REFACTOR: Separation of concerns for embedding generation. (#1027)

In a previous refactor, we moved the responsibility of querying and storing embeddings into the `Schema` class. Now, it's time for embedding generation.

The motivation behind these changes is to isolate vector characteristics in simple objects to later replace them with a DB-backed version, similar to what we did with LLM configs.
This commit is contained in:
Roman Rizzi 2024-12-16 09:55:39 -03:00 committed by GitHub
parent 222e2cf4f9
commit 534b0df391
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
36 changed files with 375 additions and 496 deletions

View File

@ -16,9 +16,7 @@ module Jobs
return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms
return if post.raw.blank? return if post.raw.blank?
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation DiscourseAi::Embeddings::Vector.instance.generate_representation_from(target)
vector_rep.generate_representation_from(target)
end end
end end
end end

View File

@ -8,11 +8,11 @@ module ::Jobs
def execute(args) def execute(args)
return if (fragments = RagDocumentFragment.where(id: args[:fragment_ids].to_a)).empty? return if (fragments = RagDocumentFragment.where(id: args[:fragment_ids].to_a)).empty?
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation vector = DiscourseAi::Embeddings::Vector.instance
# generate_representation_from checks compares the digest value to make sure # generate_representation_from checks compares the digest value to make sure
# the embedding is only generated once per fragment unless something changes. # the embedding is only generated once per fragment unless something changes.
fragments.map { |fragment| vector_rep.generate_representation_from(fragment) } fragments.map { |fragment| vector.generate_representation_from(fragment) }
last_fragment = fragments.last last_fragment = fragments.last
target = last_fragment.target target = last_fragment.target

View File

@ -20,7 +20,8 @@ module Jobs
rebaked = 0 rebaked = 0
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation vector = DiscourseAi::Embeddings::Vector.instance
vector_def = vector.vdef
table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
topics = topics =
@ -30,19 +31,19 @@ module Jobs
.where(deleted_at: nil) .where(deleted_at: nil)
.order("topics.bumped_at DESC") .order("topics.bumped_at DESC")
rebaked += populate_topic_embeddings(vector_rep, topics.limit(limit - rebaked)) rebaked += populate_topic_embeddings(vector, topics.limit(limit - rebaked))
return if rebaked >= limit return if rebaked >= limit
# Then, we'll try to backfill embeddings for topics that have outdated # Then, we'll try to backfill embeddings for topics that have outdated
# embeddings, be it model or strategy version # embeddings, be it model or strategy version
relation = topics.where(<<~SQL).limit(limit - rebaked) relation = topics.where(<<~SQL).limit(limit - rebaked)
#{table_name}.model_version < #{vector_rep.version} #{table_name}.model_version < #{vector_def.version}
OR OR
#{table_name}.strategy_version < #{vector_rep.strategy_version} #{table_name}.strategy_version < #{vector_def.strategy_version}
SQL SQL
rebaked += populate_topic_embeddings(vector_rep, relation) rebaked += populate_topic_embeddings(vector, relation)
return if rebaked >= limit return if rebaked >= limit
@ -54,7 +55,7 @@ module Jobs
.where("#{table_name}.updated_at < topics.updated_at") .where("#{table_name}.updated_at < topics.updated_at")
.limit((limit - rebaked) / 10) .limit((limit - rebaked) / 10)
populate_topic_embeddings(vector_rep, relation, force: true) populate_topic_embeddings(vector, relation, force: true)
return if rebaked >= limit return if rebaked >= limit
@ -76,7 +77,7 @@ module Jobs
.limit(limit - rebaked) .limit(limit - rebaked)
.pluck(:id) .pluck(:id)
.each_slice(posts_batch_size) do |batch| .each_slice(posts_batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Post.where(id: batch)) vector.gen_bulk_reprensentations(Post.where(id: batch))
rebaked += batch.length rebaked += batch.length
end end
@ -86,14 +87,14 @@ module Jobs
# embeddings, be it model or strategy version # embeddings, be it model or strategy version
posts posts
.where(<<~SQL) .where(<<~SQL)
#{table_name}.model_version < #{vector_rep.version} #{table_name}.model_version < #{vector_def.version}
OR OR
#{table_name}.strategy_version < #{vector_rep.strategy_version} #{table_name}.strategy_version < #{vector_def.strategy_version}
SQL SQL
.limit(limit - rebaked) .limit(limit - rebaked)
.pluck(:id) .pluck(:id)
.each_slice(posts_batch_size) do |batch| .each_slice(posts_batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Post.where(id: batch)) vector.gen_bulk_reprensentations(Post.where(id: batch))
rebaked += batch.length rebaked += batch.length
end end
@ -107,7 +108,7 @@ module Jobs
.limit((limit - rebaked) / 10) .limit((limit - rebaked) / 10)
.pluck(:id) .pluck(:id)
.each_slice(posts_batch_size) do |batch| .each_slice(posts_batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Post.where(id: batch)) vector.gen_bulk_reprensentations(Post.where(id: batch))
rebaked += batch.length rebaked += batch.length
end end
@ -116,7 +117,7 @@ module Jobs
private private
def populate_topic_embeddings(vector_rep, topics, force: false) def populate_topic_embeddings(vector, topics, force: false)
done = 0 done = 0
topics = topics =
@ -126,7 +127,7 @@ module Jobs
batch_size = 1000 batch_size = 1000
ids.each_slice(batch_size) do |batch| ids.each_slice(batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC")) vector.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
done += batch.length done += batch.length
end end

View File

@ -314,10 +314,10 @@ module DiscourseAi
return nil if !consolidated_question return nil if !consolidated_question
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation vector = DiscourseAi::Embeddings::Vector.instance
reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings
interactions_vector = vector_rep.vector_from(consolidated_question) interactions_vector = vector.vector_from(consolidated_question)
rag_conversation_chunks = self.class.rag_conversation_chunks rag_conversation_chunks = self.class.rag_conversation_chunks
search_limit = search_limit =
@ -327,7 +327,7 @@ module DiscourseAi
rag_conversation_chunks rag_conversation_chunks
end end
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep) schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector.vdef)
candidate_fragment_ids = candidate_fragment_ids =
schema schema

