REFACTOR: Separation of concerns for embedding generation. (#1027)

In a previous refactor, we moved the responsibility of querying and storing embeddings into the `Schema` class. Now, it's time for embedding generation.

The motivation behind these changes is to isolate vector characteristics in simple objects to later replace them with a DB-backed version, similar to what we did with LLM configs.
This commit is contained in:
Roman Rizzi 2024-12-16 09:55:39 -03:00 committed by GitHub
parent 222e2cf4f9
commit 534b0df391
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
36 changed files with 375 additions and 496 deletions

View File

@ -16,9 +16,7 @@ module Jobs
return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms
return if post.raw.blank?
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector_rep.generate_representation_from(target)
DiscourseAi::Embeddings::Vector.instance.generate_representation_from(target)
end
end
end

View File

@ -8,11 +8,11 @@ module ::Jobs
def execute(args)
return if (fragments = RagDocumentFragment.where(id: args[:fragment_ids].to_a)).empty?
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector = DiscourseAi::Embeddings::Vector.instance
# generate_representation_from checks compares the digest value to make sure
# the embedding is only generated once per fragment unless something changes.
fragments.map { |fragment| vector_rep.generate_representation_from(fragment) }
fragments.map { |fragment| vector.generate_representation_from(fragment) }
last_fragment = fragments.last
target = last_fragment.target

View File

@ -20,7 +20,8 @@ module Jobs
rebaked = 0
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector = DiscourseAi::Embeddings::Vector.instance
vector_def = vector.vdef
table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
topics =
@ -30,19 +31,19 @@ module Jobs
.where(deleted_at: nil)
.order("topics.bumped_at DESC")
rebaked += populate_topic_embeddings(vector_rep, topics.limit(limit - rebaked))
rebaked += populate_topic_embeddings(vector, topics.limit(limit - rebaked))
return if rebaked >= limit
# Then, we'll try to backfill embeddings for topics that have outdated
# embeddings, be it model or strategy version
relation = topics.where(<<~SQL).limit(limit - rebaked)
#{table_name}.model_version < #{vector_rep.version}
#{table_name}.model_version < #{vector_def.version}
OR
#{table_name}.strategy_version < #{vector_rep.strategy_version}
#{table_name}.strategy_version < #{vector_def.strategy_version}
SQL
rebaked += populate_topic_embeddings(vector_rep, relation)
rebaked += populate_topic_embeddings(vector, relation)
return if rebaked >= limit
@ -54,7 +55,7 @@ module Jobs
.where("#{table_name}.updated_at < topics.updated_at")
.limit((limit - rebaked) / 10)
populate_topic_embeddings(vector_rep, relation, force: true)
populate_topic_embeddings(vector, relation, force: true)
return if rebaked >= limit
@ -76,7 +77,7 @@ module Jobs
.limit(limit - rebaked)
.pluck(:id)
.each_slice(posts_batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Post.where(id: batch))
vector.gen_bulk_reprensentations(Post.where(id: batch))
rebaked += batch.length
end
@ -86,14 +87,14 @@ module Jobs
# embeddings, be it model or strategy version
posts
.where(<<~SQL)
#{table_name}.model_version < #{vector_rep.version}
#{table_name}.model_version < #{vector_def.version}
OR
#{table_name}.strategy_version < #{vector_rep.strategy_version}
#{table_name}.strategy_version < #{vector_def.strategy_version}
SQL
.limit(limit - rebaked)
.pluck(:id)
.each_slice(posts_batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Post.where(id: batch))
vector.gen_bulk_reprensentations(Post.where(id: batch))
rebaked += batch.length
end
@ -107,7 +108,7 @@ module Jobs
.limit((limit - rebaked) / 10)
.pluck(:id)
.each_slice(posts_batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Post.where(id: batch))
vector.gen_bulk_reprensentations(Post.where(id: batch))
rebaked += batch.length
end
@ -116,7 +117,7 @@ module Jobs
private
def populate_topic_embeddings(vector_rep, topics, force: false)
def populate_topic_embeddings(vector, topics, force: false)
done = 0
topics =
@ -126,7 +127,7 @@ module Jobs
batch_size = 1000
ids.each_slice(batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
vector.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
done += batch.length
end

View File

@ -314,10 +314,10 @@ module DiscourseAi
return nil if !consolidated_question
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector = DiscourseAi::Embeddings::Vector.instance
reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings
interactions_vector = vector_rep.vector_from(consolidated_question)
interactions_vector = vector.vector_from(consolidated_question)
rag_conversation_chunks = self.class.rag_conversation_chunks
search_limit =
@ -327,7 +327,7 @@ module DiscourseAi
rag_conversation_chunks
end
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep)
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector.vdef)
candidate_fragment_ids =
schema

