REFACTOR: A Simpler way of interacting with embeddings tables. (#1023)

* REFACTOR: A Simpler way of interacting with embeddings' tables.

This change adds a new abstraction called `Schema`, which acts as a repository supporting the same DB features as `VectorRepresentation::Base`, while removing the need for duplicated methods per embeddings table.

It is also a bit more flexible when performing a similarity search because you can pass it a block that gives you access to the builder, allowing you to add multiple joins/where conditions.
This commit is contained in:
Roman Rizzi 2024-12-13 10:15:21 -03:00 committed by GitHub
parent 97ec2c5ff4
commit eae527f99d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 485 additions and 599 deletions

View File

@ -18,9 +18,7 @@ module ::Jobs
target = target_type.constantize.find_by(id: target_id)
return if !target
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
tokenizer = vector_rep.tokenizer
chunk_tokens = target.rag_chunk_tokens

View File

@ -16,9 +16,7 @@ module Jobs
return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms
return if post.raw.blank?
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
vector_rep.generate_representation_from(target)
end

View File

@ -8,9 +8,7 @@ module ::Jobs
def execute(args)
return if (fragments = RagDocumentFragment.where(id: args[:fragment_ids].to_a)).empty?
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
# generate_representation_from checks compares the digest value to make sure
# the embedding is only generated once per fragment unless something changes.

View File

@ -20,10 +20,8 @@ module Jobs
rebaked = 0
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
table_name = vector_rep.topic_table_name
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
topics =
Topic
@ -41,7 +39,7 @@ module Jobs
relation = topics.where(<<~SQL).limit(limit - rebaked)
#{table_name}.model_version < #{vector_rep.version}
OR
#{table_name}.strategy_version < #{strategy.version}
#{table_name}.strategy_version < #{vector_rep.strategy_version}
SQL
rebaked += populate_topic_embeddings(vector_rep, relation)
@ -63,7 +61,7 @@ module Jobs
return unless SiteSetting.ai_embeddings_per_post_enabled
# Now for posts
table_name = vector_rep.post_table_name
table_name = DiscourseAi::Embeddings::Schema::POSTS_TABLE
posts_batch_size = 1000
posts =
@ -121,7 +119,8 @@ module Jobs
def populate_topic_embeddings(vector_rep, topics, force: false)
done = 0
topics = topics.where("#{vector_rep.topic_table_name}.topic_id IS NULL") if !force
topics =
topics.where("#{DiscourseAi::Embeddings::Schema::TOPICS_TABLE}.topic_id IS NULL") if !force
ids = topics.pluck("topics.id")
batch_size = 1000

View File

@ -39,11 +39,7 @@ class RagDocumentFragment < ActiveRecord::Base
end
def indexing_status(persona, uploads)
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
embeddings_table = vector_rep.rag_fragments_table_name
embeddings_table = DiscourseAi::Embeddings::Schema.for(self).table
results =
DB.query(

View File

@ -14,7 +14,7 @@ class MigrateEmbeddingsFromDedicatedDatabase < ActiveRecord::Migration[7.0]
].map { |k| k.new(truncation) }
vector_reps.each do |vector_rep|
new_table_name = vector_rep.topic_table_name
new_table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
old_table_name = "topic_embeddings_#{vector_rep.name.underscore}"
begin

View File

@ -147,9 +147,7 @@ class MoveEmbeddingsToSingleTablePerType < ActiveRecord::Migration[7.0]
SQL
begin
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
rescue StandardError => e
Rails.logger.error("Failed to index embeddings: #{e}")
end

View File

@ -314,30 +314,34 @@ module DiscourseAi
return nil if !consolidated_question
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings
interactions_vector = vector_rep.vector_from(consolidated_question)
rag_conversation_chunks = self.class.rag_conversation_chunks
candidate_fragment_ids =
vector_rep.asymmetric_rag_fragment_similarity_search(
interactions_vector,
target_type: "AiPersona",
target_id: id,
limit:
(
search_limit =
if reranker.reranker_configured?
rag_conversation_chunks * 5
else
rag_conversation_chunks
end
),
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep)
candidate_fragment_ids =
schema
.asymmetric_similarity_search(
interactions_vector,
limit: search_limit,
offset: 0,
)
) { |builder| builder.join(<<~SQL, target_id: id, target_type: "AiPersona") }
rag_document_fragments ON
rag_document_fragments.id = rag_document_fragment_id AND
rag_document_fragments.target_id = :target_id AND
rag_document_fragments.target_type = :target_type
SQL
.map(&:rag_document_fragment_id)
fragments =
RagDocumentFragment.where(upload_id: upload_refs, id: candidate_fragment_ids).pluck(

View File

@ -141,18 +141,20 @@ module DiscourseAi
return [] if upload_refs.empty?
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
query_vector = vector_rep.vector_from(query)
fragment_ids =
vector_rep.asymmetric_rag_fragment_similarity_search(
query_vector,
target_type: "AiTool",
target_id: tool.id,
limit: limit,
offset: 0,
)
DiscourseAi::Embeddings::Schema
.for(RagDocumentFragment, vector: vector_rep)
.asymmetric_similarity_search(query_vector, limit: limit, offset: 0) do |builder|
builder.join(<<~SQL, target_id: tool.id, target_type: "AiTool")
rag_document_fragments ON
rag_document_fragments.id = rag_document_fragment_id AND
rag_document_fragments.target_id = :target_id AND
rag_document_fragments.target_type = :target_type
SQL
end
.map(&:rag_document_fragment_id)
fragments =
RagDocumentFragment.where(id: fragment_ids, upload_id: upload_refs).pluck(

View File

@ -92,9 +92,8 @@ module DiscourseAi
private
def nearest_neighbors(limit: 100)
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep)
raw_vector = vector_rep.vector_from(@text)
@ -107,13 +106,15 @@ module DiscourseAi
).pluck(:category_id)
end
vector_rep.asymmetric_topics_similarity_search(
raw_vector,
limit: limit,
offset: 0,
return_distance: true,
exclude_category_ids: muted_category_ids,
)
schema
.asymmetric_similarity_search(raw_vector, limit: limit, offset: 0) do |builder|
builder.join("topics t on t.id = topic_id")
builder.where(<<~SQL, exclude_category_ids: muted_category_ids.map(&:to_i))
t.category_id NOT IN (:exclude_category_ids) AND
t.category_id NOT IN (SELECT categories.id FROM categories WHERE categories.parent_category_id IN (:exclude_category_ids))
SQL
end
.map { |r| [r.topic_id, r.distance] }
end
end
end