View File

@ -141,11 +141,10 @@ module DiscourseAi
return [] if upload_refs.empty? return [] if upload_refs.empty?
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation query_vector = DiscourseAi::Embeddings::Vector.instance.vector_from(query)
query_vector = vector_rep.vector_from(query)
fragment_ids = fragment_ids =
DiscourseAi::Embeddings::Schema DiscourseAi::Embeddings::Schema
.for(RagDocumentFragment, vector: vector_rep) .for(RagDocumentFragment)
.asymmetric_similarity_search(query_vector, limit: limit, offset: 0) do |builder| .asymmetric_similarity_search(query_vector, limit: limit, offset: 0) do |builder|
builder.join(<<~SQL, target_id: tool.id, target_type: "AiTool") builder.join(<<~SQL, target_id: tool.id, target_type: "AiTool")
rag_document_fragments ON rag_document_fragments ON

View File

@ -92,10 +92,10 @@ module DiscourseAi
private private
def nearest_neighbors(limit: 100) def nearest_neighbors(limit: 100)
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation vector = DiscourseAi::Embeddings::Vector.instance
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)
raw_vector = vector_rep.vector_from(@text) raw_vector = vector.vector_from(@text)
muted_category_ids = nil muted_category_ids = nil
if @user.present? if @user.present?

View File

@ -14,30 +14,31 @@ module DiscourseAi
def self.for( def self.for(
target_klass, target_klass,
vector: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation vector_def: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
) )
case target_klass&.name case target_klass&.name
when "Topic" when "Topic"
new(TOPICS_TABLE, "topic_id", vector) new(TOPICS_TABLE, "topic_id", vector_def)
when "Post" when "Post"
new(POSTS_TABLE, "post_id", vector) new(POSTS_TABLE, "post_id", vector_def)
when "RagDocumentFragment" when "RagDocumentFragment"
new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector) new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector_def)
else else
raise ArgumentError, "Invalid target type for embeddings" raise ArgumentError, "Invalid target type for embeddings"
end end
end end
def initialize(table, target_column, vector) def initialize(table, target_column, vector_def)
@table = table @table = table
@target_column = target_column @target_column = target_column
@vector = vector @vector_def = vector_def
end end
attr_reader :table, :target_column, :vector attr_reader :table, :target_column, :vector_def
def find_by_embedding(embedding) def find_by_embedding(embedding)
DB.query(<<~SQL, query_embedding: embedding, vid: vector.id, vsid: vector.strategy_id).first DB.query(
<<~SQL,
SELECT * SELECT *
FROM #{table} FROM #{table}
WHERE WHERE
@ -46,10 +47,15 @@ module DiscourseAi
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT 1 LIMIT 1
SQL SQL
query_embedding: embedding,
vid: vector_def.id,
vsid: vector_def.strategy_id,
).first
end end
def find_by_target(target) def find_by_target(target)
DB.query(<<~SQL, target_id: target.id, vid: vector.id, vsid: vector.strategy_id).first DB.query(
<<~SQL,
SELECT * SELECT *
FROM #{table} FROM #{table}
WHERE WHERE
@ -58,6 +64,10 @@ module DiscourseAi
#{target_column} = :target_id #{target_column} = :target_id
LIMIT 1 LIMIT 1
SQL SQL
target_id: target.id,
vid: vector_def.id,
vsid: vector_def.strategy_id,
).first
end end
def asymmetric_similarity_search(embedding, limit:, offset:) def asymmetric_similarity_search(embedding, limit:, offset:)
@ -87,8 +97,8 @@ module DiscourseAi
builder.where( builder.where(
"model_id = :model_id AND strategy_id = :strategy_id", "model_id = :model_id AND strategy_id = :strategy_id",
model_id: vector.id, model_id: vector_def.id,
strategy_id: vector.strategy_id, strategy_id: vector_def.strategy_id,
) )
yield(builder) if block_given? yield(builder) if block_given?
@ -156,7 +166,7 @@ module DiscourseAi
yield(builder) if block_given? yield(builder) if block_given?
builder.query(vid: vector.id, vsid: vector.strategy_id, target_id: record.id) builder.query(vid: vector_def.id, vsid: vector_def.strategy_id, target_id: record.id)
rescue PG::Error => e rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}") Rails.logger.error("Error #{e} querying embeddings for model #{name}")
raise MissingEmbeddingError raise MissingEmbeddingError
@ -176,10 +186,10 @@ module DiscourseAi
updated_at = :now updated_at = :now
SQL SQL
target_id: record.id, target_id: record.id,
model_id: vector.id, model_id: vector_def.id,
model_version: vector.version, model_version: vector_def.version,
strategy_id: vector.strategy_id, strategy_id: vector_def.strategy_id,
strategy_version: vector.strategy_version, strategy_version: vector_def.strategy_version,
digest: digest, digest: digest,
embeddings: embedding, embeddings: embedding,
now: Time.zone.now, now: Time.zone.now,
@ -188,7 +198,7 @@ module DiscourseAi
private private
delegate :dimensions, :pg_function, to: :vector delegate :dimensions, :pg_function, to: :vector_def
end end
end end
end end