View File

@ -141,11 +141,10 @@ module DiscourseAi
return [] if upload_refs.empty?
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
query_vector = vector_rep.vector_from(query)
query_vector = DiscourseAi::Embeddings::Vector.instance.vector_from(query)
fragment_ids =
DiscourseAi::Embeddings::Schema
.for(RagDocumentFragment, vector: vector_rep)
.for(RagDocumentFragment)
.asymmetric_similarity_search(query_vector, limit: limit, offset: 0) do |builder|
builder.join(<<~SQL, target_id: tool.id, target_type: "AiTool")
rag_document_fragments ON

View File

@ -92,10 +92,10 @@ module DiscourseAi
private
def nearest_neighbors(limit: 100)
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep)
vector = DiscourseAi::Embeddings::Vector.instance
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)
raw_vector = vector_rep.vector_from(@text)
raw_vector = vector.vector_from(@text)
muted_category_ids = nil
if @user.present?

View File

@ -14,30 +14,31 @@ module DiscourseAi
def self.for(
target_klass,
vector: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector_def: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
)
case target_klass&.name
when "Topic"
new(TOPICS_TABLE, "topic_id", vector)
new(TOPICS_TABLE, "topic_id", vector_def)
when "Post"
new(POSTS_TABLE, "post_id", vector)
new(POSTS_TABLE, "post_id", vector_def)
when "RagDocumentFragment"
new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector)
new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector_def)
else
raise ArgumentError, "Invalid target type for embeddings"
end
end
def initialize(table, target_column, vector)
def initialize(table, target_column, vector_def)
@table = table
@target_column = target_column
@vector = vector
@vector_def = vector_def
end
attr_reader :table, :target_column, :vector
attr_reader :table, :target_column, :vector_def
def find_by_embedding(embedding)
DB.query(<<~SQL, query_embedding: embedding, vid: vector.id, vsid: vector.strategy_id).first
DB.query(
<<~SQL,
SELECT *
FROM #{table}
WHERE
@ -46,10 +47,15 @@ module DiscourseAi
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT 1
SQL
query_embedding: embedding,
vid: vector_def.id,
vsid: vector_def.strategy_id,
).first
end
def find_by_target(target)
DB.query(<<~SQL, target_id: target.id, vid: vector.id, vsid: vector.strategy_id).first
DB.query(
<<~SQL,
SELECT *
FROM #{table}
WHERE
@ -58,6 +64,10 @@ module DiscourseAi
#{target_column} = :target_id
LIMIT 1
SQL
target_id: target.id,
vid: vector_def.id,
vsid: vector_def.strategy_id,
).first
end
def asymmetric_similarity_search(embedding, limit:, offset:)
@ -87,8 +97,8 @@ module DiscourseAi
builder.where(
"model_id = :model_id AND strategy_id = :strategy_id",
model_id: vector.id,
strategy_id: vector.strategy_id,
model_id: vector_def.id,
strategy_id: vector_def.strategy_id,
)
yield(builder) if block_given?
@ -156,7 +166,7 @@ module DiscourseAi
yield(builder) if block_given?
builder.query(vid: vector.id, vsid: vector.strategy_id, target_id: record.id)
builder.query(vid: vector_def.id, vsid: vector_def.strategy_id, target_id: record.id)
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
raise MissingEmbeddingError
@ -176,10 +186,10 @@ module DiscourseAi
updated_at = :now
SQL
target_id: record.id,
model_id: vector.id,
model_version: vector.version,
strategy_id: vector.strategy_id,
strategy_version: vector.strategy_version,
model_id: vector_def.id,
model_version: vector_def.version,
strategy_id: vector_def.strategy_id,
strategy_version: vector_def.strategy_version,
digest: digest,
embeddings: embedding,
now: Time.zone.now,
@ -188,7 +198,7 @@ module DiscourseAi
private
delegate :dimensions, :pg_function, to: :vector
delegate :dimensions, :pg_function, to: :vector_def
end
end
end