194
lib/embeddings/schema.rb Normal file
View File

@ -0,0 +1,194 @@
# frozen_string_literal: true
# We don't have AR objects for our embeddings, so this class
# acts as an intermediary between us and the DB.
# It lets us retrieve embeddings either symmetrically or asymmetrically,
# and also store them.
module DiscourseAi
module Embeddings
class Schema
TOPICS_TABLE = "ai_topic_embeddings"
POSTS_TABLE = "ai_post_embeddings"
RAG_DOCS_TABLE = "ai_document_fragment_embeddings"
# Builds the schema that matches +target_klass+.
#
# The class is matched by its name ("Topic", "Post" or "RagDocumentFragment")
# and mapped to the embeddings table and foreign-key column used for it.
#
# @param target_klass [Class, nil] class of the records the embeddings belong to
# @param vector the vector representation to scope queries to; defaults to the
#   currently configured representation
# @return [Schema]
# @raise [ArgumentError] when the class name is not one of the supported ones
def self.for(
  target_klass,
  vector: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
)
  table_and_column = {
    "Topic" => [TOPICS_TABLE, "topic_id"],
    "Post" => [POSTS_TABLE, "post_id"],
    "RagDocumentFragment" => [RAG_DOCS_TABLE, "rag_document_fragment_id"],
  }[target_klass&.name]

  raise ArgumentError, "Invalid target type for embeddings" if table_and_column.nil?

  new(*table_and_column, vector)
end
# @param table [String] name of the embeddings table to read from / write to
# @param target_column [String] column in that table referencing the embedded record
# @param vector the vector representation whose ids and versions scope every query
def initialize(table, target_column, vector)
  @table, @target_column, @vector = table, target_column, vector
end
attr_reader :table, :target_column, :vector
# Returns the single stored row whose embedding is closest to +embedding+,
# restricted to the current vector's model and strategy, or nil when the
# query yields no rows.
#
# @param embedding the query embedding, bound as :query_embedding
def find_by_embedding(embedding)
  nearest_sql = <<~SQL
    SELECT *
    FROM #{table}
    WHERE
    model_id = :vid AND strategy_id = :vsid
    ORDER BY
    embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
    LIMIT 1
  SQL

  rows = DB.query(nearest_sql, query_embedding: embedding, vid: vector.id, vsid: vector.strategy_id)
  rows.first
end
# Looks up the embeddings row stored for +target+ under the current vector's
# model and strategy. Returns nil when the query yields no rows.
#
# @param target a record responding to #id; its id is matched against the
#   schema's target column
def find_by_target(target)
  lookup_sql = <<~SQL
    SELECT *
    FROM #{table}
    WHERE
    model_id = :vid AND
    strategy_id = :vsid AND
    #{target_column} = :target_id
    LIMIT 1
  SQL

  rows = DB.query(lookup_sql, target_id: target.id, vid: vector.id, vsid: vector.strategy_id)
  rows.first