View File

@ -13,14 +13,13 @@ module DiscourseAi
def related_topic_ids_for(topic) def related_topic_ids_for(topic)
return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1 return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
cache_for = results_ttl(topic) cache_for = results_ttl(topic)
Discourse Discourse
.cache .cache
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do .fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
DiscourseAi::Embeddings::Schema DiscourseAi::Embeddings::Schema
.for(Topic, vector: vector_rep) .for(Topic)
.symmetric_similarity_search(topic) .symmetric_similarity_search(topic)
.map(&:topic_id) .map(&:topic_id)
.tap do |candidate_ids| .tap do |candidate_ids|

View File

@ -30,8 +30,8 @@ module DiscourseAi
Discourse.cache.read(embedding_key).present? Discourse.cache.read(embedding_key).present?
end end
def vector_rep def vector
@vector_rep ||= DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation @vector ||= DiscourseAi::Embeddings::Vector.instance
end end
def hyde_embedding(search_term) def hyde_embedding(search_term)
@ -52,16 +52,14 @@ module DiscourseAi
Discourse Discourse
.cache .cache
.fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(hypothetical_post) } .fetch(embedding_key, expires_in: 1.week) { vector.vector_from(hypothetical_post) }
end end
def embedding(search_term) def embedding(search_term)
digest = OpenSSL::Digest::SHA1.hexdigest(search_term) digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
embedding_key = build_embedding_key(digest, "", SiteSetting.ai_embeddings_model) embedding_key = build_embedding_key(digest, "", SiteSetting.ai_embeddings_model)
Discourse Discourse.cache.fetch(embedding_key, expires_in: 1.week) { vector.vector_from(search_term) }
.cache
.fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(search_term) }
end end
# this ensures the candidate topics are over selected # this ensures the candidate topics are over selected
@ -84,7 +82,7 @@ module DiscourseAi
over_selection_limit = limit * OVER_SELECTION_FACTOR over_selection_limit = limit * OVER_SELECTION_FACTOR
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)
candidate_topic_ids = candidate_topic_ids =
schema.asymmetric_similarity_search( schema.asymmetric_similarity_search(
@ -114,7 +112,7 @@ module DiscourseAi
return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation vector = DiscourseAi::Embeddings::Vector.instance
digest = OpenSSL::Digest::SHA1.hexdigest(search_term) digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
@ -129,12 +127,12 @@ module DiscourseAi
Discourse Discourse
.cache .cache
.fetch(embedding_key, expires_in: 1.week) do .fetch(embedding_key, expires_in: 1.week) do
vector_rep.vector_from(search_term, asymetric: true) vector.vector_from(search_term, asymetric: true)
end end
candidate_post_ids = candidate_post_ids =
DiscourseAi::Embeddings::Schema DiscourseAi::Embeddings::Schema
.for(Post, vector: vector_rep) .for(Post, vector_def: vector.vdef)
.asymmetric_similarity_search( .asymmetric_similarity_search(
search_term_embedding, search_term_embedding,
limit: max_semantic_results_per_page, limit: max_semantic_results_per_page,

View File

@ -12,19 +12,28 @@ module DiscourseAi
1 1
end end
def prepare_text_from(target, tokenizer, max_length) def prepare_target_text(target, vdef)
max_length = vdef.max_sequence_length - 2
case target case target
when Topic when Topic
topic_truncation(target, tokenizer, max_length) topic_truncation(target, vdef.tokenizer, max_length)
when Post when Post
post_truncation(target, tokenizer, max_length) post_truncation(target, vdef.tokenizer, max_length)
when RagDocumentFragment when RagDocumentFragment
tokenizer.truncate(target.fragment, max_length) vdef.tokenizer.truncate(target.fragment, max_length)
else else
raise ArgumentError, "Invalid target type" raise ArgumentError, "Invalid target type"
end end
end end
def prepare_query_text(text, vdef, asymetric: false)
qtext = asymetric ? "#{vdef.asymmetric_query_prefix} #{text}" : text
max_length = vdef.max_sequence_length - 2
vdef.tokenizer.truncate(text, max_length)
end
private private
def topic_information(topic) def topic_information(topic)

76
lib/embeddings/vector.rb Normal file
View File

@ -0,0 +1,76 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
class Vector
def self.instance
new(DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation)
end
def initialize(vector_definition)
@vdef = vector_definition
end
def gen_bulk_reprensentations(relation)
http_pool_size = 100
pool =
Concurrent::CachedThreadPool.new(
min_threads: 0,
max_threads: http_pool_size,
idletime: 30,
)
schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector_def: vdef)
embedding_gen = vdef.inference_client
promised_embeddings =
relation
.map do |record|
prepared_text = vdef.prepare_target_text(record)
next if prepared_text.blank?
new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text)
next if schema.find_by_target(record)&.digest == new_digest
Concurrent::Promises
.fulfilled_future({ target: record, text: prepared_text, digest: new_digest }, pool)
.then_on(pool) do |w_prepared_text|
w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text]))
end
end
.compact
Concurrent::Promises
.zip(*promised_embeddings)
.value!
.each { |e| schema.store(e[:target], e[:embedding], e[:digest]) }
pool.shutdown
pool.wait_for_termination
end
def generate_representation_from(target)
text = vdef.prepare_target_text(target)
return if text.blank?
schema = DiscourseAi::Embeddings::Schema.for(target.class, vector_def: vdef)
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
return if schema.find_by_target(target)&.digest == new_digest
embeddings = vdef.inference_client.perform!(text)
schema.store(target, embeddings, new_digest)
end
def vector_from(text, asymetric: false)
prepared_text = vdef.prepare_query_text(text, asymetric: asymetric)
return if prepared_text.blank?
vdef.inference_client.perform!(prepared_text)
end
attr_reader :vdef
end
end
end

