From eae527f99d276a7e7cd9e99c54d26aa3eff1bd3a Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Fri, 13 Dec 2024 10:15:21 -0300 Subject: [PATCH] REFACTOR: A Simpler way of interacting with embeddings tables. (#1023) * REFACTOR: A Simpler way of interacting with embeddings' tables. This change adds a new abstraction called `Schema`, which acts as a repository that supports the same DB features `VectorRepresentation::Base` has, with the exception that it removes the need to have duplicated methods per embeddings table. It is also a bit more flexible when performing a similarity search because you can pass it a block that gives you access to the builder, allowing you to add multiple joins/where conditions. --- app/jobs/regular/digest_rag_upload.rb | 4 +- app/jobs/regular/generate_embeddings.rb | 4 +- app/jobs/regular/generate_rag_embeddings.rb | 4 +- app/jobs/scheduled/embeddings_backfill.rb | 13 +- app/models/rag_document_fragment.rb | 6 +- ...rate_embeddings_from_dedicated_database.rb | 2 +- ...ove_embeddings_to_single_table_per_type.rb | 4 +- lib/ai_bot/personas/persona.rb | 38 +- lib/ai_bot/tool_runner.rb | 22 +- lib/ai_helper/semantic_categorizer.rb | 21 +- lib/embeddings/schema.rb | 194 +++++++++ lib/embeddings/semantic_related.rb | 10 +- lib/embeddings/semantic_search.rb | 28 +- lib/embeddings/vector_representations/base.rb | 376 +----------------- lib/tasks/modules/embeddings/database.rake | 7 +- spec/jobs/regular/digest_rag_upload_spec.rb | 5 +- .../regular/generate_rag_embeddings_spec.rb | 9 +- .../scheduled/embeddings_backfill_spec.rb | 13 +- .../modules/ai_bot/personas/persona_spec.rb | 58 +-- spec/lib/modules/ai_bot/tools/search_spec.rb | 8 +- .../ai_helper/semantic_categorizer_spec.rb | 5 +- .../jobs/generate_embeddings_spec.rb | 10 +- spec/lib/modules/embeddings/schema_spec.rb | 104 +++++ .../embeddings/semantic_related_spec.rb | 4 +- .../embeddings/semantic_search_spec.rb | 24 +- .../embeddings/semantic_topic_query_spec.rb | 34 +- 
.../vector_rep_shared_examples.rb | 68 +--- spec/models/rag_document_fragment_spec.rb | 5 +- .../embeddings/embeddings_controller_spec.rb | 4 +- 29 files changed, 485 insertions(+), 599 deletions(-) create mode 100644 lib/embeddings/schema.rb create mode 100644 spec/lib/modules/embeddings/schema_spec.rb diff --git a/app/jobs/regular/digest_rag_upload.rb b/app/jobs/regular/digest_rag_upload.rb index 9ea10c06..76b9ee65 100644 --- a/app/jobs/regular/digest_rag_upload.rb +++ b/app/jobs/regular/digest_rag_upload.rb @@ -18,9 +18,7 @@ module ::Jobs target = target_type.constantize.find_by(id: target_id) return if !target - truncation = DiscourseAi::Embeddings::Strategies::Truncation.new - vector_rep = - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation) + vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation tokenizer = vector_rep.tokenizer chunk_tokens = target.rag_chunk_tokens diff --git a/app/jobs/regular/generate_embeddings.rb b/app/jobs/regular/generate_embeddings.rb index 70961a2d..4e16dc22 100644 --- a/app/jobs/regular/generate_embeddings.rb +++ b/app/jobs/regular/generate_embeddings.rb @@ -16,9 +16,7 @@ module Jobs return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms return if post.raw.blank? - strategy = DiscourseAi::Embeddings::Strategies::Truncation.new - vector_rep = - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) + vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation vector_rep.generate_representation_from(target) end diff --git a/app/jobs/regular/generate_rag_embeddings.rb b/app/jobs/regular/generate_rag_embeddings.rb index a125a21b..836ede8b 100644 --- a/app/jobs/regular/generate_rag_embeddings.rb +++ b/app/jobs/regular/generate_rag_embeddings.rb @@ -8,9 +8,7 @@ module ::Jobs def execute(args) return if (fragments = RagDocumentFragment.where(id: args[:fragment_ids].to_a)).empty? 
- truncation = DiscourseAi::Embeddings::Strategies::Truncation.new - vector_rep = - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation) + vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation # generate_representation_from checks compares the digest value to make sure # the embedding is only generated once per fragment unless something changes. diff --git a/app/jobs/scheduled/embeddings_backfill.rb b/app/jobs/scheduled/embeddings_backfill.rb index f8ffd448..0c4c0a9a 100644 --- a/app/jobs/scheduled/embeddings_backfill.rb +++ b/app/jobs/scheduled/embeddings_backfill.rb @@ -20,10 +20,8 @@ module Jobs rebaked = 0 - strategy = DiscourseAi::Embeddings::Strategies::Truncation.new - vector_rep = - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) - table_name = vector_rep.topic_table_name + vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation + table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE topics = Topic @@ -41,7 +39,7 @@ module Jobs relation = topics.where(<<~SQL).limit(limit - rebaked) #{table_name}.model_version < #{vector_rep.version} OR - #{table_name}.strategy_version < #{strategy.version} + #{table_name}.strategy_version < #{vector_rep.strategy_version} SQL rebaked += populate_topic_embeddings(vector_rep, relation) @@ -63,7 +61,7 @@ module Jobs return unless SiteSetting.ai_embeddings_per_post_enabled # Now for posts - table_name = vector_rep.post_table_name + table_name = DiscourseAi::Embeddings::Schema::POSTS_TABLE posts_batch_size = 1000 posts = @@ -121,7 +119,8 @@ module Jobs def populate_topic_embeddings(vector_rep, topics, force: false) done = 0 - topics = topics.where("#{vector_rep.topic_table_name}.topic_id IS NULL") if !force + topics = + topics.where("#{DiscourseAi::Embeddings::Schema::TOPICS_TABLE}.topic_id IS NULL") if !force ids = topics.pluck("topics.id") batch_size = 1000 diff --git 
a/app/models/rag_document_fragment.rb b/app/models/rag_document_fragment.rb index 6c391dbc..c1cd7d9d 100644 --- a/app/models/rag_document_fragment.rb +++ b/app/models/rag_document_fragment.rb @@ -39,11 +39,7 @@ class RagDocumentFragment < ActiveRecord::Base end def indexing_status(persona, uploads) - truncation = DiscourseAi::Embeddings::Strategies::Truncation.new - vector_rep = - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation) - - embeddings_table = vector_rep.rag_fragments_table_name + embeddings_table = DiscourseAi::Embeddings::Schema.for(self).table results = DB.query( diff --git a/db/migrate/20230710171143_migrate_embeddings_from_dedicated_database.rb b/db/migrate/20230710171143_migrate_embeddings_from_dedicated_database.rb index aef3b2be..3a60d99c 100644 --- a/db/migrate/20230710171143_migrate_embeddings_from_dedicated_database.rb +++ b/db/migrate/20230710171143_migrate_embeddings_from_dedicated_database.rb @@ -14,7 +14,7 @@ class MigrateEmbeddingsFromDedicatedDatabase < ActiveRecord::Migration[7.0] ].map { |k| k.new(truncation) } vector_reps.each do |vector_rep| - new_table_name = vector_rep.topic_table_name + new_table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE old_table_name = "topic_embeddings_#{vector_rep.name.underscore}" begin diff --git a/db/migrate/20240611170905_move_embeddings_to_single_table_per_type.rb b/db/migrate/20240611170905_move_embeddings_to_single_table_per_type.rb index a21a24b8..e43c0137 100644 --- a/db/migrate/20240611170905_move_embeddings_to_single_table_per_type.rb +++ b/db/migrate/20240611170905_move_embeddings_to_single_table_per_type.rb @@ -147,9 +147,7 @@ class MoveEmbeddingsToSingleTablePerType < ActiveRecord::Migration[7.0] SQL begin - strategy = DiscourseAi::Embeddings::Strategies::Truncation.new - vector_rep = - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) + vector_rep = 
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation rescue StandardError => e Rails.logger.error("Failed to index embeddings: #{e}") end diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index e818c19f..d3b7cdf4 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -314,30 +314,34 @@ module DiscourseAi return nil if !consolidated_question - strategy = DiscourseAi::Embeddings::Strategies::Truncation.new - vector_rep = - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) + vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings interactions_vector = vector_rep.vector_from(consolidated_question) rag_conversation_chunks = self.class.rag_conversation_chunks + search_limit = + if reranker.reranker_configured? + rag_conversation_chunks * 5 + else + rag_conversation_chunks + end + + schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep) candidate_fragment_ids = - vector_rep.asymmetric_rag_fragment_similarity_search( - interactions_vector, - target_type: "AiPersona", - target_id: id, - limit: - ( - if reranker.reranker_configured? 
- rag_conversation_chunks * 5 - else - rag_conversation_chunks - end - ), - offset: 0, - ) + schema + .asymmetric_similarity_search( + interactions_vector, + limit: search_limit, + offset: 0, + ) { |builder| builder.join(<<~SQL, target_id: id, target_type: "AiPersona") } + rag_document_fragments ON + rag_document_fragments.id = rag_document_fragment_id AND + rag_document_fragments.target_id = :target_id AND + rag_document_fragments.target_type = :target_type + SQL + .map(&:rag_document_fragment_id) fragments = RagDocumentFragment.where(upload_id: upload_refs, id: candidate_fragment_ids).pluck( diff --git a/lib/ai_bot/tool_runner.rb b/lib/ai_bot/tool_runner.rb index 552bccda..74c92767 100644 --- a/lib/ai_bot/tool_runner.rb +++ b/lib/ai_bot/tool_runner.rb @@ -141,18 +141,20 @@ module DiscourseAi return [] if upload_refs.empty? - strategy = DiscourseAi::Embeddings::Strategies::Truncation.new - vector_rep = - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) + vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation query_vector = vector_rep.vector_from(query) fragment_ids = - vector_rep.asymmetric_rag_fragment_similarity_search( - query_vector, - target_type: "AiTool", - target_id: tool.id, - limit: limit, - offset: 0, - ) + DiscourseAi::Embeddings::Schema + .for(RagDocumentFragment, vector: vector_rep) + .asymmetric_similarity_search(query_vector, limit: limit, offset: 0) do |builder| + builder.join(<<~SQL, target_id: tool.id, target_type: "AiTool") + rag_document_fragments ON + rag_document_fragments.id = rag_document_fragment_id AND + rag_document_fragments.target_id = :target_id AND + rag_document_fragments.target_type = :target_type + SQL + end + .map(&:rag_document_fragment_id) fragments = RagDocumentFragment.where(id: fragment_ids, upload_id: upload_refs).pluck( diff --git a/lib/ai_helper/semantic_categorizer.rb b/lib/ai_helper/semantic_categorizer.rb index 72759a6c..fc39c264 100644 --- 
a/lib/ai_helper/semantic_categorizer.rb +++ b/lib/ai_helper/semantic_categorizer.rb @@ -92,9 +92,8 @@ module DiscourseAi private def nearest_neighbors(limit: 100) - strategy = DiscourseAi::Embeddings::Strategies::Truncation.new - vector_rep = - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) + vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation + schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) raw_vector = vector_rep.vector_from(@text) @@ -107,13 +106,15 @@ module DiscourseAi ).pluck(:category_id) end - vector_rep.asymmetric_topics_similarity_search( - raw_vector, - limit: limit, - offset: 0, - return_distance: true, - exclude_category_ids: muted_category_ids, - ) + schema + .asymmetric_similarity_search(raw_vector, limit: limit, offset: 0) do |builder| + builder.join("topics t on t.id = topic_id") + builder.where(<<~SQL, exclude_category_ids: muted_category_ids.map(&:to_i)) + t.category_id NOT IN (:exclude_category_ids) AND + t.category_id NOT IN (SELECT categories.id FROM categories WHERE categories.parent_category_id IN (:exclude_category_ids)) + SQL + end + .map { |r| [r.topic_id, r.distance] } end end end diff --git a/lib/embeddings/schema.rb b/lib/embeddings/schema.rb new file mode 100644 index 00000000..bf878fd1 --- /dev/null +++ b/lib/embeddings/schema.rb @@ -0,0 +1,194 @@ +# frozen_string_literal: true + +# We don't have AR objects for our embeddings, so this class +# acts as an intermediary between us and the DB. +# It lets us retrieve embeddings either symmetrically and asymmetrically, +# and also store them. 
+ +module DiscourseAi + module Embeddings + class Schema + TOPICS_TABLE = "ai_topic_embeddings" + POSTS_TABLE = "ai_post_embeddings" + RAG_DOCS_TABLE = "ai_document_fragment_embeddings" + + def self.for( + target_klass, + vector: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation + ) + case target_klass&.name + when "Topic" + new(TOPICS_TABLE, "topic_id", vector) + when "Post" + new(POSTS_TABLE, "post_id", vector) + when "RagDocumentFragment" + new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector) + else + raise ArgumentError, "Invalid target type for embeddings" + end + end + + def initialize(table, target_column, vector) + @table = table + @target_column = target_column + @vector = vector + end + + attr_reader :table, :target_column, :vector + + def find_by_embedding(embedding) + DB.query(<<~SQL, query_embedding: embedding, vid: vector.id, vsid: vector.strategy_id).first + SELECT * + FROM #{table} + WHERE + model_id = :vid AND strategy_id = :vsid + ORDER BY + embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) + LIMIT 1 + SQL + end + + def find_by_target(target) + DB.query(<<~SQL, target_id: target.id, vid: vector.id, vsid: vector.strategy_id).first + SELECT * + FROM #{table} + WHERE + model_id = :vid AND + strategy_id = :vsid AND + #{target_column} = :target_id + LIMIT 1 + SQL + end + + def asymmetric_similarity_search(embedding, limit:, offset:) + builder = DB.build(<<~SQL) + WITH candidates AS ( + SELECT + #{target_column}, + embeddings::halfvec(#{dimensions}) AS embeddings + FROM + #{table} + /*join*/ + /*where*/ + ORDER BY + binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions})) + LIMIT :candidates_limit + ) + SELECT + #{target_column}, + embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance + FROM + candidates + ORDER BY + embeddings::halfvec(#{dimensions}) #{pg_function} 
'[:query_embedding]'::halfvec(#{dimensions}) + LIMIT :limit + OFFSET :offset + SQL + + builder.where( + "model_id = :model_id AND strategy_id = :strategy_id", + model_id: vector.id, + strategy_id: vector.strategy_id, + ) + + yield(builder) if block_given? + + if table == RAG_DOCS_TABLE + # A too low limit exacerbates the the recall loss of binary quantization + candidates_limit = [limit * 2, 100].max + else + candidates_limit = limit * 2 + end + + builder.query( + query_embedding: embedding, + candidates_limit: candidates_limit, + limit: limit, + offset: offset, + ) + rescue PG::Error => e + Rails.logger.error("Error #{e} querying embeddings for model #{name}") + raise MissingEmbeddingError + end + + def symmetric_similarity_search(record) + builder = DB.build(<<~SQL) + WITH le_target AS ( + SELECT + embeddings + FROM + #{table} + WHERE + model_id = :vid AND + strategy_id = :vsid AND + #{target_column} = :target_id + LIMIT 1 + ) + SELECT #{target_column} FROM ( + SELECT + #{target_column}, embeddings + FROM + #{table} + /*join*/ + /*where*/ + ORDER BY + binary_quantize(embeddings)::bit(#{dimensions}) <~> ( + SELECT + binary_quantize(embeddings)::bit(#{dimensions}) + FROM + le_target + LIMIT 1 + ) + LIMIT 200 + ) AS widenet + ORDER BY + embeddings::halfvec(#{dimensions}) #{pg_function} ( + SELECT + embeddings::halfvec(#{dimensions}) + FROM + le_target + LIMIT 1 + ) + LIMIT 100; + SQL + + builder.where("model_id = :vid AND strategy_id = :vsid") + + yield(builder) if block_given? 
+ + builder.query(vid: vector.id, vsid: vector.strategy_id, target_id: record.id) + rescue PG::Error => e + Rails.logger.error("Error #{e} querying embeddings for model #{name}") + raise MissingEmbeddingError + end + + def store(record, embedding, digest) + DB.exec( + <<~SQL, + INSERT INTO #{table} (#{target_column}, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at) + VALUES (:target_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now) + ON CONFLICT (model_id, strategy_id, #{target_column}) + DO UPDATE SET + model_version = :model_version, + strategy_version = :strategy_version, + digest = :digest, + embeddings = '[:embeddings]', + updated_at = :now + SQL + target_id: record.id, + model_id: vector.id, + model_version: vector.version, + strategy_id: vector.strategy_id, + strategy_version: vector.strategy_version, + digest: digest, + embeddings: embedding, + now: Time.zone.now, + ) + end + + private + + delegate :dimensions, :pg_function, to: :vector + end + end +end diff --git a/lib/embeddings/semantic_related.rb b/lib/embeddings/semantic_related.rb index e8b8c517..c28d05fc 100644 --- a/lib/embeddings/semantic_related.rb +++ b/lib/embeddings/semantic_related.rb @@ -13,16 +13,16 @@ module DiscourseAi def related_topic_ids_for(topic) return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1 - strategy = DiscourseAi::Embeddings::Strategies::Truncation.new - vector_rep = - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) + vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation cache_for = results_ttl(topic) Discourse .cache .fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do - vector_rep - .symmetric_topics_similarity_search(topic) + DiscourseAi::Embeddings::Schema + .for(Topic, vector: vector_rep) + .symmetric_similarity_search(topic) + .map(&:topic_id) .tap do |candidate_ids| # 
Happens when the topic doesn't have any embeddings # I'd rather not use Exceptions to control the flow, so this should be refactored soon diff --git a/lib/embeddings/semantic_search.rb b/lib/embeddings/semantic_search.rb index cae93958..108239b9 100644 --- a/lib/embeddings/semantic_search.rb +++ b/lib/embeddings/semantic_search.rb @@ -31,10 +31,7 @@ module DiscourseAi end def vector_rep - @vector_rep ||= - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation( - DiscourseAi::Embeddings::Strategies::Truncation.new, - ) + @vector_rep ||= DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation end def hyde_embedding(search_term) @@ -87,12 +84,14 @@ module DiscourseAi over_selection_limit = limit * OVER_SELECTION_FACTOR + schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) + candidate_topic_ids = - vector_rep.asymmetric_topics_similarity_search( + schema.asymmetric_similarity_search( search_embedding, limit: over_selection_limit, offset: offset, - ) + ).map(&:topic_id) semantic_results = ::Post @@ -115,9 +114,7 @@ module DiscourseAi return [] if search_term.nil? 
|| search_term.length < SiteSetting.min_search_term_length - strategy = DiscourseAi::Embeddings::Strategies::Truncation.new - vector_rep = - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) + vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation digest = OpenSSL::Digest::SHA1.hexdigest(search_term) @@ -136,11 +133,14 @@ module DiscourseAi end candidate_post_ids = - vector_rep.asymmetric_posts_similarity_search( - search_term_embedding, - limit: max_semantic_results_per_page, - offset: 0, - ) + DiscourseAi::Embeddings::Schema + .for(Post, vector: vector_rep) + .asymmetric_similarity_search( + search_term_embedding, + limit: max_semantic_results_per_page, + offset: 0, + ) + .map(&:post_id) semantic_results = ::Post diff --git a/lib/embeddings/vector_representations/base.rb b/lib/embeddings/vector_representations/base.rb index 622bbebb..94e33d8d 100644 --- a/lib/embeddings/vector_representations/base.rb +++ b/lib/embeddings/vector_representations/base.rb @@ -20,8 +20,9 @@ module DiscourseAi ].find { _1.name == model_name } end - def current_representation(strategy) - find_representation(SiteSetting.ai_embeddings_model).new(strategy) + def current_representation + truncation = DiscourseAi::Embeddings::Strategies::Truncation.new + find_representation(SiteSetting.ai_embeddings_model).new(truncation) end def correctly_configured? @@ -59,6 +60,8 @@ module DiscourseAi idletime: 30, ) + schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector: self) + embedding_gen = inference_client promised_embeddings = relation @@ -67,7 +70,7 @@ module DiscourseAi next if prepared_text.blank? new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text) - next if find_digest_of(record) == new_digest + next if schema.find_by_target(record)&.digest == new_digest Concurrent::Promises .fulfilled_future( @@ -83,7 +86,7 @@ module DiscourseAi Concurrent::Promises .zip(*promised_embeddings) .value! 
- .each { |e| save_to_db(e[:target], e[:embedding], e[:digest]) } + .each { |e| schema.store(e[:target], e[:embedding], e[:digest]) } pool.shutdown pool.wait_for_termination @@ -93,265 +96,14 @@ module DiscourseAi text = prepare_text(target) return if text.blank? + schema = DiscourseAi::Embeddings::Schema.for(target.class, vector: self) + new_digest = OpenSSL::Digest::SHA1.hexdigest(text) - return if find_digest_of(target) == new_digest + return if schema.find_by_target(target)&.digest == new_digest vector = vector_from(text) - save_to_db(target, vector, new_digest) if persist - end - - def topic_id_from_representation(raw_vector) - DB.query_single(<<~SQL, query_embedding: raw_vector).first - SELECT - topic_id - FROM - #{topic_table_name} - WHERE - model_id = #{id} AND - strategy_id = #{@strategy.id} - ORDER BY - embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) - LIMIT 1 - SQL - end - - def post_id_from_representation(raw_vector) - DB.query_single(<<~SQL, query_embedding: raw_vector).first - SELECT - post_id - FROM - #{post_table_name} - WHERE - model_id = #{id} AND - strategy_id = #{@strategy.id} - ORDER BY - embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) - LIMIT 1 - SQL - end - - def asymmetric_topics_similarity_search( - raw_vector, - limit:, - offset:, - return_distance: false, - exclude_category_ids: nil - ) - builder = DB.build(<<~SQL) - WITH candidates AS ( - SELECT - topic_id, - embeddings::halfvec(#{dimensions}) AS embeddings - FROM - #{topic_table_name} - /*join*/ - /*where*/ - ORDER BY - binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions})) - LIMIT :limit * 2 - ) - SELECT - topic_id, - embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance - FROM - candidates - ORDER BY - embeddings::halfvec(#{dimensions}) #{pg_function} 
'[:query_embedding]'::halfvec(#{dimensions}) - LIMIT :limit - OFFSET :offset - SQL - - builder.where( - "model_id = :model_id AND strategy_id = :strategy_id", - model_id: id, - strategy_id: @strategy.id, - ) - - if exclude_category_ids.present? - builder.join("topics t on t.id = topic_id") - builder.where(<<~SQL, exclude_category_ids: exclude_category_ids.map(&:to_i)) - t.category_id NOT IN (:exclude_category_ids) AND - t.category_id NOT IN (SELECT categories.id FROM categories WHERE categories.parent_category_id IN (:exclude_category_ids)) - SQL - end - - results = builder.query(query_embedding: raw_vector, limit: limit, offset: offset) - - if return_distance - results.map { |r| [r.topic_id, r.distance] } - else - results.map(&:topic_id) - end - rescue PG::Error => e - Rails.logger.error("Error #{e} querying embeddings for model #{name}") - raise MissingEmbeddingError - end - - def asymmetric_posts_similarity_search(raw_vector, limit:, offset:, return_distance: false) - results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset) - WITH candidates AS ( - SELECT - post_id, - embeddings::halfvec(#{dimensions}) AS embeddings - FROM - #{post_table_name} - WHERE - model_id = #{id} AND strategy_id = #{@strategy.id} - ORDER BY - binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions})) - LIMIT :limit * 2 - ) - SELECT - post_id, - embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance - FROM - candidates - ORDER BY - embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) - LIMIT :limit - OFFSET :offset - SQL - - if return_distance - results.map { |r| [r.post_id, r.distance] } - else - results.map(&:post_id) - end - rescue PG::Error => e - Rails.logger.error("Error #{e} querying embeddings for model #{name}") - raise MissingEmbeddingError - end - - def asymmetric_rag_fragment_similarity_search( - 
raw_vector, - target_id:, - target_type:, - limit:, - offset:, - return_distance: false - ) - # A too low limit exacerbates the the recall loss of binary quantization - binary_search_limit = [limit * 2, 100].max - results = - DB.query( - <<~SQL, - WITH candidates AS ( - SELECT - rag_document_fragment_id, - embeddings::halfvec(#{dimensions}) AS embeddings - FROM - #{rag_fragments_table_name} - INNER JOIN - rag_document_fragments ON - rag_document_fragments.id = rag_document_fragment_id AND - rag_document_fragments.target_id = :target_id AND - rag_document_fragments.target_type = :target_type - WHERE - model_id = #{id} AND strategy_id = #{@strategy.id} - ORDER BY - binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions})) - LIMIT :binary_search_limit - ) - SELECT - rag_document_fragment_id, - embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance - FROM - candidates - ORDER BY - embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) - LIMIT :limit - OFFSET :offset - SQL - query_embedding: raw_vector, - target_id: target_id, - target_type: target_type, - limit: limit, - offset: offset, - binary_search_limit: binary_search_limit, - ) - - if return_distance - results.map { |r| [r.rag_document_fragment_id, r.distance] } - else - results.map(&:rag_document_fragment_id) - end - rescue PG::Error => e - Rails.logger.error("Error #{e} querying embeddings for model #{name}") - raise MissingEmbeddingError - end - - def symmetric_topics_similarity_search(topic) - DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id) - WITH le_target AS ( - SELECT - embeddings - FROM - #{topic_table_name} - WHERE - model_id = #{id} AND - strategy_id = #{@strategy.id} AND - topic_id = :topic_id - LIMIT 1 - ) - SELECT topic_id FROM ( - SELECT - topic_id, embeddings - FROM - #{topic_table_name} - WHERE - model_id = #{id} AND - strategy_id = 
#{@strategy.id} - ORDER BY - binary_quantize(embeddings)::bit(#{dimensions}) <~> ( - SELECT - binary_quantize(embeddings)::bit(#{dimensions}) - FROM - le_target - LIMIT 1 - ) - LIMIT 200 - ) AS widenet - ORDER BY - embeddings::halfvec(#{dimensions}) #{pg_function} ( - SELECT - embeddings::halfvec(#{dimensions}) - FROM - le_target - LIMIT 1 - ) - LIMIT 100; - SQL - rescue PG::Error => e - Rails.logger.error( - "Error #{e} querying embeddings for topic #{topic.id} and model #{name}", - ) - raise MissingEmbeddingError - end - - def topic_table_name - "ai_topic_embeddings" - end - - def post_table_name - "ai_post_embeddings" - end - - def rag_fragments_table_name - "ai_document_fragment_embeddings" - end - - def table_name(target) - case target - when Topic - topic_table_name - when Post - post_table_name - when RagDocumentFragment - rag_fragments_table_name - else - raise ArgumentError, "Invalid target type" - end + schema.store(target, vector, new_digest) if persist end def index_name(table_name) @@ -390,106 +142,16 @@ module DiscourseAi raise NotImplementedError end + def strategy_id + @strategy.id + end + + def strategy_version + @strategy.version + end + protected - def find_digest_of(target) - target_column = - case target - when Topic - "topic_id" - when Post - "post_id" - when RagDocumentFragment - "rag_document_fragment_id" - else - raise ArgumentError, "Invalid target type" - end - - DB.query_single(<<~SQL, target_id: target.id).first - SELECT - digest - FROM - #{table_name(target)} - WHERE - model_id = #{id} AND - strategy_id = #{@strategy.id} AND - #{target_column} = :target_id - LIMIT 1 - SQL - end - - def save_to_db(target, vector, digest) - if target.is_a?(Topic) - DB.exec( - <<~SQL, - INSERT INTO #{topic_table_name} (topic_id, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at) - VALUES (:topic_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now) - ON 
CONFLICT (strategy_id, model_id, topic_id) - DO UPDATE SET - model_version = :model_version, - strategy_version = :strategy_version, - digest = :digest, - embeddings = '[:embeddings]', - updated_at = :now - SQL - topic_id: target.id, - model_id: id, - model_version: version, - strategy_id: @strategy.id, - strategy_version: @strategy.version, - digest: digest, - embeddings: vector, - now: Time.zone.now, - ) - elsif target.is_a?(Post) - DB.exec( - <<~SQL, - INSERT INTO #{post_table_name} (post_id, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at) - VALUES (:post_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now) - ON CONFLICT (model_id, strategy_id, post_id) - DO UPDATE SET - model_version = :model_version, - strategy_version = :strategy_version, - digest = :digest, - embeddings = '[:embeddings]', - updated_at = :now - SQL - post_id: target.id, - model_id: id, - model_version: version, - strategy_id: @strategy.id, - strategy_version: @strategy.version, - digest: digest, - embeddings: vector, - now: Time.zone.now, - ) - elsif target.is_a?(RagDocumentFragment) - DB.exec( - <<~SQL, - INSERT INTO #{rag_fragments_table_name} (rag_document_fragment_id, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at) - VALUES (:fragment_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now) - ON CONFLICT (model_id, strategy_id, rag_document_fragment_id) - DO UPDATE SET - model_version = :model_version, - strategy_version = :strategy_version, - digest = :digest, - embeddings = '[:embeddings]', - updated_at = :now - SQL - fragment_id: target.id, - model_id: id, - model_version: version, - strategy_id: @strategy.id, - strategy_version: @strategy.version, - digest: digest, - embeddings: vector, - now: Time.zone.now, - ) - else - raise ArgumentError, "Invalid target type" - end - end - def 
inference_client raise NotImplementedError end diff --git a/lib/tasks/modules/embeddings/database.rake b/lib/tasks/modules/embeddings/database.rake index 83f652a0..6d4ffd0f 100644 --- a/lib/tasks/modules/embeddings/database.rake +++ b/lib/tasks/modules/embeddings/database.rake @@ -4,17 +4,16 @@ desc "Backfill embeddings for all topics and posts" task "ai:embeddings:backfill", %i[model concurrency] => [:environment] do |_, args| public_categories = Category.where(read_restricted: false).pluck(:id) - strategy = DiscourseAi::Embeddings::Strategies::Truncation.new if args[:model].present? + strategy = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(args[:model]).new( strategy, ) else - vector_rep = - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) + vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation end - table_name = vector_rep.topic_table_name + table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE topics = Topic diff --git a/spec/jobs/regular/digest_rag_upload_spec.rb b/spec/jobs/regular/digest_rag_upload_spec.rb index 063bbf73..eed03fc8 100644 --- a/spec/jobs/regular/digest_rag_upload_spec.rb +++ b/spec/jobs/regular/digest_rag_upload_spec.rb @@ -6,10 +6,7 @@ RSpec.describe Jobs::DigestRagUpload do let(:document_file) { StringIO.new("some text" * 200) } - let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } - let(:vector_rep) do - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation) - end + let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } let(:expected_embedding) { [0.0038493] * vector_rep.dimensions } diff --git a/spec/jobs/regular/generate_rag_embeddings_spec.rb b/spec/jobs/regular/generate_rag_embeddings_spec.rb index 0f45f0cc..1cba9d06 100644 --- 
a/spec/jobs/regular/generate_rag_embeddings_spec.rb +++ b/spec/jobs/regular/generate_rag_embeddings_spec.rb @@ -2,10 +2,7 @@ RSpec.describe Jobs::GenerateRagEmbeddings do describe "#execute" do - let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } - let(:vector_rep) do - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation) - end + let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } let(:expected_embedding) { [0.0038493] * vector_rep.dimensions } @@ -30,7 +27,9 @@ RSpec.describe Jobs::GenerateRagEmbeddings do subject.execute(fragment_ids: [rag_document_fragment_1.id, rag_document_fragment_2.id]) embeddings_count = - DB.query_single("SELECT COUNT(*) from #{vector_rep.rag_fragments_table_name}").first + DB.query_single( + "SELECT COUNT(*) from #{DiscourseAi::Embeddings::Schema::RAG_DOCS_TABLE}", + ).first expect(embeddings_count).to eq(expected_embeddings) end diff --git a/spec/jobs/scheduled/embeddings_backfill_spec.rb b/spec/jobs/scheduled/embeddings_backfill_spec.rb index 88cd8976..b54c2c80 100644 --- a/spec/jobs/scheduled/embeddings_backfill_spec.rb +++ b/spec/jobs/scheduled/embeddings_backfill_spec.rb @@ -19,10 +19,7 @@ RSpec.describe Jobs::EmbeddingsBackfill do topic end - let(:vector_rep) do - strategy = DiscourseAi::Embeddings::Strategies::Truncation.new - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) - end + let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } before do SiteSetting.ai_embeddings_enabled = true @@ -41,7 +38,8 @@ RSpec.describe Jobs::EmbeddingsBackfill do Jobs::EmbeddingsBackfill.new.execute({}) - topic_ids = DB.query_single("SELECT topic_id from #{vector_rep.topic_table_name}") + topic_ids = + DB.query_single("SELECT topic_id from #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE}") expect(topic_ids).to eq([first_topic.id]) @@ -49,7 +47,8 @@ RSpec.describe 
Jobs::EmbeddingsBackfill do SiteSetting.ai_embeddings_backfill_batch_size = 100 Jobs::EmbeddingsBackfill.new.execute({}) - topic_ids = DB.query_single("SELECT topic_id from #{vector_rep.topic_table_name}") + topic_ids = + DB.query_single("SELECT topic_id from #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE}") expect(topic_ids).to contain_exactly(first_topic.id, second_topic.id, third_topic.id) @@ -62,7 +61,7 @@ RSpec.describe Jobs::EmbeddingsBackfill do index_date = DB.query_single( - "SELECT updated_at from #{vector_rep.topic_table_name} WHERE topic_id = ?", + "SELECT updated_at from #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE} WHERE topic_id = ?", third_topic.id, ).first diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb index 86251d78..7cfcf9fc 100644 --- a/spec/lib/modules/ai_bot/personas/persona_spec.rb +++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb @@ -326,9 +326,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do fab!(:llm_model) { Fabricate(:fake_model) } it "will run the question consolidator" do - strategy = DiscourseAi::Embeddings::Strategies::Truncation.new - vector_rep = - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) + vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation context_embedding = vector_rep.dimensions.times.map { rand(-1.0...1.0) } EmbeddingsGenerationStubs.discourse_service( SiteSetting.ai_embeddings_model, @@ -375,41 +373,44 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do end context "when a persona has RAG uploads" do - def stub_fragments(limit, expected_limit: nil) - candidate_ids = [] + let(:vector_rep) do + DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation + end + let(:embedding_value) { 0.04381 } + let(:prompt_cc_embeddings) { [embedding_value] * vector_rep.dimensions } - limit.times do |i| - candidate_ids << Fabricate( - 
:rag_document_fragment, - fragment: "fragment-n#{i}", - target_id: ai_persona.id, - target_type: "AiPersona", - upload: upload, - ).id + def stub_fragments(fragment_count, persona: ai_persona) + schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep) + + fragment_count.times do |i| + fragment = + Fabricate( + :rag_document_fragment, + fragment: "fragment-n#{i}", + target_id: persona.id, + target_type: "AiPersona", + upload: upload, + ) + + # Similarity is determined left-to-right. + embeddings = [embedding_value + "0.000#{i}".to_f] * vector_rep.dimensions + + schema.store(fragment, embeddings, "test") end - - DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn - .any_instance - .expects(:asymmetric_rag_fragment_similarity_search) - .with { |args, kwargs| kwargs[:limit] == (expected_limit || limit) } - .returns(candidate_ids) end before do stored_ai_persona = AiPersona.find(ai_persona.id) UploadReference.ensure_exist!(target: stored_ai_persona, upload_ids: [upload.id]) - context_embedding = [0.049382, 0.9999] EmbeddingsGenerationStubs.discourse_service( SiteSetting.ai_embeddings_model, with_cc.dig(:conversation_context, 0, :content), - context_embedding, + prompt_cc_embeddings, ) end context "when persona allows for less fragments" do - before { stub_fragments(3) } - it "will only pick 3 fragments" do custom_ai_persona = Fabricate( @@ -419,6 +420,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]], ) + stub_fragments(3, persona: custom_ai_persona) + UploadReference.ensure_exist!(target: custom_ai_persona, upload_ids: [upload.id]) custom_persona = @@ -438,14 +441,13 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do context "when the reranker is available" do before do SiteSetting.ai_hugging_face_tei_reranker_endpoint = "https://test.reranker.com" - - # hard coded internal implementation, reranker takes x5 number of chunks - stub_fragments(15, expected_limit: 
50) # Mimic limit being more than 10 results + stub_fragments(15) end it "uses the re-ranker to reorder the fragments and pick the top 10 candidates" do - skip "This test is flaky needs to be investigated ordering does not come back as expected" - expected_reranked = (0..14).to_a.reverse.map { |idx| { index: idx } } + # The re-ranker reverses the similarity search, but return less results + # to act as a limit for test-purposes. + expected_reranked = (4..14).to_a.reverse.map { |idx| { index: idx } } WebMock.stub_request(:post, "https://test.reranker.com/rerank").to_return( status: 200, diff --git a/spec/lib/modules/ai_bot/tools/search_spec.rb b/spec/lib/modules/ai_bot/tools/search_spec.rb index 4f664f1b..04aa4d3f 100644 --- a/spec/lib/modules/ai_bot/tools/search_spec.rb +++ b/spec/lib/modules/ai_bot/tools/search_spec.rb @@ -110,8 +110,9 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model) SiteSetting.ai_embeddings_semantic_search_enabled = true SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" + vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation + hyde_embedding = [0.049382] * vector_rep.dimensions - hyde_embedding = [0.049382, 0.9999] EmbeddingsGenerationStubs.discourse_service( SiteSetting.ai_embeddings_model, query, @@ -126,10 +127,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do bot_user: bot_user, ) - DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn - .any_instance - .expects(:asymmetric_topics_similarity_search) - .returns([post1.topic_id]) + DiscourseAi::Embeddings::Schema.for(Topic).store(post1.topic, hyde_embedding, "digest") results = DiscourseAi::Completions::Llm.with_prepared_responses(["#{query}"]) do diff --git a/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb b/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb index 8d40b572..7a246864 100644 --- 
a/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb +++ b/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb @@ -14,10 +14,7 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do fab!(:category) fab!(:topic) { Fabricate(:topic, category: category) } - let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } - let(:vector_rep) do - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation) - end + let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) } let(:expected_embedding) { [0.0038493] * vector_rep.dimensions } diff --git a/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb b/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb index 0013f22b..5b9f0fb3 100644 --- a/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb +++ b/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb @@ -13,9 +13,9 @@ RSpec.describe Jobs::GenerateEmbeddings do fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) } let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } - let(:vector_rep) do - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation) - end + let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } + let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) } + let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector: vector_rep) } it "works for topics" do expected_embedding = [0.0038493] * vector_rep.dimensions @@ -30,7 +30,7 @@ RSpec.describe Jobs::GenerateEmbeddings do job.execute(target_id: topic.id, target_type: "Topic") - expect(vector_rep.topic_id_from_representation(expected_embedding)).to eq(topic.id) + expect(topics_schema.find_by_embedding(expected_embedding).topic_id).to eq(topic.id) end it "works 
for posts" do @@ -42,7 +42,7 @@ RSpec.describe Jobs::GenerateEmbeddings do job.execute(target_id: post.id, target_type: "Post") - expect(vector_rep.post_id_from_representation(expected_embedding)).to eq(post.id) + expect(posts_schema.find_by_embedding(expected_embedding).post_id).to eq(post.id) end end end diff --git a/spec/lib/modules/embeddings/schema_spec.rb b/spec/lib/modules/embeddings/schema_spec.rb new file mode 100644 index 00000000..45c4e243 --- /dev/null +++ b/spec/lib/modules/embeddings/schema_spec.rb @@ -0,0 +1,104 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::Embeddings::Schema do + subject(:posts_schema) { described_class.for(Post, vector: vector) } + + let(:embeddings) { [0.0038490295] * vector.dimensions } + fab!(:post) { Fabricate(:post, post_number: 1) } + let(:digest) { OpenSSL::Digest.hexdigest("SHA1", "test") } + let(:vector) do + DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new( + DiscourseAi::Embeddings::Strategies::Truncation.new, + ) + end + + before { posts_schema.store(post, embeddings, digest) } + + describe "#find_by_target" do + it "gets you the post_id of the record that matches the post" do + embeddings_record = posts_schema.find_by_target(post) + + expect(embeddings_record.digest).to eq(digest) + expect(JSON.parse(embeddings_record.embeddings)).to eq(embeddings) + end + end + + describe "#find_by_embedding" do + it "gets you the record that matches the embedding" do + embeddings_record = posts_schema.find_by_embedding(embeddings) + + expect(embeddings_record.digest).to eq(digest) + expect(embeddings_record.post_id).to eq(post.id) + end + end + + describe "similarity searches" do + fab!(:post_2) { Fabricate(:post) } + let(:similar_embeddings) { [0.0038490294] * vector.dimensions } + + describe "#symmetric_similarity_search" do + before { posts_schema.store(post_2, similar_embeddings, digest) } + + it "returns target_id with similar embeddings" do + similar_records = 
posts_schema.symmetric_similarity_search(post) + + expect(similar_records.map(&:post_id)).to contain_exactly(post.id, post_2.id) + end + + it "let's you apply additional scopes to filter results further" do + similar_records = + posts_schema.symmetric_similarity_search(post) do |builder| + builder.where("post_id = ?", post_2.id) + end + + expect(similar_records.map(&:post_id)).to contain_exactly(post_2.id) + end + + it "let's you join on additional tables and combine with additional scopes" do + similar_records = + posts_schema.symmetric_similarity_search(post) do |builder| + builder.join("posts p on p.id = post_id") + builder.join("topics t on t.id = p.topic_id") + builder.where("t.id = ?", post_2.topic_id) + end + + expect(similar_records.map(&:post_id)).to contain_exactly(post_2.id) + end + end + + describe "#asymmetric_similarity_search" do + it "returns target_id with similar embeddings" do + similar_records = + posts_schema.asymmetric_similarity_search(similar_embeddings, limit: 1, offset: 0) + + expect(similar_records.map(&:post_id)).to contain_exactly(post.id) + end + + it "let's you apply additional scopes to filter results further" do + similar_records = + posts_schema.asymmetric_similarity_search( + similar_embeddings, + limit: 1, + offset: 0, + ) { |builder| builder.where("post_id <> ?", post.id) } + + expect(similar_records.map(&:post_id)).to be_empty + end + + it "let's you join on additional tables and combine with additional scopes" do + similar_records = + posts_schema.asymmetric_similarity_search( + similar_embeddings, + limit: 1, + offset: 0, + ) do |builder| + builder.join("posts p on p.id = post_id") + builder.join("topics t on t.id = p.topic_id") + builder.where("t.id <> ?", post.topic_id) + end + + expect(similar_records.map(&:post_id)).to be_empty + end + end + end +end diff --git a/spec/lib/modules/embeddings/semantic_related_spec.rb b/spec/lib/modules/embeddings/semantic_related_spec.rb index 37965aa1..0349abad 100644 --- 
a/spec/lib/modules/embeddings/semantic_related_spec.rb +++ b/spec/lib/modules/embeddings/semantic_related_spec.rb @@ -25,9 +25,7 @@ describe DiscourseAi::Embeddings::SemanticRelated do end let(:vector_rep) do - strategy = DiscourseAi::Embeddings::Strategies::Truncation.new - - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) + DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation end it "properly generates embeddings if missing" do diff --git a/spec/lib/modules/embeddings/semantic_search_spec.rb b/spec/lib/modules/embeddings/semantic_search_spec.rb index 613c1cf5..67d5c275 100644 --- a/spec/lib/modules/embeddings/semantic_search_spec.rb +++ b/spec/lib/modules/embeddings/semantic_search_spec.rb @@ -11,11 +11,12 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do describe "#search_for_topics" do let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" } + let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } + let(:hyde_embedding) { [0.049382] * vector_rep.dimensions } before do SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" - hyde_embedding = [0.049382, 0.9999] EmbeddingsGenerationStubs.discourse_service( SiteSetting.ai_embeddings_model, hypothetical_post, @@ -25,11 +26,8 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do after { described_class.clear_cache_for(query) } - def stub_candidate_ids(candidate_ids) - DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn - .any_instance - .expects(:asymmetric_topics_similarity_search) - .returns(candidate_ids) + def insert_candidate(candidate) + DiscourseAi::Embeddings::Schema.for(Topic).store(candidate, hyde_embedding, "digest") end def trigger_search(query) @@ -39,7 +37,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do end it "returns the first post of a topic included in the asymmetric search results" do - 
stub_candidate_ids([post.topic_id]) + insert_candidate(post.topic) posts = trigger_search(query) @@ -50,7 +48,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do context "when the topic is not visible" do it "returns an empty list" do post.topic.update!(visible: false) - stub_candidate_ids([post.topic_id]) + insert_candidate(post.topic) posts = trigger_search(query) @@ -61,7 +59,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do context "when the post is not public" do it "returns an empty list" do pm_post = Fabricate(:private_message_post) - stub_candidate_ids([pm_post.topic_id]) + insert_candidate(pm_post.topic) posts = trigger_search(query) @@ -72,7 +70,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do context "when the post type is not visible" do it "returns an empty list" do post.update!(post_type: Post.types[:whisper]) - stub_candidate_ids([post.topic_id]) + insert_candidate(post.topic) posts = trigger_search(query) @@ -84,7 +82,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do it "returns an empty list" do reply = Fabricate(:reply) reply.topic.first_post.trash! 
- stub_candidate_ids([reply.topic_id]) + insert_candidate(reply.topic) posts = trigger_search(query) @@ -95,7 +93,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do context "when the post is not a candidate" do it "doesn't include it in the results" do post_2 = Fabricate(:post) - stub_candidate_ids([post.topic_id]) + insert_candidate(post.topic) posts = trigger_search(query) @@ -109,7 +107,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do before do post.topic.update!(category: private_category) - stub_candidate_ids([post.topic_id]) + insert_candidate(post.topic) end it "returns an empty list" do diff --git a/spec/lib/modules/embeddings/semantic_topic_query_spec.rb b/spec/lib/modules/embeddings/semantic_topic_query_spec.rb index ffde6e6a..9c83f42e 100644 --- a/spec/lib/modules/embeddings/semantic_topic_query_spec.rb +++ b/spec/lib/modules/embeddings/semantic_topic_query_spec.rb @@ -9,11 +9,17 @@ describe DiscourseAi::Embeddings::EntryPoint do fab!(:target) { Fabricate(:topic) } - def stub_semantic_search_with(results) - DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn - .any_instance - .expects(:symmetric_topics_similarity_search) - .returns(results.concat([target.id])) + # The Distance gap to target increases for each element of topics. 
+ def seed_embeddings(topics) + schema = DiscourseAi::Embeddings::Schema.for(Topic) + base_value = 1 + + schema.store(target, [base_value] * 1024, "disgest") + + topics.each do |t| + base_value -= 0.01 + schema.store(t, [base_value] * 1024, "digest") + end end after { DiscourseAi::Embeddings::SemanticRelated.clear_cache_for(target) } @@ -21,7 +27,7 @@ describe DiscourseAi::Embeddings::EntryPoint do context "when the semantic search returns an unlisted topic" do fab!(:unlisted_topic) { Fabricate(:topic, visible: false) } - before { stub_semantic_search_with([unlisted_topic.id]) } + before { seed_embeddings([unlisted_topic]) } it "filters it out" do expect(topic_query.list_semantic_related_topics(target).topics).to be_empty @@ -31,7 +37,7 @@ describe DiscourseAi::Embeddings::EntryPoint do context "when the semantic search returns a private topic" do fab!(:private_topic) { Fabricate(:private_message_topic) } - before { stub_semantic_search_with([private_topic.id]) } + before { seed_embeddings([private_topic]) } it "filters it out" do expect(topic_query.list_semantic_related_topics(target).topics).to be_empty @@ -43,7 +49,7 @@ describe DiscourseAi::Embeddings::EntryPoint do fab!(:category) { Fabricate(:private_category, group: group) } fab!(:secured_category_topic) { Fabricate(:topic, category: category) } - before { stub_semantic_search_with([secured_category_topic.id]) } + before { seed_embeddings([secured_category_topic]) } it "filters it out" do expect(topic_query.list_semantic_related_topics(target).topics).to be_empty @@ -63,7 +69,7 @@ describe DiscourseAi::Embeddings::EntryPoint do before do SiteSetting.ai_embeddings_semantic_related_include_closed_topics = false - stub_semantic_search_with([closed_topic.id]) + seed_embeddings([closed_topic]) end it "filters it out" do @@ -80,7 +86,7 @@ describe DiscourseAi::Embeddings::EntryPoint do category_id: category.id, notification_level: CategoryUser.notification_levels[:muted], ) - stub_semantic_search_with([topic.id]) 
+ seed_embeddings([topic]) expect(topic_query.list_semantic_related_topics(target).topics).not_to include(topic) end end @@ -91,11 +97,7 @@ describe DiscourseAi::Embeddings::EntryPoint do fab!(:normal_topic_3) { Fabricate(:topic) } fab!(:closed_topic) { Fabricate(:topic, closed: true) } - before do - stub_semantic_search_with( - [closed_topic.id, normal_topic_1.id, normal_topic_2.id, normal_topic_3.id], - ) - end + before { seed_embeddings([closed_topic, normal_topic_1, normal_topic_2, normal_topic_3]) } it "filters it out" do expect(topic_query.list_semantic_related_topics(target).topics).to eq( @@ -117,7 +119,7 @@ describe DiscourseAi::Embeddings::EntryPoint do fab!(:included_topic) { Fabricate(:topic) } fab!(:excluded_topic) { Fabricate(:topic) } - before { stub_semantic_search_with([included_topic.id, excluded_topic.id]) } + before { seed_embeddings([included_topic, excluded_topic]) } let(:modifier_block) { Proc.new { |query| query.where.not(id: excluded_topic.id) } } diff --git a/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb b/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb index 714a9189..0801860e 100644 --- a/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb +++ b/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb @@ -4,6 +4,9 @@ RSpec.shared_examples "generates and store embedding using with vector represent let(:expected_embedding_1) { [0.0038493] * vector_rep.dimensions } let(:expected_embedding_2) { [0.0037684] * vector_rep.dimensions } + let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) } + let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector: vector_rep) } + describe "#vector_from" do it "creates a vector from a given string" do text = "This is a piece of text" @@ -29,7 +32,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent 
vector_rep.generate_representation_from(topic) - expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id) + expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id) end it "creates a vector from a post and stores it in the database" do @@ -43,7 +46,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent vector_rep.generate_representation_from(post) - expect(vector_rep.post_id_from_representation(expected_embedding_1)).to eq(post.id) + expect(posts_schema.find_by_embedding(expected_embedding_1).post_id).to eq(post.id) end end @@ -76,8 +79,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id])) - expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id) - expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id) + expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id) end it "does nothing if passed record has no content" do @@ -99,69 +101,15 @@ RSpec.shared_examples "generates and store embedding using with vector represent vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id])) end # check vector exists - expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id) + expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id) vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id])) last_update = DB.query_single( - "SELECT updated_at FROM #{vector_rep.topic_table_name} WHERE topic_id = #{topic.id} LIMIT 1", + "SELECT updated_at FROM #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE} WHERE topic_id = #{topic.id} LIMIT 1", ).first expect(last_update).to eq(original_vector_gen) end end - - describe "#asymmetric_topics_similarity_search" do - fab!(:topic) { Fabricate(:topic) } - fab!(:post) { Fabricate(:post, post_number: 
1, topic: topic) } - - it "finds IDs of similar topics with a given embedding" do - similar_vector = [0.0038494] * vector_rep.dimensions - text = - truncation.prepare_text_from( - topic, - vector_rep.tokenizer, - vector_rep.max_sequence_length - 2, - ) - stub_vector_mapping(text, expected_embedding_1) - vector_rep.generate_representation_from(topic) - - expect( - vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0), - ).to contain_exactly(topic.id) - end - - it "can exclude categories" do - similar_vector = [0.0038494] * vector_rep.dimensions - text = - truncation.prepare_text_from( - topic, - vector_rep.tokenizer, - vector_rep.max_sequence_length - 2, - ) - stub_vector_mapping(text, expected_embedding_1) - vector_rep.generate_representation_from(topic) - - expect( - vector_rep.asymmetric_topics_similarity_search( - similar_vector, - limit: 1, - offset: 0, - exclude_category_ids: [topic.category_id], - ), - ).to be_empty - - child_category = Fabricate(:category, parent_category_id: topic.category_id) - topic.update!(category_id: child_category.id) - - expect( - vector_rep.asymmetric_topics_similarity_search( - similar_vector, - limit: 1, - offset: 0, - exclude_category_ids: [topic.category_id], - ), - ).to be_empty - end - end end diff --git a/spec/models/rag_document_fragment_spec.rb b/spec/models/rag_document_fragment_spec.rb index 9ddbc830..77a0061d 100644 --- a/spec/models/rag_document_fragment_spec.rb +++ b/spec/models/rag_document_fragment_spec.rb @@ -74,10 +74,7 @@ RSpec.describe RagDocumentFragment do end describe ".indexing_status" do - let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } - let(:vector_rep) do - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation) - end + let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } fab!(:rag_document_fragment_1) do Fabricate(:rag_document_fragment, upload: upload_1, target: persona) 
diff --git a/spec/requests/embeddings/embeddings_controller_spec.rb b/spec/requests/embeddings/embeddings_controller_spec.rb index 2a411644..39bcf136 100644 --- a/spec/requests/embeddings/embeddings_controller_spec.rb +++ b/spec/requests/embeddings/embeddings_controller_spec.rb @@ -19,9 +19,7 @@ describe DiscourseAi::Embeddings::EmbeddingsController do fab!(:post_in_subcategory) { Fabricate(:post, topic: topic_in_subcategory) } def index(topic) - strategy = DiscourseAi::Embeddings::Strategies::Truncation.new - vector_rep = - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) + vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation stub_request(:post, "https://api.openai.com/v1/embeddings").to_return( status: 200,