View File

@ -13,14 +13,13 @@ module DiscourseAi
def related_topic_ids_for(topic)
return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
cache_for = results_ttl(topic)
Discourse
.cache
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
DiscourseAi::Embeddings::Schema
.for(Topic, vector: vector_rep)
.for(Topic)
.symmetric_similarity_search(topic)
.map(&:topic_id)
.tap do |candidate_ids|

View File

@ -30,8 +30,8 @@ module DiscourseAi
Discourse.cache.read(embedding_key).present?
end
def vector_rep
@vector_rep ||= DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
def vector
@vector ||= DiscourseAi::Embeddings::Vector.instance
end
def hyde_embedding(search_term)
@ -52,16 +52,14 @@ module DiscourseAi
Discourse
.cache
.fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(hypothetical_post) }
.fetch(embedding_key, expires_in: 1.week) { vector.vector_from(hypothetical_post) }
end
def embedding(search_term)
digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
embedding_key = build_embedding_key(digest, "", SiteSetting.ai_embeddings_model)
Discourse
.cache
.fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(search_term) }
Discourse.cache.fetch(embedding_key, expires_in: 1.week) { vector.vector_from(search_term) }
end
# this ensures the candidate topics are over selected
@ -84,7 +82,7 @@ module DiscourseAi
over_selection_limit = limit * OVER_SELECTION_FACTOR
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep)
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)
candidate_topic_ids =
schema.asymmetric_similarity_search(
@ -114,7 +112,7 @@ module DiscourseAi
return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector = DiscourseAi::Embeddings::Vector.instance
digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
@ -129,12 +127,12 @@ module DiscourseAi
Discourse
.cache
.fetch(embedding_key, expires_in: 1.week) do
vector_rep.vector_from(search_term, asymetric: true)
vector.vector_from(search_term, asymetric: true)
end
candidate_post_ids =
DiscourseAi::Embeddings::Schema
.for(Post, vector: vector_rep)
.for(Post, vector_def: vector.vdef)
.asymmetric_similarity_search(
search_term_embedding,
limit: max_semantic_results_per_page,

View File

@ -12,19 +12,28 @@ module DiscourseAi
1
end
def prepare_text_from(target, tokenizer, max_length)
def prepare_target_text(target, vdef)
max_length = vdef.max_sequence_length - 2
case target
when Topic
topic_truncation(target, tokenizer, max_length)
topic_truncation(target, vdef.tokenizer, max_length)
when Post
post_truncation(target, tokenizer, max_length)
post_truncation(target, vdef.tokenizer, max_length)
when RagDocumentFragment
tokenizer.truncate(target.fragment, max_length)
vdef.tokenizer.truncate(target.fragment, max_length)
else
raise ArgumentError, "Invalid target type"
end
end
def prepare_query_text(text, vdef, asymetric: false)
qtext = asymetric ? "#{vdef.asymmetric_query_prefix} #{text}" : text
max_length = vdef.max_sequence_length - 2
vdef.tokenizer.truncate(text, max_length)
end
private
def topic_information(topic)

76
lib/embeddings/vector.rb Normal file
View File

@ -0,0 +1,76 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
class Vector
def self.instance
new(DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation)
end
def initialize(vector_definition)
@vdef = vector_definition
end
def gen_bulk_reprensentations(relation)
http_pool_size = 100
pool =
Concurrent::CachedThreadPool.new(
min_threads: 0,
max_threads: http_pool_size,
idletime: 30,
)
schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector_def: vdef)
embedding_gen = vdef.inference_client
promised_embeddings =
relation
.map do |record|
prepared_text = vdef.prepare_target_text(record)
next if prepared_text.blank?
new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text)
next if schema.find_by_target(record)&.digest == new_digest
Concurrent::Promises
.fulfilled_future({ target: record, text: prepared_text, digest: new_digest }, pool)
.then_on(pool) do |w_prepared_text|
w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text]))
end
end
.compact
Concurrent::Promises
.zip(*promised_embeddings)
.value!
.each { |e| schema.store(e[:target], e[:embedding], e[:digest]) }
pool.shutdown
pool.wait_for_termination
end
def generate_representation_from(target)
text = vdef.prepare_target_text(target)
return if text.blank?
schema = DiscourseAi::Embeddings::Schema.for(target.class, vector_def: vdef)
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
return if schema.find_by_target(target)&.digest == new_digest
embeddings = vdef.inference_client.perform!(text)
schema.store(target, embeddings, new_digest)
end
def vector_from(text, asymetric: false)
prepared_text = vdef.prepare_query_text(text, asymetric: asymetric)
return if prepared_text.blank?
vdef.inference_client.perform!(prepared_text)
end
attr_reader :vdef
end
end
end