View File

@ -23,10 +23,6 @@ module DiscourseAi
end end
end end
def vector_from(text, asymetric: false)
inference_client.perform!(text)
end
def dimensions def dimensions
768 768
end end
@ -47,10 +43,6 @@ module DiscourseAi
"<#>" "<#>"
end end
def pg_index_type
"halfvec_ip_ops"
end
def tokenizer def tokenizer
DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
end end

View File

@ -21,8 +21,7 @@ module DiscourseAi
end end
def current_representation def current_representation
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new find_representation(SiteSetting.ai_embeddings_model).new
find_representation(SiteSetting.ai_embeddings_model).new(truncation)
end end
def correctly_configured? def correctly_configured?
@ -43,73 +42,6 @@ module DiscourseAi
end end
end end
def initialize(strategy)
@strategy = strategy
end
def vector_from(text, asymetric: false)
raise NotImplementedError
end
def gen_bulk_reprensentations(relation)
http_pool_size = 100
pool =
Concurrent::CachedThreadPool.new(
min_threads: 0,
max_threads: http_pool_size,
idletime: 30,
)
schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector: self)
embedding_gen = inference_client
promised_embeddings =
relation
.map do |record|
prepared_text = prepare_text(record)
next if prepared_text.blank?
new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text)
next if schema.find_by_target(record)&.digest == new_digest
Concurrent::Promises
.fulfilled_future(
{ target: record, text: prepared_text, digest: new_digest },
pool,
)
.then_on(pool) do |w_prepared_text|
w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text]))
end
end
.compact
Concurrent::Promises
.zip(*promised_embeddings)
.value!
.each { |e| schema.store(e[:target], e[:embedding], e[:digest]) }
pool.shutdown
pool.wait_for_termination
end
def generate_representation_from(target, persist: true)
text = prepare_text(target)
return if text.blank?
schema = DiscourseAi::Embeddings::Schema.for(target.class, vector: self)
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
return if schema.find_by_target(target)&.digest == new_digest
vector = vector_from(text)
schema.store(target, vector, new_digest) if persist
end
def index_name(table_name)
"#{table_name}_#{id}_#{@strategy.id}_search"
end
def name def name
raise NotImplementedError raise NotImplementedError
end end
@ -139,26 +71,32 @@ module DiscourseAi
end end
def asymmetric_query_prefix def asymmetric_query_prefix
raise NotImplementedError ""
end end
def strategy_id def strategy_id
@strategy.id strategy.id
end end
def strategy_version def strategy_version
@strategy.version strategy.version
end end
protected def prepare_query_text(text, asymetric: false)
strategy.prepare_query_text(text, self, asymetric: asymetric)
end
def prepare_target_text(target)
strategy.prepare_target_text(target, self)
end
def strategy
@strategy ||= DiscourseAi::Embeddings::Strategies::Truncation.new
end
def inference_client def inference_client
raise NotImplementedError raise NotImplementedError
end end
def prepare_text(record)
@strategy.prepare_text_from(record, tokenizer, max_sequence_length - 2)
end
end end
end end
end end

View File

@ -30,21 +30,6 @@ module DiscourseAi
end end
end end
def vector_from(text, asymetric: false)
text = "#{asymmetric_query_prefix} #{text}" if asymetric
client = inference_client
needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
text = tokenizer.truncate(text, max_sequence_length - 2) if needs_truncation
inference_client.perform!(text)
end
def inference_model_name
"baai/bge-large-en-v1.5"
end
def dimensions def dimensions
1024 1024
end end
@ -65,10 +50,6 @@ module DiscourseAi
"<#>" "<#>"
end end
def pg_index_type
"halfvec_ip_ops"
end
def tokenizer def tokenizer
DiscourseAi::Tokenizer::BgeLargeEnTokenizer DiscourseAi::Tokenizer::BgeLargeEnTokenizer
end end
@ -78,6 +59,8 @@ module DiscourseAi
end end
def inference_client def inference_client
inference_model_name = "baai/bge-large-en-v1.5"
if SiteSetting.ai_cloudflare_workers_api_token.present? if SiteSetting.ai_cloudflare_workers_api_token.present?
DiscourseAi::Inference::CloudflareWorkersAi.instance(inference_model_name) DiscourseAi::Inference::CloudflareWorkersAi.instance(inference_model_name)
elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?

