| 
									
										
										
										
											2023-03-15 17:21:45 -03:00
										 |  |  | # frozen_string_literal: true | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | module DiscourseAi | 
					
						
							|  |  |  |   module Embeddings | 
					
						
							| 
									
										
										
										
											2023-05-23 10:43:24 +10:00
										 |  |  |     MissingEmbeddingError = Class.new(StandardError) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-15 17:21:45 -03:00
										 |  |  |     class Topic | 
					
						
							| 
									
										
										
										
											2023-03-31 15:29:56 -03:00
										 |  |  |       def generate_and_store_embeddings_for(topic) | 
					
						
							| 
									
										
										
										
											2023-03-15 17:21:45 -03:00
										 |  |  |         return unless SiteSetting.ai_embeddings_enabled | 
					
						
							| 
									
										
										
										
											2023-03-31 15:29:56 -03:00
										 |  |  |         return if topic.blank? || topic.first_post.blank? | 
					
						
							| 
									
										
										
										
											2023-03-15 17:21:45 -03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-31 15:29:56 -03:00
										 |  |  |         enabled_models = DiscourseAi::Embeddings::Model.enabled_models | 
					
						
							|  |  |  |         return if enabled_models.empty? | 
					
						
							| 
									
										
										
										
											2023-03-15 17:21:45 -03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-31 15:29:56 -03:00
										 |  |  |         enabled_models.each do |model| | 
					
						
							|  |  |  |           embedding = model.generate_embedding(topic.first_post.raw) | 
					
						
							|  |  |  |           persist_embedding(topic, model, embedding) if embedding | 
					
						
							| 
									
										
										
										
											2023-03-15 17:21:45 -03:00
										 |  |  |         end | 
					
						
							|  |  |  |       end | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-31 15:29:56 -03:00
										 |  |  |       def symmetric_semantic_search(model, topic) | 
					
						
							| 
									
										
										
										
											2023-05-23 10:43:24 +10:00
										 |  |  |         candidate_ids = query_symmetric_embeddings(model, topic) | 
					
						
							| 
									
										
										
										
											2023-03-31 15:29:56 -03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # Happens when the topic doesn't have any embeddings | 
					
						
							|  |  |  |         # I'd rather not use Exceptions to control the flow, so this should be refactored soon | 
					
						
							|  |  |  |         if candidate_ids.empty? || !candidate_ids.include?(topic.id) | 
					
						
							| 
									
										
										
										
											2023-05-23 10:43:24 +10:00
										 |  |  |           raise MissingEmbeddingError, "No embeddings found for topic #{topic.id}" | 
					
						
							| 
									
										
										
										
											2023-03-15 17:21:45 -03:00
										 |  |  |         end | 
					
						
							| 
									
										
										
										
											2023-03-31 15:29:56 -03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         candidate_ids | 
					
						
							| 
									
										
										
										
											2023-03-15 17:21:45 -03:00
										 |  |  |       end | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-31 15:29:56 -03:00
										 |  |  |       def asymmetric_semantic_search(model, query, limit, offset) | 
					
						
							| 
									
										
										
										
											2023-03-31 16:15:10 -03:00
										 |  |  |         embedding = model.generate_embedding(query) | 
					
						
							| 
									
										
										
										
											2023-03-31 15:29:56 -03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-23 18:57:52 -03:00
										 |  |  |         begin | 
					
						
							|  |  |  |           candidate_ids = | 
					
						
							|  |  |  |             DiscourseAi::Database::Connection | 
					
						
							|  |  |  |               .db | 
					
						
							|  |  |  |               .query(<<~SQL, query_embedding: embedding, limit: limit, offset: offset) | 
					
						
							|  |  |  |                 SELECT | 
					
						
							|  |  |  |                   topic_id | 
					
						
							|  |  |  |                 FROM | 
					
						
							|  |  |  |                   topic_embeddings_#{model.name.underscore} | 
					
						
							|  |  |  |                 ORDER BY | 
					
						
							|  |  |  |                   embedding #{model.pg_function} '[:query_embedding]' | 
					
						
							|  |  |  |                 LIMIT :limit | 
					
						
							|  |  |  |                 OFFSET :offset | 
					
						
							|  |  |  |               SQL | 
					
						
							|  |  |  |               .map(&:topic_id) | 
					
						
							|  |  |  |         rescue PG::Error => e | 
					
						
							|  |  |  |           Rails.logger.error( | 
					
						
							|  |  |  |             "Error #{e} querying embeddings for topic #{topic.id} and model #{model.name}", | 
					
						
							|  |  |  |           ) | 
					
						
							|  |  |  |           raise MissingEmbeddingError | 
					
						
							|  |  |  |         end | 
					
						
							| 
									
										
										
										
											2023-03-31 15:29:56 -03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         candidate_ids | 
					
						
							| 
									
										
										
										
											2023-03-15 17:21:45 -03:00
										 |  |  |       end | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-31 15:29:56 -03:00
										 |  |  |       private | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-23 10:43:24 +10:00
										 |  |  |       def query_symmetric_embeddings(model, topic) | 
					
						
							| 
									
										
										
										
											2023-05-23 18:57:52 -03:00
										 |  |  |         begin | 
					
						
							|  |  |  |           DiscourseAi::Database::Connection.db.query(<<~SQL, topic_id: topic.id).map(&:topic_id) | 
					
						
							|  |  |  |             SELECT | 
					
						
							|  |  |  |               topic_id | 
					
						
							|  |  |  |             FROM | 
					
						
							|  |  |  |               topic_embeddings_#{model.name.underscore} | 
					
						
							|  |  |  |             ORDER BY | 
					
						
							|  |  |  |               embedding #{model.pg_function} ( | 
					
						
							|  |  |  |                 SELECT | 
					
						
							|  |  |  |                   embedding | 
					
						
							|  |  |  |                 FROM | 
					
						
							|  |  |  |                   topic_embeddings_#{model.name.underscore} | 
					
						
							|  |  |  |                 WHERE | 
					
						
							|  |  |  |                   topic_id = :topic_id | 
					
						
							|  |  |  |                 LIMIT 1
 | 
					
						
							|  |  |  |               ) | 
					
						
							|  |  |  |             LIMIT 100
 | 
					
						
							|  |  |  |           SQL | 
					
						
							|  |  |  |         rescue PG::Error => e | 
					
						
							|  |  |  |           Rails.logger.error( | 
					
						
							|  |  |  |             "Error #{e} querying embeddings for topic #{topic.id} and model #{model.name}", | 
					
						
							|  |  |  |           ) | 
					
						
							|  |  |  |           raise MissingEmbeddingError | 
					
						
							|  |  |  |         end | 
					
						
							| 
									
										
										
										
											2023-05-23 10:43:24 +10:00
										 |  |  |       end | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-31 15:29:56 -03:00
										 |  |  |       def persist_embedding(topic, model, embedding) | 
					
						
							| 
									
										
										
										
											2023-05-23 18:57:52 -03:00
										 |  |  |         begin | 
					
						
							|  |  |  |           DiscourseAi::Database::Connection.db.exec( | 
					
						
							|  |  |  |             <<~SQL, | 
					
						
							|  |  |  |               INSERT INTO topic_embeddings_#{model.name.underscore} (topic_id, embedding) | 
					
						
							|  |  |  |               VALUES (:topic_id, '[:embedding]') | 
					
						
							|  |  |  |               ON CONFLICT (topic_id) | 
					
						
							|  |  |  |               DO UPDATE SET embedding = '[:embedding]' | 
					
						
							|  |  |  |             SQL | 
					
						
							|  |  |  |             topic_id: topic.id, | 
					
						
							|  |  |  |             embedding: embedding, | 
					
						
							|  |  |  |           ) | 
					
						
							|  |  |  |         rescue PG::Error => e | 
					
						
							|  |  |  |           Rails.logger.error( | 
					
						
							|  |  |  |             "Error #{e} persisting embedding for topic #{topic.id} and model #{model.name}", | 
					
						
							|  |  |  |           ) | 
					
						
							|  |  |  |         end | 
					
						
							| 
									
										
										
										
											2023-03-15 17:21:45 -03:00
										 |  |  |       end | 
					
						
							|  |  |  |     end | 
					
						
							|  |  |  |   end | 
					
						
							|  |  |  | end |