REFACTOR: A Simpler way of interacting with embeddings tables. (#1023)

* REFACTOR: A Simpler way of interacting with embeddings' tables.

This change adds a new abstraction called `Schema`, which acts as a repository that supports the same DB features `VectorRepresentation::Base` has, with the exception that it removes the need to have duplicated methods per embeddings table.

It is also a bit more flexible when performing a similarity search because you can pass it a block that gives you access to the builder, allowing you to add multiple joins/where conditions.
This commit is contained in:
Roman Rizzi 2024-12-13 10:15:21 -03:00 committed by GitHub
parent 97ec2c5ff4
commit eae527f99d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 485 additions and 599 deletions

View File

@ -18,9 +18,7 @@ module ::Jobs
target = target_type.constantize.find_by(id: target_id) target = target_type.constantize.find_by(id: target_id)
return if !target return if !target
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
tokenizer = vector_rep.tokenizer tokenizer = vector_rep.tokenizer
chunk_tokens = target.rag_chunk_tokens chunk_tokens = target.rag_chunk_tokens

View File

@ -16,9 +16,7 @@ module Jobs
return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms
return if post.raw.blank? return if post.raw.blank?
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
vector_rep.generate_representation_from(target) vector_rep.generate_representation_from(target)
end end

View File

@ -8,9 +8,7 @@ module ::Jobs
def execute(args) def execute(args)
return if (fragments = RagDocumentFragment.where(id: args[:fragment_ids].to_a)).empty? return if (fragments = RagDocumentFragment.where(id: args[:fragment_ids].to_a)).empty?
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
# generate_representation_from checks compares the digest value to make sure # generate_representation_from checks compares the digest value to make sure
# the embedding is only generated once per fragment unless something changes. # the embedding is only generated once per fragment unless something changes.

View File

@ -20,10 +20,8 @@ module Jobs
rebaked = 0 rebaked = 0
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector_rep = table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
table_name = vector_rep.topic_table_name
topics = topics =
Topic Topic
@ -41,7 +39,7 @@ module Jobs
relation = topics.where(<<~SQL).limit(limit - rebaked) relation = topics.where(<<~SQL).limit(limit - rebaked)
#{table_name}.model_version < #{vector_rep.version} #{table_name}.model_version < #{vector_rep.version}
OR OR
#{table_name}.strategy_version < #{strategy.version} #{table_name}.strategy_version < #{vector_rep.strategy_version}
SQL SQL
rebaked += populate_topic_embeddings(vector_rep, relation) rebaked += populate_topic_embeddings(vector_rep, relation)
@ -63,7 +61,7 @@ module Jobs
return unless SiteSetting.ai_embeddings_per_post_enabled return unless SiteSetting.ai_embeddings_per_post_enabled
# Now for posts # Now for posts
table_name = vector_rep.post_table_name table_name = DiscourseAi::Embeddings::Schema::POSTS_TABLE
posts_batch_size = 1000 posts_batch_size = 1000
posts = posts =
@ -121,7 +119,8 @@ module Jobs
def populate_topic_embeddings(vector_rep, topics, force: false) def populate_topic_embeddings(vector_rep, topics, force: false)
done = 0 done = 0
topics = topics.where("#{vector_rep.topic_table_name}.topic_id IS NULL") if !force topics =
topics.where("#{DiscourseAi::Embeddings::Schema::TOPICS_TABLE}.topic_id IS NULL") if !force
ids = topics.pluck("topics.id") ids = topics.pluck("topics.id")
batch_size = 1000 batch_size = 1000

View File

@ -39,11 +39,7 @@ class RagDocumentFragment < ActiveRecord::Base
end end
def indexing_status(persona, uploads) def indexing_status(persona, uploads)
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new embeddings_table = DiscourseAi::Embeddings::Schema.for(self).table
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
embeddings_table = vector_rep.rag_fragments_table_name
results = results =
DB.query( DB.query(

View File

@ -14,7 +14,7 @@ class MigrateEmbeddingsFromDedicatedDatabase < ActiveRecord::Migration[7.0]
].map { |k| k.new(truncation) } ].map { |k| k.new(truncation) }
vector_reps.each do |vector_rep| vector_reps.each do |vector_rep|
new_table_name = vector_rep.topic_table_name new_table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
old_table_name = "topic_embeddings_#{vector_rep.name.underscore}" old_table_name = "topic_embeddings_#{vector_rep.name.underscore}"
begin begin

View File

@ -147,9 +147,7 @@ class MoveEmbeddingsToSingleTablePerType < ActiveRecord::Migration[7.0]
SQL SQL
begin begin
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
rescue StandardError => e rescue StandardError => e
Rails.logger.error("Failed to index embeddings: #{e}") Rails.logger.error("Failed to index embeddings: #{e}")
end end

View File

@ -314,30 +314,34 @@ module DiscourseAi
return nil if !consolidated_question return nil if !consolidated_question
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings
interactions_vector = vector_rep.vector_from(consolidated_question) interactions_vector = vector_rep.vector_from(consolidated_question)
rag_conversation_chunks = self.class.rag_conversation_chunks rag_conversation_chunks = self.class.rag_conversation_chunks
search_limit =
if reranker.reranker_configured?
rag_conversation_chunks * 5
else
rag_conversation_chunks
end
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep)
candidate_fragment_ids = candidate_fragment_ids =
vector_rep.asymmetric_rag_fragment_similarity_search( schema
interactions_vector, .asymmetric_similarity_search(
target_type: "AiPersona", interactions_vector,
target_id: id, limit: search_limit,
limit: offset: 0,
( ) { |builder| builder.join(<<~SQL, target_id: id, target_type: "AiPersona") }
if reranker.reranker_configured? rag_document_fragments ON
rag_conversation_chunks * 5 rag_document_fragments.id = rag_document_fragment_id AND
else rag_document_fragments.target_id = :target_id AND
rag_conversation_chunks rag_document_fragments.target_type = :target_type
end SQL
), .map(&:rag_document_fragment_id)
offset: 0,
)
fragments = fragments =
RagDocumentFragment.where(upload_id: upload_refs, id: candidate_fragment_ids).pluck( RagDocumentFragment.where(upload_id: upload_refs, id: candidate_fragment_ids).pluck(

View File

@ -141,18 +141,20 @@ module DiscourseAi
return [] if upload_refs.empty? return [] if upload_refs.empty?
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
query_vector = vector_rep.vector_from(query) query_vector = vector_rep.vector_from(query)
fragment_ids = fragment_ids =
vector_rep.asymmetric_rag_fragment_similarity_search( DiscourseAi::Embeddings::Schema
query_vector, .for(RagDocumentFragment, vector: vector_rep)
target_type: "AiTool", .asymmetric_similarity_search(query_vector, limit: limit, offset: 0) do |builder|
target_id: tool.id, builder.join(<<~SQL, target_id: tool.id, target_type: "AiTool")
limit: limit, rag_document_fragments ON
offset: 0, rag_document_fragments.id = rag_document_fragment_id AND
) rag_document_fragments.target_id = :target_id AND
rag_document_fragments.target_type = :target_type
SQL
end
.map(&:rag_document_fragment_id)
fragments = fragments =
RagDocumentFragment.where(id: fragment_ids, upload_id: upload_refs).pluck( RagDocumentFragment.where(id: fragment_ids, upload_id: upload_refs).pluck(

View File

@ -92,9 +92,8 @@ module DiscourseAi
private private
def nearest_neighbors(limit: 100) def nearest_neighbors(limit: 100)
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector_rep = schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep)
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
raw_vector = vector_rep.vector_from(@text) raw_vector = vector_rep.vector_from(@text)
@ -107,13 +106,15 @@ module DiscourseAi
).pluck(:category_id) ).pluck(:category_id)
end end
vector_rep.asymmetric_topics_similarity_search( schema
raw_vector, .asymmetric_similarity_search(raw_vector, limit: limit, offset: 0) do |builder|
limit: limit, builder.join("topics t on t.id = topic_id")
offset: 0, builder.where(<<~SQL, exclude_category_ids: muted_category_ids.map(&:to_i))
return_distance: true, t.category_id NOT IN (:exclude_category_ids) AND
exclude_category_ids: muted_category_ids, t.category_id NOT IN (SELECT categories.id FROM categories WHERE categories.parent_category_id IN (:exclude_category_ids))
) SQL
end
.map { |r| [r.topic_id, r.distance] }
end end
end end
end end

194
lib/embeddings/schema.rb Normal file
View File

@ -0,0 +1,194 @@
# frozen_string_literal: true
# We don't have AR objects for our embeddings, so this class
# acts as an intermediary between us and the DB.
# It lets us retrieve embeddings either symmetrically or asymmetrically,
# and also store them.
module DiscourseAi
  module Embeddings
    # Repository-style accessor for the embeddings tables. There are no AR
    # models for embeddings, so this class is the single place that knows how
    # to read (symmetric/asymmetric similarity search) and write (upsert) them.
    class Schema
      TOPICS_TABLE = "ai_topic_embeddings"
      POSTS_TABLE = "ai_post_embeddings"
      RAG_DOCS_TABLE = "ai_document_fragment_embeddings"

      # Builds a Schema bound to the embeddings table backing +target_klass+.
      #
      # @param target_klass [Class] Topic, Post, or RagDocumentFragment.
      # @param vector [VectorRepresentations::Base] embeddings model in use;
      #   defaults to the site-configured representation.
      # @raise [ArgumentError] when +target_klass+ has no embeddings table.
      def self.for(
        target_klass,
        vector: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
      )
        case target_klass&.name
        when "Topic"
          new(TOPICS_TABLE, "topic_id", vector)
        when "Post"
          new(POSTS_TABLE, "post_id", vector)
        when "RagDocumentFragment"
          new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector)
        else
          raise ArgumentError, "Invalid target type for embeddings"
        end
      end

      def initialize(table, target_column, vector)
        @table = table
        @target_column = target_column
        @vector = vector
      end

      attr_reader :table, :target_column, :vector

      # Returns the row whose stored embedding is closest to +embedding+
      # for the current model/strategy pair, or nil when the table is empty.
      def find_by_embedding(embedding)
        DB.query(<<~SQL, query_embedding: embedding, vid: vector.id, vsid: vector.strategy_id).first
          SELECT *
          FROM #{table}
          WHERE
            model_id = :vid AND strategy_id = :vsid
          ORDER BY
            embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
          LIMIT 1
        SQL
      end

      # Returns the stored embeddings row for +target+ (a Topic/Post/fragment),
      # or nil when it has not been embedded yet.
      def find_by_target(target)
        DB.query(<<~SQL, target_id: target.id, vid: vector.id, vsid: vector.strategy_id).first
          SELECT *
          FROM #{table}
          WHERE
            model_id = :vid AND
            strategy_id = :vsid AND
            #{target_column} = :target_id
          LIMIT 1
        SQL
      end

      # Finds rows whose embeddings are closest to a query +embedding+ that did
      # NOT originate from the same table (e.g. an embedded search term).
      #
      # A first pass over binary-quantized embeddings selects a widened
      # candidate set cheaply; candidates are then re-ranked with the full
      # halfvec distance. Yields the query builder so callers can add extra
      # /*join*/ and /*where*/ conditions.
      #
      # @return [Array] rows exposing the target column and a +distance+ field.
      # @raise [MissingEmbeddingError] when the underlying query fails.
      def asymmetric_similarity_search(embedding, limit:, offset:)
        builder = DB.build(<<~SQL)
          WITH candidates AS (
            SELECT
              #{target_column},
              embeddings::halfvec(#{dimensions}) AS embeddings
            FROM
              #{table}
            /*join*/
            /*where*/
            ORDER BY
              binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
            LIMIT :candidates_limit
          )
          SELECT
            #{target_column},
            embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
          FROM
            candidates
          ORDER BY
            embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
          LIMIT :limit
          OFFSET :offset
        SQL

        builder.where(
          "model_id = :model_id AND strategy_id = :strategy_id",
          model_id: vector.id,
          strategy_id: vector.strategy_id,
        )

        yield(builder) if block_given?

        # A too-low candidate limit exacerbates the recall loss of binary
        # quantization, so RAG fragment searches widen the net to at least 100.
        candidates_limit = table == RAG_DOCS_TABLE ? [limit * 2, 100].max : limit * 2

        builder.query(
          query_embedding: embedding,
          candidates_limit: candidates_limit,
          limit: limit,
          offset: offset,
        )
      rescue PG::Error => e
        Rails.logger.error("Error #{e} querying embeddings for model #{vector.name}")
        raise MissingEmbeddingError
      end

      # Finds rows similar to +record+, which lives in the SAME table: its
      # stored embedding (le_target) is the query vector. Same two-phase
      # binary-quantize-then-rerank approach as the asymmetric search; yields
      # the builder for extra conditions.
      #
      # @raise [MissingEmbeddingError] when the query fails (e.g. the record
      #   has no embedding yet).
      def symmetric_similarity_search(record)
        builder = DB.build(<<~SQL)
          WITH le_target AS (
            SELECT
              embeddings
            FROM
              #{table}
            WHERE
              model_id = :vid AND
              strategy_id = :vsid AND
              #{target_column} = :target_id
            LIMIT 1
          )
          SELECT #{target_column} FROM (
            SELECT
              #{target_column}, embeddings
            FROM
              #{table}
            /*join*/
            /*where*/
            ORDER BY
              binary_quantize(embeddings)::bit(#{dimensions}) <~> (
                SELECT
                  binary_quantize(embeddings)::bit(#{dimensions})
                FROM
                  le_target
                LIMIT 1
              )
            LIMIT 200
          ) AS widenet
          ORDER BY
            embeddings::halfvec(#{dimensions}) #{pg_function} (
              SELECT
                embeddings::halfvec(#{dimensions})
              FROM
                le_target
              LIMIT 1
            )
          LIMIT 100;
        SQL

        builder.where("model_id = :vid AND strategy_id = :vsid")

        yield(builder) if block_given?

        builder.query(vid: vector.id, vsid: vector.strategy_id, target_id: record.id)
      rescue PG::Error => e
        Rails.logger.error("Error #{e} querying embeddings for model #{vector.name}")
        raise MissingEmbeddingError
      end

      # Upserts the embedding for +record+. The digest is stored alongside the
      # vector so callers can skip regeneration when the source text is
      # unchanged (see find_by_target).
      def store(record, embedding, digest)
        DB.exec(
          <<~SQL,
            INSERT INTO #{table} (#{target_column}, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
            VALUES (:target_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
            ON CONFLICT (model_id, strategy_id, #{target_column})
            DO UPDATE SET
              model_version = :model_version,
              strategy_version = :strategy_version,
              digest = :digest,
              embeddings = '[:embeddings]',
              updated_at = :now
          SQL
          target_id: record.id,
          model_id: vector.id,
          model_version: vector.version,
          strategy_id: vector.strategy_id,
          strategy_version: vector.strategy_version,
          digest: digest,
          embeddings: embedding,
          now: Time.zone.now,
        )
      end

      private

      delegate :dimensions, :pg_function, to: :vector
    end
  end
end

View File

@ -13,16 +13,16 @@ module DiscourseAi
def related_topic_ids_for(topic) def related_topic_ids_for(topic)
return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1 return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
cache_for = results_ttl(topic) cache_for = results_ttl(topic)
Discourse Discourse
.cache .cache
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do .fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
vector_rep DiscourseAi::Embeddings::Schema
.symmetric_topics_similarity_search(topic) .for(Topic, vector: vector_rep)
.symmetric_similarity_search(topic)
.map(&:topic_id)
.tap do |candidate_ids| .tap do |candidate_ids|
# Happens when the topic doesn't have any embeddings # Happens when the topic doesn't have any embeddings
# I'd rather not use Exceptions to control the flow, so this should be refactored soon # I'd rather not use Exceptions to control the flow, so this should be refactored soon

View File

@ -31,10 +31,7 @@ module DiscourseAi
end end
def vector_rep def vector_rep
@vector_rep ||= @vector_rep ||= DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(
DiscourseAi::Embeddings::Strategies::Truncation.new,
)
end end
def hyde_embedding(search_term) def hyde_embedding(search_term)
@ -87,12 +84,14 @@ module DiscourseAi
over_selection_limit = limit * OVER_SELECTION_FACTOR over_selection_limit = limit * OVER_SELECTION_FACTOR
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep)
candidate_topic_ids = candidate_topic_ids =
vector_rep.asymmetric_topics_similarity_search( schema.asymmetric_similarity_search(
search_embedding, search_embedding,
limit: over_selection_limit, limit: over_selection_limit,
offset: offset, offset: offset,
) ).map(&:topic_id)
semantic_results = semantic_results =
::Post ::Post
@ -115,9 +114,7 @@ module DiscourseAi
return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
digest = OpenSSL::Digest::SHA1.hexdigest(search_term) digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
@ -136,11 +133,14 @@ module DiscourseAi
end end
candidate_post_ids = candidate_post_ids =
vector_rep.asymmetric_posts_similarity_search( DiscourseAi::Embeddings::Schema
search_term_embedding, .for(Post, vector: vector_rep)
limit: max_semantic_results_per_page, .asymmetric_similarity_search(
offset: 0, search_term_embedding,
) limit: max_semantic_results_per_page,
offset: 0,
)
.map(&:post_id)
semantic_results = semantic_results =
::Post ::Post

View File

@ -20,8 +20,9 @@ module DiscourseAi
].find { _1.name == model_name } ].find { _1.name == model_name }
end end
def current_representation(strategy) def current_representation
find_representation(SiteSetting.ai_embeddings_model).new(strategy) truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
find_representation(SiteSetting.ai_embeddings_model).new(truncation)
end end
def correctly_configured? def correctly_configured?
@ -59,6 +60,8 @@ module DiscourseAi
idletime: 30, idletime: 30,
) )
schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector: self)
embedding_gen = inference_client embedding_gen = inference_client
promised_embeddings = promised_embeddings =
relation relation
@ -67,7 +70,7 @@ module DiscourseAi
next if prepared_text.blank? next if prepared_text.blank?
new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text) new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text)
next if find_digest_of(record) == new_digest next if schema.find_by_target(record)&.digest == new_digest
Concurrent::Promises Concurrent::Promises
.fulfilled_future( .fulfilled_future(
@ -83,7 +86,7 @@ module DiscourseAi
Concurrent::Promises Concurrent::Promises
.zip(*promised_embeddings) .zip(*promised_embeddings)
.value! .value!
.each { |e| save_to_db(e[:target], e[:embedding], e[:digest]) } .each { |e| schema.store(e[:target], e[:embedding], e[:digest]) }
pool.shutdown pool.shutdown
pool.wait_for_termination pool.wait_for_termination
@ -93,265 +96,14 @@ module DiscourseAi
text = prepare_text(target) text = prepare_text(target)
return if text.blank? return if text.blank?
schema = DiscourseAi::Embeddings::Schema.for(target.class, vector: self)
new_digest = OpenSSL::Digest::SHA1.hexdigest(text) new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
return if find_digest_of(target) == new_digest return if schema.find_by_target(target)&.digest == new_digest
vector = vector_from(text) vector = vector_from(text)
save_to_db(target, vector, new_digest) if persist schema.store(target, vector, new_digest) if persist
end
def topic_id_from_representation(raw_vector)
DB.query_single(<<~SQL, query_embedding: raw_vector).first
SELECT
topic_id
FROM
#{topic_table_name}
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id}
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT 1
SQL
end
def post_id_from_representation(raw_vector)
DB.query_single(<<~SQL, query_embedding: raw_vector).first
SELECT
post_id
FROM
#{post_table_name}
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id}
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT 1
SQL
end
def asymmetric_topics_similarity_search(
raw_vector,
limit:,
offset:,
return_distance: false,
exclude_category_ids: nil
)
builder = DB.build(<<~SQL)
WITH candidates AS (
SELECT
topic_id,
embeddings::halfvec(#{dimensions}) AS embeddings
FROM
#{topic_table_name}
/*join*/
/*where*/
ORDER BY
binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
LIMIT :limit * 2
)
SELECT
topic_id,
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
FROM
candidates
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT :limit
OFFSET :offset
SQL
builder.where(
"model_id = :model_id AND strategy_id = :strategy_id",
model_id: id,
strategy_id: @strategy.id,
)
if exclude_category_ids.present?
builder.join("topics t on t.id = topic_id")
builder.where(<<~SQL, exclude_category_ids: exclude_category_ids.map(&:to_i))
t.category_id NOT IN (:exclude_category_ids) AND
t.category_id NOT IN (SELECT categories.id FROM categories WHERE categories.parent_category_id IN (:exclude_category_ids))
SQL
end
results = builder.query(query_embedding: raw_vector, limit: limit, offset: offset)
if return_distance
results.map { |r| [r.topic_id, r.distance] }
else
results.map(&:topic_id)
end
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
raise MissingEmbeddingError
end
def asymmetric_posts_similarity_search(raw_vector, limit:, offset:, return_distance: false)
results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
WITH candidates AS (
SELECT
post_id,
embeddings::halfvec(#{dimensions}) AS embeddings
FROM
#{post_table_name}
WHERE
model_id = #{id} AND strategy_id = #{@strategy.id}
ORDER BY
binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
LIMIT :limit * 2
)
SELECT
post_id,
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
FROM
candidates
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT :limit
OFFSET :offset
SQL
if return_distance
results.map { |r| [r.post_id, r.distance] }
else
results.map(&:post_id)
end
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
raise MissingEmbeddingError
end
def asymmetric_rag_fragment_similarity_search(
raw_vector,
target_id:,
target_type:,
limit:,
offset:,
return_distance: false
)
# A too low limit exacerbates the the recall loss of binary quantization
binary_search_limit = [limit * 2, 100].max
results =
DB.query(
<<~SQL,
WITH candidates AS (
SELECT
rag_document_fragment_id,
embeddings::halfvec(#{dimensions}) AS embeddings
FROM
#{rag_fragments_table_name}
INNER JOIN
rag_document_fragments ON
rag_document_fragments.id = rag_document_fragment_id AND
rag_document_fragments.target_id = :target_id AND
rag_document_fragments.target_type = :target_type
WHERE
model_id = #{id} AND strategy_id = #{@strategy.id}
ORDER BY
binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
LIMIT :binary_search_limit
)
SELECT
rag_document_fragment_id,
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
FROM
candidates
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT :limit
OFFSET :offset
SQL
query_embedding: raw_vector,
target_id: target_id,
target_type: target_type,
limit: limit,
offset: offset,
binary_search_limit: binary_search_limit,
)
if return_distance
results.map { |r| [r.rag_document_fragment_id, r.distance] }
else
results.map(&:rag_document_fragment_id)
end
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
raise MissingEmbeddingError
end
def symmetric_topics_similarity_search(topic)
DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
WITH le_target AS (
SELECT
embeddings
FROM
#{topic_table_name}
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id} AND
topic_id = :topic_id
LIMIT 1
)
SELECT topic_id FROM (
SELECT
topic_id, embeddings
FROM
#{topic_table_name}
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id}
ORDER BY
binary_quantize(embeddings)::bit(#{dimensions}) <~> (
SELECT
binary_quantize(embeddings)::bit(#{dimensions})
FROM
le_target
LIMIT 1
)
LIMIT 200
) AS widenet
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} (
SELECT
embeddings::halfvec(#{dimensions})
FROM
le_target
LIMIT 1
)
LIMIT 100;
SQL
rescue PG::Error => e
Rails.logger.error(
"Error #{e} querying embeddings for topic #{topic.id} and model #{name}",
)
raise MissingEmbeddingError
end
def topic_table_name
"ai_topic_embeddings"
end
def post_table_name
"ai_post_embeddings"
end
def rag_fragments_table_name
"ai_document_fragment_embeddings"
end
def table_name(target)
case target
when Topic
topic_table_name
when Post
post_table_name
when RagDocumentFragment
rag_fragments_table_name
else
raise ArgumentError, "Invalid target type"
end
end end
def index_name(table_name) def index_name(table_name)
@ -390,106 +142,16 @@ module DiscourseAi
raise NotImplementedError raise NotImplementedError
end end
def strategy_id
@strategy.id
end
def strategy_version
@strategy.version
end
protected protected
def find_digest_of(target)
target_column =
case target
when Topic
"topic_id"
when Post
"post_id"
when RagDocumentFragment
"rag_document_fragment_id"
else
raise ArgumentError, "Invalid target type"
end
DB.query_single(<<~SQL, target_id: target.id).first
SELECT
digest
FROM
#{table_name(target)}
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id} AND
#{target_column} = :target_id
LIMIT 1
SQL
end
def save_to_db(target, vector, digest)
if target.is_a?(Topic)
DB.exec(
<<~SQL,
INSERT INTO #{topic_table_name} (topic_id, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
VALUES (:topic_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
ON CONFLICT (strategy_id, model_id, topic_id)
DO UPDATE SET
model_version = :model_version,
strategy_version = :strategy_version,
digest = :digest,
embeddings = '[:embeddings]',
updated_at = :now
SQL
topic_id: target.id,
model_id: id,
model_version: version,
strategy_id: @strategy.id,
strategy_version: @strategy.version,
digest: digest,
embeddings: vector,
now: Time.zone.now,
)
elsif target.is_a?(Post)
DB.exec(
<<~SQL,
INSERT INTO #{post_table_name} (post_id, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
VALUES (:post_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
ON CONFLICT (model_id, strategy_id, post_id)
DO UPDATE SET
model_version = :model_version,
strategy_version = :strategy_version,
digest = :digest,
embeddings = '[:embeddings]',
updated_at = :now
SQL
post_id: target.id,
model_id: id,
model_version: version,
strategy_id: @strategy.id,
strategy_version: @strategy.version,
digest: digest,
embeddings: vector,
now: Time.zone.now,
)
elsif target.is_a?(RagDocumentFragment)
DB.exec(
<<~SQL,
INSERT INTO #{rag_fragments_table_name} (rag_document_fragment_id, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
VALUES (:fragment_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
ON CONFLICT (model_id, strategy_id, rag_document_fragment_id)
DO UPDATE SET
model_version = :model_version,
strategy_version = :strategy_version,
digest = :digest,
embeddings = '[:embeddings]',
updated_at = :now
SQL
fragment_id: target.id,
model_id: id,
model_version: version,
strategy_id: @strategy.id,
strategy_version: @strategy.version,
digest: digest,
embeddings: vector,
now: Time.zone.now,
)
else
raise ArgumentError, "Invalid target type"
end
end
def inference_client def inference_client
raise NotImplementedError raise NotImplementedError
end end

View File

@ -4,17 +4,16 @@ desc "Backfill embeddings for all topics and posts"
task "ai:embeddings:backfill", %i[model concurrency] => [:environment] do |_, args| task "ai:embeddings:backfill", %i[model concurrency] => [:environment] do |_, args|
public_categories = Category.where(read_restricted: false).pluck(:id) public_categories = Category.where(read_restricted: false).pluck(:id)
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
if args[:model].present? if args[:model].present?
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep = vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(args[:model]).new( DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(args[:model]).new(
strategy, strategy,
) )
else else
vector_rep = vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
end end
table_name = vector_rep.topic_table_name table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
topics = topics =
Topic Topic

View File

@ -6,10 +6,7 @@ RSpec.describe Jobs::DigestRagUpload do
let(:document_file) { StringIO.new("some text" * 200) } let(:document_file) { StringIO.new("some text" * 200) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:vector_rep) do
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
end
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions } let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }

View File

@ -2,10 +2,7 @@
RSpec.describe Jobs::GenerateRagEmbeddings do RSpec.describe Jobs::GenerateRagEmbeddings do
describe "#execute" do describe "#execute" do
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:vector_rep) do
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
end
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions } let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
@ -30,7 +27,9 @@ RSpec.describe Jobs::GenerateRagEmbeddings do
subject.execute(fragment_ids: [rag_document_fragment_1.id, rag_document_fragment_2.id]) subject.execute(fragment_ids: [rag_document_fragment_1.id, rag_document_fragment_2.id])
embeddings_count = embeddings_count =
DB.query_single("SELECT COUNT(*) from #{vector_rep.rag_fragments_table_name}").first DB.query_single(
"SELECT COUNT(*) from #{DiscourseAi::Embeddings::Schema::RAG_DOCS_TABLE}",
).first
expect(embeddings_count).to eq(expected_embeddings) expect(embeddings_count).to eq(expected_embeddings)
end end

