From 534b0df391ad6d1ae54c6674453aec5a6c69203f Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Mon, 16 Dec 2024 09:55:39 -0300 Subject: [PATCH] REFACTOR: Separation of concerns for embedding generation. (#1027) In a previous refactor, we moved the responsibility of querying and storing embeddings into the `Schema` class. Now, it's time for embedding generation. The motivation behind these changes is to isolate vector characteristics in simple objects to later replace them with a DB-backed version, similar to what we did with LLM configs. --- app/jobs/regular/generate_embeddings.rb | 4 +- app/jobs/regular/generate_rag_embeddings.rb | 4 +- app/jobs/scheduled/embeddings_backfill.rb | 27 +-- lib/ai_bot/personas/persona.rb | 6 +- lib/ai_bot/tool_runner.rb | 5 +- lib/ai_helper/semantic_categorizer.rb | 6 +- lib/embeddings/schema.rb | 44 +++-- lib/embeddings/semantic_related.rb | 3 +- lib/embeddings/semantic_search.rb | 18 +- lib/embeddings/strategies/truncation.rb | 17 +- lib/embeddings/vector.rb | 76 +++++++++ .../all_mpnet_base_v2.rb | 8 - lib/embeddings/vector_representations/base.rb | 92 ++-------- .../vector_representations/bge_large_en.rb | 21 +-- .../vector_representations/bge_m3.rb | 9 - .../vector_representations/gemini.rb | 8 - .../multilingual_e5_large.rb | 31 ++-- .../text_embedding_3_large.rb | 8 - .../text_embedding_3_small.rb | 8 - .../text_embedding_ada_002.rb | 8 - .../modules/ai_bot/personas/persona_spec.rb | 12 +- .../ai_helper/semantic_categorizer_spec.rb | 8 +- .../jobs/generate_embeddings_spec.rb | 26 ++- spec/lib/modules/embeddings/schema_spec.rb | 12 +- .../embeddings/semantic_search_spec.rb | 4 +- .../embeddings/strategies/truncation_spec.rb | 13 +- .../all_mpnet_base_v2_spec.rb | 17 -- .../vector_representations/gemini_spec.rb | 18 -- .../multilingual_e5_large_spec.rb | 21 --- .../text_embedding_3_large_spec.rb | 22 --- .../text_embedding_3_small_spec.rb | 15 -- .../text_embedding_ada_002_spec.rb | 15 -- .../vector_rep_shared_examples.rb | 115 ------------- spec/lib/modules/embeddings/vector_spec.rb | 160 ++++++++++++++++++ spec/models/rag_document_fragment_spec.rb | 6 +- .../embeddings/embeddings_controller_spec.rb | 4 +- 36 files changed, 375 insertions(+), 496 deletions(-) create mode 100644 lib/embeddings/vector.rb delete mode 100644 spec/lib/modules/embeddings/vector_representations/all_mpnet_base_v2_spec.rb delete mode 100644 spec/lib/modules/embeddings/vector_representations/gemini_spec.rb delete mode 100644 spec/lib/modules/embeddings/vector_representations/multilingual_e5_large_spec.rb delete mode 100644 spec/lib/modules/embeddings/vector_representations/text_embedding_3_large_spec.rb delete mode 100644 spec/lib/modules/embeddings/vector_representations/text_embedding_3_small_spec.rb delete mode 100644 spec/lib/modules/embeddings/vector_representations/text_embedding_ada_002_spec.rb delete mode 100644 spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb create mode 100644 spec/lib/modules/embeddings/vector_spec.rb diff --git a/app/jobs/regular/generate_embeddings.rb b/app/jobs/regular/generate_embeddings.rb index 4e16dc22..6cc71fed 100644 --- a/app/jobs/regular/generate_embeddings.rb +++ b/app/jobs/regular/generate_embeddings.rb @@ -16,9 +16,7 @@ module Jobs return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms return if post.raw.blank? 
- vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation - - vector_rep.generate_representation_from(target) + DiscourseAi::Embeddings::Vector.instance.generate_representation_from(target) end end end diff --git a/app/jobs/regular/generate_rag_embeddings.rb b/app/jobs/regular/generate_rag_embeddings.rb index 836ede8b..775b849f 100644 --- a/app/jobs/regular/generate_rag_embeddings.rb +++ b/app/jobs/regular/generate_rag_embeddings.rb @@ -8,11 +8,11 @@ module ::Jobs def execute(args) return if (fragments = RagDocumentFragment.where(id: args[:fragment_ids].to_a)).empty? - vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation + vector = DiscourseAi::Embeddings::Vector.instance # generate_representation_from checks compares the digest value to make sure # the embedding is only generated once per fragment unless something changes. - fragments.map { |fragment| vector_rep.generate_representation_from(fragment) } + fragments.map { |fragment| vector.generate_representation_from(fragment) } last_fragment = fragments.last target = last_fragment.target diff --git a/app/jobs/scheduled/embeddings_backfill.rb b/app/jobs/scheduled/embeddings_backfill.rb index 383c4c05..4fdbe5bd 100644 --- a/app/jobs/scheduled/embeddings_backfill.rb +++ b/app/jobs/scheduled/embeddings_backfill.rb @@ -20,7 +20,8 @@ module Jobs rebaked = 0 - vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation + vector = DiscourseAi::Embeddings::Vector.instance + vector_def = vector.vdef table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE topics = @@ -30,19 +31,19 @@ module Jobs .where(deleted_at: nil) .order("topics.bumped_at DESC") - rebaked += populate_topic_embeddings(vector_rep, topics.limit(limit - rebaked)) + rebaked += populate_topic_embeddings(vector, topics.limit(limit - rebaked)) return if rebaked >= limit # Then, we'll try to backfill embeddings for topics that have outdated # embeddings, be it model or strategy version relation = topics.where(<<~SQL).limit(limit - rebaked) - #{table_name}.model_version < #{vector_rep.version} + #{table_name}.model_version < #{vector_def.version} OR - #{table_name}.strategy_version < #{vector_rep.strategy_version} + #{table_name}.strategy_version < #{vector_def.strategy_version} SQL - rebaked += populate_topic_embeddings(vector_rep, relation) + rebaked += populate_topic_embeddings(vector, relation) return if rebaked >= limit @@ -54,7 +55,7 @@ module Jobs .where("#{table_name}.updated_at < topics.updated_at") .limit((limit - rebaked) / 10) - populate_topic_embeddings(vector_rep, relation, force: true) + populate_topic_embeddings(vector, relation, force: true) return if rebaked >= limit @@ -76,7 +77,7 @@ module Jobs .limit(limit - rebaked) .pluck(:id) .each_slice(posts_batch_size) do |batch| - vector_rep.gen_bulk_reprensentations(Post.where(id: batch)) + vector.gen_bulk_reprensentations(Post.where(id: batch)) rebaked += batch.length end @@ -86,14 +87,14 @@ module Jobs # embeddings, be it model or strategy version posts .where(<<~SQL) - #{table_name}.model_version < #{vector_rep.version} + #{table_name}.model_version < #{vector_def.version} OR - #{table_name}.strategy_version < #{vector_rep.strategy_version} + #{table_name}.strategy_version < #{vector_def.strategy_version} SQL .limit(limit - rebaked) .pluck(:id) .each_slice(posts_batch_size) do |batch| - vector_rep.gen_bulk_reprensentations(Post.where(id: batch)) + vector.gen_bulk_reprensentations(Post.where(id: batch)) rebaked += 
batch.length end @@ -107,7 +108,7 @@ module Jobs .limit((limit - rebaked) / 10) .pluck(:id) .each_slice(posts_batch_size) do |batch| - vector_rep.gen_bulk_reprensentations(Post.where(id: batch)) + vector.gen_bulk_reprensentations(Post.where(id: batch)) rebaked += batch.length end @@ -116,7 +117,7 @@ module Jobs private - def populate_topic_embeddings(vector_rep, topics, force: false) + def populate_topic_embeddings(vector, topics, force: false) done = 0 topics = @@ -126,7 +127,7 @@ module Jobs batch_size = 1000 ids.each_slice(batch_size) do |batch| - vector_rep.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC")) + vector.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC")) done += batch.length end diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index d3b7cdf4..1569ce45 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -314,10 +314,10 @@ module DiscourseAi return nil if !consolidated_question - vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation + vector = DiscourseAi::Embeddings::Vector.instance reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings - interactions_vector = vector_rep.vector_from(consolidated_question) + interactions_vector = vector.vector_from(consolidated_question) rag_conversation_chunks = self.class.rag_conversation_chunks search_limit = @@ -327,7 +327,7 @@ module DiscourseAi rag_conversation_chunks end - schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep) + schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector.vdef) candidate_fragment_ids = schema diff --git a/lib/ai_bot/tool_runner.rb b/lib/ai_bot/tool_runner.rb index 74c92767..17440ead 100644 --- a/lib/ai_bot/tool_runner.rb +++ b/lib/ai_bot/tool_runner.rb @@ -141,11 +141,10 @@ module DiscourseAi return [] if upload_refs.empty? - vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation - query_vector = vector_rep.vector_from(query) + query_vector = DiscourseAi::Embeddings::Vector.instance.vector_from(query) fragment_ids = DiscourseAi::Embeddings::Schema - .for(RagDocumentFragment, vector: vector_rep) + .for(RagDocumentFragment) .asymmetric_similarity_search(query_vector, limit: limit, offset: 0) do |builder| builder.join(<<~SQL, target_id: tool.id, target_type: "AiTool") rag_document_fragments ON diff --git a/lib/ai_helper/semantic_categorizer.rb b/lib/ai_helper/semantic_categorizer.rb index fc39c264..b226b01b 100644 --- a/lib/ai_helper/semantic_categorizer.rb +++ b/lib/ai_helper/semantic_categorizer.rb @@ -92,10 +92,10 @@ module DiscourseAi private def nearest_neighbors(limit: 100) - vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation - schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) + vector = DiscourseAi::Embeddings::Vector.instance + schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef) - raw_vector = vector_rep.vector_from(@text) + raw_vector = vector.vector_from(@text) muted_category_ids = nil if @user.present? 
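For reference, every call site above now funnels through the same facade. A minimal usage sketch of the consolidated API (assumes a configured `ai_embeddings_model` and existing topics; `topic_ids` is a placeholder):

    # The facade pairs the configured vector definition with its inference client.
    vector = DiscourseAi::Embeddings::Vector.instance

    # Generate and persist the embedding for one record; generation is skipped
    # when the stored digest already matches the prepared text.
    vector.generate_representation_from(Topic.find(topic_ids.first))

    # Bulk path used by the backfill job above.
    vector.gen_bulk_reprensentations(Topic.where(id: topic_ids))

    # Query-time embedding; asymetric: true prepends the model's query prefix.
    query_embedding = vector.vector_from("how do I reset my password", asymetric: true)

    # Querying and storage stay in Schema, now keyed by the definition object.
    schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)
    schema.find_by_embedding(query_embedding)
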
diff --git a/lib/embeddings/schema.rb b/lib/embeddings/schema.rb index bf878fd1..3ae51c7b 100644 --- a/lib/embeddings/schema.rb +++ b/lib/embeddings/schema.rb @@ -14,30 +14,31 @@ module DiscourseAi def self.for( target_klass, - vector: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation + vector_def: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation ) case target_klass&.name when "Topic" - new(TOPICS_TABLE, "topic_id", vector) + new(TOPICS_TABLE, "topic_id", vector_def) when "Post" - new(POSTS_TABLE, "post_id", vector) + new(POSTS_TABLE, "post_id", vector_def) when "RagDocumentFragment" - new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector) + new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector_def) else raise ArgumentError, "Invalid target type for embeddings" end end - def initialize(table, target_column, vector) + def initialize(table, target_column, vector_def) @table = table @target_column = target_column - @vector = vector + @vector_def = vector_def end - attr_reader :table, :target_column, :vector + attr_reader :table, :target_column, :vector_def def find_by_embedding(embedding) - DB.query(<<~SQL, query_embedding: embedding, vid: vector.id, vsid: vector.strategy_id).first + DB.query( + <<~SQL, SELECT * FROM #{table} WHERE @@ -46,10 +47,15 @@ module DiscourseAi embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) LIMIT 1 SQL + query_embedding: embedding, + vid: vector_def.id, + vsid: vector_def.strategy_id, + ).first end def find_by_target(target) - DB.query(<<~SQL, target_id: target.id, vid: vector.id, vsid: vector.strategy_id).first + DB.query( + <<~SQL, SELECT * FROM #{table} WHERE @@ -58,6 +64,10 @@ module DiscourseAi #{target_column} = :target_id LIMIT 1 SQL + target_id: target.id, + vid: vector_def.id, + vsid: vector_def.strategy_id, + ).first end def asymmetric_similarity_search(embedding, limit:, offset:) @@ -87,8 +97,8 @@ module DiscourseAi builder.where( "model_id = :model_id AND strategy_id = :strategy_id", - model_id: vector.id, - strategy_id: vector.strategy_id, + model_id: vector_def.id, + strategy_id: vector_def.strategy_id, ) yield(builder) if block_given? @@ -156,7 +166,7 @@ module DiscourseAi yield(builder) if block_given? 
- builder.query(vid: vector.id, vsid: vector.strategy_id, target_id: record.id) + builder.query(vid: vector_def.id, vsid: vector_def.strategy_id, target_id: record.id) rescue PG::Error => e Rails.logger.error("Error #{e} querying embeddings for model #{name}") raise MissingEmbeddingError @@ -176,10 +186,10 @@ module DiscourseAi updated_at = :now SQL target_id: record.id, - model_id: vector.id, - model_version: vector.version, - strategy_id: vector.strategy_id, - strategy_version: vector.strategy_version, + model_id: vector_def.id, + model_version: vector_def.version, + strategy_id: vector_def.strategy_id, + strategy_version: vector_def.strategy_version, digest: digest, embeddings: embedding, now: Time.zone.now, @@ -188,7 +198,7 @@ module DiscourseAi private - delegate :dimensions, :pg_function, to: :vector + delegate :dimensions, :pg_function, to: :vector_def end end end diff --git a/lib/embeddings/semantic_related.rb b/lib/embeddings/semantic_related.rb index c28d05fc..54c8e572 100644 --- a/lib/embeddings/semantic_related.rb +++ b/lib/embeddings/semantic_related.rb @@ -13,14 +13,13 @@ module DiscourseAi def related_topic_ids_for(topic) return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1 - vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation cache_for = results_ttl(topic) Discourse .cache .fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do DiscourseAi::Embeddings::Schema - .for(Topic, vector: vector_rep) + .for(Topic) .symmetric_similarity_search(topic) .map(&:topic_id) .tap do |candidate_ids| diff --git a/lib/embeddings/semantic_search.rb b/lib/embeddings/semantic_search.rb index 108239b9..490772fa 100644 --- a/lib/embeddings/semantic_search.rb +++ b/lib/embeddings/semantic_search.rb @@ -30,8 +30,8 @@ module DiscourseAi Discourse.cache.read(embedding_key).present? end - def vector_rep - @vector_rep ||= DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation + def vector + @vector ||= DiscourseAi::Embeddings::Vector.instance end def hyde_embedding(search_term) @@ -52,16 +52,14 @@ module DiscourseAi Discourse .cache - .fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(hypothetical_post) } + .fetch(embedding_key, expires_in: 1.week) { vector.vector_from(hypothetical_post) } end def embedding(search_term) digest = OpenSSL::Digest::SHA1.hexdigest(search_term) embedding_key = build_embedding_key(digest, "", SiteSetting.ai_embeddings_model) - Discourse - .cache - .fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(search_term) } + Discourse.cache.fetch(embedding_key, expires_in: 1.week) { vector.vector_from(search_term) } end # this ensures the candidate topics are over selected @@ -84,7 +82,7 @@ module DiscourseAi over_selection_limit = limit * OVER_SELECTION_FACTOR - schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) + schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef) candidate_topic_ids = schema.asymmetric_similarity_search( @@ -114,7 +112,7 @@ module DiscourseAi return [] if search_term.nil? 
|| search_term.length < SiteSetting.min_search_term_length
 
-        vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
+        vector = DiscourseAi::Embeddings::Vector.instance
 
         digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
 
@@ -129,12 +127,12 @@ module DiscourseAi
           Discourse
             .cache
             .fetch(embedding_key, expires_in: 1.week) do
-              vector_rep.vector_from(search_term, asymetric: true)
+              vector.vector_from(search_term, asymetric: true)
             end
 
         candidate_post_ids =
           DiscourseAi::Embeddings::Schema
-            .for(Post, vector: vector_rep)
+            .for(Post, vector_def: vector.vdef)
             .asymmetric_similarity_search(
               search_term_embedding,
               limit: max_semantic_results_per_page,
diff --git a/lib/embeddings/strategies/truncation.rb b/lib/embeddings/strategies/truncation.rb
index 1ae82d80..cbe463c0 100644
--- a/lib/embeddings/strategies/truncation.rb
+++ b/lib/embeddings/strategies/truncation.rb
@@ -12,19 +12,28 @@ module DiscourseAi
           1
         end
 
-        def prepare_text_from(target, tokenizer, max_length)
+        def prepare_target_text(target, vdef)
+          max_length = vdef.max_sequence_length - 2
+
           case target
           when Topic
-            topic_truncation(target, tokenizer, max_length)
+            topic_truncation(target, vdef.tokenizer, max_length)
           when Post
-            post_truncation(target, tokenizer, max_length)
+            post_truncation(target, vdef.tokenizer, max_length)
           when RagDocumentFragment
-            tokenizer.truncate(target.fragment, max_length)
+            vdef.tokenizer.truncate(target.fragment, max_length)
           else
             raise ArgumentError, "Invalid target type"
           end
         end
 
+        def prepare_query_text(text, vdef, asymetric: false)
+          qtext = asymetric ? "#{vdef.asymmetric_query_prefix} #{text}" : text
+          max_length = vdef.max_sequence_length - 2
+
+          vdef.tokenizer.truncate(qtext, max_length)
+        end
+
         private
 
         def topic_information(topic)
diff --git a/lib/embeddings/vector.rb b/lib/embeddings/vector.rb
new file mode 100644
index 00000000..2fe8c72c
--- /dev/null
+++ b/lib/embeddings/vector.rb
@@ -0,0 +1,76 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    class Vector
+      def self.instance
+        new(DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation)
+      end
+
+      def initialize(vector_definition)
+        @vdef = vector_definition
+      end
+
+      def gen_bulk_reprensentations(relation)
+        http_pool_size = 100
+        pool =
+          Concurrent::CachedThreadPool.new(
+            min_threads: 0,
+            max_threads: http_pool_size,
+            idletime: 30,
+          )
+
+        schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector_def: vdef)
+
+        embedding_gen = vdef.inference_client
+        promised_embeddings =
+          relation
+            .map do |record|
+              prepared_text = vdef.prepare_target_text(record)
+              next if prepared_text.blank?
+
+              new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text)
+              next if schema.find_by_target(record)&.digest == new_digest
+
+              Concurrent::Promises
+                .fulfilled_future({ target: record, text: prepared_text, digest: new_digest }, pool)
+                .then_on(pool) do |w_prepared_text|
+                  w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text]))
+                end
+            end
+            .compact
+
+        Concurrent::Promises
+          .zip(*promised_embeddings)
+          .value!
+          .each { |e| schema.store(e[:target], e[:embedding], e[:digest]) }
+
+        pool.shutdown
+        pool.wait_for_termination
+      end
+
+      def generate_representation_from(target)
+        text = vdef.prepare_target_text(target)
+        return if text.blank?
+ + schema = DiscourseAi::Embeddings::Schema.for(target.class, vector_def: vdef) + + new_digest = OpenSSL::Digest::SHA1.hexdigest(text) + return if schema.find_by_target(target)&.digest == new_digest + + embeddings = vdef.inference_client.perform!(text) + + schema.store(target, embeddings, new_digest) + end + + def vector_from(text, asymetric: false) + prepared_text = vdef.prepare_query_text(text, asymetric: asymetric) + return if prepared_text.blank? + + vdef.inference_client.perform!(prepared_text) + end + + attr_reader :vdef + end + end +end diff --git a/lib/embeddings/vector_representations/all_mpnet_base_v2.rb b/lib/embeddings/vector_representations/all_mpnet_base_v2.rb index 7e4a2ad7..a89aab8b 100644 --- a/lib/embeddings/vector_representations/all_mpnet_base_v2.rb +++ b/lib/embeddings/vector_representations/all_mpnet_base_v2.rb @@ -23,10 +23,6 @@ module DiscourseAi end end - def vector_from(text, asymetric: false) - inference_client.perform!(text) - end - def dimensions 768 end @@ -47,10 +43,6 @@ module DiscourseAi "<#>" end - def pg_index_type - "halfvec_ip_ops" - end - def tokenizer DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer end diff --git a/lib/embeddings/vector_representations/base.rb b/lib/embeddings/vector_representations/base.rb index 94e33d8d..8670404b 100644 --- a/lib/embeddings/vector_representations/base.rb +++ b/lib/embeddings/vector_representations/base.rb @@ -21,8 +21,7 @@ module DiscourseAi end def current_representation - truncation = DiscourseAi::Embeddings::Strategies::Truncation.new - find_representation(SiteSetting.ai_embeddings_model).new(truncation) + find_representation(SiteSetting.ai_embeddings_model).new end def correctly_configured? @@ -43,73 +42,6 @@ module DiscourseAi end end - def initialize(strategy) - @strategy = strategy - end - - def vector_from(text, asymetric: false) - raise NotImplementedError - end - - def gen_bulk_reprensentations(relation) - http_pool_size = 100 - pool = - Concurrent::CachedThreadPool.new( - min_threads: 0, - max_threads: http_pool_size, - idletime: 30, - ) - - schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector: self) - - embedding_gen = inference_client - promised_embeddings = - relation - .map do |record| - prepared_text = prepare_text(record) - next if prepared_text.blank? - - new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text) - next if schema.find_by_target(record)&.digest == new_digest - - Concurrent::Promises - .fulfilled_future( - { target: record, text: prepared_text, digest: new_digest }, - pool, - ) - .then_on(pool) do |w_prepared_text| - w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text])) - end - end - .compact - - Concurrent::Promises - .zip(*promised_embeddings) - .value! - .each { |e| schema.store(e[:target], e[:embedding], e[:digest]) } - - pool.shutdown - pool.wait_for_termination - end - - def generate_representation_from(target, persist: true) - text = prepare_text(target) - return if text.blank? 
- - schema = DiscourseAi::Embeddings::Schema.for(target.class, vector: self) - - new_digest = OpenSSL::Digest::SHA1.hexdigest(text) - return if schema.find_by_target(target)&.digest == new_digest - - vector = vector_from(text) - - schema.store(target, vector, new_digest) if persist - end - - def index_name(table_name) - "#{table_name}_#{id}_#{@strategy.id}_search" - end - def name raise NotImplementedError end @@ -139,26 +71,32 @@ module DiscourseAi end def asymmetric_query_prefix - raise NotImplementedError + "" end def strategy_id - @strategy.id + strategy.id end def strategy_version - @strategy.version + strategy.version end - protected + def prepare_query_text(text, asymetric: false) + strategy.prepare_query_text(text, self, asymetric: asymetric) + end + + def prepare_target_text(target) + strategy.prepare_target_text(target, self) + end + + def strategy + @strategy ||= DiscourseAi::Embeddings::Strategies::Truncation.new + end def inference_client raise NotImplementedError end - - def prepare_text(record) - @strategy.prepare_text_from(record, tokenizer, max_sequence_length - 2) - end end end end diff --git a/lib/embeddings/vector_representations/bge_large_en.rb b/lib/embeddings/vector_representations/bge_large_en.rb index 923ee19e..9006ebbe 100644 --- a/lib/embeddings/vector_representations/bge_large_en.rb +++ b/lib/embeddings/vector_representations/bge_large_en.rb @@ -30,21 +30,6 @@ module DiscourseAi end end - def vector_from(text, asymetric: false) - text = "#{asymmetric_query_prefix} #{text}" if asymetric - - client = inference_client - - needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings") - text = tokenizer.truncate(text, max_sequence_length - 2) if needs_truncation - - inference_client.perform!(text) - end - - def inference_model_name - "baai/bge-large-en-v1.5" - end - def dimensions 1024 end @@ -65,10 +50,6 @@ module DiscourseAi "<#>" end - def pg_index_type - "halfvec_ip_ops" - end - def tokenizer DiscourseAi::Tokenizer::BgeLargeEnTokenizer end @@ -78,6 +59,8 @@ module DiscourseAi end def inference_client + inference_model_name = "baai/bge-large-en-v1.5" + if SiteSetting.ai_cloudflare_workers_api_token.present? DiscourseAi::Inference::CloudflareWorkersAi.instance(inference_model_name) elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? diff --git a/lib/embeddings/vector_representations/bge_m3.rb b/lib/embeddings/vector_representations/bge_m3.rb index d7e963fc..d3625e41 100644 --- a/lib/embeddings/vector_representations/bge_m3.rb +++ b/lib/embeddings/vector_representations/bge_m3.rb @@ -18,11 +18,6 @@ module DiscourseAi end end - def vector_from(text, asymetric: false) - truncated_text = tokenizer.truncate(text, max_sequence_length - 2) - inference_client.perform!(truncated_text) - end - def dimensions 1024 end @@ -43,10 +38,6 @@ module DiscourseAi "<#>" end - def pg_index_type - "halfvec_ip_ops" - end - def tokenizer DiscourseAi::Tokenizer::BgeM3Tokenizer end diff --git a/lib/embeddings/vector_representations/gemini.rb b/lib/embeddings/vector_representations/gemini.rb index 5ac1cba7..110b8b4c 100644 --- a/lib/embeddings/vector_representations/gemini.rb +++ b/lib/embeddings/vector_representations/gemini.rb @@ -38,14 +38,6 @@ module DiscourseAi "<=>" end - def pg_index_type - "halfvec_cosine_ops" - end - - def vector_from(text, asymetric: false) - inference_client.perform!(text) - end - # There is no public tokenizer for Gemini, and from the ones we already ship in the plugin # OpenAI gets the closest results. 
Gemini Tokenizer results in ~10% less tokens, so it's safe
         # to use OpenAI tokenizer since it will overestimate the number of tokens.
diff --git a/lib/embeddings/vector_representations/multilingual_e5_large.rb b/lib/embeddings/vector_representations/multilingual_e5_large.rb
index fe611ec9..7d6894a8 100644
--- a/lib/embeddings/vector_representations/multilingual_e5_large.rb
+++ b/lib/embeddings/vector_representations/multilingual_e5_large.rb
@@ -28,19 +28,6 @@ module DiscourseAi
           end
         end
 
-        def vector_from(text, asymetric: false)
-          client = inference_client
-
-          needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
-          if needs_truncation
-            text = tokenizer.truncate(text, max_sequence_length - 2)
-          elsif !text.starts_with?("query:")
-            text = "query: #{text}"
-          end
-
-          client.perform!(text)
-        end
-
         def id
           3
         end
@@ -61,10 +48,6 @@ module DiscourseAi
           "<=>"
         end
 
-        def pg_index_type
-          "halfvec_cosine_ops"
-        end
-
         def tokenizer
           DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
         end
@@ -80,8 +63,18 @@ module DiscourseAi
           end
         end
 
-        def prepare_text(record)
-          prepared_text = super(record)
+        def prepare_query_text(text, asymetric: false)
+          prepared_text = super(text, asymetric: asymetric)
+
+          if prepared_text.present? && inference_client.class.name.include?("DiscourseClassifier")
+            return "query: #{prepared_text}"
+          end
+
+          prepared_text
+        end
+
+        def prepare_target_text(target)
+          prepared_text = super(target)
 
           if prepared_text.present? && inference_client.class.name.include?("DiscourseClassifier")
             return "query: #{prepared_text}"
diff --git a/lib/embeddings/vector_representations/text_embedding_3_large.rb b/lib/embeddings/vector_representations/text_embedding_3_large.rb
index 202d66de..d73f4cee 100644
--- a/lib/embeddings/vector_representations/text_embedding_3_large.rb
+++ b/lib/embeddings/vector_representations/text_embedding_3_large.rb
@@ -40,14 +40,6 @@ module DiscourseAi
           "<=>"
         end
 
-        def pg_index_type
-          "halfvec_cosine_ops"
-        end
-
-        def vector_from(text, asymetric: false)
-          inference_client.perform!(text)
-        end
-
         def tokenizer
           DiscourseAi::Tokenizer::OpenAiTokenizer
         end
diff --git a/lib/embeddings/vector_representations/text_embedding_3_small.rb b/lib/embeddings/vector_representations/text_embedding_3_small.rb
index 87f31185..90a3f790 100644
--- a/lib/embeddings/vector_representations/text_embedding_3_small.rb
+++ b/lib/embeddings/vector_representations/text_embedding_3_small.rb
@@ -38,14 +38,6 @@ module DiscourseAi
           "<=>"
         end
 
-        def pg_index_type
-          "halfvec_cosine_ops"
-        end
-
-        def vector_from(text, asymetric: false)
-          inference_client.perform!(text)
-        end
-
         def tokenizer
           DiscourseAi::Tokenizer::OpenAiTokenizer
         end
diff --git a/lib/embeddings/vector_representations/text_embedding_ada_002.rb b/lib/embeddings/vector_representations/text_embedding_ada_002.rb
index 1e570b98..f5340918 100644
--- a/lib/embeddings/vector_representations/text_embedding_ada_002.rb
+++ b/lib/embeddings/vector_representations/text_embedding_ada_002.rb
@@ -38,14 +38,6 @@ module DiscourseAi
           "<=>"
         end
 
-        def pg_index_type
-          "halfvec_cosine_ops"
-        end
-
-        def vector_from(text, asymetric: false)
-          inference_client.perform!(text)
-        end
-
         def tokenizer
           DiscourseAi::Tokenizer::OpenAiTokenizer
         end
diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb
index 0ff4bd53..c58158d3 100644
--- a/spec/lib/modules/ai_bot/personas/persona_spec.rb
+++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb
@@ -326,8 +326,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
fab!(:llm_model) { Fabricate(:fake_model) } it "will run the question consolidator" do - vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation - context_embedding = vector_rep.dimensions.times.map { rand(-1.0...1.0) } + vector_def = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation + context_embedding = vector_def.dimensions.times.map { rand(-1.0...1.0) } EmbeddingsGenerationStubs.discourse_service( SiteSetting.ai_embeddings_model, consolidated_question, @@ -373,14 +373,14 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do end context "when a persona has RAG uploads" do - let(:vector_rep) do + let(:vector_def) do DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation end let(:embedding_value) { 0.04381 } - let(:prompt_cc_embeddings) { [embedding_value] * vector_rep.dimensions } + let(:prompt_cc_embeddings) { [embedding_value] * vector_def.dimensions } def stub_fragments(fragment_count, persona: ai_persona) - schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep) + schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector_def) fragment_count.times do |i| fragment = @@ -393,7 +393,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do ) # Similarity is determined left-to-right. - embeddings = [embedding_value + "0.000#{i}".to_f] * vector_rep.dimensions + embeddings = [embedding_value + "0.000#{i}".to_f] * vector_def.dimensions schema.store(fragment, embeddings, "test") end diff --git a/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb b/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb index 7a246864..8b89d800 100644 --- a/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb +++ b/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb @@ -14,9 +14,9 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do fab!(:category) fab!(:topic) { Fabricate(:topic, category: category) } - let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } + let(:vector) { DiscourseAi::Embeddings::Vector.instance } let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) } - let(:expected_embedding) { [0.0038493] * vector_rep.dimensions } + let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions } before do SiteSetting.ai_embeddings_enabled = true @@ -28,8 +28,8 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", ).to_return(status: 200, body: JSON.dump(expected_embedding)) - vector_rep.generate_representation_from(topic) - vector_rep.generate_representation_from(muted_topic) + vector.generate_representation_from(topic) + vector.generate_representation_from(muted_topic) end it "respects user muted categories when making suggestions" do diff --git a/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb b/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb index 5b9f0fb3..56384d59 100644 --- a/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb +++ b/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb @@ -12,21 +12,16 @@ RSpec.describe Jobs::GenerateEmbeddings do fab!(:topic) fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) } - let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } - let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } - let(:topics_schema) { 
DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) } - let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector: vector_rep) } + let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } + let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector_def) } + let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector_def: vector_def) } it "works for topics" do - expected_embedding = [0.0038493] * vector_rep.dimensions + expected_embedding = [0.0038493] * vector_def.dimensions - text = - truncation.prepare_text_from( - topic, - vector_rep.tokenizer, - vector_rep.max_sequence_length - 2, - ) - EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding) + text = vector_def.prepare_target_text(topic) + + EmbeddingsGenerationStubs.discourse_service(vector_def.class.name, text, expected_embedding) job.execute(target_id: topic.id, target_type: "Topic") @@ -34,11 +29,10 @@ RSpec.describe Jobs::GenerateEmbeddings do end it "works for posts" do - expected_embedding = [0.0038493] * vector_rep.dimensions + expected_embedding = [0.0038493] * vector_def.dimensions - text = - truncation.prepare_text_from(post, vector_rep.tokenizer, vector_rep.max_sequence_length - 2) - EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding) + text = vector_def.prepare_target_text(post) + EmbeddingsGenerationStubs.discourse_service(vector_def.class.name, text, expected_embedding) job.execute(target_id: post.id, target_type: "Post") diff --git a/spec/lib/modules/embeddings/schema_spec.rb b/spec/lib/modules/embeddings/schema_spec.rb index 45c4e243..107428aa 100644 --- a/spec/lib/modules/embeddings/schema_spec.rb +++ b/spec/lib/modules/embeddings/schema_spec.rb @@ -1,16 +1,12 @@ # frozen_string_literal: true RSpec.describe DiscourseAi::Embeddings::Schema do - subject(:posts_schema) { described_class.for(Post, vector: vector) } + subject(:posts_schema) { described_class.for(Post, vector_def: vector_def) } - let(:embeddings) { [0.0038490295] * vector.dimensions } + let(:embeddings) { [0.0038490295] * vector_def.dimensions } fab!(:post) { Fabricate(:post, post_number: 1) } let(:digest) { OpenSSL::Digest.hexdigest("SHA1", "test") } - let(:vector) do - DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new( - DiscourseAi::Embeddings::Strategies::Truncation.new, - ) - end + let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new } before { posts_schema.store(post, embeddings, digest) } @@ -34,7 +30,7 @@ RSpec.describe DiscourseAi::Embeddings::Schema do describe "similarity searches" do fab!(:post_2) { Fabricate(:post) } - let(:similar_embeddings) { [0.0038490294] * vector.dimensions } + let(:similar_embeddings) { [0.0038490294] * vector_def.dimensions } describe "#symmetric_similarity_search" do before { posts_schema.store(post_2, similar_embeddings, digest) } diff --git a/spec/lib/modules/embeddings/semantic_search_spec.rb b/spec/lib/modules/embeddings/semantic_search_spec.rb index 67d5c275..2ed0bae9 100644 --- a/spec/lib/modules/embeddings/semantic_search_spec.rb +++ b/spec/lib/modules/embeddings/semantic_search_spec.rb @@ -11,8 +11,8 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do describe "#search_for_topics" do let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" } - let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } - 
let(:hyde_embedding) { [0.049382] * vector_rep.dimensions }
+    let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
+    let(:hyde_embedding) { [0.049382] * vector_def.dimensions }
 
     before do
       SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
diff --git a/spec/lib/modules/embeddings/strategies/truncation_spec.rb b/spec/lib/modules/embeddings/strategies/truncation_spec.rb
index b0498023..0b31fb11 100644
--- a/spec/lib/modules/embeddings/strategies/truncation_spec.rb
+++ b/spec/lib/modules/embeddings/strategies/truncation_spec.rb
@@ -3,8 +3,8 @@
 RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
   subject(:truncation) { described_class.new }
 
-  describe "#prepare_text_from" do
-    context "when using vector from OpenAI" do
+  describe "#prepare_target_text" do
+    context "when using vector def from OpenAI" do
       before { SiteSetting.max_post_length = 100_000 }
 
       fab!(:topic)
@@ -20,15 +20,12 @@
       end
 
       fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
-      let(:model) do
-        DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new(truncation)
-      end
+      let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new }
 
       it "truncates a topic" do
-        prepared_text =
-          truncation.prepare_text_from(topic, model.tokenizer, model.max_sequence_length)
+        prepared_text = truncation.prepare_target_text(topic, vector_def)
 
-        expect(model.tokenizer.size(prepared_text)).to be <= model.max_sequence_length
+        expect(vector_def.tokenizer.size(prepared_text)).to be <= vector_def.max_sequence_length
       end
     end
   end
diff --git a/spec/lib/modules/embeddings/vector_representations/all_mpnet_base_v2_spec.rb b/spec/lib/modules/embeddings/vector_representations/all_mpnet_base_v2_spec.rb
deleted file mode 100644
index c1cefe04..00000000
--- a/spec/lib/modules/embeddings/vector_representations/all_mpnet_base_v2_spec.rb
+++ /dev/null
@@ -1,17 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "vector_rep_shared_examples"
-
-RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2 do
-  subject(:vector_rep) { described_class.new(truncation) }
-
-  let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
-
-  before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
-
-  def stub_vector_mapping(text, expected_embedding)
-    EmbeddingsGenerationStubs.discourse_service(described_class.name, text, expected_embedding)
-  end
-
-  it_behaves_like "generates and store embedding using with vector representation"
-end
diff --git a/spec/lib/modules/embeddings/vector_representations/gemini_spec.rb b/spec/lib/modules/embeddings/vector_representations/gemini_spec.rb
deleted file mode 100644
index 80f021f7..00000000
--- a/spec/lib/modules/embeddings/vector_representations/gemini_spec.rb
+++ /dev/null
@@ -1,18 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "vector_rep_shared_examples"
-
-RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::Gemini do
-  subject(:vector_rep) { described_class.new(truncation) }
-
-  let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
-  let!(:api_key) { "test-123" }
-
-  before { SiteSetting.ai_gemini_api_key = api_key }
-
-  def stub_vector_mapping(text, expected_embedding)
-    EmbeddingsGenerationStubs.gemini_service(api_key, text, expected_embedding)
-  end
-
-  it_behaves_like "generates and store embedding using with
vector representation" -end diff --git a/spec/lib/modules/embeddings/vector_representations/multilingual_e5_large_spec.rb b/spec/lib/modules/embeddings/vector_representations/multilingual_e5_large_spec.rb deleted file mode 100644 index e7af5eba..00000000 --- a/spec/lib/modules/embeddings/vector_representations/multilingual_e5_large_spec.rb +++ /dev/null @@ -1,21 +0,0 @@ -# frozen_string_literal: true - -require_relative "vector_rep_shared_examples" - -RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large do - subject(:vector_rep) { described_class.new(truncation) } - - let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } - - before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" } - - def stub_vector_mapping(text, expected_embedding) - EmbeddingsGenerationStubs.discourse_service( - described_class.name, - "query: #{text}", - expected_embedding, - ) - end - - it_behaves_like "generates and store embedding using with vector representation" -end diff --git a/spec/lib/modules/embeddings/vector_representations/text_embedding_3_large_spec.rb b/spec/lib/modules/embeddings/vector_representations/text_embedding_3_large_spec.rb deleted file mode 100644 index 5bed2863..00000000 --- a/spec/lib/modules/embeddings/vector_representations/text_embedding_3_large_spec.rb +++ /dev/null @@ -1,22 +0,0 @@ -# frozen_string_literal: true - -require_relative "vector_rep_shared_examples" - -RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large do - subject(:vector_rep) { described_class.new(truncation) } - - let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } - - def stub_vector_mapping(text, expected_embedding) - EmbeddingsGenerationStubs.openai_service( - described_class.name, - text, - expected_embedding, - extra_args: { - dimensions: 2000, - }, - ) - end - - it_behaves_like "generates and store embedding using with vector representation" -end diff --git a/spec/lib/modules/embeddings/vector_representations/text_embedding_3_small_spec.rb b/spec/lib/modules/embeddings/vector_representations/text_embedding_3_small_spec.rb deleted file mode 100644 index 1f4f01c2..00000000 --- a/spec/lib/modules/embeddings/vector_representations/text_embedding_3_small_spec.rb +++ /dev/null @@ -1,15 +0,0 @@ -# frozen_string_literal: true - -require_relative "vector_rep_shared_examples" - -RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small do - subject(:vector_rep) { described_class.new(truncation) } - - let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } - - def stub_vector_mapping(text, expected_embedding) - EmbeddingsGenerationStubs.openai_service(described_class.name, text, expected_embedding) - end - - it_behaves_like "generates and store embedding using with vector representation" -end diff --git a/spec/lib/modules/embeddings/vector_representations/text_embedding_ada_002_spec.rb b/spec/lib/modules/embeddings/vector_representations/text_embedding_ada_002_spec.rb deleted file mode 100644 index ed5a80fc..00000000 --- a/spec/lib/modules/embeddings/vector_representations/text_embedding_ada_002_spec.rb +++ /dev/null @@ -1,15 +0,0 @@ -# frozen_string_literal: true - -require_relative "vector_rep_shared_examples" - -RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002 do - subject(:vector_rep) { described_class.new(truncation) } - - let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } - - def 
stub_vector_mapping(text, expected_embedding) - EmbeddingsGenerationStubs.openai_service(described_class.name, text, expected_embedding) - end - - it_behaves_like "generates and store embedding using with vector representation" -end diff --git a/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb b/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb deleted file mode 100644 index 0801860e..00000000 --- a/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb +++ /dev/null @@ -1,115 +0,0 @@ -# frozen_string_literal: true - -RSpec.shared_examples "generates and store embedding using with vector representation" do - let(:expected_embedding_1) { [0.0038493] * vector_rep.dimensions } - let(:expected_embedding_2) { [0.0037684] * vector_rep.dimensions } - - let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) } - let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector: vector_rep) } - - describe "#vector_from" do - it "creates a vector from a given string" do - text = "This is a piece of text" - stub_vector_mapping(text, expected_embedding_1) - - expect(vector_rep.vector_from(text)).to eq(expected_embedding_1) - end - end - - describe "#generate_representation_from" do - fab!(:topic) { Fabricate(:topic) } - fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) } - fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) } - - it "creates a vector from a topic and stores it in the database" do - text = - truncation.prepare_text_from( - topic, - vector_rep.tokenizer, - vector_rep.max_sequence_length - 2, - ) - stub_vector_mapping(text, expected_embedding_1) - - vector_rep.generate_representation_from(topic) - - expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id) - end - - it "creates a vector from a post and stores it in the database" do - text = - truncation.prepare_text_from( - post2, - vector_rep.tokenizer, - vector_rep.max_sequence_length - 2, - ) - stub_vector_mapping(text, expected_embedding_1) - - vector_rep.generate_representation_from(post) - - expect(posts_schema.find_by_embedding(expected_embedding_1).post_id).to eq(post.id) - end - end - - describe "#gen_bulk_reprensentations" do - fab!(:topic) { Fabricate(:topic) } - fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) } - fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) } - - fab!(:topic_2) { Fabricate(:topic) } - fab!(:post_2_1) { Fabricate(:post, post_number: 1, topic: topic_2) } - fab!(:post_2_2) { Fabricate(:post, post_number: 2, topic: topic_2) } - - it "creates a vector for each object in the relation" do - text = - truncation.prepare_text_from( - topic, - vector_rep.tokenizer, - vector_rep.max_sequence_length - 2, - ) - - text2 = - truncation.prepare_text_from( - topic_2, - vector_rep.tokenizer, - vector_rep.max_sequence_length - 2, - ) - - stub_vector_mapping(text, expected_embedding_1) - stub_vector_mapping(text2, expected_embedding_2) - - vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id])) - - expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id) - end - - it "does nothing if passed record has no content" do - expect { vector_rep.gen_bulk_reprensentations([Topic.new]) }.not_to raise_error - end - - it "doesn't ask for a new embedding if digest is the same" do - text = - truncation.prepare_text_from( - topic, - vector_rep.tokenizer, - vector_rep.max_sequence_length - 2, - ) 
-      stub_vector_mapping(text, expected_embedding_1)
-
-      original_vector_gen = Time.zone.parse("2021-06-04 10:00")
-
-      freeze_time(original_vector_gen) do
-        vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
-      end
-      # check vector exists
-      expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
-
-      vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
-      last_update =
-        DB.query_single(
-          "SELECT updated_at FROM #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE} WHERE topic_id = #{topic.id} LIMIT 1",
-        ).first
-
-      expect(last_update).to eq(original_vector_gen)
-    end
-  end
-end
diff --git a/spec/lib/modules/embeddings/vector_spec.rb b/spec/lib/modules/embeddings/vector_spec.rb
new file mode 100644
index 00000000..6542323d
--- /dev/null
+++ b/spec/lib/modules/embeddings/vector_spec.rb
@@ -0,0 +1,160 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Embeddings::Vector do
+  shared_examples "generates and stores embeddings using a vector definition" do
+    subject(:vector) { described_class.new(vdef) }
+
+    let(:expected_embedding_1) { [0.0038493] * vdef.dimensions }
+    let(:expected_embedding_2) { [0.0037684] * vdef.dimensions }
+
+    let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vdef) }
+    let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector_def: vdef) }
+
+    fab!(:topic)
+    fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
+    fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
+
+    describe "#vector_from" do
+      it "creates a vector from a given string" do
+        text = "This is a piece of text"
+        stub_vector_mapping(text, expected_embedding_1)
+
+        expect(vector.vector_from(text)).to eq(expected_embedding_1)
+      end
+    end
+
+    describe "#generate_representation_from" do
+      it "creates a vector from a topic and stores it in the database" do
+        text = vdef.prepare_target_text(topic)
+        stub_vector_mapping(text, expected_embedding_1)
+
+        vector.generate_representation_from(topic)
+
+        expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
+      end
+
+      it "creates a vector from a post and stores it in the database" do
+        text = vdef.prepare_target_text(post2)
+        stub_vector_mapping(text, expected_embedding_1)
+
+        vector.generate_representation_from(post)
+
+        expect(posts_schema.find_by_embedding(expected_embedding_1).post_id).to eq(post.id)
+      end
+    end
+
+    describe "#gen_bulk_reprensentations" do
+      fab!(:topic_2) { Fabricate(:topic) }
+      fab!(:post_2_1) { Fabricate(:post, post_number: 1, topic: topic_2) }
+      fab!(:post_2_2) { Fabricate(:post, post_number: 2, topic: topic_2) }
+
+      it "creates a vector for each object in the relation" do
+        text = vdef.prepare_target_text(topic)
+
+        text2 = vdef.prepare_target_text(topic_2)
+
+        stub_vector_mapping(text, expected_embedding_1)
+        stub_vector_mapping(text2, expected_embedding_2)
+
+        vector.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))
+
+        expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
+      end
+
+      it "does nothing if passed record has no content" do
+        expect { vector.gen_bulk_reprensentations([Topic.new]) }.not_to raise_error
+      end
+
+      it "doesn't ask for a new embedding if digest is the same" do
+        text = vdef.prepare_target_text(topic)
+        stub_vector_mapping(text, expected_embedding_1)
+
+        original_vector_gen = Time.zone.parse("2021-06-04 10:00")
+
+        freeze_time(original_vector_gen) do
+          vector.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
+        end
+        # check vector exists
+        expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
+
+        vector.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
+
+        expect(topics_schema.find_by_target(topic).updated_at).to eq_time(original_vector_gen)
+      end
+    end
+  end
+
+  context "with text-embedding-ada-002" do
+    let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new }
+
+    def stub_vector_mapping(text, expected_embedding)
+      EmbeddingsGenerationStubs.openai_service(vdef.class.name, text, expected_embedding)
+    end
+
+    it_behaves_like "generates and stores embeddings using a vector definition"
+  end
+
+  context "with all-mpnet-base-v2" do
+    let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new }
+
+    before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
+
+    def stub_vector_mapping(text, expected_embedding)
+      EmbeddingsGenerationStubs.discourse_service(vdef.class.name, text, expected_embedding)
+    end
+
+    it_behaves_like "generates and stores embeddings using a vector definition"
+  end
+
+  context "with gemini" do
+    let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::Gemini.new }
+    let(:api_key) { "test-123" }
+
+    before { SiteSetting.ai_gemini_api_key = api_key }
+
+    def stub_vector_mapping(text, expected_embedding)
+      EmbeddingsGenerationStubs.gemini_service(api_key, text, expected_embedding)
+    end
+
+    it_behaves_like "generates and stores embeddings using a vector definition"
+  end
+
+  context "with multilingual-e5-large" do
+    let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large.new }
+
+    before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
+
+    def stub_vector_mapping(text, expected_embedding)
+      EmbeddingsGenerationStubs.discourse_service(vdef.class.name, text, expected_embedding)
+    end
+
+    it_behaves_like "generates and stores embeddings using a vector definition"
+  end
+
+  context "with text-embedding-3-large" do
+    let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large.new }
+
+    def stub_vector_mapping(text, expected_embedding)
+      EmbeddingsGenerationStubs.openai_service(
+        vdef.class.name,
+        text,
+        expected_embedding,
+        extra_args: {
+          dimensions: 2000,
+        },
+      )
+    end
+
+    it_behaves_like "generates and stores embeddings using a vector definition"
+  end
+
+  context "with text-embedding-3-small" do
+    let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small.new }
+
+    def stub_vector_mapping(text, expected_embedding)
+      EmbeddingsGenerationStubs.openai_service(vdef.class.name, text, expected_embedding)
+    end
+
+    it_behaves_like "generates and stores embeddings using a vector definition"
+  end
+end
diff --git a/spec/models/rag_document_fragment_spec.rb b/spec/models/rag_document_fragment_spec.rb
index 77a0061d..afb95d83 100644
--- a/spec/models/rag_document_fragment_spec.rb
+++ b/spec/models/rag_document_fragment_spec.rb
@@ -74,7 +74,7 @@ RSpec.describe RagDocumentFragment do
   end
 
   describe ".indexing_status" do
-    let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
+    let(:vector) { DiscourseAi::Embeddings::Vector.instance }
 
     fab!(:rag_document_fragment_1) do
       Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
     end
 
     fab!(:rag_document_fragment_2) do
       Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
     end
 
-    let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
+ 
let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions } before do SiteSetting.ai_embeddings_enabled = true @@ -96,7 +96,7 @@ RSpec.describe RagDocumentFragment do "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", ).to_return(status: 200, body: JSON.dump(expected_embedding)) - vector_rep.generate_representation_from(rag_document_fragment_1) + vector.generate_representation_from(rag_document_fragment_1) end it "regenerates all embeddings if ai_embeddings_model changes" do diff --git a/spec/requests/embeddings/embeddings_controller_spec.rb b/spec/requests/embeddings/embeddings_controller_spec.rb index 39bcf136..e79039fb 100644 --- a/spec/requests/embeddings/embeddings_controller_spec.rb +++ b/spec/requests/embeddings/embeddings_controller_spec.rb @@ -19,14 +19,14 @@ describe DiscourseAi::Embeddings::EmbeddingsController do fab!(:post_in_subcategory) { Fabricate(:post, topic: topic_in_subcategory) } def index(topic) - vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation + vector = DiscourseAi::Embeddings::Vector.instance stub_request(:post, "https://api.openai.com/v1/embeddings").to_return( status: 200, body: JSON.dump({ data: [{ embedding: [0.1] * 1536 }] }), ) - vector_rep.generate_representation_from(topic) + vector.generate_representation_from(topic) end def stub_embedding(query)
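The bulk path in `Vector#gen_bulk_reprensentations` fans embedding requests out over a concurrent-ruby thread pool and joins them with promise zipping. A self-contained sketch of that fan-out/fan-in pattern using the same gem; `fetch_embedding` is a hypothetical stand-in for the inference client's `perform!`:

    # frozen_string_literal: true
    require "concurrent"

    # Hypothetical stand-in for vdef.inference_client.perform!(text).
    def fetch_embedding(text)
      [text.length * 0.001] * 4 # fake 4-dimensional vector
    end

    # Mirrors the patch: threads are created on demand and reaped after 30s idle.
    pool = Concurrent::CachedThreadPool.new(min_threads: 0, max_threads: 8, idletime: 30)

    texts = ["first topic", "second topic", "third topic"]

    # Fan out: one future per record, resolved on the pool so slow HTTP-style
    # calls overlap instead of running serially.
    futures =
      texts.map do |text|
        Concurrent::Promises
          .fulfilled_future(text, pool)
          .then_on(pool) { |t| { text: t, embedding: fetch_embedding(t) } }
      end

    # Fan in: zip waits for every future; value! re-raises the first failure,
    # so one failed inference call aborts the whole batch, as in the patch.
    Concurrent::Promises
      .zip(*futures)
      .value!
      .each { |r| puts "#{r[:text]} => #{r[:embedding].inspect}" }

    pool.shutdown
    pool.wait_for_termination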