View File

@ -18,11 +18,6 @@ module DiscourseAi
end end
end end
def vector_from(text, asymetric: false)
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
inference_client.perform!(truncated_text)
end
def dimensions def dimensions
1024 1024
end end
@ -43,10 +38,6 @@ module DiscourseAi
"<#>" "<#>"
end end
def pg_index_type
"halfvec_ip_ops"
end
def tokenizer def tokenizer
DiscourseAi::Tokenizer::BgeM3Tokenizer DiscourseAi::Tokenizer::BgeM3Tokenizer
end end

View File

@ -38,14 +38,6 @@ module DiscourseAi
"<=>" "<=>"
end end
def pg_index_type
"halfvec_cosine_ops"
end
def vector_from(text, asymetric: false)
inference_client.perform!(text)
end
# There is no public tokenizer for Gemini, and from the ones we already ship in the plugin # There is no public tokenizer for Gemini, and from the ones we already ship in the plugin
# OpenAI gets the closest results. Gemini Tokenizer results in ~10% less tokens, so it's safe # OpenAI gets the closest results. Gemini Tokenizer results in ~10% less tokens, so it's safe
# to use OpenAI tokenizer since it will overestimate the number of tokens. # to use OpenAI tokenizer since it will overestimate the number of tokens.

View File

@ -28,19 +28,6 @@ module DiscourseAi
end end
end end
def vector_from(text, asymetric: false)
client = inference_client
needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
if needs_truncation
text = tokenizer.truncate(text, max_sequence_length - 2)
elsif !text.starts_with?("query:")
text = "query: #{text}"
end
client.perform!(text)
end
def id def id
3 3
end end
@ -61,10 +48,6 @@ module DiscourseAi
"<=>" "<=>"
end end
def pg_index_type
"halfvec_cosine_ops"
end
def tokenizer def tokenizer
DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
end end
@ -80,8 +63,18 @@ module DiscourseAi
end end
end end
def prepare_text(record) def prepare_text(text, asymetric: false)
prepared_text = super(record) prepared_text = super(text, asymetric: asymetric)
if prepared_text.present? && inference_client.class.name.include?("DiscourseClassifier")
return "query: #{prepared_text}"
end
prepared_text
end
def prepare_target_text(target)
prepared_text = super(target)
if prepared_text.present? && inference_client.class.name.include?("DiscourseClassifier") if prepared_text.present? && inference_client.class.name.include?("DiscourseClassifier")
return "query: #{prepared_text}" return "query: #{prepared_text}"

View File

@ -40,14 +40,6 @@ module DiscourseAi
"<=>" "<=>"
end end
def pg_index_type
"halfvec_cosine_ops"
end
def vector_from(text, asymetric: false)
inference_client.perform!(text)
end
def tokenizer def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer DiscourseAi::Tokenizer::OpenAiTokenizer
end end

View File

@ -38,14 +38,6 @@ module DiscourseAi
"<=>" "<=>"
end end
def pg_index_type
"halfvec_cosine_ops"
end
def vector_from(text, asymetric: false)
inference_client.perform!(text)
end
def tokenizer def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer DiscourseAi::Tokenizer::OpenAiTokenizer
end end

View File

@ -38,14 +38,6 @@ module DiscourseAi
"<=>" "<=>"
end end
def pg_index_type
"halfvec_cosine_ops"
end
def vector_from(text, asymetric: false)
inference_client.perform!(text)
end
def tokenizer def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer DiscourseAi::Tokenizer::OpenAiTokenizer
end end

View File

@ -326,8 +326,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
fab!(:llm_model) { Fabricate(:fake_model) } fab!(:llm_model) { Fabricate(:fake_model) }
it "will run the question consolidator" do it "will run the question consolidator" do
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation vector_def = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
context_embedding = vector_rep.dimensions.times.map { rand(-1.0...1.0) } context_embedding = vector_def.dimensions.times.map { rand(-1.0...1.0) }
EmbeddingsGenerationStubs.discourse_service( EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model, SiteSetting.ai_embeddings_model,
consolidated_question, consolidated_question,
@ -373,14 +373,14 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
end end
context "when a persona has RAG uploads" do context "when a persona has RAG uploads" do
let(:vector_rep) do let(:vector_def) do
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
end end
let(:embedding_value) { 0.04381 } let(:embedding_value) { 0.04381 }
let(:prompt_cc_embeddings) { [embedding_value] * vector_rep.dimensions } let(:prompt_cc_embeddings) { [embedding_value] * vector_def.dimensions }
def stub_fragments(fragment_count, persona: ai_persona) def stub_fragments(fragment_count, persona: ai_persona)
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep) schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector_def)
fragment_count.times do |i| fragment_count.times do |i|
fragment = fragment =
@ -393,7 +393,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
) )
# Similarity is determined left-to-right. # Similarity is determined left-to-right.
embeddings = [embedding_value + "0.000#{i}".to_f] * vector_rep.dimensions embeddings = [embedding_value + "0.000#{i}".to_f] * vector_def.dimensions
schema.store(fragment, embeddings, "test") schema.store(fragment, embeddings, "test")
end end