View File

@ -19,10 +19,7 @@ RSpec.describe Jobs::EmbeddingsBackfill do
topic topic
end end
let(:vector_rep) do let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
end
before do before do
SiteSetting.ai_embeddings_enabled = true SiteSetting.ai_embeddings_enabled = true
@ -41,7 +38,8 @@ RSpec.describe Jobs::EmbeddingsBackfill do
Jobs::EmbeddingsBackfill.new.execute({}) Jobs::EmbeddingsBackfill.new.execute({})
topic_ids = DB.query_single("SELECT topic_id from #{vector_rep.topic_table_name}") topic_ids =
DB.query_single("SELECT topic_id from #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE}")
expect(topic_ids).to eq([first_topic.id]) expect(topic_ids).to eq([first_topic.id])
@ -49,7 +47,8 @@ RSpec.describe Jobs::EmbeddingsBackfill do
SiteSetting.ai_embeddings_backfill_batch_size = 100 SiteSetting.ai_embeddings_backfill_batch_size = 100
Jobs::EmbeddingsBackfill.new.execute({}) Jobs::EmbeddingsBackfill.new.execute({})
topic_ids = DB.query_single("SELECT topic_id from #{vector_rep.topic_table_name}") topic_ids =
DB.query_single("SELECT topic_id from #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE}")
expect(topic_ids).to contain_exactly(first_topic.id, second_topic.id, third_topic.id) expect(topic_ids).to contain_exactly(first_topic.id, second_topic.id, third_topic.id)
@ -62,7 +61,7 @@ RSpec.describe Jobs::EmbeddingsBackfill do
index_date = index_date =
DB.query_single( DB.query_single(
"SELECT updated_at from #{vector_rep.topic_table_name} WHERE topic_id = ?", "SELECT updated_at from #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE} WHERE topic_id = ?",
third_topic.id, third_topic.id,
).first ).first

