REFACTOR: Separation of concerns for embedding generation. (#1027)
In a previous refactor, we moved the responsibility of querying and storing embeddings into the `Schema` class. Now, it's time for embedding generation. The motivation behind these changes is to isolate vector characteristics in simple objects to later replace them with a DB-backed version, similar to what we did with LLM configs.
This commit is contained in:
parent
222e2cf4f9
commit
534b0df391
|
@ -16,9 +16,7 @@ module Jobs
|
|||
return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms
|
||||
return if post.raw.blank?
|
||||
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
|
||||
vector_rep.generate_representation_from(target)
|
||||
DiscourseAi::Embeddings::Vector.instance.generate_representation_from(target)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -8,11 +8,11 @@ module ::Jobs
|
|||
def execute(args)
|
||||
return if (fragments = RagDocumentFragment.where(id: args[:fragment_ids].to_a)).empty?
|
||||
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
vector = DiscourseAi::Embeddings::Vector.instance
|
||||
|
||||
# generate_representation_from checks compares the digest value to make sure
|
||||
# the embedding is only generated once per fragment unless something changes.
|
||||
fragments.map { |fragment| vector_rep.generate_representation_from(fragment) }
|
||||
fragments.map { |fragment| vector.generate_representation_from(fragment) }
|
||||
|
||||
last_fragment = fragments.last
|
||||
target = last_fragment.target
|
||||
|
|
|
@ -20,7 +20,8 @@ module Jobs
|
|||
|
||||
rebaked = 0
|
||||
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
vector = DiscourseAi::Embeddings::Vector.instance
|
||||
vector_def = vector.vdef
|
||||
table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
|
||||
|
||||
topics =
|
||||
|
@ -30,19 +31,19 @@ module Jobs
|
|||
.where(deleted_at: nil)
|
||||
.order("topics.bumped_at DESC")
|
||||
|
||||
rebaked += populate_topic_embeddings(vector_rep, topics.limit(limit - rebaked))
|
||||
rebaked += populate_topic_embeddings(vector, topics.limit(limit - rebaked))
|
||||
|
||||
return if rebaked >= limit
|
||||
|
||||
# Then, we'll try to backfill embeddings for topics that have outdated
|
||||
# embeddings, be it model or strategy version
|
||||
relation = topics.where(<<~SQL).limit(limit - rebaked)
|
||||
#{table_name}.model_version < #{vector_rep.version}
|
||||
#{table_name}.model_version < #{vector_def.version}
|
||||
OR
|
||||
#{table_name}.strategy_version < #{vector_rep.strategy_version}
|
||||
#{table_name}.strategy_version < #{vector_def.strategy_version}
|
||||
SQL
|
||||
|
||||
rebaked += populate_topic_embeddings(vector_rep, relation)
|
||||
rebaked += populate_topic_embeddings(vector, relation)
|
||||
|
||||
return if rebaked >= limit
|
||||
|
||||
|
@ -54,7 +55,7 @@ module Jobs
|
|||
.where("#{table_name}.updated_at < topics.updated_at")
|
||||
.limit((limit - rebaked) / 10)
|
||||
|
||||
populate_topic_embeddings(vector_rep, relation, force: true)
|
||||
populate_topic_embeddings(vector, relation, force: true)
|
||||
|
||||
return if rebaked >= limit
|
||||
|
||||
|
@ -76,7 +77,7 @@ module Jobs
|
|||
.limit(limit - rebaked)
|
||||
.pluck(:id)
|
||||
.each_slice(posts_batch_size) do |batch|
|
||||
vector_rep.gen_bulk_reprensentations(Post.where(id: batch))
|
||||
vector.gen_bulk_reprensentations(Post.where(id: batch))
|
||||
rebaked += batch.length
|
||||
end
|
||||
|
||||
|
@ -86,14 +87,14 @@ module Jobs
|
|||
# embeddings, be it model or strategy version
|
||||
posts
|
||||
.where(<<~SQL)
|
||||
#{table_name}.model_version < #{vector_rep.version}
|
||||
#{table_name}.model_version < #{vector_def.version}
|
||||
OR
|
||||
#{table_name}.strategy_version < #{vector_rep.strategy_version}
|
||||
#{table_name}.strategy_version < #{vector_def.strategy_version}
|
||||
SQL
|
||||
.limit(limit - rebaked)
|
||||
.pluck(:id)
|
||||
.each_slice(posts_batch_size) do |batch|
|
||||
vector_rep.gen_bulk_reprensentations(Post.where(id: batch))
|
||||
vector.gen_bulk_reprensentations(Post.where(id: batch))
|
||||
rebaked += batch.length
|
||||
end
|
||||
|
||||
|
@ -107,7 +108,7 @@ module Jobs
|
|||
.limit((limit - rebaked) / 10)
|
||||
.pluck(:id)
|
||||
.each_slice(posts_batch_size) do |batch|
|
||||
vector_rep.gen_bulk_reprensentations(Post.where(id: batch))
|
||||
vector.gen_bulk_reprensentations(Post.where(id: batch))
|
||||
rebaked += batch.length
|
||||
end
|
||||
|
||||
|
@ -116,7 +117,7 @@ module Jobs
|
|||
|
||||
private
|
||||
|
||||
def populate_topic_embeddings(vector_rep, topics, force: false)
|
||||
def populate_topic_embeddings(vector, topics, force: false)
|
||||
done = 0
|
||||
|
||||
topics =
|
||||
|
@ -126,7 +127,7 @@ module Jobs
|
|||
batch_size = 1000
|
||||
|
||||
ids.each_slice(batch_size) do |batch|
|
||||
vector_rep.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
|
||||
vector.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
|
||||
done += batch.length
|
||||
end
|
||||
|
||||
|
|
|
@ -314,10 +314,10 @@ module DiscourseAi
|
|||
|
||||
return nil if !consolidated_question
|
||||
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
vector = DiscourseAi::Embeddings::Vector.instance
|
||||
reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings
|
||||
|
||||
interactions_vector = vector_rep.vector_from(consolidated_question)
|
||||
interactions_vector = vector.vector_from(consolidated_question)
|
||||
|
||||
rag_conversation_chunks = self.class.rag_conversation_chunks
|
||||
search_limit =
|
||||
|
@ -327,7 +327,7 @@ module DiscourseAi
|
|||
rag_conversation_chunks
|
||||
end
|
||||
|
||||
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep)
|
||||
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector.vdef)
|
||||
|
||||
candidate_fragment_ids =
|
||||
schema
|
||||
|
|
|
@ -141,11 +141,10 @@ module DiscourseAi
|
|||
|
||||
return [] if upload_refs.empty?
|
||||
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
query_vector = vector_rep.vector_from(query)
|
||||
query_vector = DiscourseAi::Embeddings::Vector.instance.vector_from(query)
|
||||
fragment_ids =
|
||||
DiscourseAi::Embeddings::Schema
|
||||
.for(RagDocumentFragment, vector: vector_rep)
|
||||
.for(RagDocumentFragment)
|
||||
.asymmetric_similarity_search(query_vector, limit: limit, offset: 0) do |builder|
|
||||
builder.join(<<~SQL, target_id: tool.id, target_type: "AiTool")
|
||||
rag_document_fragments ON
|
||||
|
|
|
@ -92,10 +92,10 @@ module DiscourseAi
|
|||
private
|
||||
|
||||
def nearest_neighbors(limit: 100)
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep)
|
||||
vector = DiscourseAi::Embeddings::Vector.instance
|
||||
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)
|
||||
|
||||
raw_vector = vector_rep.vector_from(@text)
|
||||
raw_vector = vector.vector_from(@text)
|
||||
|
||||
muted_category_ids = nil
|
||||
if @user.present?
|
||||
|
|
|
@ -14,30 +14,31 @@ module DiscourseAi
|
|||
|
||||
def self.for(
|
||||
target_klass,
|
||||
vector: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
vector_def: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
)
|
||||
case target_klass&.name
|
||||
when "Topic"
|
||||
new(TOPICS_TABLE, "topic_id", vector)
|
||||
new(TOPICS_TABLE, "topic_id", vector_def)
|
||||
when "Post"
|
||||
new(POSTS_TABLE, "post_id", vector)
|
||||
new(POSTS_TABLE, "post_id", vector_def)
|
||||
when "RagDocumentFragment"
|
||||
new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector)
|
||||
new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector_def)
|
||||
else
|
||||
raise ArgumentError, "Invalid target type for embeddings"
|
||||
end
|
||||
end
|
||||
|
||||
def initialize(table, target_column, vector)
|
||||
def initialize(table, target_column, vector_def)
|
||||
@table = table
|
||||
@target_column = target_column
|
||||
@vector = vector
|
||||
@vector_def = vector_def
|
||||
end
|
||||
|
||||
attr_reader :table, :target_column, :vector
|
||||
attr_reader :table, :target_column, :vector_def
|
||||
|
||||
def find_by_embedding(embedding)
|
||||
DB.query(<<~SQL, query_embedding: embedding, vid: vector.id, vsid: vector.strategy_id).first
|
||||
DB.query(
|
||||
<<~SQL,
|
||||
SELECT *
|
||||
FROM #{table}
|
||||
WHERE
|
||||
|
@ -46,10 +47,15 @@ module DiscourseAi
|
|||
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
|
||||
LIMIT 1
|
||||
SQL
|
||||
query_embedding: embedding,
|
||||
vid: vector_def.id,
|
||||
vsid: vector_def.strategy_id,
|
||||
).first
|
||||
end
|
||||
|
||||
def find_by_target(target)
|
||||
DB.query(<<~SQL, target_id: target.id, vid: vector.id, vsid: vector.strategy_id).first
|
||||
DB.query(
|
||||
<<~SQL,
|
||||
SELECT *
|
||||
FROM #{table}
|
||||
WHERE
|
||||
|
@ -58,6 +64,10 @@ module DiscourseAi
|
|||
#{target_column} = :target_id
|
||||
LIMIT 1
|
||||
SQL
|
||||
target_id: target.id,
|
||||
vid: vector_def.id,
|
||||
vsid: vector_def.strategy_id,
|
||||
).first
|
||||
end
|
||||
|
||||
def asymmetric_similarity_search(embedding, limit:, offset:)
|
||||
|
@ -87,8 +97,8 @@ module DiscourseAi
|
|||
|
||||
builder.where(
|
||||
"model_id = :model_id AND strategy_id = :strategy_id",
|
||||
model_id: vector.id,
|
||||
strategy_id: vector.strategy_id,
|
||||
model_id: vector_def.id,
|
||||
strategy_id: vector_def.strategy_id,
|
||||
)
|
||||
|
||||
yield(builder) if block_given?
|
||||
|
@ -156,7 +166,7 @@ module DiscourseAi
|
|||
|
||||
yield(builder) if block_given?
|
||||
|
||||
builder.query(vid: vector.id, vsid: vector.strategy_id, target_id: record.id)
|
||||
builder.query(vid: vector_def.id, vsid: vector_def.strategy_id, target_id: record.id)
|
||||
rescue PG::Error => e
|
||||
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
|
||||
raise MissingEmbeddingError
|
||||
|
@ -176,10 +186,10 @@ module DiscourseAi
|
|||
updated_at = :now
|
||||
SQL
|
||||
target_id: record.id,
|
||||
model_id: vector.id,
|
||||
model_version: vector.version,
|
||||
strategy_id: vector.strategy_id,
|
||||
strategy_version: vector.strategy_version,
|
||||
model_id: vector_def.id,
|
||||
model_version: vector_def.version,
|
||||
strategy_id: vector_def.strategy_id,
|
||||
strategy_version: vector_def.strategy_version,
|
||||
digest: digest,
|
||||
embeddings: embedding,
|
||||
now: Time.zone.now,
|
||||
|
@ -188,7 +198,7 @@ module DiscourseAi
|
|||
|
||||
private
|
||||
|
||||
delegate :dimensions, :pg_function, to: :vector
|
||||
delegate :dimensions, :pg_function, to: :vector_def
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -13,14 +13,13 @@ module DiscourseAi
|
|||
def related_topic_ids_for(topic)
|
||||
return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1
|
||||
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
cache_for = results_ttl(topic)
|
||||
|
||||
Discourse
|
||||
.cache
|
||||
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
|
||||
DiscourseAi::Embeddings::Schema
|
||||
.for(Topic, vector: vector_rep)
|
||||
.for(Topic)
|
||||
.symmetric_similarity_search(topic)
|
||||
.map(&:topic_id)
|
||||
.tap do |candidate_ids|
|
||||
|
|
|
@ -30,8 +30,8 @@ module DiscourseAi
|
|||
Discourse.cache.read(embedding_key).present?
|
||||
end
|
||||
|
||||
def vector_rep
|
||||
@vector_rep ||= DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
def vector
|
||||
@vector ||= DiscourseAi::Embeddings::Vector.instance
|
||||
end
|
||||
|
||||
def hyde_embedding(search_term)
|
||||
|
@ -52,16 +52,14 @@ module DiscourseAi
|
|||
|
||||
Discourse
|
||||
.cache
|
||||
.fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(hypothetical_post) }
|
||||
.fetch(embedding_key, expires_in: 1.week) { vector.vector_from(hypothetical_post) }
|
||||
end
|
||||
|
||||
def embedding(search_term)
|
||||
digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
|
||||
embedding_key = build_embedding_key(digest, "", SiteSetting.ai_embeddings_model)
|
||||
|
||||
Discourse
|
||||
.cache
|
||||
.fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(search_term) }
|
||||
Discourse.cache.fetch(embedding_key, expires_in: 1.week) { vector.vector_from(search_term) }
|
||||
end
|
||||
|
||||
# this ensures the candidate topics are over selected
|
||||
|
@ -84,7 +82,7 @@ module DiscourseAi
|
|||
|
||||
over_selection_limit = limit * OVER_SELECTION_FACTOR
|
||||
|
||||
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep)
|
||||
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)
|
||||
|
||||
candidate_topic_ids =
|
||||
schema.asymmetric_similarity_search(
|
||||
|
@ -114,7 +112,7 @@ module DiscourseAi
|
|||
|
||||
return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length
|
||||
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
vector = DiscourseAi::Embeddings::Vector.instance
|
||||
|
||||
digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
|
||||
|
||||
|
@ -129,12 +127,12 @@ module DiscourseAi
|
|||
Discourse
|
||||
.cache
|
||||
.fetch(embedding_key, expires_in: 1.week) do
|
||||
vector_rep.vector_from(search_term, asymetric: true)
|
||||
vector.vector_from(search_term, asymetric: true)
|
||||
end
|
||||
|
||||
candidate_post_ids =
|
||||
DiscourseAi::Embeddings::Schema
|
||||
.for(Post, vector: vector_rep)
|
||||
.for(Post, vector_def: vector.vdef)
|
||||
.asymmetric_similarity_search(
|
||||
search_term_embedding,
|
||||
limit: max_semantic_results_per_page,
|
||||
|
|
|
@ -12,19 +12,28 @@ module DiscourseAi
|
|||
1
|
||||
end
|
||||
|
||||
def prepare_text_from(target, tokenizer, max_length)
|
||||
def prepare_target_text(target, vdef)
|
||||
max_length = vdef.max_sequence_length - 2
|
||||
|
||||
case target
|
||||
when Topic
|
||||
topic_truncation(target, tokenizer, max_length)
|
||||
topic_truncation(target, vdef.tokenizer, max_length)
|
||||
when Post
|
||||
post_truncation(target, tokenizer, max_length)
|
||||
post_truncation(target, vdef.tokenizer, max_length)
|
||||
when RagDocumentFragment
|
||||
tokenizer.truncate(target.fragment, max_length)
|
||||
vdef.tokenizer.truncate(target.fragment, max_length)
|
||||
else
|
||||
raise ArgumentError, "Invalid target type"
|
||||
end
|
||||
end
|
||||
|
||||
def prepare_query_text(text, vdef, asymetric: false)
|
||||
qtext = asymetric ? "#{vdef.asymmetric_query_prefix} #{text}" : text
|
||||
max_length = vdef.max_sequence_length - 2
|
||||
|
||||
vdef.tokenizer.truncate(text, max_length)
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def topic_information(topic)
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
class Vector
|
||||
def self.instance
|
||||
new(DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation)
|
||||
end
|
||||
|
||||
def initialize(vector_definition)
|
||||
@vdef = vector_definition
|
||||
end
|
||||
|
||||
def gen_bulk_reprensentations(relation)
|
||||
http_pool_size = 100
|
||||
pool =
|
||||
Concurrent::CachedThreadPool.new(
|
||||
min_threads: 0,
|
||||
max_threads: http_pool_size,
|
||||
idletime: 30,
|
||||
)
|
||||
|
||||
schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector_def: vdef)
|
||||
|
||||
embedding_gen = vdef.inference_client
|
||||
promised_embeddings =
|
||||
relation
|
||||
.map do |record|
|
||||
prepared_text = vdef.prepare_target_text(record)
|
||||
next if prepared_text.blank?
|
||||
|
||||
new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text)
|
||||
next if schema.find_by_target(record)&.digest == new_digest
|
||||
|
||||
Concurrent::Promises
|
||||
.fulfilled_future({ target: record, text: prepared_text, digest: new_digest }, pool)
|
||||
.then_on(pool) do |w_prepared_text|
|
||||
w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text]))
|
||||
end
|
||||
end
|
||||
.compact
|
||||
|
||||
Concurrent::Promises
|
||||
.zip(*promised_embeddings)
|
||||
.value!
|
||||
.each { |e| schema.store(e[:target], e[:embedding], e[:digest]) }
|
||||
|
||||
pool.shutdown
|
||||
pool.wait_for_termination
|
||||
end
|
||||
|
||||
def generate_representation_from(target)
|
||||
text = vdef.prepare_target_text(target)
|
||||
return if text.blank?
|
||||
|
||||
schema = DiscourseAi::Embeddings::Schema.for(target.class, vector_def: vdef)
|
||||
|
||||
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
|
||||
return if schema.find_by_target(target)&.digest == new_digest
|
||||
|
||||
embeddings = vdef.inference_client.perform!(text)
|
||||
|
||||
schema.store(target, embeddings, new_digest)
|
||||
end
|
||||
|
||||
def vector_from(text, asymetric: false)
|
||||
prepared_text = vdef.prepare_query_text(text, asymetric: asymetric)
|
||||
return if prepared_text.blank?
|
||||
|
||||
vdef.inference_client.perform!(prepared_text)
|
||||
end
|
||||
|
||||
attr_reader :vdef
|
||||
end
|
||||
end
|
||||
end
|
|
@ -23,10 +23,6 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
def vector_from(text, asymetric: false)
|
||||
inference_client.perform!(text)
|
||||
end
|
||||
|
||||
def dimensions
|
||||
768
|
||||
end
|
||||
|
@ -47,10 +43,6 @@ module DiscourseAi
|
|||
"<#>"
|
||||
end
|
||||
|
||||
def pg_index_type
|
||||
"halfvec_ip_ops"
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
|
||||
end
|
||||
|
|
|
@ -21,8 +21,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def current_representation
|
||||
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
find_representation(SiteSetting.ai_embeddings_model).new(truncation)
|
||||
find_representation(SiteSetting.ai_embeddings_model).new
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
|
@ -43,73 +42,6 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
def initialize(strategy)
|
||||
@strategy = strategy
|
||||
end
|
||||
|
||||
def vector_from(text, asymetric: false)
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def gen_bulk_reprensentations(relation)
|
||||
http_pool_size = 100
|
||||
pool =
|
||||
Concurrent::CachedThreadPool.new(
|
||||
min_threads: 0,
|
||||
max_threads: http_pool_size,
|
||||
idletime: 30,
|
||||
)
|
||||
|
||||
schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector: self)
|
||||
|
||||
embedding_gen = inference_client
|
||||
promised_embeddings =
|
||||
relation
|
||||
.map do |record|
|
||||
prepared_text = prepare_text(record)
|
||||
next if prepared_text.blank?
|
||||
|
||||
new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text)
|
||||
next if schema.find_by_target(record)&.digest == new_digest
|
||||
|
||||
Concurrent::Promises
|
||||
.fulfilled_future(
|
||||
{ target: record, text: prepared_text, digest: new_digest },
|
||||
pool,
|
||||
)
|
||||
.then_on(pool) do |w_prepared_text|
|
||||
w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text]))
|
||||
end
|
||||
end
|
||||
.compact
|
||||
|
||||
Concurrent::Promises
|
||||
.zip(*promised_embeddings)
|
||||
.value!
|
||||
.each { |e| schema.store(e[:target], e[:embedding], e[:digest]) }
|
||||
|
||||
pool.shutdown
|
||||
pool.wait_for_termination
|
||||
end
|
||||
|
||||
def generate_representation_from(target, persist: true)
|
||||
text = prepare_text(target)
|
||||
return if text.blank?
|
||||
|
||||
schema = DiscourseAi::Embeddings::Schema.for(target.class, vector: self)
|
||||
|
||||
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
|
||||
return if schema.find_by_target(target)&.digest == new_digest
|
||||
|
||||
vector = vector_from(text)
|
||||
|
||||
schema.store(target, vector, new_digest) if persist
|
||||
end
|
||||
|
||||
def index_name(table_name)
|
||||
"#{table_name}_#{id}_#{@strategy.id}_search"
|
||||
end
|
||||
|
||||
def name
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
@ -139,26 +71,32 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def asymmetric_query_prefix
|
||||
raise NotImplementedError
|
||||
""
|
||||
end
|
||||
|
||||
def strategy_id
|
||||
@strategy.id
|
||||
strategy.id
|
||||
end
|
||||
|
||||
def strategy_version
|
||||
@strategy.version
|
||||
strategy.version
|
||||
end
|
||||
|
||||
protected
|
||||
def prepare_query_text(text, asymetric: false)
|
||||
strategy.prepare_query_text(text, self, asymetric: asymetric)
|
||||
end
|
||||
|
||||
def prepare_target_text(target)
|
||||
strategy.prepare_target_text(target, self)
|
||||
end
|
||||
|
||||
def strategy
|
||||
@strategy ||= DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
end
|
||||
|
||||
def inference_client
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def prepare_text(record)
|
||||
@strategy.prepare_text_from(record, tokenizer, max_sequence_length - 2)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -30,21 +30,6 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
def vector_from(text, asymetric: false)
|
||||
text = "#{asymmetric_query_prefix} #{text}" if asymetric
|
||||
|
||||
client = inference_client
|
||||
|
||||
needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
|
||||
text = tokenizer.truncate(text, max_sequence_length - 2) if needs_truncation
|
||||
|
||||
inference_client.perform!(text)
|
||||
end
|
||||
|
||||
def inference_model_name
|
||||
"baai/bge-large-en-v1.5"
|
||||
end
|
||||
|
||||
def dimensions
|
||||
1024
|
||||
end
|
||||
|
@ -65,10 +50,6 @@ module DiscourseAi
|
|||
"<#>"
|
||||
end
|
||||
|
||||
def pg_index_type
|
||||
"halfvec_ip_ops"
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::BgeLargeEnTokenizer
|
||||
end
|
||||
|
@ -78,6 +59,8 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def inference_client
|
||||
inference_model_name = "baai/bge-large-en-v1.5"
|
||||
|
||||
if SiteSetting.ai_cloudflare_workers_api_token.present?
|
||||
DiscourseAi::Inference::CloudflareWorkersAi.instance(inference_model_name)
|
||||
elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
||||
|
|
|
@ -18,11 +18,6 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
def vector_from(text, asymetric: false)
|
||||
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
|
||||
inference_client.perform!(truncated_text)
|
||||
end
|
||||
|
||||
def dimensions
|
||||
1024
|
||||
end
|
||||
|
@ -43,10 +38,6 @@ module DiscourseAi
|
|||
"<#>"
|
||||
end
|
||||
|
||||
def pg_index_type
|
||||
"halfvec_ip_ops"
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::BgeM3Tokenizer
|
||||
end
|
||||
|
|
|
@ -38,14 +38,6 @@ module DiscourseAi
|
|||
"<=>"
|
||||
end
|
||||
|
||||
def pg_index_type
|
||||
"halfvec_cosine_ops"
|
||||
end
|
||||
|
||||
def vector_from(text, asymetric: false)
|
||||
inference_client.perform!(text)
|
||||
end
|
||||
|
||||
# There is no public tokenizer for Gemini, and from the ones we already ship in the plugin
|
||||
# OpenAI gets the closest results. Gemini Tokenizer results in ~10% less tokens, so it's safe
|
||||
# to use OpenAI tokenizer since it will overestimate the number of tokens.
|
||||
|
|
|
@ -28,19 +28,6 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
def vector_from(text, asymetric: false)
|
||||
client = inference_client
|
||||
|
||||
needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
|
||||
if needs_truncation
|
||||
text = tokenizer.truncate(text, max_sequence_length - 2)
|
||||
elsif !text.starts_with?("query:")
|
||||
text = "query: #{text}"
|
||||
end
|
||||
|
||||
client.perform!(text)
|
||||
end
|
||||
|
||||
def id
|
||||
3
|
||||
end
|
||||
|
@ -61,10 +48,6 @@ module DiscourseAi
|
|||
"<=>"
|
||||
end
|
||||
|
||||
def pg_index_type
|
||||
"halfvec_cosine_ops"
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
|
||||
end
|
||||
|
@ -80,8 +63,18 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
def prepare_text(record)
|
||||
prepared_text = super(record)
|
||||
def prepare_text(text, asymetric: false)
|
||||
prepared_text = super(text, asymetric: asymetric)
|
||||
|
||||
if prepared_text.present? && inference_client.class.name.include?("DiscourseClassifier")
|
||||
return "query: #{prepared_text}"
|
||||
end
|
||||
|
||||
prepared_text
|
||||
end
|
||||
|
||||
def prepare_target_text(target)
|
||||
prepared_text = super(target)
|
||||
|
||||
if prepared_text.present? && inference_client.class.name.include?("DiscourseClassifier")
|
||||
return "query: #{prepared_text}"
|
||||
|
|
|
@ -40,14 +40,6 @@ module DiscourseAi
|
|||
"<=>"
|
||||
end
|
||||
|
||||
def pg_index_type
|
||||
"halfvec_cosine_ops"
|
||||
end
|
||||
|
||||
def vector_from(text, asymetric: false)
|
||||
inference_client.perform!(text)
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
|
|
|
@ -38,14 +38,6 @@ module DiscourseAi
|
|||
"<=>"
|
||||
end
|
||||
|
||||
def pg_index_type
|
||||
"halfvec_cosine_ops"
|
||||
end
|
||||
|
||||
def vector_from(text, asymetric: false)
|
||||
inference_client.perform!(text)
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
|
|
|
@ -38,14 +38,6 @@ module DiscourseAi
|
|||
"<=>"
|
||||
end
|
||||
|
||||
def pg_index_type
|
||||
"halfvec_cosine_ops"
|
||||
end
|
||||
|
||||
def vector_from(text, asymetric: false)
|
||||
inference_client.perform!(text)
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
|
|
|
@ -326,8 +326,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
|||
fab!(:llm_model) { Fabricate(:fake_model) }
|
||||
|
||||
it "will run the question consolidator" do
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
context_embedding = vector_rep.dimensions.times.map { rand(-1.0...1.0) }
|
||||
vector_def = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
context_embedding = vector_def.dimensions.times.map { rand(-1.0...1.0) }
|
||||
EmbeddingsGenerationStubs.discourse_service(
|
||||
SiteSetting.ai_embeddings_model,
|
||||
consolidated_question,
|
||||
|
@ -373,14 +373,14 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
|||
end
|
||||
|
||||
context "when a persona has RAG uploads" do
|
||||
let(:vector_rep) do
|
||||
let(:vector_def) do
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
end
|
||||
let(:embedding_value) { 0.04381 }
|
||||
let(:prompt_cc_embeddings) { [embedding_value] * vector_rep.dimensions }
|
||||
let(:prompt_cc_embeddings) { [embedding_value] * vector_def.dimensions }
|
||||
|
||||
def stub_fragments(fragment_count, persona: ai_persona)
|
||||
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep)
|
||||
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector_def)
|
||||
|
||||
fragment_count.times do |i|
|
||||
fragment =
|
||||
|
@ -393,7 +393,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
|||
)
|
||||
|
||||
# Similarity is determined left-to-right.
|
||||
embeddings = [embedding_value + "0.000#{i}".to_f] * vector_rep.dimensions
|
||||
embeddings = [embedding_value + "0.000#{i}".to_f] * vector_def.dimensions
|
||||
|
||||
schema.store(fragment, embeddings, "test")
|
||||
end
|
||||
|
|
|
@ -14,9 +14,9 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
|
|||
fab!(:category)
|
||||
fab!(:topic) { Fabricate(:topic, category: category) }
|
||||
|
||||
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
||||
let(:vector) { DiscourseAi::Embeddings::Vector.instance }
|
||||
let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) }
|
||||
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
|
||||
let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
|
@ -28,8 +28,8 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
|
|||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||
).to_return(status: 200, body: JSON.dump(expected_embedding))
|
||||
|
||||
vector_rep.generate_representation_from(topic)
|
||||
vector_rep.generate_representation_from(muted_topic)
|
||||
vector.generate_representation_from(topic)
|
||||
vector.generate_representation_from(muted_topic)
|
||||
end
|
||||
|
||||
it "respects user muted categories when making suggestions" do
|
||||
|
|
|
@ -12,21 +12,16 @@ RSpec.describe Jobs::GenerateEmbeddings do
|
|||
fab!(:topic)
|
||||
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
|
||||
|
||||
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
|
||||
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
||||
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) }
|
||||
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector: vector_rep) }
|
||||
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
||||
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector_def) }
|
||||
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector_def: vector_def) }
|
||||
|
||||
it "works for topics" do
|
||||
expected_embedding = [0.0038493] * vector_rep.dimensions
|
||||
expected_embedding = [0.0038493] * vector_def.dimensions
|
||||
|
||||
text =
|
||||
truncation.prepare_text_from(
|
||||
topic,
|
||||
vector_rep.tokenizer,
|
||||
vector_rep.max_sequence_length - 2,
|
||||
)
|
||||
EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding)
|
||||
text = vector_def.prepare_target_text(topic)
|
||||
|
||||
EmbeddingsGenerationStubs.discourse_service(vector_def.class.name, text, expected_embedding)
|
||||
|
||||
job.execute(target_id: topic.id, target_type: "Topic")
|
||||
|
||||
|
@ -34,11 +29,10 @@ RSpec.describe Jobs::GenerateEmbeddings do
|
|||
end
|
||||
|
||||
it "works for posts" do
|
||||
expected_embedding = [0.0038493] * vector_rep.dimensions
|
||||
expected_embedding = [0.0038493] * vector_def.dimensions
|
||||
|
||||
text =
|
||||
truncation.prepare_text_from(post, vector_rep.tokenizer, vector_rep.max_sequence_length - 2)
|
||||
EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding)
|
||||
text = vector_def.prepare_target_text(post)
|
||||
EmbeddingsGenerationStubs.discourse_service(vector_def.class.name, text, expected_embedding)
|
||||
|
||||
job.execute(target_id: post.id, target_type: "Post")
|
||||
|
||||
|
|
|
@ -1,16 +1,12 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::Schema do
|
||||
subject(:posts_schema) { described_class.for(Post, vector: vector) }
|
||||
subject(:posts_schema) { described_class.for(Post, vector_def: vector_def) }
|
||||
|
||||
let(:embeddings) { [0.0038490295] * vector.dimensions }
|
||||
let(:embeddings) { [0.0038490295] * vector_def.dimensions }
|
||||
fab!(:post) { Fabricate(:post, post_number: 1) }
|
||||
let(:digest) { OpenSSL::Digest.hexdigest("SHA1", "test") }
|
||||
let(:vector) do
|
||||
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new(
|
||||
DiscourseAi::Embeddings::Strategies::Truncation.new,
|
||||
)
|
||||
end
|
||||
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new }
|
||||
|
||||
before { posts_schema.store(post, embeddings, digest) }
|
||||
|
||||
|
@ -34,7 +30,7 @@ RSpec.describe DiscourseAi::Embeddings::Schema do
|
|||
|
||||
describe "similarity searches" do
|
||||
fab!(:post_2) { Fabricate(:post) }
|
||||
let(:similar_embeddings) { [0.0038490294] * vector.dimensions }
|
||||
let(:similar_embeddings) { [0.0038490294] * vector_def.dimensions }
|
||||
|
||||
describe "#symmetric_similarity_search" do
|
||||
before { posts_schema.store(post_2, similar_embeddings, digest) }
|
||||
|
|
|
@ -11,8 +11,8 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
|
|||
|
||||
describe "#search_for_topics" do
|
||||
let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" }
|
||||
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
||||
let(:hyde_embedding) { [0.049382] * vector_rep.dimensions }
|
||||
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
||||
let(:hyde_embedding) { [0.049382] * vector_def.dimensions }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
|
||||
subject(:truncation) { described_class.new }
|
||||
|
||||
describe "#prepare_text_from" do
|
||||
context "when using vector from OpenAI" do
|
||||
describe "#prepare_query_text" do
|
||||
context "when using vector def from OpenAI" do
|
||||
before { SiteSetting.max_post_length = 100_000 }
|
||||
|
||||
fab!(:topic)
|
||||
|
@ -20,15 +20,12 @@ RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
|
|||
end
|
||||
fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
|
||||
|
||||
let(:model) do
|
||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new(truncation)
|
||||
end
|
||||
let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new }
|
||||
|
||||
it "truncates a topic" do
|
||||
prepared_text =
|
||||
truncation.prepare_text_from(topic, model.tokenizer, model.max_sequence_length)
|
||||
prepared_text = truncation.prepare_target_text(topic, vector_def)
|
||||
|
||||
expect(model.tokenizer.size(prepared_text)).to be <= model.max_sequence_length
|
||||
expect(vector_def.tokenizer.size(prepared_text)).to be <= vector_def.max_sequence_length
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,17 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "vector_rep_shared_examples"
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2 do
|
||||
subject(:vector_rep) { described_class.new(truncation) }
|
||||
|
||||
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
|
||||
|
||||
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.discourse_service(described_class.name, text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embedding using with vector representation"
|
||||
end
|
|
@ -1,18 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "vector_rep_shared_examples"
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::Gemini do
|
||||
subject(:vector_rep) { described_class.new(truncation) }
|
||||
|
||||
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
|
||||
let!(:api_key) { "test-123" }
|
||||
|
||||
before { SiteSetting.ai_gemini_api_key = api_key }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.gemini_service(api_key, text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embedding using with vector representation"
|
||||
end
|
|
@ -1,21 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "vector_rep_shared_examples"
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large do
|
||||
subject(:vector_rep) { described_class.new(truncation) }
|
||||
|
||||
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
|
||||
|
||||
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.discourse_service(
|
||||
described_class.name,
|
||||
"query: #{text}",
|
||||
expected_embedding,
|
||||
)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embedding using with vector representation"
|
||||
end
|
|
@ -1,22 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "vector_rep_shared_examples"
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large do
|
||||
subject(:vector_rep) { described_class.new(truncation) }
|
||||
|
||||
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.openai_service(
|
||||
described_class.name,
|
||||
text,
|
||||
expected_embedding,
|
||||
extra_args: {
|
||||
dimensions: 2000,
|
||||
},
|
||||
)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embedding using with vector representation"
|
||||
end
|
|
@ -1,15 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "vector_rep_shared_examples"
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small do
|
||||
subject(:vector_rep) { described_class.new(truncation) }
|
||||
|
||||
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.openai_service(described_class.name, text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embedding using with vector representation"
|
||||
end
|
|
@ -1,15 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "vector_rep_shared_examples"
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002 do
|
||||
subject(:vector_rep) { described_class.new(truncation) }
|
||||
|
||||
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.openai_service(described_class.name, text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embedding using with vector representation"
|
||||
end
|
|
@ -1,115 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.shared_examples "generates and store embedding using with vector representation" do
|
||||
let(:expected_embedding_1) { [0.0038493] * vector_rep.dimensions }
|
||||
let(:expected_embedding_2) { [0.0037684] * vector_rep.dimensions }
|
||||
|
||||
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) }
|
||||
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector: vector_rep) }
|
||||
|
||||
describe "#vector_from" do
|
||||
it "creates a vector from a given string" do
|
||||
text = "This is a piece of text"
|
||||
stub_vector_mapping(text, expected_embedding_1)
|
||||
|
||||
expect(vector_rep.vector_from(text)).to eq(expected_embedding_1)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#generate_representation_from" do
|
||||
fab!(:topic) { Fabricate(:topic) }
|
||||
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
|
||||
fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
|
||||
|
||||
it "creates a vector from a topic and stores it in the database" do
|
||||
text =
|
||||
truncation.prepare_text_from(
|
||||
topic,
|
||||
vector_rep.tokenizer,
|
||||
vector_rep.max_sequence_length - 2,
|
||||
)
|
||||
stub_vector_mapping(text, expected_embedding_1)
|
||||
|
||||
vector_rep.generate_representation_from(topic)
|
||||
|
||||
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
|
||||
end
|
||||
|
||||
it "creates a vector from a post and stores it in the database" do
|
||||
text =
|
||||
truncation.prepare_text_from(
|
||||
post2,
|
||||
vector_rep.tokenizer,
|
||||
vector_rep.max_sequence_length - 2,
|
||||
)
|
||||
stub_vector_mapping(text, expected_embedding_1)
|
||||
|
||||
vector_rep.generate_representation_from(post)
|
||||
|
||||
expect(posts_schema.find_by_embedding(expected_embedding_1).post_id).to eq(post.id)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#gen_bulk_reprensentations" do
|
||||
fab!(:topic) { Fabricate(:topic) }
|
||||
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
|
||||
fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
|
||||
|
||||
fab!(:topic_2) { Fabricate(:topic) }
|
||||
fab!(:post_2_1) { Fabricate(:post, post_number: 1, topic: topic_2) }
|
||||
fab!(:post_2_2) { Fabricate(:post, post_number: 2, topic: topic_2) }
|
||||
|
||||
it "creates a vector for each object in the relation" do
|
||||
text =
|
||||
truncation.prepare_text_from(
|
||||
topic,
|
||||
vector_rep.tokenizer,
|
||||
vector_rep.max_sequence_length - 2,
|
||||
)
|
||||
|
||||
text2 =
|
||||
truncation.prepare_text_from(
|
||||
topic_2,
|
||||
vector_rep.tokenizer,
|
||||
vector_rep.max_sequence_length - 2,
|
||||
)
|
||||
|
||||
stub_vector_mapping(text, expected_embedding_1)
|
||||
stub_vector_mapping(text2, expected_embedding_2)
|
||||
|
||||
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))
|
||||
|
||||
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
|
||||
end
|
||||
|
||||
it "does nothing if passed record has no content" do
|
||||
expect { vector_rep.gen_bulk_reprensentations([Topic.new]) }.not_to raise_error
|
||||
end
|
||||
|
||||
it "doesn't ask for a new embedding if digest is the same" do
|
||||
text =
|
||||
truncation.prepare_text_from(
|
||||
topic,
|
||||
vector_rep.tokenizer,
|
||||
vector_rep.max_sequence_length - 2,
|
||||
)
|
||||
stub_vector_mapping(text, expected_embedding_1)
|
||||
|
||||
original_vector_gen = Time.zone.parse("2021-06-04 10:00")
|
||||
|
||||
freeze_time(original_vector_gen) do
|
||||
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
|
||||
end
|
||||
# check vector exists
|
||||
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
|
||||
|
||||
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
|
||||
last_update =
|
||||
DB.query_single(
|
||||
"SELECT updated_at FROM #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE} WHERE topic_id = #{topic.id} LIMIT 1",
|
||||
).first
|
||||
|
||||
expect(last_update).to eq(original_vector_gen)
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,160 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::Vector do
|
||||
shared_examples "generates and store embeddings using a vector definition" do
|
||||
subject(:vector) { described_class.new(vdef) }
|
||||
|
||||
let(:expected_embedding_1) { [0.0038493] * vdef.dimensions }
|
||||
let(:expected_embedding_2) { [0.0037684] * vdef.dimensions }
|
||||
|
||||
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vdef) }
|
||||
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector_def: vdef) }
|
||||
|
||||
fab!(:topic)
|
||||
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
|
||||
fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
|
||||
|
||||
describe "#vector_from" do
|
||||
it "creates a vector from a given string" do
|
||||
text = "This is a piece of text"
|
||||
stub_vector_mapping(text, expected_embedding_1)
|
||||
|
||||
expect(vector.vector_from(text)).to eq(expected_embedding_1)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#generate_representation_from" do
|
||||
it "creates a vector from a topic and stores it in the database" do
|
||||
text = vdef.prepare_target_text(topic)
|
||||
stub_vector_mapping(text, expected_embedding_1)
|
||||
|
||||
vector.generate_representation_from(topic)
|
||||
|
||||
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
|
||||
end
|
||||
|
||||
it "creates a vector from a post and stores it in the database" do
|
||||
text = vdef.prepare_target_text(post2)
|
||||
stub_vector_mapping(text, expected_embedding_1)
|
||||
|
||||
vector.generate_representation_from(post)
|
||||
|
||||
expect(posts_schema.find_by_embedding(expected_embedding_1).post_id).to eq(post.id)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#gen_bulk_reprensentations" do
|
||||
fab!(:topic_2) { Fabricate(:topic) }
|
||||
fab!(:post_2_1) { Fabricate(:post, post_number: 1, topic: topic_2) }
|
||||
fab!(:post_2_2) { Fabricate(:post, post_number: 2, topic: topic_2) }
|
||||
|
||||
it "creates a vector for each object in the relation" do
|
||||
text = vdef.prepare_target_text(topic)
|
||||
|
||||
text2 = vdef.prepare_target_text(topic_2)
|
||||
|
||||
stub_vector_mapping(text, expected_embedding_1)
|
||||
stub_vector_mapping(text2, expected_embedding_2)
|
||||
|
||||
vector.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))
|
||||
|
||||
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
|
||||
end
|
||||
|
||||
it "does nothing if passed record has no content" do
|
||||
expect { vector.gen_bulk_reprensentations([Topic.new]) }.not_to raise_error
|
||||
end
|
||||
|
||||
it "doesn't ask for a new embedding if digest is the same" do
|
||||
text = vdef.prepare_target_text(topic)
|
||||
stub_vector_mapping(text, expected_embedding_1)
|
||||
|
||||
original_vector_gen = Time.zone.parse("2021-06-04 10:00")
|
||||
|
||||
freeze_time(original_vector_gen) do
|
||||
vector.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
|
||||
end
|
||||
# check vector exists
|
||||
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
|
||||
|
||||
vector.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
|
||||
|
||||
expect(topics_schema.find_by_target(topic).updated_at).to eq_time(original_vector_gen)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
context "with text-embedding-ada-002" do
|
||||
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.openai_service(vdef.class.name, text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embeddings using a vector definition"
|
||||
end
|
||||
|
||||
context "with all all-mpnet-base-v2" do
|
||||
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new }
|
||||
|
||||
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.discourse_service(vdef.class.name, text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embeddings using a vector definition"
|
||||
end
|
||||
|
||||
context "with gemini" do
|
||||
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::Gemini.new }
|
||||
let(:api_key) { "test-123" }
|
||||
|
||||
before { SiteSetting.ai_gemini_api_key = api_key }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.gemini_service(api_key, text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embeddings using a vector definition"
|
||||
end
|
||||
|
||||
context "with multilingual-e5-large" do
|
||||
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large.new }
|
||||
|
||||
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.discourse_service(vdef.class.name, text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embeddings using a vector definition"
|
||||
end
|
||||
|
||||
context "with text-embedding-3-large" do
|
||||
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large.new }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.openai_service(
|
||||
vdef.class.name,
|
||||
text,
|
||||
expected_embedding,
|
||||
extra_args: {
|
||||
dimensions: 2000,
|
||||
},
|
||||
)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embeddings using a vector definition"
|
||||
end
|
||||
|
||||
context "with text-embedding-3-small" do
|
||||
let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small.new }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.openai_service(vdef.class.name, text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embeddings using a vector definition"
|
||||
end
|
||||
end
|
|
@ -74,7 +74,7 @@ RSpec.describe RagDocumentFragment do
|
|||
end
|
||||
|
||||
describe ".indexing_status" do
|
||||
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
||||
let(:vector) { DiscourseAi::Embeddings::Vector.instance }
|
||||
|
||||
fab!(:rag_document_fragment_1) do
|
||||
Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
|
||||
|
@ -84,7 +84,7 @@ RSpec.describe RagDocumentFragment do
|
|||
Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
|
||||
end
|
||||
|
||||
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
|
||||
let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
|
@ -96,7 +96,7 @@ RSpec.describe RagDocumentFragment do
|
|||
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||
).to_return(status: 200, body: JSON.dump(expected_embedding))
|
||||
|
||||
vector_rep.generate_representation_from(rag_document_fragment_1)
|
||||
vector.generate_representation_from(rag_document_fragment_1)
|
||||
end
|
||||
|
||||
it "regenerates all embeddings if ai_embeddings_model changes" do
|
||||
|
|
|
@ -19,14 +19,14 @@ describe DiscourseAi::Embeddings::EmbeddingsController do
|
|||
fab!(:post_in_subcategory) { Fabricate(:post, topic: topic_in_subcategory) }
|
||||
|
||||
def index(topic)
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
vector = DiscourseAi::Embeddings::Vector.instance
|
||||
|
||||
stub_request(:post, "https://api.openai.com/v1/embeddings").to_return(
|
||||
status: 200,
|
||||
body: JSON.dump({ data: [{ embedding: [0.1] * 1536 }] }),
|
||||
)
|
||||
|
||||
vector_rep.generate_representation_from(topic)
|
||||
vector.generate_representation_from(topic)
|
||||
end
|
||||
|
||||
def stub_embedding(query)
|
||||
|
|
Loading…
Reference in New Issue