View File

@ -14,9 +14,9 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
fab!(:category) fab!(:category)
fab!(:topic) { Fabricate(:topic, category: category) } fab!(:topic) { Fabricate(:topic, category: category) }
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } let(:vector) { DiscourseAi::Embeddings::Vector.instance }
let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) } let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) }
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions } let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions }
before do before do
SiteSetting.ai_embeddings_enabled = true SiteSetting.ai_embeddings_enabled = true
@ -28,8 +28,8 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
).to_return(status: 200, body: JSON.dump(expected_embedding)) ).to_return(status: 200, body: JSON.dump(expected_embedding))
vector_rep.generate_representation_from(topic) vector.generate_representation_from(topic)
vector_rep.generate_representation_from(muted_topic) vector.generate_representation_from(muted_topic)
end end
it "respects user muted categories when making suggestions" do it "respects user muted categories when making suggestions" do

View File

@ -12,21 +12,16 @@ RSpec.describe Jobs::GenerateEmbeddings do
fab!(:topic) fab!(:topic)
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) } fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector_def) }
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) } let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector_def: vector_def) }
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector: vector_rep) }
it "works for topics" do it "works for topics" do
expected_embedding = [0.0038493] * vector_rep.dimensions expected_embedding = [0.0038493] * vector_def.dimensions
text = text = vector_def.prepare_target_text(topic)
truncation.prepare_text_from(
topic, EmbeddingsGenerationStubs.discourse_service(vector_def.class.name, text, expected_embedding)
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding)
job.execute(target_id: topic.id, target_type: "Topic") job.execute(target_id: topic.id, target_type: "Topic")
@ -34,11 +29,10 @@ RSpec.describe Jobs::GenerateEmbeddings do
end end
it "works for posts" do it "works for posts" do
expected_embedding = [0.0038493] * vector_rep.dimensions expected_embedding = [0.0038493] * vector_def.dimensions
text = text = vector_def.prepare_target_text(post)
truncation.prepare_text_from(post, vector_rep.tokenizer, vector_rep.max_sequence_length - 2) EmbeddingsGenerationStubs.discourse_service(vector_def.class.name, text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding)
job.execute(target_id: post.id, target_type: "Post") job.execute(target_id: post.id, target_type: "Post")

View File

@ -1,16 +1,12 @@
# frozen_string_literal: true # frozen_string_literal: true
RSpec.describe DiscourseAi::Embeddings::Schema do RSpec.describe DiscourseAi::Embeddings::Schema do
subject(:posts_schema) { described_class.for(Post, vector: vector) } subject(:posts_schema) { described_class.for(Post, vector_def: vector_def) }
let(:embeddings) { [0.0038490295] * vector.dimensions } let(:embeddings) { [0.0038490295] * vector_def.dimensions }
fab!(:post) { Fabricate(:post, post_number: 1) } fab!(:post) { Fabricate(:post, post_number: 1) }
let(:digest) { OpenSSL::Digest.hexdigest("SHA1", "test") } let(:digest) { OpenSSL::Digest.hexdigest("SHA1", "test") }
let(:vector) do let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new }
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new(
DiscourseAi::Embeddings::Strategies::Truncation.new,
)
end
before { posts_schema.store(post, embeddings, digest) } before { posts_schema.store(post, embeddings, digest) }
@ -34,7 +30,7 @@ RSpec.describe DiscourseAi::Embeddings::Schema do
describe "similarity searches" do describe "similarity searches" do
fab!(:post_2) { Fabricate(:post) } fab!(:post_2) { Fabricate(:post) }
let(:similar_embeddings) { [0.0038490294] * vector.dimensions } let(:similar_embeddings) { [0.0038490294] * vector_def.dimensions }
describe "#symmetric_similarity_search" do describe "#symmetric_similarity_search" do
before { posts_schema.store(post_2, similar_embeddings, digest) } before { posts_schema.store(post_2, similar_embeddings, digest) }

View File

@ -11,8 +11,8 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
describe "#search_for_topics" do describe "#search_for_topics" do
let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" } let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" }
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:hyde_embedding) { [0.049382] * vector_rep.dimensions } let(:hyde_embedding) { [0.049382] * vector_def.dimensions }
before do before do
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"

View File