View File

@ -326,9 +326,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
fab!(:llm_model) { Fabricate(:fake_model) } fab!(:llm_model) { Fabricate(:fake_model) }
it "will run the question consolidator" do it "will run the question consolidator" do
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
context_embedding = vector_rep.dimensions.times.map { rand(-1.0...1.0) } context_embedding = vector_rep.dimensions.times.map { rand(-1.0...1.0) }
EmbeddingsGenerationStubs.discourse_service( EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model, SiteSetting.ai_embeddings_model,
@ -375,41 +373,44 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
end end
context "when a persona has RAG uploads" do context "when a persona has RAG uploads" do
def stub_fragments(limit, expected_limit: nil) let(:vector_rep) do
candidate_ids = [] DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
end
let(:embedding_value) { 0.04381 }
let(:prompt_cc_embeddings) { [embedding_value] * vector_rep.dimensions }
limit.times do |i| def stub_fragments(fragment_count, persona: ai_persona)
candidate_ids << Fabricate( schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep)
:rag_document_fragment,
fragment: "fragment-n#{i}", fragment_count.times do |i|
target_id: ai_persona.id, fragment =
target_type: "AiPersona", Fabricate(
upload: upload, :rag_document_fragment,
).id fragment: "fragment-n#{i}",
target_id: persona.id,
target_type: "AiPersona",
upload: upload,
)
# Similarity is determined left-to-right.
embeddings = [embedding_value + "0.000#{i}".to_f] * vector_rep.dimensions
schema.store(fragment, embeddings, "test")
end end
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn
.any_instance
.expects(:asymmetric_rag_fragment_similarity_search)
.with { |args, kwargs| kwargs[:limit] == (expected_limit || limit) }
.returns(candidate_ids)
end end
before do before do
stored_ai_persona = AiPersona.find(ai_persona.id) stored_ai_persona = AiPersona.find(ai_persona.id)
UploadReference.ensure_exist!(target: stored_ai_persona, upload_ids: [upload.id]) UploadReference.ensure_exist!(target: stored_ai_persona, upload_ids: [upload.id])
context_embedding = [0.049382, 0.9999]
EmbeddingsGenerationStubs.discourse_service( EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model, SiteSetting.ai_embeddings_model,
with_cc.dig(:conversation_context, 0, :content), with_cc.dig(:conversation_context, 0, :content),
context_embedding, prompt_cc_embeddings,
) )
end end
context "when persona allows for less fragments" do context "when persona allows for less fragments" do
before { stub_fragments(3) }
it "will only pick 3 fragments" do it "will only pick 3 fragments" do
custom_ai_persona = custom_ai_persona =
Fabricate( Fabricate(
@ -419,6 +420,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]], allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
) )
stub_fragments(3, persona: custom_ai_persona)
UploadReference.ensure_exist!(target: custom_ai_persona, upload_ids: [upload.id]) UploadReference.ensure_exist!(target: custom_ai_persona, upload_ids: [upload.id])
custom_persona = custom_persona =
@ -438,14 +441,13 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
context "when the reranker is available" do context "when the reranker is available" do
before do before do
SiteSetting.ai_hugging_face_tei_reranker_endpoint = "https://test.reranker.com" SiteSetting.ai_hugging_face_tei_reranker_endpoint = "https://test.reranker.com"
stub_fragments(15)
# hard coded internal implementation, reranker takes x5 number of chunks
stub_fragments(15, expected_limit: 50) # Mimic limit being more than 10 results
end end
it "uses the re-ranker to reorder the fragments and pick the top 10 candidates" do it "uses the re-ranker to reorder the fragments and pick the top 10 candidates" do
skip "This test is flaky needs to be investigated ordering does not come back as expected" # The re-ranker reverses the similarity search, but return less results
expected_reranked = (0..14).to_a.reverse.map { |idx| { index: idx } } # to act as a limit for test-purposes.
expected_reranked = (4..14).to_a.reverse.map { |idx| { index: idx } }
WebMock.stub_request(:post, "https://test.reranker.com/rerank").to_return( WebMock.stub_request(:post, "https://test.reranker.com/rerank").to_return(
status: 200, status: 200,

View File

@ -110,8 +110,9 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model) assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model)
SiteSetting.ai_embeddings_semantic_search_enabled = true SiteSetting.ai_embeddings_semantic_search_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
hyde_embedding = [0.049382] * vector_rep.dimensions
hyde_embedding = [0.049382, 0.9999]
EmbeddingsGenerationStubs.discourse_service( EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model, SiteSetting.ai_embeddings_model,
query, query,
@ -126,10 +127,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
bot_user: bot_user, bot_user: bot_user,
) )
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn DiscourseAi::Embeddings::Schema.for(Topic).store(post1.topic, hyde_embedding, "digest")
.any_instance
.expects(:asymmetric_topics_similarity_search)
.returns([post1.topic_id])
results = results =
DiscourseAi::Completions::Llm.with_prepared_responses(["<ai>#{query}</ai>"]) do DiscourseAi::Completions::Llm.with_prepared_responses(["<ai>#{query}</ai>"]) do