end
# Finds the stored embeddings closest to an arbitrary query embedding.
#
# The search runs in two passes: a binary-quantized ordering selects a pool of
# candidates, which is then re-ranked using the vector's distance operator
# (pg_function). Rows are always scoped to the current vector's model_id and
# strategy_id. When a block is given it receives the SQL builder so the caller
# can add extra joins / where conditions before the query runs.
#
# NOTE: OFFSET is applied to the candidate pool, not to the whole table, so an
# offset at or beyond the candidate limit returns no rows.
#
# @param embedding the query embedding, bound as :query_embedding
# @param limit [Integer] maximum number of rows to return
# @param offset [Integer] number of re-ranked rows to skip
# @yieldparam builder the SQL builder for the query
# @return rows exposing the schema's target column and a +distance+ value
# @raise [MissingEmbeddingError] when the database raises a PG::Error
def asymmetric_similarity_search(embedding, limit:, offset:)
  builder = DB.build(<<~SQL)
    WITH candidates AS (
    SELECT
    #{target_column},
    embeddings::halfvec(#{dimensions}) AS embeddings
    FROM
    #{table}
    /*join*/
    /*where*/
    ORDER BY
    binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
    LIMIT :candidates_limit
    )
    SELECT
    #{target_column},
    embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
    FROM
    candidates
    ORDER BY
    embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
    LIMIT :limit
    OFFSET :offset
  SQL

  builder.where(
    "model_id = :model_id AND strategy_id = :strategy_id",
    model_id: vector.id,
    strategy_id: vector.strategy_id,
  )

  yield(builder) if block_given?

  candidates_limit =
    if table == RAG_DOCS_TABLE
      # A too low limit exacerbates the recall loss of binary quantization
      [limit * 2, 100].max
    else
      limit * 2
    end

  builder.query(
    query_embedding: embedding,
    candidates_limit: candidates_limit,
    limit: limit,
    offset: offset,
  )
rescue PG::Error => e
  # The model name lives on the vector representation; this class has no
  # `name` of its own, so referencing it here would raise a NameError and hide
  # the MissingEmbeddingError callers expect.
  Rails.logger.error("Error #{e} querying embeddings for model #{vector.name}")
  raise MissingEmbeddingError
end
# Finds stored embeddings similar to the one already stored for +record+.
#
# The record's own embedding is read from the table (le_target). A
# binary-quantized ordering then picks up to 200 candidates, which are
# re-ranked with the vector's distance operator (pg_function) and cut to 100
# rows. Rows are scoped to the current vector's model_id and strategy_id. When
# a block is given it receives the SQL builder so the caller can add extra
# joins / where conditions before the query runs.
#
# @param record a record responding to #id whose embedding is the query
# @yieldparam builder the SQL builder for the query
# @return rows exposing the schema's target column
# @raise [MissingEmbeddingError] when the database raises a PG::Error
def symmetric_similarity_search(record)
  builder = DB.build(<<~SQL)
    WITH le_target AS (
    SELECT
    embeddings
    FROM
    #{table}
    WHERE
    model_id = :vid AND
    strategy_id = :vsid AND
    #{target_column} = :target_id
    LIMIT 1
    )
    SELECT #{target_column} FROM (
    SELECT
    #{target_column}, embeddings
    FROM
    #{table}
    /*join*/
    /*where*/
    ORDER BY
    binary_quantize(embeddings)::bit(#{dimensions}) <~> (
    SELECT
    binary_quantize(embeddings)::bit(#{dimensions})
    FROM
    le_target
    LIMIT 1
    )
    LIMIT 200
    ) AS widenet
    ORDER BY
    embeddings::halfvec(#{dimensions}) #{pg_function} (
    SELECT
    embeddings::halfvec(#{dimensions})
    FROM
    le_target
    LIMIT 1
    )
    LIMIT 100;
  SQL

  builder.where("model_id = :vid AND strategy_id = :vsid")

  yield(builder) if block_given?

  builder.query(vid: vector.id, vsid: vector.strategy_id, target_id: record.id)
rescue PG::Error => e
  # The model name lives on the vector representation; this class has no
  # `name` of its own, so referencing it here would raise a NameError and hide
  # the MissingEmbeddingError callers expect.
  Rails.logger.error("Error #{e} querying embeddings for model #{vector.name}")
  raise MissingEmbeddingError
end
# Inserts the embedding for +record+, or updates the existing row when one
# already exists for the same (model_id, strategy_id, target) combination.
#
# @param record a record responding to #id; stored in the schema's target column
# @param embedding the embedding values, bound as :embeddings
# @param digest digest value stored alongside the embedding
def store(record, embedding, digest)
  upsert_sql = <<~SQL
    INSERT INTO #{table} (#{target_column}, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
    VALUES (:target_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
    ON CONFLICT (model_id, strategy_id, #{target_column})
    DO UPDATE SET
    model_version = :model_version,
    strategy_version = :strategy_version,
    digest = :digest,
    embeddings = '[:embeddings]',
    updated_at = :now
  SQL

  bindings = {
    target_id: record.id,
    model_id: vector.id,
    model_version: vector.version,
    strategy_id: vector.strategy_id,
    strategy_version: vector.strategy_version,
    digest: digest,
    embeddings: embedding,
    now: Time.zone.now,
  }

  DB.exec(upsert_sql, **bindings)
end
private
delegate :dimensions, :pg_function, to: :vector
end
end
end

View File

@ -13,16 +13,16 @@ module DiscourseAi
def related_topic_ids_for(topic)
return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
cache_for = results_ttl(topic)
Discourse
.cache
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
vector_rep
.symmetric_topics_similarity_search(topic)
DiscourseAi::Embeddings::Schema
.for(Topic, vector: vector_rep)
.symmetric_similarity_search(topic)
.map(&:topic_id)
.tap do |candidate_ids|
# Happens when the topic doesn't have any embeddings
# I'd rather not use Exceptions to control the flow, so this should be refactored soon

View File

@ -31,10 +31,7 @@ module DiscourseAi
end
def vector_rep
@vector_rep ||=
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(
DiscourseAi::Embeddings::Strategies::Truncation.new,
)
@vector_rep ||= DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
end
def hyde_embedding(search_term)
@ -87,12 +84,14 @@ module DiscourseAi
over_selection_limit = limit * OVER_SELECTION_FACTOR
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep)
candidate_topic_ids =
vector_rep.asymmetric_topics_similarity_search(
schema.asymmetric_similarity_search(
search_embedding,
limit: over_selection_limit,
offset: offset,
)
).map(&:topic_id)
semantic_results =
::Post
@ -115,9 +114,7 @@ module DiscourseAi
return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
@ -136,11 +133,14 @@ module DiscourseAi
end
candidate_post_ids =
vector_rep.asymmetric_posts_similarity_search(
DiscourseAi::Embeddings::Schema
.for(Post, vector: vector_rep)
.asymmetric_similarity_search(
search_term_embedding,
limit: max_semantic_results_per_page,
offset: 0,
)
.map(&:post_id)
semantic_results =
::Post

View File

@ -20,8 +20,9 @@ module DiscourseAi
].find { _1.name == model_name }
end
def current_representation(strategy)
find_representation(SiteSetting.ai_embeddings_model).new(strategy)
def current_representation
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
find_representation(SiteSetting.ai_embeddings_model).new(truncation)
end
def correctly_configured?
@ -59,6 +60,8 @@ module DiscourseAi
idletime: 30,
)
schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector: self)
embedding_gen = inference_client
promised_embeddings =
relation
@ -67,7 +70,7 @@ module DiscourseAi
next if prepared_text.blank?
new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text)
next if find_digest_of(record) == new_digest
next if schema.find_by_target(record)&.digest == new_digest
Concurrent::Promises
.fulfilled_future(
@ -83,7 +86,7 @@ module DiscourseAi
Concurrent::Promises
.zip(*promised_embeddings)
.value!
.each { |e| save_to_db(e[:target], e[:embedding], e[:digest]) }
.each { |e| schema.store(e[:target], e[:embedding], e[:digest]) }
pool.shutdown
pool.wait_for_termination
@ -93,265 +96,14 @@ module DiscourseAi
text = prepare_text(target)
return if text.blank?
schema = DiscourseAi::Embeddings::Schema.for(target.class, vector: self)
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
return if find_digest_of(target) == new_digest
return if schema.find_by_target(target)&.digest == new_digest
vector = vector_from(text)
save_to_db(target, vector, new_digest) if persist
end
def topic_id_from_representation(raw_vector)
DB.query_single(<<~SQL, query_embedding: raw_vector).first
SELECT
topic_id
FROM
#{topic_table_name}
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id}
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT 1
SQL
end
def post_id_from_representation(raw_vector)
DB.query_single(<<~SQL, query_embedding: raw_vector).first
SELECT
post_id
FROM
#{post_table_name}
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id}
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT 1
SQL
end
def asymmetric_topics_similarity_search(
raw_vector,
limit:,
offset:,
return_distance: false,
exclude_category_ids: nil
)
builder = DB.build(<<~SQL)
WITH candidates AS (
SELECT
topic_id,
embeddings::halfvec(#{dimensions}) AS embeddings
FROM
#{topic_table_name}
/*join*/
/*where*/
ORDER BY
binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
LIMIT :limit * 2
)
SELECT
topic_id,
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
FROM
candidates
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT :limit
OFFSET :offset
SQL
builder.where(
"model_id = :model_id AND strategy_id = :strategy_id",
model_id: id,
strategy_id: @strategy.id,
)
if exclude_category_ids.present?
builder.join("topics t on t.id = topic_id")
builder.where(<<~SQL, exclude_category_ids: exclude_category_ids.map(&:to_i))
t.category_id NOT IN (:exclude_category_ids) AND
t.category_id NOT IN (SELECT categories.id FROM categories WHERE categories.parent_category_id IN (:exclude_category_ids))
SQL
end
results = builder.query(query_embedding: raw_vector, limit: limit, offset: offset)
if return_distance
results.map { |r| [r.topic_id, r.distance] }
else
results.map(&:topic_id)
end
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
raise MissingEmbeddingError
end
def asymmetric_posts_similarity_search(raw_vector, limit:, offset:, return_distance: false)
results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
WITH candidates AS (
SELECT
post_id,
embeddings::halfvec(#{dimensions}) AS embeddings
FROM
#{post_table_name}
WHERE
model_id = #{id} AND strategy_id = #{@strategy.id}
ORDER BY
binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
LIMIT :limit * 2
)
SELECT
post_id,
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
FROM
candidates
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT :limit
OFFSET :offset
SQL
if return_distance
results.map { |r| [r.post_id, r.distance] }
else
results.map(&:post_id)
end
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
raise MissingEmbeddingError
end
def asymmetric_rag_fragment_similarity_search(
raw_vector,
target_id:,
target_type:,
limit:,
offset:,
return_distance: false
)
# A too low limit exacerbates the the recall loss of binary quantization
binary_search_limit = [limit * 2, 100].max
results =
DB.query(
<<~SQL,
WITH candidates AS (
SELECT
rag_document_fragment_id,
embeddings::halfvec(#{dimensions}) AS embeddings
FROM
#{rag_fragments_table_name}
INNER JOIN
rag_document_fragments ON
rag_document_fragments.id = rag_document_fragment_id AND
rag_document_fragments.target_id = :target_id AND
rag_document_fragments.target_type = :target_type
WHERE
model_id = #{id} AND strategy_id = #{@strategy.id}
ORDER BY
binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
LIMIT :binary_search_limit
)
SELECT
rag_document_fragment_id,
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
FROM
candidates
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT :limit
OFFSET :offset
SQL
query_embedding: raw_vector,
target_id: target_id,
target_type: target_type,
limit: limit,
offset: offset,
binary_search_limit: binary_search_limit,
)
if return_distance
results.map { |r| [r.rag_document_fragment_id, r.distance] }
else
results.map(&:rag_document_fragment_id)
end
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
raise MissingEmbeddingError
end
def symmetric_topics_similarity_search(topic)
DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
WITH le_target AS (
SELECT
embeddings
FROM
#{topic_table_name}
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id} AND
topic_id = :topic_id
LIMIT 1
)
SELECT topic_id FROM (
SELECT
topic_id, embeddings
FROM
#{topic_table_name}
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id}
ORDER BY
binary_quantize(embeddings)::bit(#{dimensions}) <~> (
SELECT
binary_quantize(embeddings)::bit(#{dimensions})
FROM
le_target
LIMIT 1
)
LIMIT 200
) AS widenet
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} (
SELECT
embeddings::halfvec(#{dimensions})
FROM
le_target
LIMIT 1
)
LIMIT 100;
SQL
rescue PG::Error => e
Rails.logger.error(
"Error #{e} querying embeddings for topic #{topic.id} and model #{name}",
)
raise MissingEmbeddingError
end
def topic_table_name
"ai_topic_embeddings"
end
def post_table_name
"ai_post_embeddings"
end
def rag_fragments_table_name
"ai_document_fragment_embeddings"
end
def table_name(target)
case target
when Topic
topic_table_name
when Post
post_table_name
when RagDocumentFragment
rag_fragments_table_name
else
raise ArgumentError, "Invalid target type"
end
schema.store(target, vector, new_digest) if persist
end
def index_name(table_name)
@ -390,106 +142,16 @@ module DiscourseAi
raise NotImplementedError
end
# Returns the id of the strategy held in @strategy.
def strategy_id
@strategy.id
end
# Returns the version of the strategy held in @strategy.
def strategy_version
@strategy.version
end
protected
def find_digest_of(target)
target_column =
case target
when Topic
"topic_id"
when Post
"post_id"
when RagDocumentFragment
"rag_document_fragment_id"
else
raise ArgumentError, "Invalid target type"
end
DB.query_single(<<~SQL, target_id: target.id).first
SELECT
digest
FROM
#{table_name(target)}
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id} AND
#{target_column} = :target_id
LIMIT 1
SQL
end
def save_to_db(target, vector, digest)
if target.is_a?(Topic)
DB.exec(
<<~SQL,
INSERT INTO #{topic_table_name} (topic_id, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
VALUES (:topic_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
ON CONFLICT (strategy_id, model_id, topic_id)
DO UPDATE SET
model_version = :model_version,
strategy_version = :strategy_version,
digest = :digest,
embeddings = '[:embeddings]',
updated_at = :now
SQL
topic_id: target.id,
model_id: id,
model_version: version,
strategy_id: @strategy.id,
strategy_version: @strategy.version,
digest: digest,
embeddings: vector,
now: Time.zone.now,
)
elsif target.is_a?(Post)
DB.exec(
<<~SQL,
INSERT INTO #{post_table_name} (post_id, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
VALUES (:post_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
ON CONFLICT (model_id, strategy_id, post_id)
DO UPDATE SET
model_version = :model_version,
strategy_version = :strategy_version,
digest = :digest,
embeddings = '[:embeddings]',
updated_at = :now
SQL
post_id: target.id,
model_id: id,
model_version: version,
strategy_id: @strategy.id,
strategy_version: @strategy.version,
digest: digest,
embeddings: vector,
now: Time.zone.now,
)
elsif target.is_a?(RagDocumentFragment)
DB.exec(
<<~SQL,
INSERT INTO #{rag_fragments_table_name} (rag_document_fragment_id, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
VALUES (:fragment_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
ON CONFLICT (model_id, strategy_id, rag_document_fragment_id)
DO UPDATE SET
model_version = :model_version,
strategy_version = :strategy_version,
digest = :digest,
embeddings = '[:embeddings]',
updated_at = :now
SQL
fragment_id: target.id,
model_id: id,
model_version: version,
strategy_id: @strategy.id,
strategy_version: @strategy.version,
digest: digest,
embeddings: vector,
now: Time.zone.now,
)
else
raise ArgumentError, "Invalid target type"
end
end
def inference_client
raise NotImplementedError
end

View File

@ -4,17 +4,16 @@ desc "Backfill embeddings for all topics and posts"
task "ai:embeddings:backfill", %i[model concurrency] => [:environment] do |_, args|
public_categories = Category.where(read_restricted: false).pluck(:id)
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
if args[:model].present?
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(args[:model]).new(
strategy,
)
else
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
end
table_name = vector_rep.topic_table_name
table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
topics =
Topic

View File

@ -6,10 +6,7 @@ RSpec.describe Jobs::DigestRagUpload do
let(:document_file) { StringIO.new("some text" * 200) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
let(:vector_rep) do
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
end
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }

View File

@ -2,10 +2,7 @@
RSpec.describe Jobs::GenerateRagEmbeddings do
describe "#execute" do
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
let(:vector_rep) do
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
end
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
@ -30,7 +27,9 @@ RSpec.describe Jobs::GenerateRagEmbeddings do
subject.execute(fragment_ids: [rag_document_fragment_1.id, rag_document_fragment_2.id])
embeddings_count =
DB.query_single("SELECT COUNT(*) from #{vector_rep.rag_fragments_table_name}").first
DB.query_single(
"SELECT COUNT(*) from #{DiscourseAi::Embeddings::Schema::RAG_DOCS_TABLE}",
).first
expect(embeddings_count).to eq(expected_embeddings)
end

View File

@ -19,10 +19,7 @@ RSpec.describe Jobs::EmbeddingsBackfill do
topic
end
let(:vector_rep) do
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
end
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
before do
SiteSetting.ai_embeddings_enabled = true
@ -41,7 +38,8 @@ RSpec.describe Jobs::EmbeddingsBackfill do
Jobs::EmbeddingsBackfill.new.execute({})
topic_ids = DB.query_single("SELECT topic_id from #{vector_rep.topic_table_name}")
topic_ids =
DB.query_single("SELECT topic_id from #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE}")
expect(topic_ids).to eq([first_topic.id])
@ -49,7 +47,8 @@ RSpec.describe Jobs::EmbeddingsBackfill do
SiteSetting.ai_embeddings_backfill_batch_size = 100
Jobs::EmbeddingsBackfill.new.execute({})
topic_ids = DB.query_single("SELECT topic_id from #{vector_rep.topic_table_name}")
topic_ids =
DB.query_single("SELECT topic_id from #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE}")
expect(topic_ids).to contain_exactly(first_topic.id, second_topic.id, third_topic.id)
@ -62,7 +61,7 @@ RSpec.describe Jobs::EmbeddingsBackfill do
index_date =
DB.query_single(
"SELECT updated_at from #{vector_rep.topic_table_name} WHERE topic_id = ?",
"SELECT updated_at from #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE} WHERE topic_id = ?",
third_topic.id,
).first

View File

@ -326,9 +326,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
fab!(:llm_model) { Fabricate(:fake_model) }
it "will run the question consolidator" do
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
context_embedding = vector_rep.dimensions.times.map { rand(-1.0...1.0) }
EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model,
@ -375,41 +373,44 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
end
context "when a persona has RAG uploads" do
def stub_fragments(limit, expected_limit: nil)
candidate_ids = []
let(:vector_rep) do
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
end
let(:embedding_value) { 0.04381 }
let(:prompt_cc_embeddings) { [embedding_value] * vector_rep.dimensions }
limit.times do |i|
candidate_ids << Fabricate(
def stub_fragments(fragment_count, persona: ai_persona)
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep)
fragment_count.times do |i|
fragment =
Fabricate(
:rag_document_fragment,
fragment: "fragment-n#{i}",
target_id: ai_persona.id,
target_id: persona.id,
target_type: "AiPersona",
upload: upload,
).id
end
)
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn
.any_instance
.expects(:asymmetric_rag_fragment_similarity_search)
.with { |args, kwargs| kwargs[:limit] == (expected_limit || limit) }
.returns(candidate_ids)
# Similarity is determined left-to-right.
embeddings = [embedding_value + "0.000#{i}".to_f] * vector_rep.dimensions
schema.store(fragment, embeddings, "test")
end
end
before do
stored_ai_persona = AiPersona.find(ai_persona.id)
UploadReference.ensure_exist!(target: stored_ai_persona, upload_ids: [upload.id])
context_embedding = [0.049382, 0.9999]
EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model,
with_cc.dig(:conversation_context, 0, :content),
context_embedding,
prompt_cc_embeddings,
)
end
context "when persona allows for less fragments" do
before { stub_fragments(3) }
it "will only pick 3 fragments" do
custom_ai_persona =
Fabricate(
@ -419,6 +420,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
)
stub_fragments(3, persona: custom_ai_persona)
UploadReference.ensure_exist!(target: custom_ai_persona, upload_ids: [upload.id])
custom_persona =
@ -438,14 +441,13 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
context "when the reranker is available" do
before do
SiteSetting.ai_hugging_face_tei_reranker_endpoint = "https://test.reranker.com"
# hard coded internal implementation, reranker takes x5 number of chunks
stub_fragments(15, expected_limit: 50) # Mimic limit being more than 10 results
stub_fragments(15)
end
it "uses the re-ranker to reorder the fragments and pick the top 10 candidates" do
skip "This test is flaky needs to be investigated ordering does not come back as expected"
expected_reranked = (0..14).to_a.reverse.map { |idx| { index: idx } }
# The re-ranker reverses the similarity search, but return less results
# to act as a limit for test-purposes.
expected_reranked = (4..14).to_a.reverse.map { |idx| { index: idx } }
WebMock.stub_request(:post, "https://test.reranker.com/rerank").to_return(
status: 200,

View File

@ -110,8 +110,9 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model)
SiteSetting.ai_embeddings_semantic_search_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
hyde_embedding = [0.049382] * vector_rep.dimensions
hyde_embedding = [0.049382, 0.9999]
EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model,
query,
@ -126,10 +127,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
bot_user: bot_user,
)
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn
.any_instance
.expects(:asymmetric_topics_similarity_search)
.returns([post1.topic_id])
DiscourseAi::Embeddings::Schema.for(Topic).store(post1.topic, hyde_embedding, "digest")
results =
DiscourseAi::Completions::Llm.with_prepared_responses(["<ai>#{query}</ai>"]) do

View File

@ -14,10 +14,7 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
fab!(:category)
fab!(:topic) { Fabricate(:topic, category: category) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
let(:vector_rep) do
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
end
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) }
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }

View File

@ -13,9 +13,9 @@ RSpec.describe Jobs::GenerateEmbeddings do
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
let(:vector_rep) do
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
end
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) }
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector: vector_rep) }
it "works for topics" do
expected_embedding = [0.0038493] * vector_rep.dimensions
@ -30,7 +30,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
job.execute(target_id: topic.id, target_type: "Topic")
expect(vector_rep.topic_id_from_representation(expected_embedding)).to eq(topic.id)
expect(topics_schema.find_by_embedding(expected_embedding).topic_id).to eq(topic.id)
end
it "works for posts" do
@ -42,7 +42,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
job.execute(target_id: post.id, target_type: "Post")
expect(vector_rep.post_id_from_representation(expected_embedding)).to eq(post.id)
expect(posts_schema.find_by_embedding(expected_embedding).post_id).to eq(post.id)
end
end
end

View File

@ -0,0 +1,104 @@
# frozen_string_literal: true
# Exercises DiscourseAi::Embeddings::Schema through a schema built for Post:
# storing an embedding, looking it up by target or by embedding, and running
# symmetric/asymmetric similarity searches with optional builder blocks.
RSpec.describe DiscourseAi::Embeddings::Schema do
  subject(:posts_schema) { described_class.for(Post, vector: vector) }

  let(:embeddings) { [0.0038490295] * vector.dimensions }
  fab!(:post) { Fabricate(:post, post_number: 1) }
  let(:digest) { OpenSSL::Digest.hexdigest("SHA1", "test") }
  let(:vector) do
    DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new(
      DiscourseAi::Embeddings::Strategies::Truncation.new,
    )
  end

  # Every example starts with an embedding already stored for `post`.
  before { posts_schema.store(post, embeddings, digest) }

  describe "#find_by_target" do
    it "gets you the record that matches the post" do
      embeddings_record = posts_schema.find_by_target(post)

      expect(embeddings_record.digest).to eq(digest)
      expect(JSON.parse(embeddings_record.embeddings)).to eq(embeddings)
    end
  end

  describe "#find_by_embedding" do
    it "gets you the record that matches the embedding" do
      embeddings_record = posts_schema.find_by_embedding(embeddings)

      expect(embeddings_record.digest).to eq(digest)
      expect(embeddings_record.post_id).to eq(post.id)
    end
  end

  describe "similarity searches" do
    fab!(:post_2) { Fabricate(:post) }

    # Differs from `embeddings` only in the last decimal digit of each component.
    let(:similar_embeddings) { [0.0038490294] * vector.dimensions }

    describe "#symmetric_similarity_search" do
      # Only this group stores an embedding for `post_2`.
      before { posts_schema.store(post_2, similar_embeddings, digest) }

      it "returns target_id with similar embeddings" do
        similar_records = posts_schema.symmetric_similarity_search(post)

        # The searched post itself is expected among the results.
        expect(similar_records.map(&:post_id)).to contain_exactly(post.id, post_2.id)
      end

      it "lets you apply additional scopes to filter results further" do
        similar_records =
          posts_schema.symmetric_similarity_search(post) do |builder|
            builder.where("post_id = ?", post_2.id)
          end

        expect(similar_records.map(&:post_id)).to contain_exactly(post_2.id)
      end

      it "lets you join on additional tables and combine with additional scopes" do
        similar_records =
          posts_schema.symmetric_similarity_search(post) do |builder|
            builder.join("posts p on p.id = post_id")
            builder.join("topics t on t.id = p.topic_id")
            builder.where("t.id = ?", post_2.topic_id)
          end

        expect(similar_records.map(&:post_id)).to contain_exactly(post_2.id)
      end
    end

    describe "#asymmetric_similarity_search" do
      # In this group only `post` has an embedding stored by this spec, so
      # excluding it through the builder is expected to leave no results.
      it "returns target_id with similar embeddings" do
        similar_records =
          posts_schema.asymmetric_similarity_search(similar_embeddings, limit: 1, offset: 0)

        expect(similar_records.map(&:post_id)).to contain_exactly(post.id)
      end

      it "lets you apply additional scopes to filter results further" do
        similar_records =
          posts_schema.asymmetric_similarity_search(
            similar_embeddings,
            limit: 1,
            offset: 0,
          ) { |builder| builder.where("post_id <> ?", post.id) }

        expect(similar_records.map(&:post_id)).to be_empty
      end

      it "lets you join on additional tables and combine with additional scopes" do
        similar_records =
          posts_schema.asymmetric_similarity_search(
            similar_embeddings,
            limit: 1,
            offset: 0,
          ) do |builder|
            builder.join("posts p on p.id = post_id")
            builder.join("topics t on t.id = p.topic_id")
            builder.where("t.id <> ?", post.topic_id)
          end

        expect(similar_records.map(&:post_id)).to be_empty
      end
    end
  end