View File

@ -23,10 +23,6 @@ module DiscourseAi
end
end
def vector_from(text, asymetric: false)
inference_client.perform!(text)
end
def dimensions
768
end
@ -47,10 +43,6 @@ module DiscourseAi
"<#>"
end
def pg_index_type
"halfvec_ip_ops"
end
def tokenizer
DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
end

View File

@ -21,8 +21,7 @@ module DiscourseAi
end
def current_representation
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
find_representation(SiteSetting.ai_embeddings_model).new(truncation)
find_representation(SiteSetting.ai_embeddings_model).new
end
def correctly_configured?
@ -43,73 +42,6 @@ module DiscourseAi
end
end
def initialize(strategy)
@strategy = strategy
end
def vector_from(text, asymetric: false)
raise NotImplementedError
end
def gen_bulk_reprensentations(relation)
http_pool_size = 100
pool =
Concurrent::CachedThreadPool.new(
min_threads: 0,
max_threads: http_pool_size,
idletime: 30,
)
schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector: self)
embedding_gen = inference_client
promised_embeddings =
relation
.map do |record|
prepared_text = prepare_text(record)
next if prepared_text.blank?
new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text)
next if schema.find_by_target(record)&.digest == new_digest
Concurrent::Promises
.fulfilled_future(
{ target: record, text: prepared_text, digest: new_digest },
pool,
)
.then_on(pool) do |w_prepared_text|
w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text]))
end
end
.compact
Concurrent::Promises
.zip(*promised_embeddings)
.value!
.each { |e| schema.store(e[:target], e[:embedding], e[:digest]) }
pool.shutdown
pool.wait_for_termination
end
def generate_representation_from(target, persist: true)
text = prepare_text(target)
return if text.blank?
schema = DiscourseAi::Embeddings::Schema.for(target.class, vector: self)
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
return if schema.find_by_target(target)&.digest == new_digest
vector = vector_from(text)
schema.store(target, vector, new_digest) if persist
end
def index_name(table_name)
"#{table_name}_#{id}_#{@strategy.id}_search"
end
def name
raise NotImplementedError
end
@ -139,26 +71,32 @@ module DiscourseAi
end
def asymmetric_query_prefix
raise NotImplementedError
""
end
def strategy_id
@strategy.id
strategy.id
end
def strategy_version
@strategy.version
strategy.version
end
protected
def prepare_query_text(text, asymetric: false)
strategy.prepare_query_text(text, self, asymetric: asymetric)
end
def prepare_target_text(target)
strategy.prepare_target_text(target, self)
end
def strategy
@strategy ||= DiscourseAi::Embeddings::Strategies::Truncation.new
end
def inference_client
raise NotImplementedError
end
def prepare_text(record)
@strategy.prepare_text_from(record, tokenizer, max_sequence_length - 2)
end
end
end
end

View File

@ -30,21 +30,6 @@ module DiscourseAi
end
end
def vector_from(text, asymetric: false)
text = "#{asymmetric_query_prefix} #{text}" if asymetric
client = inference_client
needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
text = tokenizer.truncate(text, max_sequence_length - 2) if needs_truncation
inference_client.perform!(text)
end
def inference_model_name
"baai/bge-large-en-v1.5"
end
def dimensions
1024
end
@ -65,10 +50,6 @@ module DiscourseAi
"<#>"
end
def pg_index_type
"halfvec_ip_ops"
end
def tokenizer
DiscourseAi::Tokenizer::BgeLargeEnTokenizer
end
@ -78,6 +59,8 @@ module DiscourseAi
end
def inference_client
inference_model_name = "baai/bge-large-en-v1.5"
if SiteSetting.ai_cloudflare_workers_api_token.present?
DiscourseAi::Inference::CloudflareWorkersAi.instance(inference_model_name)
elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?