View File

@ -14,10 +14,7 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
fab!(:category) fab!(:category)
fab!(:topic) { Fabricate(:topic, category: category) } fab!(:topic) { Fabricate(:topic, category: category) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:vector_rep) do
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
end
let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) } let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) }
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions } let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }

View File

@ -13,9 +13,9 @@ RSpec.describe Jobs::GenerateEmbeddings do
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) } fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
let(:vector_rep) do let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation) let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) }
end let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector: vector_rep) }
it "works for topics" do it "works for topics" do
expected_embedding = [0.0038493] * vector_rep.dimensions expected_embedding = [0.0038493] * vector_rep.dimensions
@ -30,7 +30,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
job.execute(target_id: topic.id, target_type: "Topic") job.execute(target_id: topic.id, target_type: "Topic")
expect(vector_rep.topic_id_from_representation(expected_embedding)).to eq(topic.id) expect(topics_schema.find_by_embedding(expected_embedding).topic_id).to eq(topic.id)
end end
it "works for posts" do it "works for posts" do
@ -42,7 +42,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
job.execute(target_id: post.id, target_type: "Post") job.execute(target_id: post.id, target_type: "Post")
expect(vector_rep.post_id_from_representation(expected_embedding)).to eq(post.id) expect(posts_schema.find_by_embedding(expected_embedding).post_id).to eq(post.id)
end end
end end
end end