end

View File

@ -25,9 +25,7 @@ describe DiscourseAi::Embeddings::SemanticRelated do
end
let(:vector_rep) do
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
end
it "properly generates embeddings if missing" do

View File

@ -11,11 +11,12 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
describe "#search_for_topics" do
let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" }
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
let(:hyde_embedding) { [0.049382] * vector_rep.dimensions }
before do
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
hyde_embedding = [0.049382, 0.9999]
EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model,
hypothetical_post,
@ -25,11 +26,8 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
after { described_class.clear_cache_for(query) }
def stub_candidate_ids(candidate_ids)
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn
.any_instance
.expects(:asymmetric_topics_similarity_search)
.returns(candidate_ids)
# Stores `hyde_embedding` for the given candidate through the schema built
# for Topic, using the fixed digest string "digest".
def insert_candidate(candidate)
  topics_schema = DiscourseAi::Embeddings::Schema.for(Topic)
  topics_schema.store(candidate, hyde_embedding, "digest")
end
def trigger_search(query)
@ -39,7 +37,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
end
it "returns the first post of a topic included in the asymmetric search results" do
stub_candidate_ids([post.topic_id])
insert_candidate(post.topic)
posts = trigger_search(query)
@ -50,7 +48,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
context "when the topic is not visible" do
it "returns an empty list" do
post.topic.update!(visible: false)
stub_candidate_ids([post.topic_id])
insert_candidate(post.topic)
posts = trigger_search(query)
@ -61,7 +59,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
context "when the post is not public" do
it "returns an empty list" do
pm_post = Fabricate(:private_message_post)
stub_candidate_ids([pm_post.topic_id])
insert_candidate(pm_post.topic)
posts = trigger_search(query)
@ -72,7 +70,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
context "when the post type is not visible" do
it "returns an empty list" do
post.update!(post_type: Post.types[:whisper])
stub_candidate_ids([post.topic_id])
insert_candidate(post.topic)
posts = trigger_search(query)
@ -84,7 +82,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
it "returns an empty list" do
reply = Fabricate(:reply)
reply.topic.first_post.trash!
stub_candidate_ids([reply.topic_id])
insert_candidate(reply.topic)
posts = trigger_search(query)
@ -95,7 +93,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
context "when the post is not a candidate" do
it "doesn't include it in the results" do
post_2 = Fabricate(:post)
stub_candidate_ids([post.topic_id])
insert_candidate(post.topic)
posts = trigger_search(query)
@ -109,7 +107,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
before do
post.topic.update!(category: private_category)
stub_candidate_ids([post.topic_id])
insert_candidate(post.topic)
end
it "returns an empty list" do