View File

@ -18,11 +18,6 @@ module DiscourseAi
end
end
def vector_from(text, asymetric: false)
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
inference_client.perform!(truncated_text)
end
def dimensions
1024
end
@ -43,10 +38,6 @@ module DiscourseAi
"<#>"
end
def pg_index_type
"halfvec_ip_ops"
end
def tokenizer
DiscourseAi::Tokenizer::BgeM3Tokenizer
end

View File

@ -38,14 +38,6 @@ module DiscourseAi
"<=>"
end
def pg_index_type
"halfvec_cosine_ops"
end
def vector_from(text, asymetric: false)
inference_client.perform!(text)
end
# There is no public tokenizer for Gemini, and from the ones we already ship in the plugin
# OpenAI gets the closest results. Gemini Tokenizer results in ~10% less tokens, so it's safe
# to use OpenAI tokenizer since it will overestimate the number of tokens.

View File

@ -28,19 +28,6 @@ module DiscourseAi
end
end
def vector_from(text, asymetric: false)
client = inference_client
needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
if needs_truncation
text = tokenizer.truncate(text, max_sequence_length - 2)
elsif !text.starts_with?("query:")
text = "query: #{text}"
end
client.perform!(text)
end
def id
3
end
@ -61,10 +48,6 @@ module DiscourseAi
"<=>"
end
def pg_index_type
"halfvec_cosine_ops"
end
def tokenizer
DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
end
@ -80,8 +63,18 @@ module DiscourseAi
end
end
def prepare_text(record)
prepared_text = super(record)
def prepare_text(text, asymetric: false)
prepared_text = super(text, asymetric: asymetric)
if prepared_text.present? && inference_client.class.name.include?("DiscourseClassifier")
return "query: #{prepared_text}"
end
prepared_text
end
def prepare_target_text(target)
prepared_text = super(target)
if prepared_text.present? && inference_client.class.name.include?("DiscourseClassifier")
return "query: #{prepared_text}"

View File

@ -40,14 +40,6 @@ module DiscourseAi
"<=>"
end
def pg_index_type
"halfvec_cosine_ops"
end
def vector_from(text, asymetric: false)
inference_client.perform!(text)
end
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end

View File

@ -38,14 +38,6 @@ module DiscourseAi
"<=>"
end
def pg_index_type
"halfvec_cosine_ops"
end
def vector_from(text, asymetric: false)
inference_client.perform!(text)
end
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end

View File

@ -38,14 +38,6 @@ module DiscourseAi
"<=>"
end
def pg_index_type
"halfvec_cosine_ops"
end
def vector_from(text, asymetric: false)
inference_client.perform!(text)
end
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end

View File

@ -326,8 +326,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
fab!(:llm_model) { Fabricate(:fake_model) }
it "will run the question consolidator" do
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
context_embedding = vector_rep.dimensions.times.map { rand(-1.0...1.0) }
vector_def = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
context_embedding = vector_def.dimensions.times.map { rand(-1.0...1.0) }
EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model,
consolidated_question,
@ -373,14 +373,14 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
end
context "when a persona has RAG uploads" do
let(:vector_rep) do
let(:vector_def) do
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
end
let(:embedding_value) { 0.04381 }
let(:prompt_cc_embeddings) { [embedding_value] * vector_rep.dimensions }
let(:prompt_cc_embeddings) { [embedding_value] * vector_def.dimensions }
def stub_fragments(fragment_count, persona: ai_persona)
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep)
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector_def)
fragment_count.times do |i|
fragment =
@ -393,7 +393,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
)
# Similarity is determined left-to-right.
embeddings = [embedding_value + "0.000#{i}".to_f] * vector_rep.dimensions
embeddings = [embedding_value + "0.000#{i}".to_f] * vector_def.dimensions
schema.store(fragment, embeddings, "test")
end

View File

@ -14,9 +14,9 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
fab!(:category)
fab!(:topic) { Fabricate(:topic, category: category) }
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:vector) { DiscourseAi::Embeddings::Vector.instance }
let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) }
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions }
before do
SiteSetting.ai_embeddings_enabled = true
@ -28,8 +28,8 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
).to_return(status: 200, body: JSON.dump(expected_embedding))
vector_rep.generate_representation_from(topic)
vector_rep.generate_representation_from(muted_topic)
vector.generate_representation_from(topic)
vector.generate_representation_from(muted_topic)
end
it "respects user muted categories when making suggestions" do