View File

@ -0,0 +1,104 @@
# frozen_string_literal: true

RSpec.describe DiscourseAi::Embeddings::Schema do
  subject(:posts_schema) { described_class.for(Post, vector: vector) }

  let(:embeddings) { [0.0038490295] * vector.dimensions }
  fab!(:post) { Fabricate(:post, post_number: 1) }
  # Digest of the indexed content; stored alongside the embedding so we can
  # detect stale rows without re-embedding.
  let(:digest) { OpenSSL::Digest.hexdigest("SHA1", "test") }
  let(:vector) do
    DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new(
      DiscourseAi::Embeddings::Strategies::Truncation.new,
    )
  end

  before { posts_schema.store(post, embeddings, digest) }

  describe "#find_by_target" do
    it "gets you the post_id of the record that matches the post" do
      embeddings_record = posts_schema.find_by_target(post)

      expect(embeddings_record.digest).to eq(digest)
      expect(JSON.parse(embeddings_record.embeddings)).to eq(embeddings)
    end
  end

  describe "#find_by_embedding" do
    it "gets you the record that matches the embedding" do
      embeddings_record = posts_schema.find_by_embedding(embeddings)

      expect(embeddings_record.digest).to eq(digest)
      expect(embeddings_record.post_id).to eq(post.id)
    end
  end

  describe "similarity searches" do
    fab!(:post_2) { Fabricate(:post) }
    # Differs from `embeddings` only in the last decimal, so both rows rank as
    # near neighbors of each other.
    let(:similar_embeddings) { [0.0038490294] * vector.dimensions }

    describe "#symmetric_similarity_search" do
      before { posts_schema.store(post_2, similar_embeddings, digest) }

      it "returns target_id with similar embeddings" do
        similar_records = posts_schema.symmetric_similarity_search(post)

        expect(similar_records.map(&:post_id)).to contain_exactly(post.id, post_2.id)
      end

      it "lets you apply additional scopes to filter results further" do
        similar_records =
          posts_schema.symmetric_similarity_search(post) do |builder|
            builder.where("post_id = ?", post_2.id)
          end

        expect(similar_records.map(&:post_id)).to contain_exactly(post_2.id)
      end

      it "lets you join on additional tables and combine with additional scopes" do
        similar_records =
          posts_schema.symmetric_similarity_search(post) do |builder|
            builder.join("posts p on p.id = post_id")
            builder.join("topics t on t.id = p.topic_id")
            builder.where("t.id = ?", post_2.topic_id)
          end

        expect(similar_records.map(&:post_id)).to contain_exactly(post_2.id)
      end
    end

    describe "#asymmetric_similarity_search" do
      it "returns target_id with similar embeddings" do
        similar_records =
          posts_schema.asymmetric_similarity_search(similar_embeddings, limit: 1, offset: 0)

        expect(similar_records.map(&:post_id)).to contain_exactly(post.id)
      end

      it "lets you apply additional scopes to filter results further" do
        similar_records =
          posts_schema.asymmetric_similarity_search(
            similar_embeddings,
            limit: 1,
            offset: 0,
          ) { |builder| builder.where("post_id <> ?", post.id) }

        expect(similar_records.map(&:post_id)).to be_empty
      end

      it "lets you join on additional tables and combine with additional scopes" do
        similar_records =
          posts_schema.asymmetric_similarity_search(
            similar_embeddings,
            limit: 1,
            offset: 0,
          ) do |builder|
            builder.join("posts p on p.id = post_id")
            builder.join("topics t on t.id = p.topic_id")
            builder.where("t.id <> ?", post.topic_id)
          end

        expect(similar_records.map(&:post_id)).to be_empty
      end
    end
  end