View File

@ -9,11 +9,17 @@ describe DiscourseAi::Embeddings::EntryPoint do
fab!(:target) { Fabricate(:topic) }
def stub_semantic_search_with(results)
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn
.any_instance
.expects(:symmetric_topics_similarity_search)
.returns(results.concat([target.id]))
# Stores an embedding for `target` and one for every element of `topics`
# through the schema built for Topic.
#
# `target` gets a vector whose components are all 1; each subsequent topic
# gets a vector whose components are 0.01 lower than the previous one, so the
# distance gap to target increases for each element of topics.
def seed_embeddings(topics)
  schema = DiscourseAi::Embeddings::Schema.for(Topic)
  # NOTE(review): presumably the dimension count of the representation the
  # schema uses by default — confirm before changing the embeddings model.
  dimensions = 1024
  base_value = 1

  schema.store(target, [base_value] * dimensions, "digest")

  topics.each do |t|
    base_value -= 0.01
    schema.store(t, [base_value] * dimensions, "digest")
  end
end
after { DiscourseAi::Embeddings::SemanticRelated.clear_cache_for(target) }
@ -21,7 +27,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
context "when the semantic search returns an unlisted topic" do
fab!(:unlisted_topic) { Fabricate(:topic, visible: false) }
before { stub_semantic_search_with([unlisted_topic.id]) }
before { seed_embeddings([unlisted_topic]) }
it "filters it out" do
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
@ -31,7 +37,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
context "when the semantic search returns a private topic" do
fab!(:private_topic) { Fabricate(:private_message_topic) }
before { stub_semantic_search_with([private_topic.id]) }
before { seed_embeddings([private_topic]) }
it "filters it out" do
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
@ -43,7 +49,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
fab!(:category) { Fabricate(:private_category, group: group) }
fab!(:secured_category_topic) { Fabricate(:topic, category: category) }
before { stub_semantic_search_with([secured_category_topic.id]) }
before { seed_embeddings([secured_category_topic]) }
it "filters it out" do
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
@ -63,7 +69,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
before do
SiteSetting.ai_embeddings_semantic_related_include_closed_topics = false
stub_semantic_search_with([closed_topic.id])
seed_embeddings([closed_topic])
end
it "filters it out" do
@ -80,7 +86,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
category_id: category.id,
notification_level: CategoryUser.notification_levels[:muted],
)
stub_semantic_search_with([topic.id])
seed_embeddings([topic])
expect(topic_query.list_semantic_related_topics(target).topics).not_to include(topic)
end
end
@ -91,11 +97,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
fab!(:normal_topic_3) { Fabricate(:topic) }
fab!(:closed_topic) { Fabricate(:topic, closed: true) }
before do
stub_semantic_search_with(
[closed_topic.id, normal_topic_1.id, normal_topic_2.id, normal_topic_3.id],
)
end
before { seed_embeddings([closed_topic, normal_topic_1, normal_topic_2, normal_topic_3]) }
it "filters it out" do
expect(topic_query.list_semantic_related_topics(target).topics).to eq(
@ -117,7 +119,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
fab!(:included_topic) { Fabricate(:topic) }
fab!(:excluded_topic) { Fabricate(:topic) }
before { stub_semantic_search_with([included_topic.id, excluded_topic.id]) }
before { seed_embeddings([included_topic, excluded_topic]) }
let(:modifier_block) { Proc.new { |query| query.where.not(id: excluded_topic.id) } }

View File

@ -4,6 +4,9 @@ RSpec.shared_examples "generates and store embedding using with vector represent
let(:expected_embedding_1) { [0.0038493] * vector_rep.dimensions }
let(:expected_embedding_2) { [0.0037684] * vector_rep.dimensions }
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) }
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector: vector_rep) }
describe "#vector_from" do
it "creates a vector from a given string" do
text = "This is a piece of text"
@ -29,7 +32,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent
vector_rep.generate_representation_from(topic)
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
end
it "creates a vector from a post and stores it in the database" do
@ -43,7 +46,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent
vector_rep.generate_representation_from(post)
expect(vector_rep.post_id_from_representation(expected_embedding_1)).to eq(post.id)
expect(posts_schema.find_by_embedding(expected_embedding_1).post_id).to eq(post.id)
end
end
@ -76,8 +79,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
end
it "does nothing if passed record has no content" do
@ -99,69 +101,15 @@ RSpec.shared_examples "generates and store embedding using with vector represent
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
end
# check vector exists
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
last_update =
DB.query_single(
"SELECT updated_at FROM #{vector_rep.topic_table_name} WHERE topic_id = #{topic.id} LIMIT 1",
"SELECT updated_at FROM #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE} WHERE topic_id = #{topic.id} LIMIT 1",
).first
expect(last_update).to eq(original_vector_gen)
end
end
describe "#asymmetric_topics_similarity_search" do
fab!(:topic) { Fabricate(:topic) }
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
it "finds IDs of similar topics with a given embedding" do
similar_vector = [0.0038494] * vector_rep.dimensions
text =
truncation.prepare_text_from(
topic,
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, expected_embedding_1)
vector_rep.generate_representation_from(topic)
expect(
vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0),
).to contain_exactly(topic.id)
end
it "can exclude categories" do
similar_vector = [0.0038494] * vector_rep.dimensions
text =
truncation.prepare_text_from(
topic,
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, expected_embedding_1)
vector_rep.generate_representation_from(topic)
expect(
vector_rep.asymmetric_topics_similarity_search(
similar_vector,
limit: 1,
offset: 0,
exclude_category_ids: [topic.category_id],
),
).to be_empty
child_category = Fabricate(:category, parent_category_id: topic.category_id)
topic.update!(category_id: child_category.id)
expect(
vector_rep.asymmetric_topics_similarity_search(
similar_vector,
limit: 1,
offset: 0,
exclude_category_ids: [topic.category_id],
),
).to be_empty
end
end
end

View File

@ -74,10 +74,7 @@ RSpec.describe RagDocumentFragment do
end
describe ".indexing_status" do
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
let(:vector_rep) do
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
end
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
fab!(:rag_document_fragment_1) do
Fabricate(:rag_document_fragment, upload: upload_1, target: persona)

View File

@ -19,9 +19,7 @@ describe DiscourseAi::Embeddings::EmbeddingsController do
fab!(:post_in_subcategory) { Fabricate(:post, topic: topic_in_subcategory) }
def index(topic)
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
stub_request(:post, "https://api.openai.com/v1/embeddings").to_return(
status: 200,