@ -3,8 +3,8 @@
RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
subject(:truncation) { described_class.new } subject(:truncation) { described_class.new }
describe "#prepare_text_from" do describe "#prepare_query_text" do
context "when using vector from OpenAI" do context "when using vector def from OpenAI" do
before { SiteSetting.max_post_length = 100_000 } before { SiteSetting.max_post_length = 100_000 }
fab!(:topic) fab!(:topic)
@ -20,15 +20,12 @@ RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
end end
fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) } fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
let(:model) do let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new }
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new(truncation)
end
it "truncates a topic" do it "truncates a topic" do
prepared_text = prepared_text = truncation.prepare_target_text(topic, vector_def)
truncation.prepare_text_from(topic, model.tokenizer, model.max_sequence_length)
expect(model.tokenizer.size(prepared_text)).to be <= model.max_sequence_length expect(vector_def.tokenizer.size(prepared_text)).to be <= vector_def.max_sequence_length
end end
end end
end end

View File

@ -1,17 +0,0 @@
# frozen_string_literal: true
require_relative "vector_rep_shared_examples"
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2 do
subject(:vector_rep) { described_class.new(truncation) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(described_class.name, text, expected_embedding)
end
it_behaves_like "generates and store embedding using with vector representation"
end

View File

@ -1,18 +0,0 @@
# frozen_string_literal: true
require_relative "vector_rep_shared_examples"
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::Gemini do
subject(:vector_rep) { described_class.new(truncation) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
let!(:api_key) { "test-123" }
before { SiteSetting.ai_gemini_api_key = api_key }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.gemini_service(api_key, text, expected_embedding)
end
it_behaves_like "generates and store embedding using with vector representation"
end

View File

@ -1,21 +0,0 @@
# frozen_string_literal: true
require_relative "vector_rep_shared_examples"
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large do
subject(:vector_rep) { described_class.new(truncation) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(
described_class.name,
"query: #{text}",
expected_embedding,
)
end
it_behaves_like "generates and store embedding using with vector representation"
end

View File

@ -1,22 +0,0 @@
# frozen_string_literal: true
require_relative "vector_rep_shared_examples"
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large do
subject(:vector_rep) { described_class.new(truncation) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(
described_class.name,
text,
expected_embedding,
extra_args: {
dimensions: 2000,
},
)
end
it_behaves_like "generates and store embedding using with vector representation"
end

View File

@ -1,15 +0,0 @@
# frozen_string_literal: true
require_relative "vector_rep_shared_examples"
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small do
subject(:vector_rep) { described_class.new(truncation) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(described_class.name, text, expected_embedding)
end
it_behaves_like "generates and store embedding using with vector representation"
end

View File

@ -1,15 +0,0 @@
# frozen_string_literal: true
require_relative "vector_rep_shared_examples"
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002 do
subject(:vector_rep) { described_class.new(truncation) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(described_class.name, text, expected_embedding)
end
it_behaves_like "generates and store embedding using with vector representation"
end

View File

@ -1,115 +0,0 @@
# frozen_string_literal: true
RSpec.shared_examples "generates and store embedding using with vector representation" do
let(:expected_embedding_1) { [0.0038493] * vector_rep.dimensions }
let(:expected_embedding_2) { [0.0037684] * vector_rep.dimensions }
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) }
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector: vector_rep) }
describe "#vector_from" do
it "creates a vector from a given string" do
text = "This is a piece of text"
stub_vector_mapping(text, expected_embedding_1)
expect(vector_rep.vector_from(text)).to eq(expected_embedding_1)
end
end
describe "#generate_representation_from" do
fab!(:topic) { Fabricate(:topic) }
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
it "creates a vector from a topic and stores it in the database" do
text =
truncation.prepare_text_from(
topic,
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, expected_embedding_1)
vector_rep.generate_representation_from(topic)
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
end
it "creates a vector from a post and stores it in the database" do
text =
truncation.prepare_text_from(
post2,
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, expected_embedding_1)
vector_rep.generate_representation_from(post)
expect(posts_schema.find_by_embedding(expected_embedding_1).post_id).to eq(post.id)
end
end
describe "#gen_bulk_reprensentations" do
fab!(:topic) { Fabricate(:topic) }
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
fab!(:topic_2) { Fabricate(:topic) }
fab!(:post_2_1) { Fabricate(:post, post_number: 1, topic: topic_2) }
fab!(:post_2_2) { Fabricate(:post, post_number: 2, topic: topic_2) }
it "creates a vector for each object in the relation" do
text =
truncation.prepare_text_from(
topic,
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
text2 =
truncation.prepare_text_from(
topic_2,
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, expected_embedding_1)
stub_vector_mapping(text2, expected_embedding_2)
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
end
it "does nothing if passed record has no content" do
expect { vector_rep.gen_bulk_reprensentations([Topic.new]) }.not_to raise_error
end
it "doesn't ask for a new embedding if digest is the same" do
text =
truncation.prepare_text_from(
topic,
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, expected_embedding_1)
original_vector_gen = Time.zone.parse("2021-06-04 10:00")
freeze_time(original_vector_gen) do
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
end
# check vector exists
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
last_update =
DB.query_single(
"SELECT updated_at FROM #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE} WHERE topic_id = #{topic.id} LIMIT 1",
).first
expect(last_update).to eq(original_vector_gen)
end
end
end

View File

@ -0,0 +1,160 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Embeddings::Vector do
shared_examples "generates and store embeddings using a vector definition" do
subject(:vector) { described_class.new(vdef) }
let(:expected_embedding_1) { [0.0038493] * vdef.dimensions }
let(:expected_embedding_2) { [0.0037684] * vdef.dimensions }
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vdef) }
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector_def: vdef) }
fab!(:topic)
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
describe "#vector_from" do
it "creates a vector from a given string" do
text = "This is a piece of text"
stub_vector_mapping(text, expected_embedding_1)
expect(vector.vector_from(text)).to eq(expected_embedding_1)
end
end
describe "#generate_representation_from" do
it "creates a vector from a topic and stores it in the database" do
text = vdef.prepare_target_text(topic)
stub_vector_mapping(text, expected_embedding_1)
vector.generate_representation_from(topic)
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
end
it "creates a vector from a post and stores it in the database" do
text = vdef.prepare_target_text(post2)
stub_vector_mapping(text, expected_embedding_1)
vector.generate_representation_from(post)
expect(posts_schema.find_by_embedding(expected_embedding_1).post_id).to eq(post.id)
end
end
describe "#gen_bulk_reprensentations" do
fab!(:topic_2) { Fabricate(:topic) }
fab!(:post_2_1) { Fabricate(:post, post_number: 1, topic: topic_2) }
fab!(:post_2_2) { Fabricate(:post, post_number: 2, topic: topic_2) }
it "creates a vector for each object in the relation" do
text = vdef.prepare_target_text(topic)
text2 = vdef.prepare_target_text(topic_2)
stub_vector_mapping(text, expected_embedding_1)
stub_vector_mapping(text2, expected_embedding_2)
vector.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
end
it "does nothing if passed record has no content" do
expect { vector.gen_bulk_reprensentations([Topic.new]) }.not_to raise_error
end
it "doesn't ask for a new embedding if digest is the same" do
text = vdef.prepare_target_text(topic)
stub_vector_mapping(text, expected_embedding_1)
original_vector_gen = Time.zone.parse("2021-06-04 10:00")
freeze_time(original_vector_gen) do
vector.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
end
# check vector exists
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
vector.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
expect(topics_schema.find_by_target(topic).updated_at).to eq_time(original_vector_gen)
end
end
end
context "with text-embedding-ada-002" do
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(vdef.class.name, text, expected_embedding)
end
it_behaves_like "generates and store embeddings using a vector definition"
end
context "with all all-mpnet-base-v2" do
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new }
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(vdef.class.name, text, expected_embedding)
end
it_behaves_like "generates and store embeddings using a vector definition"
end
context "with gemini" do
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::Gemini.new }
let(:api_key) { "test-123" }
before { SiteSetting.ai_gemini_api_key = api_key }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.gemini_service(api_key, text, expected_embedding)
end
it_behaves_like "generates and store embeddings using a vector definition"
end
context "with multilingual-e5-large" do
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large.new }
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(vdef.class.name, text, expected_embedding)
end
it_behaves_like "generates and store embeddings using a vector definition"
end
context "with text-embedding-3-large" do
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large.new }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(
vdef.class.name,
text,
expected_embedding,
extra_args: {
dimensions: 2000,
},
)
end
it_behaves_like "generates and store embeddings using a vector definition"
end
context "with text-embedding-3-small" do
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small.new }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(vdef.class.name, text, expected_embedding)
end
it_behaves_like "generates and store embeddings using a vector definition"
end
end