end

View File

@ -25,9 +25,7 @@ describe DiscourseAi::Embeddings::SemanticRelated do
end end
let(:vector_rep) do let(:vector_rep) do
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
end end
it "properly generates embeddings if missing" do it "properly generates embeddings if missing" do

View File

@ -11,11 +11,12 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
describe "#search_for_topics" do describe "#search_for_topics" do
let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" } let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" }
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:hyde_embedding) { [0.049382] * vector_rep.dimensions }
before do before do
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
hyde_embedding = [0.049382, 0.9999]
EmbeddingsGenerationStubs.discourse_service( EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model, SiteSetting.ai_embeddings_model,
hypothetical_post, hypothetical_post,
@ -25,11 +26,8 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
after { described_class.clear_cache_for(query) } after { described_class.clear_cache_for(query) }
def stub_candidate_ids(candidate_ids) def insert_candidate(candidate)
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn DiscourseAi::Embeddings::Schema.for(Topic).store(candidate, hyde_embedding, "digest")
.any_instance
.expects(:asymmetric_topics_similarity_search)
.returns(candidate_ids)
end end
def trigger_search(query) def trigger_search(query)
@ -39,7 +37,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
end end
it "returns the first post of a topic included in the asymmetric search results" do it "returns the first post of a topic included in the asymmetric search results" do
stub_candidate_ids([post.topic_id]) insert_candidate(post.topic)
posts = trigger_search(query) posts = trigger_search(query)
@ -50,7 +48,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
context "when the topic is not visible" do context "when the topic is not visible" do
it "returns an empty list" do it "returns an empty list" do
post.topic.update!(visible: false) post.topic.update!(visible: false)
stub_candidate_ids([post.topic_id]) insert_candidate(post.topic)
posts = trigger_search(query) posts = trigger_search(query)
@ -61,7 +59,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
context "when the post is not public" do context "when the post is not public" do
it "returns an empty list" do it "returns an empty list" do
pm_post = Fabricate(:private_message_post) pm_post = Fabricate(:private_message_post)
stub_candidate_ids([pm_post.topic_id]) insert_candidate(pm_post.topic)
posts = trigger_search(query) posts = trigger_search(query)
@ -72,7 +70,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
context "when the post type is not visible" do context "when the post type is not visible" do
it "returns an empty list" do it "returns an empty list" do
post.update!(post_type: Post.types[:whisper]) post.update!(post_type: Post.types[:whisper])
stub_candidate_ids([post.topic_id]) insert_candidate(post.topic)
posts = trigger_search(query) posts = trigger_search(query)
@ -84,7 +82,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
it "returns an empty list" do it "returns an empty list" do
reply = Fabricate(:reply) reply = Fabricate(:reply)
reply.topic.first_post.trash! reply.topic.first_post.trash!
stub_candidate_ids([reply.topic_id]) insert_candidate(reply.topic)
posts = trigger_search(query) posts = trigger_search(query)
@ -95,7 +93,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
context "when the post is not a candidate" do context "when the post is not a candidate" do
it "doesn't include it in the results" do it "doesn't include it in the results" do
post_2 = Fabricate(:post) post_2 = Fabricate(:post)
stub_candidate_ids([post.topic_id]) insert_candidate(post.topic)
posts = trigger_search(query) posts = trigger_search(query)
@ -109,7 +107,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
before do before do
post.topic.update!(category: private_category) post.topic.update!(category: private_category)
stub_candidate_ids([post.topic_id]) insert_candidate(post.topic)
end end
it "returns an empty list" do it "returns an empty list" do

View File

@ -9,11 +9,17 @@ describe DiscourseAi::Embeddings::EntryPoint do
fab!(:target) { Fabricate(:topic) } fab!(:target) { Fabricate(:topic) }
def stub_semantic_search_with(results) # The Distance gap to target increases for each element of topics.
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn def seed_embeddings(topics)
.any_instance schema = DiscourseAi::Embeddings::Schema.for(Topic)
.expects(:symmetric_topics_similarity_search) base_value = 1
.returns(results.concat([target.id]))
schema.store(target, [base_value] * 1024, "disgest")
topics.each do |t|
base_value -= 0.01
schema.store(t, [base_value] * 1024, "digest")
end
end end
after { DiscourseAi::Embeddings::SemanticRelated.clear_cache_for(target) } after { DiscourseAi::Embeddings::SemanticRelated.clear_cache_for(target) }
@ -21,7 +27,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
context "when the semantic search returns an unlisted topic" do context "when the semantic search returns an unlisted topic" do
fab!(:unlisted_topic) { Fabricate(:topic, visible: false) } fab!(:unlisted_topic) { Fabricate(:topic, visible: false) }
before { stub_semantic_search_with([unlisted_topic.id]) } before { seed_embeddings([unlisted_topic]) }
it "filters it out" do it "filters it out" do
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
@ -31,7 +37,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
context "when the semantic search returns a private topic" do context "when the semantic search returns a private topic" do
fab!(:private_topic) { Fabricate(:private_message_topic) } fab!(:private_topic) { Fabricate(:private_message_topic) }
before { stub_semantic_search_with([private_topic.id]) } before { seed_embeddings([private_topic]) }
it "filters it out" do it "filters it out" do
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
@ -43,7 +49,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
fab!(:category) { Fabricate(:private_category, group: group) } fab!(:category) { Fabricate(:private_category, group: group) }
fab!(:secured_category_topic) { Fabricate(:topic, category: category) } fab!(:secured_category_topic) { Fabricate(:topic, category: category) }
before { stub_semantic_search_with([secured_category_topic.id]) } before { seed_embeddings([secured_category_topic]) }
it "filters it out" do it "filters it out" do
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
@ -63,7 +69,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
before do before do
SiteSetting.ai_embeddings_semantic_related_include_closed_topics = false SiteSetting.ai_embeddings_semantic_related_include_closed_topics = false
stub_semantic_search_with([closed_topic.id]) seed_embeddings([closed_topic])
end end
it "filters it out" do it "filters it out" do
@ -80,7 +86,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
category_id: category.id, category_id: category.id,
notification_level: CategoryUser.notification_levels[:muted], notification_level: CategoryUser.notification_levels[:muted],
) )
stub_semantic_search_with([topic.id]) seed_embeddings([topic])
expect(topic_query.list_semantic_related_topics(target).topics).not_to include(topic) expect(topic_query.list_semantic_related_topics(target).topics).not_to include(topic)
end end
end end
@ -91,11 +97,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
fab!(:normal_topic_3) { Fabricate(:topic) } fab!(:normal_topic_3) { Fabricate(:topic) }
fab!(:closed_topic) { Fabricate(:topic, closed: true) } fab!(:closed_topic) { Fabricate(:topic, closed: true) }
before do before { seed_embeddings([closed_topic, normal_topic_1, normal_topic_2, normal_topic_3]) }
stub_semantic_search_with(
[closed_topic.id, normal_topic_1.id, normal_topic_2.id, normal_topic_3.id],
)
end
it "filters it out" do it "filters it out" do
expect(topic_query.list_semantic_related_topics(target).topics).to eq( expect(topic_query.list_semantic_related_topics(target).topics).to eq(
@ -117,7 +119,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
fab!(:included_topic) { Fabricate(:topic) } fab!(:included_topic) { Fabricate(:topic) }
fab!(:excluded_topic) { Fabricate(:topic) } fab!(:excluded_topic) { Fabricate(:topic) }
before { stub_semantic_search_with([included_topic.id, excluded_topic.id]) } before { seed_embeddings([included_topic, excluded_topic]) }
let(:modifier_block) { Proc.new { |query| query.where.not(id: excluded_topic.id) } } let(:modifier_block) { Proc.new { |query| query.where.not(id: excluded_topic.id) } }

View File

@ -4,6 +4,9 @@ RSpec.shared_examples "generates and store embedding using with vector represent
let(:expected_embedding_1) { [0.0038493] * vector_rep.dimensions } let(:expected_embedding_1) { [0.0038493] * vector_rep.dimensions }
let(:expected_embedding_2) { [0.0037684] * vector_rep.dimensions } let(:expected_embedding_2) { [0.0037684] * vector_rep.dimensions }
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) }
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector: vector_rep) }
describe "#vector_from" do describe "#vector_from" do
it "creates a vector from a given string" do it "creates a vector from a given string" do
text = "This is a piece of text" text = "This is a piece of text"
@ -29,7 +32,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent
vector_rep.generate_representation_from(topic) vector_rep.generate_representation_from(topic)
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id) expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
end end
it "creates a vector from a post and stores it in the database" do it "creates a vector from a post and stores it in the database" do
@ -43,7 +46,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent
vector_rep.generate_representation_from(post) vector_rep.generate_representation_from(post)
expect(vector_rep.post_id_from_representation(expected_embedding_1)).to eq(post.id) expect(posts_schema.find_by_embedding(expected_embedding_1).post_id).to eq(post.id)
end end
end end
@ -76,8 +79,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id])) vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id) expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
end end
it "does nothing if passed record has no content" do it "does nothing if passed record has no content" do
@ -99,69 +101,15 @@ RSpec.shared_examples "generates and store embedding using with vector represent
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id])) vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
end end
# check vector exists # check vector exists
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id) expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id])) vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
last_update = last_update =
DB.query_single( DB.query_single(
"SELECT updated_at FROM #{vector_rep.topic_table_name} WHERE topic_id = #{topic.id} LIMIT 1", "SELECT updated_at FROM #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE} WHERE topic_id = #{topic.id} LIMIT 1",
).first ).first
expect(last_update).to eq(original_vector_gen) expect(last_update).to eq(original_vector_gen)
end end
end end
describe "#asymmetric_topics_similarity_search" do
fab!(:topic) { Fabricate(:topic) }
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
it "finds IDs of similar topics with a given embedding" do
similar_vector = [0.0038494] * vector_rep.dimensions
text =
truncation.prepare_text_from(
topic,
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, expected_embedding_1)
vector_rep.generate_representation_from(topic)
expect(
vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0),
).to contain_exactly(topic.id)
end
it "can exclude categories" do
similar_vector = [0.0038494] * vector_rep.dimensions
text =
truncation.prepare_text_from(
topic,
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, expected_embedding_1)
vector_rep.generate_representation_from(topic)
expect(
vector_rep.asymmetric_topics_similarity_search(
similar_vector,
limit: 1,
offset: 0,
exclude_category_ids: [topic.category_id],
),
).to be_empty
child_category = Fabricate(:category, parent_category_id: topic.category_id)
topic.update!(category_id: child_category.id)
expect(
vector_rep.asymmetric_topics_similarity_search(
similar_vector,
limit: 1,
offset: 0,
exclude_category_ids: [topic.category_id],
),
).to be_empty
end
end
end end

View File

@ -74,10 +74,7 @@ RSpec.describe RagDocumentFragment do
end end
describe ".indexing_status" do describe ".indexing_status" do
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:vector_rep) do
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
end
fab!(:rag_document_fragment_1) do fab!(:rag_document_fragment_1) do
Fabricate(:rag_document_fragment, upload: upload_1, target: persona) Fabricate(:rag_document_fragment, upload: upload_1, target: persona)

View File

@ -19,9 +19,7 @@ describe DiscourseAi::Embeddings::EmbeddingsController do
fab!(:post_in_subcategory) { Fabricate(:post, topic: topic_in_subcategory) } fab!(:post_in_subcategory) { Fabricate(:post, topic: topic_in_subcategory) }
def index(topic) def index(topic)
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
stub_request(:post, "https://api.openai.com/v1/embeddings").to_return( stub_request(:post, "https://api.openai.com/v1/embeddings").to_return(
status: 200, status: 200,