# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    module VectorRepresentations
      # Abstract base class for an embeddings model ("vector representation").
      #
      # It owns everything that is model-agnostic: deciding when to (re)build
      # the ivfflat indexes, persisting embeddings for topics, posts and RAG
      # document fragments, and running nearest-neighbour queries.
      #
      # Subclasses must implement the methods that raise NotImplementedError
      # below (+name+, +dimensions+, +id+, +pg_function+, +version+,
      # +tokenizer+, +max_sequence_length+, +asymmetric_query_prefix+,
      # +vector_from+). +create_index!+ additionally calls +pg_index_type+,
      # which is not defined in this class and must also come from the subclass.
      class Base
        class << self
          # Finds the representation class whose +name+ equals +model_name+.
          #
          # @param model_name [String]
          # @return [Class, nil] the matching subclass, or nil when none matches
          def find_representation(model_name)
            # we are explicit here cause the loader may have not
            # loaded the subclasses yet
            [
              DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
              DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
              DiscourseAi::Embeddings::VectorRepresentations::BgeM3,
              DiscourseAi::Embeddings::VectorRepresentations::Gemini,
              DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
              DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
              DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
              DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
            ].find { _1.name == model_name }
          end

          # Instantiates the representation selected by the
          # +ai_embeddings_model+ site setting.
          #
          # @param strategy [Object] passed straight to the representation's constructor
          # @raise [NoMethodError] when the setting names no known representation
          #   (+find_representation+ returns nil)
          def current_representation(strategy)
            find_representation(SiteSetting.ai_embeddings_model).new(strategy)
          end

          # @abstract
          def correctly_configured?
            raise NotImplementedError
          end

          # @abstract
          def dependant_setting_names
            raise NotImplementedError
          end

          # Translated hint listing the settings this representation depends on.
          #
          # @return [String]
          def configuration_hint
            settings = dependant_setting_names
            I18n.t(
              "discourse_ai.embeddings.configuration.hint",
              settings: settings.join(", "),
              count: settings.length,
            )
          end
        end

        # @param strategy [Object] must respond to +id+, +version+ and +prepare_text_from+
        def initialize(strategy)
          @strategy = strategy
        end

        # Creates or refreshes the ivfflat index of the topic and post
        # embedding tables when needed.
        #
        # Every table is evaluated on every call: a rebuild of the topic index
        # no longer prevents the post index from being considered.
        #
        # @param memory [String] value used for work_mem / maintenance_work_mem
        #   while an index is being built
        def consider_indexing(memory: "100MB")
          [topic_table_name, post_table_name].each do |table_name|
            consider_indexing_table(table_name, memory)
          end
        end

        # Drops and rebuilds the ivfflat index of +table_name+, then stamps the
        # index comment with the current epoch time (read back by
        # +consider_indexing+ to compute the index age).
        #
        # When PostgreSQL reports that it needs more memory, the setting it
        # names is raised to the amount it asks for and the build is retried,
        # for at most three attempts in total.
        #
        # @param table_name [String]
        # @param memory [String] initial work_mem / maintenance_work_mem
        # @param lists [Integer] ivfflat +lists+ storage parameter
        # @param probes [Integer] unused here; kept for interface compatibility
        # @raise [PG::ProgramLimitExceeded] when the error message cannot be
        #   parsed or the third attempt still fails
        def create_index!(table_name, memory, lists, probes)
          tries = 0
          idx_name = index_name(table_name)

          DB.exec("SET work_mem TO '#{memory}';")
          DB.exec("SET maintenance_work_mem TO '#{memory}';")

          begin
            begin
              DB.exec(<<~SQL)
                DROP INDEX IF EXISTS #{idx_name};
                CREATE INDEX IF NOT EXISTS
                  #{idx_name}
                ON
                  #{table_name}
                USING
                  ivfflat ((embeddings::halfvec(#{dimensions})) #{pg_index_type})
                WITH
                  (lists = #{lists})
                WHERE
                  model_id = #{id} AND strategy_id = #{@strategy.id};
              SQL
            rescue PG::ProgramLimitExceeded => e
              parsed_error = e.message.match(/memory required is (\d+ [A-Z]{2}), ([a-z_]+)/)
              # Not a memory complaint we know how to act on: surface it as-is.
              raise if parsed_error.nil?

              tries += 1
              # Out of attempts: fail loudly instead of commenting on an index
              # that was never built.
              raise if tries >= 3

              DB.exec("SET #{parsed_error[2]} TO '#{parsed_error[1].tr(" ", "")}';")
              retry
            end

            DB.exec("COMMENT ON INDEX #{idx_name} IS '#{Time.now.to_i}';")
          ensure
            # Always restore the memory settings, even when the build failed.
            DB.exec("RESET work_mem;")
            DB.exec("RESET maintenance_work_mem;")
          end
        end

        # @abstract
        # NOTE: the keyword is spelled +asymetric+; it is part of the public
        # interface and therefore left untouched.
        def vector_from(text, asymetric: false)
          raise NotImplementedError
        end

        # Computes the embedding of +target+ and stores it, unless the prepared
        # text is blank or its SHA1 digest equals the stored one.
        #
        # @param target [Topic, Post, RagDocumentFragment]
        # @param persist [Boolean] when false the vector is computed but not saved
        # @raise [ArgumentError] for any other target type
        def generate_representation_from(target, persist: true)
          text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
          return if text.blank?

          target_column = target_column_name(target)

          new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
          current_digest = DB.query_single(<<~SQL, target_id: target.id).first
            SELECT
              digest
            FROM
              #{table_name(target)}
            WHERE
              model_id = #{id} AND
              strategy_id = #{@strategy.id} AND
              #{target_column} = :target_id
            LIMIT 1
          SQL
          return if current_digest == new_digest

          vector = vector_from(text)

          save_to_db(target, vector, new_digest) if persist
        end

        # @param raw_vector [Array<Numeric>] query embedding
        # @return [Integer, nil] id of the closest topic for this model and strategy
        def topic_id_from_representation(raw_vector)
          DB.query_single(<<~SQL, query_embedding: raw_vector).first
            SELECT
              topic_id
            FROM
              #{topic_table_name}
            WHERE
              model_id = #{id} AND
              strategy_id = #{@strategy.id}
            ORDER BY
              embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
            LIMIT 1
          SQL
        end

        # @param raw_vector [Array<Numeric>] query embedding
        # @return [Integer, nil] id of the closest post for this model and strategy
        def post_id_from_representation(raw_vector)
          DB.query_single(<<~SQL, query_embedding: raw_vector).first
            SELECT
              post_id
            FROM
              #{post_table_name}
            WHERE
              model_id = #{id} AND
              strategy_id = #{@strategy.id}
            ORDER BY
              embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
            LIMIT 1
          SQL
        end

        # Topics ordered by distance to +raw_vector+.
        #
        # @return [Array<Integer>] topic ids, or [topic_id, distance] pairs when
        #   +return_distance+ is true
        # @raise [MissingEmbeddingError] when the query fails with a PG::Error
        def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_distance: false)
          results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
            #{probes_sql(topic_table_name)}
            SELECT
              topic_id,
              embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
            FROM
              #{topic_table_name}
            WHERE
              model_id = #{id} AND
              strategy_id = #{@strategy.id}
            ORDER BY
              embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
            LIMIT :limit
            OFFSET :offset
          SQL

          if return_distance
            results.map { |r| [r.topic_id, r.distance] }
          else
            results.map(&:topic_id)
          end
        rescue PG::Error => e
          Rails.logger.error("Error #{e} querying embeddings for model #{name}")
          raise MissingEmbeddingError
        end

        # Posts (restricted to topics with the 'regular' archetype) ordered by
        # distance to +raw_vector+.
        #
        # @return [Array<Integer>] post ids, or [post_id, distance] pairs when
        #   +return_distance+ is true
        # @raise [MissingEmbeddingError] when the query fails with a PG::Error
        def asymmetric_posts_similarity_search(raw_vector, limit:, offset:, return_distance: false)
          results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
            #{probes_sql(post_table_name)}
            SELECT
              post_id,
              embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
            FROM
              #{post_table_name}
            INNER JOIN
              posts AS p ON p.id = post_id
            INNER JOIN
              topics AS t ON t.id = p.topic_id AND t.archetype = 'regular'
            WHERE
              model_id = #{id} AND
              strategy_id = #{@strategy.id}
            ORDER BY
              embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
            LIMIT :limit
            OFFSET :offset
          SQL

          if return_distance
            results.map { |r| [r.post_id, r.distance] }
          else
            results.map(&:post_id)
          end
        rescue PG::Error => e
          Rails.logger.error("Error #{e} querying embeddings for model #{name}")
          raise MissingEmbeddingError
        end

        # RAG document fragments belonging to (+target_type+, +target_id+)
        # ordered by distance to +raw_vector+.
        #
        # @return [Array<Integer>] fragment ids, or [fragment_id, distance] pairs
        #   when +return_distance+ is true
        # @raise [MissingEmbeddingError] when the query fails with a PG::Error
        def asymmetric_rag_fragment_similarity_search(
          raw_vector,
          target_id:,
          target_type:,
          limit:,
          offset:,
          return_distance: false
        )
          # NOTE(review): the probes value is read from the *post* table's cache
          # entry, not the fragment table's. consider_indexing only writes
          # entries for the topic and post tables - confirm this is intentional.
          results =
            DB.query(
              <<~SQL,
              #{probes_sql(post_table_name)}
              SELECT
                rag_document_fragment_id,
                embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
              FROM
                #{rag_fragments_table_name}
              INNER JOIN
                rag_document_fragments AS rdf ON rdf.id = rag_document_fragment_id
              WHERE
                model_id = #{id} AND
                strategy_id = #{@strategy.id} AND
                rdf.target_id = :target_id AND
                rdf.target_type = :target_type
              ORDER BY
                embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
              LIMIT :limit
              OFFSET :offset
            SQL
              query_embedding: raw_vector,
              target_id: target_id,
              target_type: target_type,
              limit: limit,
              offset: offset,
            )

          if return_distance
            results.map { |r| [r.rag_document_fragment_id, r.distance] }
          else
            results.map(&:rag_document_fragment_id)
          end
        rescue PG::Error => e
          Rails.logger.error("Error #{e} querying embeddings for model #{name}")
          raise MissingEmbeddingError
        end

        # Up to 100 topic ids ordered by distance to the stored embedding of +topic+.
        #
        # @raise [MissingEmbeddingError] when the query fails with a PG::Error
        def symmetric_topics_similarity_search(topic)
          DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
            #{probes_sql(topic_table_name)}
            SELECT
              topic_id
            FROM
              #{topic_table_name}
            WHERE
              model_id = #{id} AND
              strategy_id = #{@strategy.id}
            ORDER BY
              embeddings::halfvec(#{dimensions}) #{pg_function} (
                SELECT
                  embeddings
                FROM
                  #{topic_table_name}
                WHERE
                  model_id = #{id} AND
                  strategy_id = #{@strategy.id} AND
                  topic_id = :topic_id
                LIMIT 1
              )::halfvec(#{dimensions})
            LIMIT 100
          SQL
        rescue PG::Error => e
          Rails.logger.error(
            "Error #{e} querying embeddings for topic #{topic.id} and model #{name}",
          )
          raise MissingEmbeddingError
        end

        def topic_table_name
          "ai_topic_embeddings"
        end

        def post_table_name
          "ai_post_embeddings"
        end

        def rag_fragments_table_name
          "ai_document_fragment_embeddings"
        end

        # Embeddings table that stores vectors for +target+.
        #
        # @raise [ArgumentError] for unsupported target types
        def table_name(target)
          case target
          when Topic
            topic_table_name
          when Post
            post_table_name
          when RagDocumentFragment
            rag_fragments_table_name
          else
            raise ArgumentError, "Invalid target type"
          end
        end

        # Name of the search index of +table_name+ for this model and strategy.
        def index_name(table_name)
          "#{table_name}_#{id}_#{@strategy.id}_search"
        end

        # SQL fragment setting ivfflat.probes to the value cached by
        # +consider_indexing+, or an empty string when nothing is cached.
        def probes_sql(table_name)
          probes = Discourse.cache.read("#{table_name}-#{id}-#{@strategy.id}-probes")
          probes.present? ? "SET LOCAL ivfflat.probes TO #{probes};" : ""
        end

        # @abstract
        def name
          raise NotImplementedError
        end

        # @abstract
        def dimensions
          raise NotImplementedError
        end

        # @abstract
        def max_sequence_length
          raise NotImplementedError
        end

        # @abstract
        def id
          raise NotImplementedError
        end

        # @abstract
        def pg_function
          raise NotImplementedError
        end

        # @abstract
        def version
          raise NotImplementedError
        end

        # @abstract
        def tokenizer
          raise NotImplementedError
        end

        # @abstract
        def asymmetric_query_prefix
          raise NotImplementedError
        end

        protected

        # Decides whether the index of a single table must be created or rebuilt.
        # Also caches the probes value computed for that table.
        def consider_indexing_table(table_name, memory)
          idx_name = index_name(table_name)

          # Using extension maintainer's recommendation for ivfflat indexes
          # Results are not as good as without indexes, but it's much faster
          # Disk usage is ~1x the size of the table, so this doubles table total size
          count =
            DB.query_single(
              "SELECT count(*) FROM #{table_name} WHERE model_id = #{id} AND strategy_id = #{@strategy.id};",
            ).first
          lists = [count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i, 10].max
          probes = [count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i, 1].max
          Discourse.cache.write("#{table_name}-#{id}-#{@strategy.id}-probes", probes)

          existing_index = DB.query_single(<<~SQL, index_name: idx_name).first
            SELECT
              indexdef
            FROM
              pg_indexes
            WHERE
              indexname = :index_name
            AND schemaname = 'public'
            LIMIT 1
          SQL

          if existing_index.blank?
            Rails.logger.info("Index #{idx_name} does not exist, creating...")
            return create_index!(table_name, memory, lists, probes)
          end

          # The index comment holds the epoch time written by create_index!;
          # a missing comment yields nil, and nil.to_i is 0.
          existing_index_age =
            DB.query_single(
              "SELECT pg_catalog.obj_description((:index_name)::regclass, 'pg_class');",
              index_name: idx_name,
            ).first.to_i
          index_age_seconds = Time.now.to_i - existing_index_age

          new_rows =
            DB.query_single(
              "SELECT count(*) FROM #{table_name} WHERE model_id = #{id} AND strategy_id = #{@strategy.id} AND created_at > '#{Time.at(existing_index_age)}';",
            ).first
          existing_lists = existing_index.match(/lists='(\d+)'/)&.captures&.first&.to_i

          # Indexes built more recently than this are left alone.
          rebuild_cutoff =
            if SiteSetting.ai_embeddings_semantic_related_topics_enabled
              1.hour.ago.to_i
            else
              1.day.ago.to_i
            end

          if existing_index_age > 0 && existing_index_age < rebuild_cutoff
            if new_rows > 10_000
              Rails.logger.info(
                "Index #{idx_name} is #{index_age_seconds} seconds old, and there are #{new_rows} new rows, updating...",
              )
              return create_index!(table_name, memory, lists, probes)
            elsif existing_lists != lists
              Rails.logger.info(
                "Index #{idx_name} already exists, but lists is #{existing_lists} instead of #{lists}, updating...",
              )
              return create_index!(table_name, memory, lists, probes)
            end
          end

          Rails.logger.info(
            "Index #{idx_name} kept. #{index_age_seconds} seconds old, #{new_rows} new rows, #{existing_lists} lists, #{probes} probes.",
          )
        end

        # Foreign-key column that identifies +target+ in its embeddings table.
        #
        # @raise [ArgumentError] for unsupported target types
        def target_column_name(target)
          case target
          when Topic
            "topic_id"
          when Post
            "post_id"
          when RagDocumentFragment
            "rag_document_fragment_id"
          else
            raise ArgumentError, "Invalid target type"
          end
        end

        # Upserts the embedding of +target+ for the current model and strategy.
        #
        # @raise [ArgumentError] for unsupported target types
        def save_to_db(target, vector, digest)
          table = table_name(target)
          column = target_column_name(target)

          DB.exec(
            <<~SQL,
            INSERT INTO #{table} (#{column}, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
            VALUES (:target_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
            ON CONFLICT (model_id, strategy_id, #{column})
            DO UPDATE SET
              model_version = :model_version,
              strategy_version = :strategy_version,
              digest = :digest,
              embeddings = '[:embeddings]',
              updated_at = :now
            SQL
            target_id: target.id,
            model_id: id,
            model_version: version,
            strategy_id: @strategy.id,
            strategy_version: @strategy.version,
            digest: digest,
            embeddings: vector,
            now: Time.zone.now,
          )
        end

        # Base URL of the Discourse-hosted embeddings service: resolved through
        # a DNS SRV lookup when the SRV setting is present, otherwise taken
        # from the plain endpoint setting.
        def discourse_embeddings_endpoint
          if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present?
            service =
              DiscourseAi::Utils::DnsSrv.lookup(
                SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv,
              )
            "https://#{service.target}:#{service.port}"
          else
            SiteSetting.ai_embeddings_discourse_service_api_endpoint
          end
        end
      end
    end
  end
end