View File

@ -12,21 +12,16 @@ RSpec.describe Jobs::GenerateEmbeddings do
fab!(:topic)
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) }
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector: vector_rep) }
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector_def) }
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector_def: vector_def) }
it "works for topics" do
expected_embedding = [0.0038493] * vector_rep.dimensions
expected_embedding = [0.0038493] * vector_def.dimensions
text =
truncation.prepare_text_from(
topic,
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding)
text = vector_def.prepare_target_text(topic)
EmbeddingsGenerationStubs.discourse_service(vector_def.class.name, text, expected_embedding)
job.execute(target_id: topic.id, target_type: "Topic")
@ -34,11 +29,10 @@ RSpec.describe Jobs::GenerateEmbeddings do
end
it "works for posts" do
expected_embedding = [0.0038493] * vector_rep.dimensions
expected_embedding = [0.0038493] * vector_def.dimensions
text =
truncation.prepare_text_from(post, vector_rep.tokenizer, vector_rep.max_sequence_length - 2)
EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding)
text = vector_def.prepare_target_text(post)
EmbeddingsGenerationStubs.discourse_service(vector_def.class.name, text, expected_embedding)
job.execute(target_id: post.id, target_type: "Post")

View File

@ -1,16 +1,12 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Embeddings::Schema do
subject(:posts_schema) { described_class.for(Post, vector: vector) }
subject(:posts_schema) { described_class.for(Post, vector_def: vector_def) }
let(:embeddings) { [0.0038490295] * vector.dimensions }
let(:embeddings) { [0.0038490295] * vector_def.dimensions }
fab!(:post) { Fabricate(:post, post_number: 1) }
let(:digest) { OpenSSL::Digest.hexdigest("SHA1", "test") }
let(:vector) do
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new(
DiscourseAi::Embeddings::Strategies::Truncation.new,
)
end
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new }
before { posts_schema.store(post, embeddings, digest) }
@ -34,7 +30,7 @@ RSpec.describe DiscourseAi::Embeddings::Schema do
describe "similarity searches" do
fab!(:post_2) { Fabricate(:post) }
let(:similar_embeddings) { [0.0038490294] * vector.dimensions }
let(:similar_embeddings) { [0.0038490294] * vector_def.dimensions }
describe "#symmetric_similarity_search" do
before { posts_schema.store(post_2, similar_embeddings, digest) }

View File

@ -11,8 +11,8 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
describe "#search_for_topics" do
let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" }
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:hyde_embedding) { [0.049382] * vector_rep.dimensions }
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:hyde_embedding) { [0.049382] * vector_def.dimensions }
before do
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"

View File

@ -3,8 +3,8 @@
RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
subject(:truncation) { described_class.new }
describe "#prepare_text_from" do
context "when using vector from OpenAI" do
describe "#prepare_query_text" do
context "when using vector def from OpenAI" do
before { SiteSetting.max_post_length = 100_000 }
fab!(:topic)
@ -20,15 +20,12 @@ RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
end
fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
let(:model) do
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new(truncation)
end
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new }
it "truncates a topic" do
prepared_text =
truncation.prepare_text_from(topic, model.tokenizer, model.max_sequence_length)
prepared_text = truncation.prepare_target_text(topic, vector_def)
expect(model.tokenizer.size(prepared_text)).to be <= model.max_sequence_length
expect(vector_def.tokenizer.size(prepared_text)).to be <= vector_def.max_sequence_length
end
end
end

View File

