REFACTOR: Separation of concerns for embedding generation. (#1027)
In a previous refactor, we moved the responsibility of querying and storing embeddings into the `Schema` class. Now it's time to do the same for embedding generation. The motivation behind these changes is to isolate vector characteristics in simple value objects so we can later replace them with a DB-backed version, similar to what we did with LLM configs.
parent 222e2cf4f9
commit 534b0df391
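At call sites, the change looks like this — a minimal sketch distilled from the hunks below (`topic` stands in for any embeddable record):

    # Before: callers fetched a full "vector representation", which mixed
    # model metadata with generation logic.
    vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
    vector_rep.generate_representation_from(topic)

    # After: generation lives in the new Vector class; the representation
    # (the "vector definition") only describes the model.
    vector = DiscourseAi::Embeddings::Vector.instance
    vector.generate_representation_from(topic)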
@@ -16,9 +16,7 @@ module Jobs
       return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms
       return if post.raw.blank?

-      vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
-
-      vector_rep.generate_representation_from(target)
+      DiscourseAi::Embeddings::Vector.instance.generate_representation_from(target)
     end
   end
 end
@@ -8,11 +8,11 @@ module ::Jobs
     def execute(args)
       return if (fragments = RagDocumentFragment.where(id: args[:fragment_ids].to_a)).empty?

-      vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
+      vector = DiscourseAi::Embeddings::Vector.instance

       # generate_representation_from compares the digest value to make sure
       # the embedding is only generated once per fragment unless something changes.
-      fragments.map { |fragment| vector_rep.generate_representation_from(fragment) }
+      fragments.map { |fragment| vector.generate_representation_from(fragment) }

       last_fragment = fragments.last
       target = last_fragment.target
@@ -20,7 +20,8 @@ module Jobs

       rebaked = 0

-      vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
+      vector = DiscourseAi::Embeddings::Vector.instance
+      vector_def = vector.vdef
       table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE

       topics =
@@ -30,19 +31,19 @@ module Jobs
           .where(deleted_at: nil)
           .order("topics.bumped_at DESC")

-      rebaked += populate_topic_embeddings(vector_rep, topics.limit(limit - rebaked))
+      rebaked += populate_topic_embeddings(vector, topics.limit(limit - rebaked))

       return if rebaked >= limit

       # Then, we'll try to backfill embeddings for topics that have outdated
       # embeddings, be it model or strategy version
       relation = topics.where(<<~SQL).limit(limit - rebaked)
-        #{table_name}.model_version < #{vector_rep.version}
+        #{table_name}.model_version < #{vector_def.version}
         OR
-        #{table_name}.strategy_version < #{vector_rep.strategy_version}
+        #{table_name}.strategy_version < #{vector_def.strategy_version}
       SQL

-      rebaked += populate_topic_embeddings(vector_rep, relation)
+      rebaked += populate_topic_embeddings(vector, relation)

       return if rebaked >= limit

@@ -54,7 +55,7 @@ module Jobs
           .where("#{table_name}.updated_at < topics.updated_at")
           .limit((limit - rebaked) / 10)

-      populate_topic_embeddings(vector_rep, relation, force: true)
+      populate_topic_embeddings(vector, relation, force: true)

       return if rebaked >= limit

@@ -76,7 +77,7 @@ module Jobs
         .limit(limit - rebaked)
         .pluck(:id)
         .each_slice(posts_batch_size) do |batch|
-          vector_rep.gen_bulk_reprensentations(Post.where(id: batch))
+          vector.gen_bulk_reprensentations(Post.where(id: batch))
           rebaked += batch.length
         end

@@ -86,14 +87,14 @@ module Jobs
       # embeddings, be it model or strategy version
       posts
         .where(<<~SQL)
-          #{table_name}.model_version < #{vector_rep.version}
+          #{table_name}.model_version < #{vector_def.version}
           OR
-          #{table_name}.strategy_version < #{vector_rep.strategy_version}
+          #{table_name}.strategy_version < #{vector_def.strategy_version}
         SQL
         .limit(limit - rebaked)
         .pluck(:id)
         .each_slice(posts_batch_size) do |batch|
-          vector_rep.gen_bulk_reprensentations(Post.where(id: batch))
+          vector.gen_bulk_reprensentations(Post.where(id: batch))
           rebaked += batch.length
         end

@@ -107,7 +108,7 @@ module Jobs
         .limit((limit - rebaked) / 10)
         .pluck(:id)
         .each_slice(posts_batch_size) do |batch|
-          vector_rep.gen_bulk_reprensentations(Post.where(id: batch))
+          vector.gen_bulk_reprensentations(Post.where(id: batch))
           rebaked += batch.length
         end

@@ -116,7 +117,7 @@ module Jobs

     private

-    def populate_topic_embeddings(vector_rep, topics, force: false)
+    def populate_topic_embeddings(vector, topics, force: false)
       done = 0

       topics =
@@ -126,7 +127,7 @@ module Jobs
       batch_size = 1000

       ids.each_slice(batch_size) do |batch|
-        vector_rep.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
+        vector.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
         done += batch.length
       end
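Note the split the backfill job now makes between the generator and its definition: `vector` performs the embedding calls, while `vector_def` supplies the model and strategy versions interpolated into the backfill SQL. A minimal sketch of the pattern:

    vector = DiscourseAi::Embeddings::Vector.instance
    vector_def = vector.vdef

    vector.gen_bulk_reprensentations(Post.where(id: batch)) # generation
    vector_def.version                                      # SQL version checks
    vector_def.strategy_version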
@@ -314,10 +314,10 @@ module DiscourseAi

         return nil if !consolidated_question

-        vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
+        vector = DiscourseAi::Embeddings::Vector.instance
         reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings

-        interactions_vector = vector_rep.vector_from(consolidated_question)
+        interactions_vector = vector.vector_from(consolidated_question)

         rag_conversation_chunks = self.class.rag_conversation_chunks
         search_limit =
@@ -327,7 +327,7 @@ module DiscourseAi
             rag_conversation_chunks
           end

-        schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep)
+        schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector.vdef)

         candidate_fragment_ids =
           schema
@@ -141,11 +141,10 @@ module DiscourseAi

         return [] if upload_refs.empty?

-        vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
-
-        query_vector = vector_rep.vector_from(query)
+        query_vector = DiscourseAi::Embeddings::Vector.instance.vector_from(query)

         fragment_ids =
           DiscourseAi::Embeddings::Schema
-            .for(RagDocumentFragment, vector: vector_rep)
+            .for(RagDocumentFragment)
             .asymmetric_similarity_search(query_vector, limit: limit, offset: 0) do |builder|
               builder.join(<<~SQL, target_id: tool.id, target_type: "AiTool")
                 rag_document_fragments ON
@@ -92,10 +92,10 @@ module DiscourseAi
     private

     def nearest_neighbors(limit: 100)
-      vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
-      schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep)
+      vector = DiscourseAi::Embeddings::Vector.instance
+      schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)

-      raw_vector = vector_rep.vector_from(@text)
+      raw_vector = vector.vector_from(@text)

       muted_category_ids = nil
       if @user.present?
@@ -14,30 +14,31 @@ module DiscourseAi

       def self.for(
         target_klass,
-        vector: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
+        vector_def: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
       )
         case target_klass&.name
         when "Topic"
-          new(TOPICS_TABLE, "topic_id", vector)
+          new(TOPICS_TABLE, "topic_id", vector_def)
         when "Post"
-          new(POSTS_TABLE, "post_id", vector)
+          new(POSTS_TABLE, "post_id", vector_def)
         when "RagDocumentFragment"
-          new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector)
+          new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector_def)
         else
           raise ArgumentError, "Invalid target type for embeddings"
         end
       end

-      def initialize(table, target_column, vector)
+      def initialize(table, target_column, vector_def)
         @table = table
         @target_column = target_column
-        @vector = vector
+        @vector_def = vector_def
       end

-      attr_reader :table, :target_column, :vector
+      attr_reader :table, :target_column, :vector_def

       def find_by_embedding(embedding)
-        DB.query(<<~SQL, query_embedding: embedding, vid: vector.id, vsid: vector.strategy_id).first
+        DB.query(
+          <<~SQL,
           SELECT *
           FROM #{table}
           WHERE
@@ -46,10 +47,15 @@ module DiscourseAi
             embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
           LIMIT 1
         SQL
+          query_embedding: embedding,
+          vid: vector_def.id,
+          vsid: vector_def.strategy_id,
+        ).first
       end

       def find_by_target(target)
-        DB.query(<<~SQL, target_id: target.id, vid: vector.id, vsid: vector.strategy_id).first
+        DB.query(
+          <<~SQL,
           SELECT *
           FROM #{table}
           WHERE
@@ -58,6 +64,10 @@ module DiscourseAi
             #{target_column} = :target_id
           LIMIT 1
         SQL
+          target_id: target.id,
+          vid: vector_def.id,
+          vsid: vector_def.strategy_id,
+        ).first
       end

       def asymmetric_similarity_search(embedding, limit:, offset:)
@@ -87,8 +97,8 @@ module DiscourseAi

         builder.where(
           "model_id = :model_id AND strategy_id = :strategy_id",
-          model_id: vector.id,
-          strategy_id: vector.strategy_id,
+          model_id: vector_def.id,
+          strategy_id: vector_def.strategy_id,
         )

         yield(builder) if block_given?
@@ -156,7 +166,7 @@ module DiscourseAi

         yield(builder) if block_given?

-        builder.query(vid: vector.id, vsid: vector.strategy_id, target_id: record.id)
+        builder.query(vid: vector_def.id, vsid: vector_def.strategy_id, target_id: record.id)
       rescue PG::Error => e
         Rails.logger.error("Error #{e} querying embeddings for model #{name}")
         raise MissingEmbeddingError
@@ -176,10 +186,10 @@ module DiscourseAi
             updated_at = :now
         SQL
           target_id: record.id,
-          model_id: vector.id,
-          model_version: vector.version,
-          strategy_id: vector.strategy_id,
-          strategy_version: vector.strategy_version,
+          model_id: vector_def.id,
+          model_version: vector_def.version,
+          strategy_id: vector_def.strategy_id,
+          strategy_version: vector_def.strategy_version,
           digest: digest,
           embeddings: embedding,
           now: Time.zone.now,
@@ -188,7 +198,7 @@ module DiscourseAi

       private

-      delegate :dimensions, :pg_function, to: :vector
+      delegate :dimensions, :pg_function, to: :vector_def
     end
   end
 end
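With these changes `Schema.for` accepts a plain vector definition (`vector_def:`) and defaults to the currently configured representation, so most callers can omit it entirely. A usage sketch, assuming embeddings are configured:

    schema = DiscourseAi::Embeddings::Schema.for(Topic) # uses the default vector_def
    schema.find_by_target(topic)                        # scoped by vector_def.id / strategy_id
    schema.store(topic, embedding, digest)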
@@ -13,14 +13,13 @@ module DiscourseAi
     def related_topic_ids_for(topic)
       return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1

-      vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
       cache_for = results_ttl(topic)

       Discourse
         .cache
         .fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
           DiscourseAi::Embeddings::Schema
-            .for(Topic, vector: vector_rep)
+            .for(Topic)
             .symmetric_similarity_search(topic)
             .map(&:topic_id)
             .tap do |candidate_ids|
@@ -30,8 +30,8 @@ module DiscourseAi
       Discourse.cache.read(embedding_key).present?
     end

-    def vector_rep
-      @vector_rep ||= DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
+    def vector
+      @vector ||= DiscourseAi::Embeddings::Vector.instance
     end

     def hyde_embedding(search_term)
@@ -52,16 +52,14 @@ module DiscourseAi

       Discourse
         .cache
-        .fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(hypothetical_post) }
+        .fetch(embedding_key, expires_in: 1.week) { vector.vector_from(hypothetical_post) }
     end

     def embedding(search_term)
       digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
       embedding_key = build_embedding_key(digest, "", SiteSetting.ai_embeddings_model)

-      Discourse
-        .cache
-        .fetch(embedding_key, expires_in: 1.week) { vector_rep.vector_from(search_term) }
+      Discourse.cache.fetch(embedding_key, expires_in: 1.week) { vector.vector_from(search_term) }
     end

     # this ensures the candidate topics are over selected
@@ -84,7 +82,7 @@ module DiscourseAi

       over_selection_limit = limit * OVER_SELECTION_FACTOR

-      schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep)
+      schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef)

       candidate_topic_ids =
         schema.asymmetric_similarity_search(
@@ -114,7 +112,7 @@ module DiscourseAi

       return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length

-      vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
+      vector = DiscourseAi::Embeddings::Vector.instance

       digest = OpenSSL::Digest::SHA1.hexdigest(search_term)

@@ -129,12 +127,12 @@ module DiscourseAi
         Discourse
           .cache
           .fetch(embedding_key, expires_in: 1.week) do
-            vector_rep.vector_from(search_term, asymetric: true)
+            vector.vector_from(search_term, asymetric: true)
           end

       candidate_post_ids =
         DiscourseAi::Embeddings::Schema
-          .for(Post, vector: vector_rep)
+          .for(Post, vector_def: vector.vdef)
           .asymmetric_similarity_search(
             search_term_embedding,
             limit: max_semantic_results_per_page,
@@ -12,19 +12,28 @@ module DiscourseAi
           1
         end

-        def prepare_text_from(target, tokenizer, max_length)
+        def prepare_target_text(target, vdef)
+          max_length = vdef.max_sequence_length - 2

           case target
           when Topic
-            topic_truncation(target, tokenizer, max_length)
+            topic_truncation(target, vdef.tokenizer, max_length)
           when Post
-            post_truncation(target, tokenizer, max_length)
+            post_truncation(target, vdef.tokenizer, max_length)
           when RagDocumentFragment
-            tokenizer.truncate(target.fragment, max_length)
+            vdef.tokenizer.truncate(target.fragment, max_length)
           else
             raise ArgumentError, "Invalid target type"
           end
         end

+        def prepare_query_text(text, vdef, asymetric: false)
+          qtext = asymetric ? "#{vdef.asymmetric_query_prefix} #{text}" : text
+          max_length = vdef.max_sequence_length - 2
+
+          vdef.tokenizer.truncate(qtext, max_length)
+        end
+
         private

         def topic_information(topic)
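The strategy now owns both text-preparation entry points and pulls the tokenizer and length limit from the vector definition it is handed. A sketch of calling it directly, where `vector_def` is the currently configured representation (normally this happens through `Vector` and `Base`):

    truncation = DiscourseAi::Embeddings::Strategies::Truncation.new

    # Records are truncated to the model's max sequence length minus two tokens.
    text = truncation.prepare_target_text(topic, vector_def)

    # Queries optionally get the model's asymmetric prefix before truncation.
    query = truncation.prepare_query_text("surfin' birds", vector_def, asymetric: true)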
@@ -0,0 +1,76 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Embeddings
+    class Vector
+      def self.instance
+        new(DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation)
+      end
+
+      def initialize(vector_definition)
+        @vdef = vector_definition
+      end
+
+      def gen_bulk_reprensentations(relation)
+        http_pool_size = 100
+        pool =
+          Concurrent::CachedThreadPool.new(
+            min_threads: 0,
+            max_threads: http_pool_size,
+            idletime: 30,
+          )
+
+        schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector_def: vdef)
+
+        embedding_gen = vdef.inference_client
+        promised_embeddings =
+          relation
+            .map do |record|
+              prepared_text = vdef.prepare_target_text(record)
+              next if prepared_text.blank?
+
+              new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text)
+              next if schema.find_by_target(record)&.digest == new_digest
+
+              Concurrent::Promises
+                .fulfilled_future({ target: record, text: prepared_text, digest: new_digest }, pool)
+                .then_on(pool) do |w_prepared_text|
+                  w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text]))
+                end
+            end
+            .compact
+
+        Concurrent::Promises
+          .zip(*promised_embeddings)
+          .value!
+          .each { |e| schema.store(e[:target], e[:embedding], e[:digest]) }
+
+        pool.shutdown
+        pool.wait_for_termination
+      end
+
+      def generate_representation_from(target)
+        text = vdef.prepare_target_text(target)
+        return if text.blank?
+
+        schema = DiscourseAi::Embeddings::Schema.for(target.class, vector_def: vdef)
+
+        new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
+        return if schema.find_by_target(target)&.digest == new_digest
+
+        embeddings = vdef.inference_client.perform!(text)
+
+        schema.store(target, embeddings, new_digest)
+      end
+
+      def vector_from(text, asymetric: false)
+        prepared_text = vdef.prepare_query_text(text, asymetric: asymetric)
+        return if prepared_text.blank?
+
+        vdef.inference_client.perform!(prepared_text)
+      end
+
+      attr_reader :vdef
+    end
+  end
+end
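This new class concentrates what used to be spread across the representation subclasses: digest-based deduplication and the thread-pooled bulk path. A minimal usage sketch, assuming an embeddings model is configured (`topic_ids` is a stand-in):

    vector = DiscourseAi::Embeddings::Vector.instance

    # Single record; skipped when the stored digest matches the prepared text.
    vector.generate_representation_from(topic)

    # Bulk; texts are prepared, deduplicated by digest, then embedded
    # concurrently on a CachedThreadPool (up to 100 in-flight HTTP calls).
    vector.gen_bulk_reprensentations(Topic.where(id: topic_ids))

    # Raw query vector, e.g. for similarity search.
    embedding = vector.vector_from("some search term", asymetric: true)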
@@ -23,10 +23,6 @@ module DiscourseAi
           end
         end

-        def vector_from(text, asymetric: false)
-          inference_client.perform!(text)
-        end
-
         def dimensions
           768
         end
@@ -47,10 +43,6 @@ module DiscourseAi
           "<#>"
         end

-        def pg_index_type
-          "halfvec_ip_ops"
-        end
-
         def tokenizer
           DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
         end
@@ -21,8 +21,7 @@ module DiscourseAi
         end

         def current_representation
-          truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
-          find_representation(SiteSetting.ai_embeddings_model).new(truncation)
+          find_representation(SiteSetting.ai_embeddings_model).new
         end

         def correctly_configured?
@@ -43,73 +42,6 @@ module DiscourseAi
         end
       end

-      def initialize(strategy)
-        @strategy = strategy
-      end
-
-      def vector_from(text, asymetric: false)
-        raise NotImplementedError
-      end
-
-      def gen_bulk_reprensentations(relation)
-        http_pool_size = 100
-        pool =
-          Concurrent::CachedThreadPool.new(
-            min_threads: 0,
-            max_threads: http_pool_size,
-            idletime: 30,
-          )
-
-        schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector: self)
-
-        embedding_gen = inference_client
-        promised_embeddings =
-          relation
-            .map do |record|
-              prepared_text = prepare_text(record)
-              next if prepared_text.blank?
-
-              new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text)
-              next if schema.find_by_target(record)&.digest == new_digest
-
-              Concurrent::Promises
-                .fulfilled_future(
-                  { target: record, text: prepared_text, digest: new_digest },
-                  pool,
-                )
-                .then_on(pool) do |w_prepared_text|
-                  w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text]))
-                end
-            end
-            .compact
-
-        Concurrent::Promises
-          .zip(*promised_embeddings)
-          .value!
-          .each { |e| schema.store(e[:target], e[:embedding], e[:digest]) }
-
-        pool.shutdown
-        pool.wait_for_termination
-      end
-
-      def generate_representation_from(target, persist: true)
-        text = prepare_text(target)
-        return if text.blank?
-
-        schema = DiscourseAi::Embeddings::Schema.for(target.class, vector: self)
-
-        new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
-        return if schema.find_by_target(target)&.digest == new_digest
-
-        vector = vector_from(text)
-
-        schema.store(target, vector, new_digest) if persist
-      end
-
-      def index_name(table_name)
-        "#{table_name}_#{id}_#{@strategy.id}_search"
-      end
-
       def name
         raise NotImplementedError
       end
@@ -139,26 +71,32 @@ module DiscourseAi
       end

       def asymmetric_query_prefix
-        raise NotImplementedError
+        ""
       end

       def strategy_id
-        @strategy.id
+        strategy.id
       end

       def strategy_version
-        @strategy.version
+        strategy.version
       end

-      protected
+      def prepare_query_text(text, asymetric: false)
+        strategy.prepare_query_text(text, self, asymetric: asymetric)
+      end
+
+      def prepare_target_text(target)
+        strategy.prepare_target_text(target, self)
+      end
+
+      def strategy
+        @strategy ||= DiscourseAi::Embeddings::Strategies::Truncation.new
+      end

       def inference_client
         raise NotImplementedError
       end
-
-      def prepare_text(record)
-        @strategy.prepare_text_from(record, tokenizer, max_sequence_length - 2)
-      end
     end
   end
 end
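After this hunk, a concrete representation no longer receives a strategy at construction time (it lazily defaults to Truncation) and only has to declare model metadata plus an inference client. A hypothetical subclass, purely for illustration — not part of this commit:

    # Hypothetical example; every value below is invented to show the shape.
    class ExampleDefinition < DiscourseAi::Embeddings::VectorRepresentations::Base
      def name
        "example-model"
      end

      def id
        99 # hypothetical unique id per representation
      end

      def dimensions
        768
      end

      def max_sequence_length
        512
      end

      def pg_function
        "<=>" # cosine distance operator
      end

      def tokenizer
        DiscourseAi::Tokenizer::OpenAiTokenizer
      end

      def inference_client
        raise NotImplementedError # must return an object responding to #perform!(text)
      end
    end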
@@ -30,21 +30,6 @@ module DiscourseAi
           end
         end

-        def vector_from(text, asymetric: false)
-          text = "#{asymmetric_query_prefix} #{text}" if asymetric
-
-          client = inference_client
-
-          needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
-          text = tokenizer.truncate(text, max_sequence_length - 2) if needs_truncation
-
-          inference_client.perform!(text)
-        end
-
-        def inference_model_name
-          "baai/bge-large-en-v1.5"
-        end
-
         def dimensions
           1024
         end
@@ -65,10 +50,6 @@ module DiscourseAi
           "<#>"
         end

-        def pg_index_type
-          "halfvec_ip_ops"
-        end
-
         def tokenizer
           DiscourseAi::Tokenizer::BgeLargeEnTokenizer
         end
@@ -78,6 +59,8 @@ module DiscourseAi
         end

         def inference_client
+          inference_model_name = "baai/bge-large-en-v1.5"
+
           if SiteSetting.ai_cloudflare_workers_api_token.present?
             DiscourseAi::Inference::CloudflareWorkersAi.instance(inference_model_name)
           elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
@@ -18,11 +18,6 @@ module DiscourseAi
           end
         end

-        def vector_from(text, asymetric: false)
-          truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
-          inference_client.perform!(truncated_text)
-        end
-
         def dimensions
           1024
         end
@@ -43,10 +38,6 @@ module DiscourseAi
           "<#>"
         end

-        def pg_index_type
-          "halfvec_ip_ops"
-        end
-
         def tokenizer
           DiscourseAi::Tokenizer::BgeM3Tokenizer
         end
@@ -38,14 +38,6 @@ module DiscourseAi
           "<=>"
         end

-        def pg_index_type
-          "halfvec_cosine_ops"
-        end
-
-        def vector_from(text, asymetric: false)
-          inference_client.perform!(text)
-        end
-
         # There is no public tokenizer for Gemini, and from the ones we already ship in the plugin
         # OpenAI gets the closest results. Gemini Tokenizer results in ~10% less tokens, so it's safe
         # to use OpenAI tokenizer since it will overestimate the number of tokens.
@@ -28,19 +28,6 @@ module DiscourseAi
           end
         end

-        def vector_from(text, asymetric: false)
-          client = inference_client
-
-          needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
-          if needs_truncation
-            text = tokenizer.truncate(text, max_sequence_length - 2)
-          elsif !text.starts_with?("query:")
-            text = "query: #{text}"
-          end
-
-          client.perform!(text)
-        end
-
         def id
           3
         end
@@ -61,10 +48,6 @@ module DiscourseAi
           "<=>"
         end

-        def pg_index_type
-          "halfvec_cosine_ops"
-        end
-
         def tokenizer
           DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
         end
@@ -80,8 +63,18 @@ module DiscourseAi
           end
         end

-        def prepare_text(record)
-          prepared_text = super(record)
+        def prepare_query_text(text, asymetric: false)
+          prepared_text = super(text, asymetric: asymetric)

+          if prepared_text.present? && inference_client.class.name.include?("DiscourseClassifier")
+            return "query: #{prepared_text}"
+          end
+
+          prepared_text
+        end
+
+        def prepare_target_text(target)
+          prepared_text = super(target)
+
           if prepared_text.present? && inference_client.class.name.include?("DiscourseClassifier")
             return "query: #{prepared_text}"
@@ -40,14 +40,6 @@ module DiscourseAi
           "<=>"
         end

-        def pg_index_type
-          "halfvec_cosine_ops"
-        end
-
-        def vector_from(text, asymetric: false)
-          inference_client.perform!(text)
-        end
-
         def tokenizer
           DiscourseAi::Tokenizer::OpenAiTokenizer
         end
@@ -38,14 +38,6 @@ module DiscourseAi
           "<=>"
         end

-        def pg_index_type
-          "halfvec_cosine_ops"
-        end
-
-        def vector_from(text, asymetric: false)
-          inference_client.perform!(text)
-        end
-
         def tokenizer
           DiscourseAi::Tokenizer::OpenAiTokenizer
         end
@@ -38,14 +38,6 @@ module DiscourseAi
           "<=>"
         end

-        def pg_index_type
-          "halfvec_cosine_ops"
-        end
-
-        def vector_from(text, asymetric: false)
-          inference_client.perform!(text)
-        end
-
         def tokenizer
           DiscourseAi::Tokenizer::OpenAiTokenizer
         end
@@ -326,8 +326,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
     fab!(:llm_model) { Fabricate(:fake_model) }

     it "will run the question consolidator" do
-      vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
-      context_embedding = vector_rep.dimensions.times.map { rand(-1.0...1.0) }
+      vector_def = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
+      context_embedding = vector_def.dimensions.times.map { rand(-1.0...1.0) }
       EmbeddingsGenerationStubs.discourse_service(
         SiteSetting.ai_embeddings_model,
         consolidated_question,
@@ -373,14 +373,14 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
     end

     context "when a persona has RAG uploads" do
-      let(:vector_rep) do
+      let(:vector_def) do
         DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
       end
       let(:embedding_value) { 0.04381 }
-      let(:prompt_cc_embeddings) { [embedding_value] * vector_rep.dimensions }
+      let(:prompt_cc_embeddings) { [embedding_value] * vector_def.dimensions }

       def stub_fragments(fragment_count, persona: ai_persona)
-        schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep)
+        schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector_def)

         fragment_count.times do |i|
           fragment =
@@ -393,7 +393,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
           )

           # Similarity is determined left-to-right.
-          embeddings = [embedding_value + "0.000#{i}".to_f] * vector_rep.dimensions
+          embeddings = [embedding_value + "0.000#{i}".to_f] * vector_def.dimensions

           schema.store(fragment, embeddings, "test")
         end
@@ -14,9 +14,9 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
   fab!(:category)
   fab!(:topic) { Fabricate(:topic, category: category) }

-  let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
+  let(:vector) { DiscourseAi::Embeddings::Vector.instance }
   let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) }
-  let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
+  let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions }

   before do
     SiteSetting.ai_embeddings_enabled = true
@@ -28,8 +28,8 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
       "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
     ).to_return(status: 200, body: JSON.dump(expected_embedding))

-    vector_rep.generate_representation_from(topic)
-    vector_rep.generate_representation_from(muted_topic)
+    vector.generate_representation_from(topic)
+    vector.generate_representation_from(muted_topic)
   end

   it "respects user muted categories when making suggestions" do
@@ -12,21 +12,16 @@ RSpec.describe Jobs::GenerateEmbeddings do
   fab!(:topic)
   fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }

-  let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
-  let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
-  let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) }
-  let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector: vector_rep) }
+  let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
+  let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector_def) }
+  let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector_def: vector_def) }

   it "works for topics" do
-    expected_embedding = [0.0038493] * vector_rep.dimensions
+    expected_embedding = [0.0038493] * vector_def.dimensions

-    text =
-      truncation.prepare_text_from(
-        topic,
-        vector_rep.tokenizer,
-        vector_rep.max_sequence_length - 2,
-      )
-    EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding)
+    text = vector_def.prepare_target_text(topic)
+    EmbeddingsGenerationStubs.discourse_service(vector_def.class.name, text, expected_embedding)

     job.execute(target_id: topic.id, target_type: "Topic")

@@ -34,11 +29,10 @@ RSpec.describe Jobs::GenerateEmbeddings do
   end

   it "works for posts" do
-    expected_embedding = [0.0038493] * vector_rep.dimensions
+    expected_embedding = [0.0038493] * vector_def.dimensions

-    text =
-      truncation.prepare_text_from(post, vector_rep.tokenizer, vector_rep.max_sequence_length - 2)
-    EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding)
+    text = vector_def.prepare_target_text(post)
+    EmbeddingsGenerationStubs.discourse_service(vector_def.class.name, text, expected_embedding)

     job.execute(target_id: post.id, target_type: "Post")

@@ -1,16 +1,12 @@
 # frozen_string_literal: true

 RSpec.describe DiscourseAi::Embeddings::Schema do
-  subject(:posts_schema) { described_class.for(Post, vector: vector) }
+  subject(:posts_schema) { described_class.for(Post, vector_def: vector_def) }

-  let(:embeddings) { [0.0038490295] * vector.dimensions }
+  let(:embeddings) { [0.0038490295] * vector_def.dimensions }
   fab!(:post) { Fabricate(:post, post_number: 1) }
   let(:digest) { OpenSSL::Digest.hexdigest("SHA1", "test") }
-  let(:vector) do
-    DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new(
-      DiscourseAi::Embeddings::Strategies::Truncation.new,
-    )
-  end
+  let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new }

   before { posts_schema.store(post, embeddings, digest) }

@@ -34,7 +30,7 @@ RSpec.describe DiscourseAi::Embeddings::Schema do

   describe "similarity searches" do
     fab!(:post_2) { Fabricate(:post) }
-    let(:similar_embeddings) { [0.0038490294] * vector.dimensions }
+    let(:similar_embeddings) { [0.0038490294] * vector_def.dimensions }

     describe "#symmetric_similarity_search" do
       before { posts_schema.store(post_2, similar_embeddings, digest) }
@@ -11,8 +11,8 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do

   describe "#search_for_topics" do
     let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" }
-    let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
-    let(:hyde_embedding) { [0.049382] * vector_rep.dimensions }
+    let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
+    let(:hyde_embedding) { [0.049382] * vector_def.dimensions }

     before do
       SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
@@ -3,8 +3,8 @@
 RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
   subject(:truncation) { described_class.new }

-  describe "#prepare_text_from" do
-    context "when using vector from OpenAI" do
+  describe "#prepare_query_text" do
+    context "when using vector def from OpenAI" do
       before { SiteSetting.max_post_length = 100_000 }

       fab!(:topic)
@@ -20,15 +20,12 @@ RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
       end
       fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }

-      let(:model) do
-        DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new(truncation)
-      end
+      let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new }

       it "truncates a topic" do
-        prepared_text =
-          truncation.prepare_text_from(topic, model.tokenizer, model.max_sequence_length)
+        prepared_text = truncation.prepare_target_text(topic, vector_def)

-        expect(model.tokenizer.size(prepared_text)).to be <= model.max_sequence_length
+        expect(vector_def.tokenizer.size(prepared_text)).to be <= vector_def.max_sequence_length
       end
     end
   end
 end
@@ -1,16 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "vector_rep_shared_examples"
-
-RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2 do
-  subject(:vector_rep) { described_class.new(truncation) }
-
-  let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
-
-  before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
-
-  def stub_vector_mapping(text, expected_embedding)
-    EmbeddingsGenerationStubs.discourse_service(described_class.name, text, expected_embedding)
-  end
-
-  it_behaves_like "generates and store embedding using with vector representation"
-end
@@ -1,18 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "vector_rep_shared_examples"
-
-RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::Gemini do
-  subject(:vector_rep) { described_class.new(truncation) }
-
-  let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
-  let!(:api_key) { "test-123" }
-
-  before { SiteSetting.ai_gemini_api_key = api_key }
-
-  def stub_vector_mapping(text, expected_embedding)
-    EmbeddingsGenerationStubs.gemini_service(api_key, text, expected_embedding)
-  end
-
-  it_behaves_like "generates and store embedding using with vector representation"
-end
@@ -1,21 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "vector_rep_shared_examples"
-
-RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large do
-  subject(:vector_rep) { described_class.new(truncation) }
-
-  let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
-
-  before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
-
-  def stub_vector_mapping(text, expected_embedding)
-    EmbeddingsGenerationStubs.discourse_service(
-      described_class.name,
-      "query: #{text}",
-      expected_embedding,
-    )
-  end
-
-  it_behaves_like "generates and store embedding using with vector representation"
-end
@@ -1,22 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "vector_rep_shared_examples"
-
-RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large do
-  subject(:vector_rep) { described_class.new(truncation) }
-
-  let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
-
-  def stub_vector_mapping(text, expected_embedding)
-    EmbeddingsGenerationStubs.openai_service(
-      described_class.name,
-      text,
-      expected_embedding,
-      extra_args: {
-        dimensions: 2000,
-      },
-    )
-  end
-
-  it_behaves_like "generates and store embedding using with vector representation"
-end
@@ -1,15 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "vector_rep_shared_examples"
-
-RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small do
-  subject(:vector_rep) { described_class.new(truncation) }
-
-  let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
-
-  def stub_vector_mapping(text, expected_embedding)
-    EmbeddingsGenerationStubs.openai_service(described_class.name, text, expected_embedding)
-  end
-
-  it_behaves_like "generates and store embedding using with vector representation"
-end
@@ -1,15 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "vector_rep_shared_examples"
-
-RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002 do
-  subject(:vector_rep) { described_class.new(truncation) }
-
-  let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
-
-  def stub_vector_mapping(text, expected_embedding)
-    EmbeddingsGenerationStubs.openai_service(described_class.name, text, expected_embedding)
-  end
-
-  it_behaves_like "generates and store embedding using with vector representation"
-end
@@ -1,115 +0,0 @@
-# frozen_string_literal: true
-
-RSpec.shared_examples "generates and store embedding using with vector representation" do
-  let(:expected_embedding_1) { [0.0038493] * vector_rep.dimensions }
-  let(:expected_embedding_2) { [0.0037684] * vector_rep.dimensions }
-
-  let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) }
-  let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector: vector_rep) }
-
-  describe "#vector_from" do
-    it "creates a vector from a given string" do
-      text = "This is a piece of text"
-      stub_vector_mapping(text, expected_embedding_1)
-
-      expect(vector_rep.vector_from(text)).to eq(expected_embedding_1)
-    end
-  end
-
-  describe "#generate_representation_from" do
-    fab!(:topic) { Fabricate(:topic) }
-    fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
-    fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
-
-    it "creates a vector from a topic and stores it in the database" do
-      text =
-        truncation.prepare_text_from(
-          topic,
-          vector_rep.tokenizer,
-          vector_rep.max_sequence_length - 2,
-        )
-      stub_vector_mapping(text, expected_embedding_1)
-
-      vector_rep.generate_representation_from(topic)
-
-      expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
-    end
-
-    it "creates a vector from a post and stores it in the database" do
-      text =
-        truncation.prepare_text_from(
-          post2,
-          vector_rep.tokenizer,
-          vector_rep.max_sequence_length - 2,
-        )
-      stub_vector_mapping(text, expected_embedding_1)
-
-      vector_rep.generate_representation_from(post)
-
-      expect(posts_schema.find_by_embedding(expected_embedding_1).post_id).to eq(post.id)
-    end
-  end
-
-  describe "#gen_bulk_reprensentations" do
-    fab!(:topic) { Fabricate(:topic) }
-    fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
-    fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
-
-    fab!(:topic_2) { Fabricate(:topic) }
-    fab!(:post_2_1) { Fabricate(:post, post_number: 1, topic: topic_2) }
-    fab!(:post_2_2) { Fabricate(:post, post_number: 2, topic: topic_2) }
-
-    it "creates a vector for each object in the relation" do
-      text =
-        truncation.prepare_text_from(
-          topic,
-          vector_rep.tokenizer,
-          vector_rep.max_sequence_length - 2,
-        )
-
-      text2 =
-        truncation.prepare_text_from(
-          topic_2,
-          vector_rep.tokenizer,
-          vector_rep.max_sequence_length - 2,
-        )
-
-      stub_vector_mapping(text, expected_embedding_1)
-      stub_vector_mapping(text2, expected_embedding_2)
-
-      vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))
-
-      expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
-    end
-
-    it "does nothing if passed record has no content" do
-      expect { vector_rep.gen_bulk_reprensentations([Topic.new]) }.not_to raise_error
-    end
-
-    it "doesn't ask for a new embedding if digest is the same" do
-      text =
-        truncation.prepare_text_from(
-          topic,
-          vector_rep.tokenizer,
-          vector_rep.max_sequence_length - 2,
-        )
-      stub_vector_mapping(text, expected_embedding_1)
-
-      original_vector_gen = Time.zone.parse("2021-06-04 10:00")
-
-      freeze_time(original_vector_gen) do
-        vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
-      end
-      # check vector exists
-      expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
-
-      vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
-      last_update =
-        DB.query_single(
-          "SELECT updated_at FROM #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE} WHERE topic_id = #{topic.id} LIMIT 1",
-        ).first
-
-      expect(last_update).to eq(original_vector_gen)
-    end
-  end
-end
@ -0,0 +1,160 @@
# frozen_string_literal: true

RSpec.describe DiscourseAi::Embeddings::Vector do
  shared_examples "generates and store embeddings using a vector definition" do
    subject(:vector) { described_class.new(vdef) }

    let(:expected_embedding_1) { [0.0038493] * vdef.dimensions }
    let(:expected_embedding_2) { [0.0037684] * vdef.dimensions }

    let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vdef) }
    let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector_def: vdef) }

    fab!(:topic)
    fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
    fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }

    describe "#vector_from" do
      it "creates a vector from a given string" do
        text = "This is a piece of text"
        stub_vector_mapping(text, expected_embedding_1)

        expect(vector.vector_from(text)).to eq(expected_embedding_1)
      end
    end

    describe "#generate_representation_from" do
      it "creates a vector from a topic and stores it in the database" do
        text = vdef.prepare_target_text(topic)
        stub_vector_mapping(text, expected_embedding_1)

        vector.generate_representation_from(topic)

        expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
      end

      it "creates a vector from a post and stores it in the database" do
        text = vdef.prepare_target_text(post2)
        stub_vector_mapping(text, expected_embedding_1)

        vector.generate_representation_from(post)

        expect(posts_schema.find_by_embedding(expected_embedding_1).post_id).to eq(post.id)
      end
    end

    describe "#gen_bulk_reprensentations" do
      fab!(:topic_2) { Fabricate(:topic) }
      fab!(:post_2_1) { Fabricate(:post, post_number: 1, topic: topic_2) }
      fab!(:post_2_2) { Fabricate(:post, post_number: 2, topic: topic_2) }

      it "creates a vector for each object in the relation" do
        text = vdef.prepare_target_text(topic)
        text2 = vdef.prepare_target_text(topic_2)

        stub_vector_mapping(text, expected_embedding_1)
        stub_vector_mapping(text2, expected_embedding_2)

        vector.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))

        expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
      end

      it "does nothing if passed record has no content" do
        expect { vector.gen_bulk_reprensentations([Topic.new]) }.not_to raise_error
      end

      it "doesn't ask for a new embedding if digest is the same" do
        text = vdef.prepare_target_text(topic)
        stub_vector_mapping(text, expected_embedding_1)

        original_vector_gen = Time.zone.parse("2021-06-04 10:00")

        freeze_time(original_vector_gen) do
          vector.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
        end
        # check vector exists
        expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)

        vector.gen_bulk_reprensentations(Topic.where(id: [topic.id]))

        expect(topics_schema.find_by_target(topic).updated_at).to eq_time(original_vector_gen)
      end
    end
  end

  context "with text-embedding-ada-002" do
    let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new }

    def stub_vector_mapping(text, expected_embedding)
      EmbeddingsGenerationStubs.openai_service(vdef.class.name, text, expected_embedding)
    end

    it_behaves_like "generates and store embeddings using a vector definition"
  end

context "with all all-mpnet-base-v2" do
    let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new }

    before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }

    def stub_vector_mapping(text, expected_embedding)
      EmbeddingsGenerationStubs.discourse_service(vdef.class.name, text, expected_embedding)
    end

    it_behaves_like "generates and store embeddings using a vector definition"
  end

  context "with gemini" do
    let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::Gemini.new }
    let(:api_key) { "test-123" }

    before { SiteSetting.ai_gemini_api_key = api_key }

    def stub_vector_mapping(text, expected_embedding)
      EmbeddingsGenerationStubs.gemini_service(api_key, text, expected_embedding)
    end

    it_behaves_like "generates and store embeddings using a vector definition"
  end

  context "with multilingual-e5-large" do
    let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large.new }

    before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }

    def stub_vector_mapping(text, expected_embedding)
      EmbeddingsGenerationStubs.discourse_service(vdef.class.name, text, expected_embedding)
    end

    it_behaves_like "generates and store embeddings using a vector definition"
  end

  context "with text-embedding-3-large" do
    let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large.new }

    def stub_vector_mapping(text, expected_embedding)
      EmbeddingsGenerationStubs.openai_service(
        vdef.class.name,
        text,
        expected_embedding,
        extra_args: {
          dimensions: 2000,
        },
      )
    end

    it_behaves_like "generates and store embeddings using a vector definition"
  end

  context "with text-embedding-3-small" do
    let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small.new }

    def stub_vector_mapping(text, expected_embedding)
      EmbeddingsGenerationStubs.openai_service(vdef.class.name, text, expected_embedding)
    end

    it_behaves_like "generates and store embeddings using a vector definition"
  end
end
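Reviewer note: the new spec pins down a small public surface: .instance, #vector_from, #generate_representation_from, #gen_bulk_reprensentations, and the vdef reader. A minimal sketch of the class shape those calls imply; everything beyond the names used in the specs (notably "store" and the .instance wiring) is an assumption, not the plugin's actual implementation:

    # Hypothetical sketch only, inferred from the calls exercised above.
    module DiscourseAi
      module Embeddings
        class Vector
          # Assumed: the entry point wraps the site-configured definition.
          def self.instance
            new(VectorRepresentations::Base.current_representation)
          end

          attr_reader :vdef

          def initialize(vdef)
            @vdef = vdef
          end

          # Asks the definition's embedding endpoint for a raw vector.
          def vector_from(text)
            vdef.vector_from(text) # assumed delegation
          end

          # Prepares the target's text, embeds it, and hands persistence
          # to the Schema object (which owns the digest/upsert logic).
          def generate_representation_from(target)
            text = vdef.prepare_target_text(target)
            Schema.for(target.class, vector_def: vdef).store(target, vector_from(text)) # "store" is assumed
          end
        end
      end
    end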
@ -74,7 +74,7 @@ RSpec.describe RagDocumentFragment
     end
   end

   describe ".indexing_status" do
-    let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
+    let(:vector) { DiscourseAi::Embeddings::Vector.instance }

     fab!(:rag_document_fragment_1) do
       Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
@ -84,7 +84,7 @@ RSpec.describe RagDocumentFragment
       Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
     end

-    let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
+    let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions }

     before do
       SiteSetting.ai_embeddings_enabled = true
@ -96,7 +96,7 @@ RSpec.describe RagDocumentFragment
         "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
       ).to_return(status: 200, body: JSON.dump(expected_embedding))

-      vector_rep.generate_representation_from(rag_document_fragment_1)
+      vector.generate_representation_from(rag_document_fragment_1)
     end

     it "regenerates all embeddings if ai_embeddings_model changes" do
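For context on the assertions in these specs: the Schema object is the query-and-storage half of the split, while Vector only generates. A short usage sketch built strictly from calls that appear above; embedding, topic, and vdef stand in for values from the surrounding test context:

    topics_schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vdef)

    # Reverse lookup: which topic does a stored embedding belong to?
    topics_schema.find_by_embedding(embedding).topic_id

    # Row lookup by record, e.g. to check when the embedding was last written.
    topics_schema.find_by_target(topic).updated_at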
@ -19,14 +19,14 @@ describe DiscourseAi::Embeddings::EmbeddingsController do
   fab!(:post_in_subcategory) { Fabricate(:post, topic: topic_in_subcategory) }

   def index(topic)
-    vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
+    vector = DiscourseAi::Embeddings::Vector.instance

     stub_request(:post, "https://api.openai.com/v1/embeddings").to_return(
       status: 200,
       body: JSON.dump({ data: [{ embedding: [0.1] * 1536 }] }),
     )

-    vector_rep.generate_representation_from(topic)
+    vector.generate_representation_from(topic)
   end

   def stub_embedding(query)