View File

@ -74,7 +74,7 @@ RSpec.describe RagDocumentFragment do
end end
describe ".indexing_status" do describe ".indexing_status" do
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } let(:vector) { DiscourseAi::Embeddings::Vector.instance }
fab!(:rag_document_fragment_1) do fab!(:rag_document_fragment_1) do
Fabricate(:rag_document_fragment, upload: upload_1, target: persona) Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
@ -84,7 +84,7 @@ RSpec.describe RagDocumentFragment do
Fabricate(:rag_document_fragment, upload: upload_1, target: persona) Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
end end
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions } let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions }
before do before do
SiteSetting.ai_embeddings_enabled = true SiteSetting.ai_embeddings_enabled = true
@ -96,7 +96,7 @@ RSpec.describe RagDocumentFragment do
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
).to_return(status: 200, body: JSON.dump(expected_embedding)) ).to_return(status: 200, body: JSON.dump(expected_embedding))
vector_rep.generate_representation_from(rag_document_fragment_1) vector.generate_representation_from(rag_document_fragment_1)
end end
it "regenerates all embeddings if ai_embeddings_model changes" do it "regenerates all embeddings if ai_embeddings_model changes" do

View File

@ -19,14 +19,14 @@ describe DiscourseAi::Embeddings::EmbeddingsController do
fab!(:post_in_subcategory) { Fabricate(:post, topic: topic_in_subcategory) } fab!(:post_in_subcategory) { Fabricate(:post, topic: topic_in_subcategory) }
def index(topic) def index(topic)
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation vector = DiscourseAi::Embeddings::Vector.instance
stub_request(:post, "https://api.openai.com/v1/embeddings").to_return( stub_request(:post, "https://api.openai.com/v1/embeddings").to_return(
status: 200, status: 200,
body: JSON.dump({ data: [{ embedding: [0.1] * 1536 }] }), body: JSON.dump({ data: [{ embedding: [0.1] * 1536 }] }),
) )
vector_rep.generate_representation_from(topic) vector.generate_representation_from(topic)
end end
def stub_embedding(query) def stub_embedding(query)