@ -1,17 +0,0 @@
# frozen_string_literal: true
require_relative "vector_rep_shared_examples"
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2 do
subject(:vector_rep) { described_class.new(truncation) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(described_class.name, text, expected_embedding)
end
it_behaves_like "generates and store embedding using with vector representation"
end

View File

@ -1,18 +0,0 @@
# frozen_string_literal: true
require_relative "vector_rep_shared_examples"
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::Gemini do
subject(:vector_rep) { described_class.new(truncation) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
let!(:api_key) { "test-123" }
before { SiteSetting.ai_gemini_api_key = api_key }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.gemini_service(api_key, text, expected_embedding)
end
it_behaves_like "generates and store embedding using with vector representation"
end

View File

@ -1,21 +0,0 @@
# frozen_string_literal: true
require_relative "vector_rep_shared_examples"
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large do
subject(:vector_rep) { described_class.new(truncation) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(
described_class.name,
"query: #{text}",
expected_embedding,
)
end
it_behaves_like "generates and store embedding using with vector representation"
end

View File

@ -1,22 +0,0 @@
# frozen_string_literal: true
require_relative "vector_rep_shared_examples"
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large do
subject(:vector_rep) { described_class.new(truncation) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(
described_class.name,
text,
expected_embedding,
extra_args: {
dimensions: 2000,
},
)
end
it_behaves_like "generates and store embedding using with vector representation"
end

View File

@ -1,15 +0,0 @@
# frozen_string_literal: true
require_relative "vector_rep_shared_examples"
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small do
subject(:vector_rep) { described_class.new(truncation) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(described_class.name, text, expected_embedding)
end
it_behaves_like "generates and store embedding using with vector representation"
end

View File

@ -1,15 +0,0 @@
# frozen_string_literal: true
require_relative "vector_rep_shared_examples"
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002 do
subject(:vector_rep) { described_class.new(truncation) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(described_class.name, text, expected_embedding)
end
it_behaves_like "generates and store embedding using with vector representation"
end

View File

@ -1,115 +0,0 @@
# frozen_string_literal: true
RSpec.shared_examples "generates and store embedding using with vector representation" do
let(:expected_embedding_1) { [0.0038493] * vector_rep.dimensions }
let(:expected_embedding_2) { [0.0037684] * vector_rep.dimensions }
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) }
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector: vector_rep) }
describe "#vector_from" do
it "creates a vector from a given string" do
text = "This is a piece of text"
stub_vector_mapping(text, expected_embedding_1)
expect(vector_rep.vector_from(text)).to eq(expected_embedding_1)
end
end
describe "#generate_representation_from" do
fab!(:topic) { Fabricate(:topic) }
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
it "creates a vector from a topic and stores it in the database" do
text =
truncation.prepare_text_from(
topic,
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, expected_embedding_1)
vector_rep.generate_representation_from(topic)
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
end
it "creates a vector from a post and stores it in the database" do
text =
truncation.prepare_text_from(
post2,
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, expected_embedding_1)
vector_rep.generate_representation_from(post)
expect(posts_schema.find_by_embedding(expected_embedding_1).post_id).to eq(post.id)
end
end
describe "#gen_bulk_reprensentations" do
fab!(:topic) { Fabricate(:topic) }
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
fab!(:topic_2) { Fabricate(:topic) }
fab!(:post_2_1) { Fabricate(:post, post_number: 1, topic: topic_2) }
fab!(:post_2_2) { Fabricate(:post, post_number: 2, topic: topic_2) }
it "creates a vector for each object in the relation" do
text =
truncation.prepare_text_from(
topic,
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
text2 =
truncation.prepare_text_from(
topic_2,
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, expected_embedding_1)
stub_vector_mapping(text2, expected_embedding_2)
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
end
it "does nothing if passed record has no content" do
expect { vector_rep.gen_bulk_reprensentations([Topic.new]) }.not_to raise_error
end
it "doesn't ask for a new embedding if digest is the same" do
text =
truncation.prepare_text_from(
topic,
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, expected_embedding_1)
original_vector_gen = Time.zone.parse("2021-06-04 10:00")
freeze_time(original_vector_gen) do
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
end
# check vector exists
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
last_update =
DB.query_single(
"SELECT updated_at FROM #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE} WHERE topic_id = #{topic.id} LIMIT 1",
).first
expect(last_update).to eq(original_vector_gen)
end
end
end

View File

@ -0,0 +1,160 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Embeddings::Vector do
shared_examples "generates and store embeddings using a vector definition" do
subject(:vector) { described_class.new(vdef) }
let(:expected_embedding_1) { [0.0038493] * vdef.dimensions }
let(:expected_embedding_2) { [0.0037684] * vdef.dimensions }
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vdef) }
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector_def: vdef) }
fab!(:topic)
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
describe "#vector_from" do
it "creates a vector from a given string" do
text = "This is a piece of text"
stub_vector_mapping(text, expected_embedding_1)
expect(vector.vector_from(text)).to eq(expected_embedding_1)
end
end
describe "#generate_representation_from" do
it "creates a vector from a topic and stores it in the database" do
text = vdef.prepare_target_text(topic)
stub_vector_mapping(text, expected_embedding_1)
vector.generate_representation_from(topic)
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
end
it "creates a vector from a post and stores it in the database" do
text = vdef.prepare_target_text(post2)
stub_vector_mapping(text, expected_embedding_1)
vector.generate_representation_from(post)
expect(posts_schema.find_by_embedding(expected_embedding_1).post_id).to eq(post.id)
end
end
describe "#gen_bulk_reprensentations" do
fab!(:topic_2) { Fabricate(:topic) }
fab!(:post_2_1) { Fabricate(:post, post_number: 1, topic: topic_2) }
fab!(:post_2_2) { Fabricate(:post, post_number: 2, topic: topic_2) }
it "creates a vector for each object in the relation" do
text = vdef.prepare_target_text(topic)
text2 = vdef.prepare_target_text(topic_2)
stub_vector_mapping(text, expected_embedding_1)
stub_vector_mapping(text2, expected_embedding_2)
vector.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
end
it "does nothing if passed record has no content" do
expect { vector.gen_bulk_reprensentations([Topic.new]) }.not_to raise_error
end
it "doesn't ask for a new embedding if digest is the same" do
text = vdef.prepare_target_text(topic)
stub_vector_mapping(text, expected_embedding_1)
original_vector_gen = Time.zone.parse("2021-06-04 10:00")
freeze_time(original_vector_gen) do
vector.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
end
# check vector exists
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
vector.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
expect(topics_schema.find_by_target(topic).updated_at).to eq_time(original_vector_gen)
end
end
end
context "with text-embedding-ada-002" do
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(vdef.class.name, text, expected_embedding)
end
it_behaves_like "generates and store embeddings using a vector definition"
end
context "with all all-mpnet-base-v2" do
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new }
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(vdef.class.name, text, expected_embedding)
end
it_behaves_like "generates and store embeddings using a vector definition"
end
context "with gemini" do
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::Gemini.new }
let(:api_key) { "test-123" }
before { SiteSetting.ai_gemini_api_key = api_key }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.gemini_service(api_key, text, expected_embedding)
end
it_behaves_like "generates and store embeddings using a vector definition"
end
context "with multilingual-e5-large" do
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large.new }
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(vdef.class.name, text, expected_embedding)
end
it_behaves_like "generates and store embeddings using a vector definition"
end
context "with text-embedding-3-large" do
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large.new }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(
vdef.class.name,
text,
expected_embedding,
extra_args: {
dimensions: 2000,
},
)
end
it_behaves_like "generates and store embeddings using a vector definition"
end
context "with text-embedding-3-small" do
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small.new }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(vdef.class.name, text, expected_embedding)
end
it_behaves_like "generates and store embeddings using a vector definition"
end
end

View File

@ -74,7 +74,7 @@ RSpec.describe RagDocumentFragment do
end
describe ".indexing_status" do
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:vector) { DiscourseAi::Embeddings::Vector.instance }
fab!(:rag_document_fragment_1) do
Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
@ -84,7 +84,7 @@ RSpec.describe RagDocumentFragment do
Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
end
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions }
before do
SiteSetting.ai_embeddings_enabled = true
@ -96,7 +96,7 @@ RSpec.describe RagDocumentFragment do
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
).to_return(status: 200, body: JSON.dump(expected_embedding))
vector_rep.generate_representation_from(rag_document_fragment_1)
vector.generate_representation_from(rag_document_fragment_1)
end
it "regenerates all embeddings if ai_embeddings_model changes" do

View File

@ -19,14 +19,14 @@ describe DiscourseAi::Embeddings::EmbeddingsController do
fab!(:post_in_subcategory) { Fabricate(:post, topic: topic_in_subcategory) }
def index(topic)
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector = DiscourseAi::Embeddings::Vector.instance
stub_request(:post, "https://api.openai.com/v1/embeddings").to_return(
status: 200,
body: JSON.dump({ data: [{ embedding: [0.1] * 1536 }] }),
)
vector_rep.generate_representation_from(topic)
vector.generate_representation_from(topic)
end
def stub_embedding(query)