REFACTOR: A Simpler way of interacting with embeddings tables. (#1023)
* REFACTOR: A simpler way of interacting with embeddings' tables. This change adds a new abstraction called `Schema`, which acts as a repository that supports the same DB features `VectorRepresentation::Base` has, except that it removes the need for duplicated methods per embeddings table. It is also a bit more flexible when performing a similarity search because you can pass it a block that gives you access to the builder, allowing you to add multiple joins/where conditions.
This commit is contained in:
parent
97ec2c5ff4
commit
eae527f99d
|
@ -18,9 +18,7 @@ module ::Jobs
|
|||
target = target_type.constantize.find_by(id: target_id)
|
||||
return if !target
|
||||
|
||||
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
|
||||
tokenizer = vector_rep.tokenizer
|
||||
chunk_tokens = target.rag_chunk_tokens
|
||||
|
|
|
@ -16,9 +16,7 @@ module Jobs
|
|||
return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms
|
||||
return if post.raw.blank?
|
||||
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
|
||||
vector_rep.generate_representation_from(target)
|
||||
end
|
||||
|
|
|
@ -8,9 +8,7 @@ module ::Jobs
|
|||
def execute(args)
|
||||
return if (fragments = RagDocumentFragment.where(id: args[:fragment_ids].to_a)).empty?
|
||||
|
||||
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
|
||||
# generate_representation_from compares the digest value to make sure
|
||||
# the embedding is only generated once per fragment unless something changes.
|
||||
|
|
|
@ -20,10 +20,8 @@ module Jobs
|
|||
|
||||
rebaked = 0
|
||||
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
table_name = vector_rep.topic_table_name
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
|
||||
|
||||
topics =
|
||||
Topic
|
||||
|
@ -41,7 +39,7 @@ module Jobs
|
|||
relation = topics.where(<<~SQL).limit(limit - rebaked)
|
||||
#{table_name}.model_version < #{vector_rep.version}
|
||||
OR
|
||||
#{table_name}.strategy_version < #{strategy.version}
|
||||
#{table_name}.strategy_version < #{vector_rep.strategy_version}
|
||||
SQL
|
||||
|
||||
rebaked += populate_topic_embeddings(vector_rep, relation)
|
||||
|
@ -63,7 +61,7 @@ module Jobs
|
|||
return unless SiteSetting.ai_embeddings_per_post_enabled
|
||||
|
||||
# Now for posts
|
||||
table_name = vector_rep.post_table_name
|
||||
table_name = DiscourseAi::Embeddings::Schema::POSTS_TABLE
|
||||
posts_batch_size = 1000
|
||||
|
||||
posts =
|
||||
|
@ -121,7 +119,8 @@ module Jobs
|
|||
def populate_topic_embeddings(vector_rep, topics, force: false)
|
||||
done = 0
|
||||
|
||||
topics = topics.where("#{vector_rep.topic_table_name}.topic_id IS NULL") if !force
|
||||
topics =
|
||||
topics.where("#{DiscourseAi::Embeddings::Schema::TOPICS_TABLE}.topic_id IS NULL") if !force
|
||||
|
||||
ids = topics.pluck("topics.id")
|
||||
batch_size = 1000
|
||||
|
|
|
@ -39,11 +39,7 @@ class RagDocumentFragment < ActiveRecord::Base
|
|||
end
|
||||
|
||||
def indexing_status(persona, uploads)
|
||||
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
|
||||
|
||||
embeddings_table = vector_rep.rag_fragments_table_name
|
||||
embeddings_table = DiscourseAi::Embeddings::Schema.for(self).table
|
||||
|
||||
results =
|
||||
DB.query(
|
||||
|
|
|
@ -14,7 +14,7 @@ class MigrateEmbeddingsFromDedicatedDatabase < ActiveRecord::Migration[7.0]
|
|||
].map { |k| k.new(truncation) }
|
||||
|
||||
vector_reps.each do |vector_rep|
|
||||
new_table_name = vector_rep.topic_table_name
|
||||
new_table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
|
||||
old_table_name = "topic_embeddings_#{vector_rep.name.underscore}"
|
||||
|
||||
begin
|
||||
|
|
|
@ -147,9 +147,7 @@ class MoveEmbeddingsToSingleTablePerType < ActiveRecord::Migration[7.0]
|
|||
SQL
|
||||
|
||||
begin
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
rescue StandardError => e
|
||||
Rails.logger.error("Failed to index embeddings: #{e}")
|
||||
end
|
||||
|
|
|
@ -314,30 +314,34 @@ module DiscourseAi
|
|||
|
||||
return nil if !consolidated_question
|
||||
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings
|
||||
|
||||
interactions_vector = vector_rep.vector_from(consolidated_question)
|
||||
|
||||
rag_conversation_chunks = self.class.rag_conversation_chunks
|
||||
search_limit =
|
||||
if reranker.reranker_configured?
|
||||
rag_conversation_chunks * 5
|
||||
else
|
||||
rag_conversation_chunks
|
||||
end
|
||||
|
||||
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep)
|
||||
|
||||
candidate_fragment_ids =
|
||||
vector_rep.asymmetric_rag_fragment_similarity_search(
|
||||
interactions_vector,
|
||||
target_type: "AiPersona",
|
||||
target_id: id,
|
||||
limit:
|
||||
(
|
||||
if reranker.reranker_configured?
|
||||
rag_conversation_chunks * 5
|
||||
else
|
||||
rag_conversation_chunks
|
||||
end
|
||||
),
|
||||
offset: 0,
|
||||
)
|
||||
schema
|
||||
.asymmetric_similarity_search(
|
||||
interactions_vector,
|
||||
limit: search_limit,
|
||||
offset: 0,
|
||||
) { |builder| builder.join(<<~SQL, target_id: id, target_type: "AiPersona") }
|
||||
rag_document_fragments ON
|
||||
rag_document_fragments.id = rag_document_fragment_id AND
|
||||
rag_document_fragments.target_id = :target_id AND
|
||||
rag_document_fragments.target_type = :target_type
|
||||
SQL
|
||||
.map(&:rag_document_fragment_id)
|
||||
|
||||
fragments =
|
||||
RagDocumentFragment.where(upload_id: upload_refs, id: candidate_fragment_ids).pluck(
|
||||
|
|
|
@ -141,18 +141,20 @@ module DiscourseAi
|
|||
|
||||
return [] if upload_refs.empty?
|
||||
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
query_vector = vector_rep.vector_from(query)
|
||||
fragment_ids =
|
||||
vector_rep.asymmetric_rag_fragment_similarity_search(
|
||||
query_vector,
|
||||
target_type: "AiTool",
|
||||
target_id: tool.id,
|
||||
limit: limit,
|
||||
offset: 0,
|
||||
)
|
||||
DiscourseAi::Embeddings::Schema
|
||||
.for(RagDocumentFragment, vector: vector_rep)
|
||||
.asymmetric_similarity_search(query_vector, limit: limit, offset: 0) do |builder|
|
||||
builder.join(<<~SQL, target_id: tool.id, target_type: "AiTool")
|
||||
rag_document_fragments ON
|
||||
rag_document_fragments.id = rag_document_fragment_id AND
|
||||
rag_document_fragments.target_id = :target_id AND
|
||||
rag_document_fragments.target_type = :target_type
|
||||
SQL
|
||||
end
|
||||
.map(&:rag_document_fragment_id)
|
||||
|
||||
fragments =
|
||||
RagDocumentFragment.where(id: fragment_ids, upload_id: upload_refs).pluck(
|
||||
|
|
|
@ -92,9 +92,8 @@ module DiscourseAi
|
|||
private
|
||||
|
||||
def nearest_neighbors(limit: 100)
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep)
|
||||
|
||||
raw_vector = vector_rep.vector_from(@text)
|
||||
|
||||
|
@ -107,13 +106,15 @@ module DiscourseAi
|
|||
).pluck(:category_id)
|
||||
end
|
||||
|
||||
vector_rep.asymmetric_topics_similarity_search(
|
||||
raw_vector,
|
||||
limit: limit,
|
||||
offset: 0,
|
||||
return_distance: true,
|
||||
exclude_category_ids: muted_category_ids,
|
||||
)
|
||||
schema
|
||||
.asymmetric_similarity_search(raw_vector, limit: limit, offset: 0) do |builder|
|
||||
builder.join("topics t on t.id = topic_id")
|
||||
builder.where(<<~SQL, exclude_category_ids: muted_category_ids.map(&:to_i))
|
||||
t.category_id NOT IN (:exclude_category_ids) AND
|
||||
t.category_id NOT IN (SELECT categories.id FROM categories WHERE categories.parent_category_id IN (:exclude_category_ids))
|
||||
SQL
|
||||
end
|
||||
.map { |r| [r.topic_id, r.distance] }
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -0,0 +1,194 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
# We don't have AR objects for our embeddings, so this class
|
||||
# acts as an intermediary between us and the DB.
|
||||
# It lets us retrieve embeddings either symmetrically or asymmetrically,
|
||||
# and also store them.
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
class Schema
|
||||
TOPICS_TABLE = "ai_topic_embeddings"
|
||||
POSTS_TABLE = "ai_post_embeddings"
|
||||
RAG_DOCS_TABLE = "ai_document_fragment_embeddings"
|
||||
|
||||
# Builds the Schema bound to the embeddings table of +target_klass+.
#
# The class is matched by its name, so only exactly "Topic", "Post" and
# "RagDocumentFragment" are accepted.
#
# @param target_klass [Class, nil] class of the records being embedded.
# @param vector the vector representation used to scope queries; defaults to
#   VectorRepresentations::Base.current_representation.
# @return [Schema]
# @raise [ArgumentError] when the class name is not one of the supported targets.
def self.for(
  target_klass,
  vector: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
)
  table_and_column = {
    "Topic" => [TOPICS_TABLE, "topic_id"],
    "Post" => [POSTS_TABLE, "post_id"],
    "RagDocumentFragment" => [RAG_DOCS_TABLE, "rag_document_fragment_id"],
  }[target_klass&.name]

  raise ArgumentError, "Invalid target type for embeddings" if table_and_column.nil?

  new(*table_and_column, vector)
end
|
||||
|
||||
# @param table [String] name of the embeddings table this schema reads and writes.
# @param target_column [String] column holding the id of the embedded record.
# @param vector the vector representation whose id and strategy id scope every query.
def initialize(table, target_column, vector)
  @table, @target_column, @vector = table, target_column, vector
end
|
||||
|
||||
attr_reader :table, :target_column, :vector
|
||||
|
||||
# Looks up the single stored row that sorts first when the table is ordered
# by `embeddings <pg_function> embedding`, restricted to the current
# vector's model and strategy ids.
#
# @param embedding the query embedding, interpolated into '[:query_embedding]'.
# @return the first row returned by DB.query (all columns of the table).
def find_by_embedding(embedding)
  sql = <<~SQL
    SELECT *
    FROM #{table}
    WHERE
      model_id = :vid AND strategy_id = :vsid
    ORDER BY
      embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
    LIMIT 1
  SQL

  rows = DB.query(sql, query_embedding: embedding, vid: vector.id, vsid: vector.strategy_id)
  rows.first
end
|
||||
|
||||
# Fetches the stored embedding row for +target+ under the current vector's
# model and strategy ids.
#
# @param target a record responding to #id; matched against the target column.
# @return the first row returned by DB.query (all columns of the table).
def find_by_target(target)
  sql = <<~SQL
    SELECT *
    FROM #{table}
    WHERE
      model_id = :vid AND
      strategy_id = :vsid AND
      #{target_column} = :target_id
    LIMIT 1
  SQL

  rows = DB.query(sql, target_id: target.id, vid: vector.id, vsid: vector.strategy_id)
  rows.first
end
|
||||
|
||||
# Searches the table for the rows nearest to an arbitrary query embedding.
#
# The query runs in two stages: a `candidates` CTE ordered with `<~>` over
# binary-quantized embeddings and capped at :candidates_limit, followed by a
# re-ranking of those candidates with the vector's pg_function over halfvec
# embeddings. Rows are always scoped to the current model and strategy ids.
#
# @param embedding the query embedding, interpolated into '[:query_embedding]'.
# @param limit [Integer] maximum number of rows returned by the outer query.
# @param offset [Integer] offset applied to the outer query.
# @yield [builder] optional; receives the SQL builder so callers can add
#   joins/where conditions to the candidates CTE.
# @return rows exposing the target column and `distance`.
# @raise [MissingEmbeddingError] when the query fails with a PG::Error.
def asymmetric_similarity_search(embedding, limit:, offset:)
  builder = DB.build(<<~SQL)
    WITH candidates AS (
      SELECT
        #{target_column},
        embeddings::halfvec(#{dimensions}) AS embeddings
      FROM
        #{table}
      /*join*/
      /*where*/
      ORDER BY
        binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
      LIMIT :candidates_limit
    )
    SELECT
      #{target_column},
      embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
    FROM
      candidates
    ORDER BY
      embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
    LIMIT :limit
    OFFSET :offset
  SQL

  builder.where(
    "model_id = :model_id AND strategy_id = :strategy_id",
    model_id: vector.id,
    strategy_id: vector.strategy_id,
  )

  yield(builder) if block_given?

  candidates_limit =
    if table == RAG_DOCS_TABLE
      # A too low limit exacerbates the recall loss of binary quantization
      [limit * 2, 100].max
    else
      limit * 2
    end

  builder.query(
    query_embedding: embedding,
    candidates_limit: candidates_limit,
    limit: limit,
    offset: offset,
  )
rescue PG::Error => e
  # Schema has no #name of its own; the model name lives on the vector
  # representation. Referencing a bare `name` here would raise NameError and
  # hide the MissingEmbeddingError callers expect.
  Rails.logger.error("Error #{e} querying embeddings for model #{vector.name}")
  # NOTE(review): MissingEmbeddingError is not defined in the visible part of
  # this file; confirm it resolves from this namespace.
  raise MissingEmbeddingError
end
|
||||
|
||||
# Searches the table for rows similar to the embedding already stored for
# +record+.
#
# The `le_target` CTE loads the record's own embedding (current model and
# strategy). An inner query then selects up to 200 rows ordered with `<~>`
# over binary-quantized embeddings, and the outer query re-ranks them with
# the vector's pg_function over halfvec embeddings, returning at most 100.
#
# @param record a record responding to #id; matched against the target column.
# @yield [builder] optional; receives the SQL builder so callers can add
#   joins/where conditions to the inner query.
# @return rows exposing the target column.
# @raise [MissingEmbeddingError] when the query fails with a PG::Error.
def symmetric_similarity_search(record)
  builder = DB.build(<<~SQL)
    WITH le_target AS (
      SELECT
        embeddings
      FROM
        #{table}
      WHERE
        model_id = :vid AND
        strategy_id = :vsid AND
        #{target_column} = :target_id
      LIMIT 1
    )
    SELECT #{target_column} FROM (
      SELECT
        #{target_column}, embeddings
      FROM
        #{table}
      /*join*/
      /*where*/
      ORDER BY
        binary_quantize(embeddings)::bit(#{dimensions}) <~> (
          SELECT
            binary_quantize(embeddings)::bit(#{dimensions})
          FROM
            le_target
          LIMIT 1
        )
      LIMIT 200
    ) AS widenet
    ORDER BY
      embeddings::halfvec(#{dimensions}) #{pg_function} (
        SELECT
          embeddings::halfvec(#{dimensions})
        FROM
          le_target
        LIMIT 1
      )
    LIMIT 100;
  SQL

  # :vid / :vsid are bound at query time below, shared with the CTE.
  builder.where("model_id = :vid AND strategy_id = :vsid")

  yield(builder) if block_given?

  builder.query(vid: vector.id, vsid: vector.strategy_id, target_id: record.id)
rescue PG::Error => e
  # Schema has no #name of its own; the model name lives on the vector
  # representation. Referencing a bare `name` here would raise NameError and
  # hide the MissingEmbeddingError callers expect.
  Rails.logger.error("Error #{e} querying embeddings for model #{vector.name}")
  # NOTE(review): MissingEmbeddingError is not defined in the visible part of
  # this file; confirm it resolves from this namespace.
  raise MissingEmbeddingError
end
|
||||
|
||||
# Upserts the embedding for +record+ into the table.
#
# A row is inserted keyed by (model_id, strategy_id, target column); on
# conflict the versions, digest, embeddings and updated_at are overwritten.
#
# @param record a record responding to #id; stored in the target column.
# @param embedding the embedding values, interpolated into '[:embeddings]'.
# @param digest the digest stored next to the embedding.
# @return whatever DB.exec returns.
def store(record, embedding, digest)
  upsert_sql = <<~SQL
    INSERT INTO #{table} (#{target_column}, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
    VALUES (:target_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
    ON CONFLICT (model_id, strategy_id, #{target_column})
    DO UPDATE SET
      model_version = :model_version,
      strategy_version = :strategy_version,
      digest = :digest,
      embeddings = '[:embeddings]',
      updated_at = :now
  SQL

  bindings = {
    target_id: record.id,
    model_id: vector.id,
    model_version: vector.version,
    strategy_id: vector.strategy_id,
    strategy_version: vector.strategy_version,
    digest: digest,
    embeddings: embedding,
    now: Time.zone.now,
  }

  DB.exec(upsert_sql, **bindings)
end
|
||||
|
||||
private
|
||||
|
||||
delegate :dimensions, :pg_function, to: :vector
|
||||
end
|
||||
end
|
||||
end
|
|
@ -13,16 +13,16 @@ module DiscourseAi
|
|||
def related_topic_ids_for(topic)
|
||||
return [] if SiteSetting.ai_embeddings_semantic_related_topics < 1
|
||||
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
cache_for = results_ttl(topic)
|
||||
|
||||
Discourse
|
||||
.cache
|
||||
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
|
||||
vector_rep
|
||||
.symmetric_topics_similarity_search(topic)
|
||||
DiscourseAi::Embeddings::Schema
|
||||
.for(Topic, vector: vector_rep)
|
||||
.symmetric_similarity_search(topic)
|
||||
.map(&:topic_id)
|
||||
.tap do |candidate_ids|
|
||||
# Happens when the topic doesn't have any embeddings
|
||||
# I'd rather not use Exceptions to control the flow, so this should be refactored soon
|
||||
|
|
|
@ -31,10 +31,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def vector_rep
|
||||
@vector_rep ||=
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(
|
||||
DiscourseAi::Embeddings::Strategies::Truncation.new,
|
||||
)
|
||||
@vector_rep ||= DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
end
|
||||
|
||||
def hyde_embedding(search_term)
|
||||
|
@ -87,12 +84,14 @@ module DiscourseAi
|
|||
|
||||
over_selection_limit = limit * OVER_SELECTION_FACTOR
|
||||
|
||||
schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep)
|
||||
|
||||
candidate_topic_ids =
|
||||
vector_rep.asymmetric_topics_similarity_search(
|
||||
schema.asymmetric_similarity_search(
|
||||
search_embedding,
|
||||
limit: over_selection_limit,
|
||||
offset: offset,
|
||||
)
|
||||
).map(&:topic_id)
|
||||
|
||||
semantic_results =
|
||||
::Post
|
||||
|
@ -115,9 +114,7 @@ module DiscourseAi
|
|||
|
||||
return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length
|
||||
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
|
||||
digest = OpenSSL::Digest::SHA1.hexdigest(search_term)
|
||||
|
||||
|
@ -136,11 +133,14 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
candidate_post_ids =
|
||||
vector_rep.asymmetric_posts_similarity_search(
|
||||
search_term_embedding,
|
||||
limit: max_semantic_results_per_page,
|
||||
offset: 0,
|
||||
)
|
||||
DiscourseAi::Embeddings::Schema
|
||||
.for(Post, vector: vector_rep)
|
||||
.asymmetric_similarity_search(
|
||||
search_term_embedding,
|
||||
limit: max_semantic_results_per_page,
|
||||
offset: 0,
|
||||
)
|
||||
.map(&:post_id)
|
||||
|
||||
semantic_results =
|
||||
::Post
|
||||
|
|
|
@ -20,8 +20,9 @@ module DiscourseAi
|
|||
].find { _1.name == model_name }
|
||||
end
|
||||
|
||||
def current_representation(strategy)
|
||||
find_representation(SiteSetting.ai_embeddings_model).new(strategy)
|
||||
def current_representation
|
||||
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
find_representation(SiteSetting.ai_embeddings_model).new(truncation)
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
|
@ -59,6 +60,8 @@ module DiscourseAi
|
|||
idletime: 30,
|
||||
)
|
||||
|
||||
schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector: self)
|
||||
|
||||
embedding_gen = inference_client
|
||||
promised_embeddings =
|
||||
relation
|
||||
|
@ -67,7 +70,7 @@ module DiscourseAi
|
|||
next if prepared_text.blank?
|
||||
|
||||
new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text)
|
||||
next if find_digest_of(record) == new_digest
|
||||
next if schema.find_by_target(record)&.digest == new_digest
|
||||
|
||||
Concurrent::Promises
|
||||
.fulfilled_future(
|
||||
|
@ -83,7 +86,7 @@ module DiscourseAi
|
|||
Concurrent::Promises
|
||||
.zip(*promised_embeddings)
|
||||
.value!
|
||||
.each { |e| save_to_db(e[:target], e[:embedding], e[:digest]) }
|
||||
.each { |e| schema.store(e[:target], e[:embedding], e[:digest]) }
|
||||
|
||||
pool.shutdown
|
||||
pool.wait_for_termination
|
||||
|
@ -93,265 +96,14 @@ module DiscourseAi
|
|||
text = prepare_text(target)
|
||||
return if text.blank?
|
||||
|
||||
schema = DiscourseAi::Embeddings::Schema.for(target.class, vector: self)
|
||||
|
||||
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
|
||||
return if find_digest_of(target) == new_digest
|
||||
return if schema.find_by_target(target)&.digest == new_digest
|
||||
|
||||
vector = vector_from(text)
|
||||
|
||||
save_to_db(target, vector, new_digest) if persist
|
||||
end
|
||||
|
||||
def topic_id_from_representation(raw_vector)
|
||||
DB.query_single(<<~SQL, query_embedding: raw_vector).first
|
||||
SELECT
|
||||
topic_id
|
||||
FROM
|
||||
#{topic_table_name}
|
||||
WHERE
|
||||
model_id = #{id} AND
|
||||
strategy_id = #{@strategy.id}
|
||||
ORDER BY
|
||||
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
|
||||
LIMIT 1
|
||||
SQL
|
||||
end
|
||||
|
||||
def post_id_from_representation(raw_vector)
|
||||
DB.query_single(<<~SQL, query_embedding: raw_vector).first
|
||||
SELECT
|
||||
post_id
|
||||
FROM
|
||||
#{post_table_name}
|
||||
WHERE
|
||||
model_id = #{id} AND
|
||||
strategy_id = #{@strategy.id}
|
||||
ORDER BY
|
||||
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
|
||||
LIMIT 1
|
||||
SQL
|
||||
end
|
||||
|
||||
def asymmetric_topics_similarity_search(
|
||||
raw_vector,
|
||||
limit:,
|
||||
offset:,
|
||||
return_distance: false,
|
||||
exclude_category_ids: nil
|
||||
)
|
||||
builder = DB.build(<<~SQL)
|
||||
WITH candidates AS (
|
||||
SELECT
|
||||
topic_id,
|
||||
embeddings::halfvec(#{dimensions}) AS embeddings
|
||||
FROM
|
||||
#{topic_table_name}
|
||||
/*join*/
|
||||
/*where*/
|
||||
ORDER BY
|
||||
binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
|
||||
LIMIT :limit * 2
|
||||
)
|
||||
SELECT
|
||||
topic_id,
|
||||
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
|
||||
FROM
|
||||
candidates
|
||||
ORDER BY
|
||||
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
|
||||
LIMIT :limit
|
||||
OFFSET :offset
|
||||
SQL
|
||||
|
||||
builder.where(
|
||||
"model_id = :model_id AND strategy_id = :strategy_id",
|
||||
model_id: id,
|
||||
strategy_id: @strategy.id,
|
||||
)
|
||||
|
||||
if exclude_category_ids.present?
|
||||
builder.join("topics t on t.id = topic_id")
|
||||
builder.where(<<~SQL, exclude_category_ids: exclude_category_ids.map(&:to_i))
|
||||
t.category_id NOT IN (:exclude_category_ids) AND
|
||||
t.category_id NOT IN (SELECT categories.id FROM categories WHERE categories.parent_category_id IN (:exclude_category_ids))
|
||||
SQL
|
||||
end
|
||||
|
||||
results = builder.query(query_embedding: raw_vector, limit: limit, offset: offset)
|
||||
|
||||
if return_distance
|
||||
results.map { |r| [r.topic_id, r.distance] }
|
||||
else
|
||||
results.map(&:topic_id)
|
||||
end
|
||||
rescue PG::Error => e
|
||||
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
|
||||
raise MissingEmbeddingError
|
||||
end
|
||||
|
||||
def asymmetric_posts_similarity_search(raw_vector, limit:, offset:, return_distance: false)
|
||||
results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
|
||||
WITH candidates AS (
|
||||
SELECT
|
||||
post_id,
|
||||
embeddings::halfvec(#{dimensions}) AS embeddings
|
||||
FROM
|
||||
#{post_table_name}
|
||||
WHERE
|
||||
model_id = #{id} AND strategy_id = #{@strategy.id}
|
||||
ORDER BY
|
||||
binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
|
||||
LIMIT :limit * 2
|
||||
)
|
||||
SELECT
|
||||
post_id,
|
||||
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
|
||||
FROM
|
||||
candidates
|
||||
ORDER BY
|
||||
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
|
||||
LIMIT :limit
|
||||
OFFSET :offset
|
||||
SQL
|
||||
|
||||
if return_distance
|
||||
results.map { |r| [r.post_id, r.distance] }
|
||||
else
|
||||
results.map(&:post_id)
|
||||
end
|
||||
rescue PG::Error => e
|
||||
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
|
||||
raise MissingEmbeddingError
|
||||
end
|
||||
|
||||
def asymmetric_rag_fragment_similarity_search(
|
||||
raw_vector,
|
||||
target_id:,
|
||||
target_type:,
|
||||
limit:,
|
||||
offset:,
|
||||
return_distance: false
|
||||
)
|
||||
# A too low limit exacerbates the the recall loss of binary quantization
|
||||
binary_search_limit = [limit * 2, 100].max
|
||||
results =
|
||||
DB.query(
|
||||
<<~SQL,
|
||||
WITH candidates AS (
|
||||
SELECT
|
||||
rag_document_fragment_id,
|
||||
embeddings::halfvec(#{dimensions}) AS embeddings
|
||||
FROM
|
||||
#{rag_fragments_table_name}
|
||||
INNER JOIN
|
||||
rag_document_fragments ON
|
||||
rag_document_fragments.id = rag_document_fragment_id AND
|
||||
rag_document_fragments.target_id = :target_id AND
|
||||
rag_document_fragments.target_type = :target_type
|
||||
WHERE
|
||||
model_id = #{id} AND strategy_id = #{@strategy.id}
|
||||
ORDER BY
|
||||
binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
|
||||
LIMIT :binary_search_limit
|
||||
)
|
||||
SELECT
|
||||
rag_document_fragment_id,
|
||||
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
|
||||
FROM
|
||||
candidates
|
||||
ORDER BY
|
||||
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
|
||||
LIMIT :limit
|
||||
OFFSET :offset
|
||||
SQL
|
||||
query_embedding: raw_vector,
|
||||
target_id: target_id,
|
||||
target_type: target_type,
|
||||
limit: limit,
|
||||
offset: offset,
|
||||
binary_search_limit: binary_search_limit,
|
||||
)
|
||||
|
||||
if return_distance
|
||||
results.map { |r| [r.rag_document_fragment_id, r.distance] }
|
||||
else
|
||||
results.map(&:rag_document_fragment_id)
|
||||
end
|
||||
rescue PG::Error => e
|
||||
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
|
||||
raise MissingEmbeddingError
|
||||
end
|
||||
|
||||
def symmetric_topics_similarity_search(topic)
|
||||
DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
|
||||
WITH le_target AS (
|
||||
SELECT
|
||||
embeddings
|
||||
FROM
|
||||
#{topic_table_name}
|
||||
WHERE
|
||||
model_id = #{id} AND
|
||||
strategy_id = #{@strategy.id} AND
|
||||
topic_id = :topic_id
|
||||
LIMIT 1
|
||||
)
|
||||
SELECT topic_id FROM (
|
||||
SELECT
|
||||
topic_id, embeddings
|
||||
FROM
|
||||
#{topic_table_name}
|
||||
WHERE
|
||||
model_id = #{id} AND
|
||||
strategy_id = #{@strategy.id}
|
||||
ORDER BY
|
||||
binary_quantize(embeddings)::bit(#{dimensions}) <~> (
|
||||
SELECT
|
||||
binary_quantize(embeddings)::bit(#{dimensions})
|
||||
FROM
|
||||
le_target
|
||||
LIMIT 1
|
||||
)
|
||||
LIMIT 200
|
||||
) AS widenet
|
||||
ORDER BY
|
||||
embeddings::halfvec(#{dimensions}) #{pg_function} (
|
||||
SELECT
|
||||
embeddings::halfvec(#{dimensions})
|
||||
FROM
|
||||
le_target
|
||||
LIMIT 1
|
||||
)
|
||||
LIMIT 100;
|
||||
SQL
|
||||
rescue PG::Error => e
|
||||
Rails.logger.error(
|
||||
"Error #{e} querying embeddings for topic #{topic.id} and model #{name}",
|
||||
)
|
||||
raise MissingEmbeddingError
|
||||
end
|
||||
|
||||
def topic_table_name
|
||||
"ai_topic_embeddings"
|
||||
end
|
||||
|
||||
def post_table_name
|
||||
"ai_post_embeddings"
|
||||
end
|
||||
|
||||
def rag_fragments_table_name
|
||||
"ai_document_fragment_embeddings"
|
||||
end
|
||||
|
||||
def table_name(target)
|
||||
case target
|
||||
when Topic
|
||||
topic_table_name
|
||||
when Post
|
||||
post_table_name
|
||||
when RagDocumentFragment
|
||||
rag_fragments_table_name
|
||||
else
|
||||
raise ArgumentError, "Invalid target type"
|
||||
end
|
||||
schema.store(target, vector, new_digest) if persist
|
||||
end
|
||||
|
||||
def index_name(table_name)
|
||||
|
@ -390,106 +142,16 @@ module DiscourseAi
|
|||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def strategy_id
|
||||
@strategy.id
|
||||
end
|
||||
|
||||
def strategy_version
|
||||
@strategy.version
|
||||
end
|
||||
|
||||
protected
|
||||
|
||||
def find_digest_of(target)
|
||||
target_column =
|
||||
case target
|
||||
when Topic
|
||||
"topic_id"
|
||||
when Post
|
||||
"post_id"
|
||||
when RagDocumentFragment
|
||||
"rag_document_fragment_id"
|
||||
else
|
||||
raise ArgumentError, "Invalid target type"
|
||||
end
|
||||
|
||||
DB.query_single(<<~SQL, target_id: target.id).first
|
||||
SELECT
|
||||
digest
|
||||
FROM
|
||||
#{table_name(target)}
|
||||
WHERE
|
||||
model_id = #{id} AND
|
||||
strategy_id = #{@strategy.id} AND
|
||||
#{target_column} = :target_id
|
||||
LIMIT 1
|
||||
SQL
|
||||
end
|
||||
|
||||
def save_to_db(target, vector, digest)
|
||||
if target.is_a?(Topic)
|
||||
DB.exec(
|
||||
<<~SQL,
|
||||
INSERT INTO #{topic_table_name} (topic_id, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
|
||||
VALUES (:topic_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
|
||||
ON CONFLICT (strategy_id, model_id, topic_id)
|
||||
DO UPDATE SET
|
||||
model_version = :model_version,
|
||||
strategy_version = :strategy_version,
|
||||
digest = :digest,
|
||||
embeddings = '[:embeddings]',
|
||||
updated_at = :now
|
||||
SQL
|
||||
topic_id: target.id,
|
||||
model_id: id,
|
||||
model_version: version,
|
||||
strategy_id: @strategy.id,
|
||||
strategy_version: @strategy.version,
|
||||
digest: digest,
|
||||
embeddings: vector,
|
||||
now: Time.zone.now,
|
||||
)
|
||||
elsif target.is_a?(Post)
|
||||
DB.exec(
|
||||
<<~SQL,
|
||||
INSERT INTO #{post_table_name} (post_id, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
|
||||
VALUES (:post_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
|
||||
ON CONFLICT (model_id, strategy_id, post_id)
|
||||
DO UPDATE SET
|
||||
model_version = :model_version,
|
||||
strategy_version = :strategy_version,
|
||||
digest = :digest,
|
||||
embeddings = '[:embeddings]',
|
||||
updated_at = :now
|
||||
SQL
|
||||
post_id: target.id,
|
||||
model_id: id,
|
||||
model_version: version,
|
||||
strategy_id: @strategy.id,
|
||||
strategy_version: @strategy.version,
|
||||
digest: digest,
|
||||
embeddings: vector,
|
||||
now: Time.zone.now,
|
||||
)
|
||||
elsif target.is_a?(RagDocumentFragment)
|
||||
DB.exec(
|
||||
<<~SQL,
|
||||
INSERT INTO #{rag_fragments_table_name} (rag_document_fragment_id, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
|
||||
VALUES (:fragment_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
|
||||
ON CONFLICT (model_id, strategy_id, rag_document_fragment_id)
|
||||
DO UPDATE SET
|
||||
model_version = :model_version,
|
||||
strategy_version = :strategy_version,
|
||||
digest = :digest,
|
||||
embeddings = '[:embeddings]',
|
||||
updated_at = :now
|
||||
SQL
|
||||
fragment_id: target.id,
|
||||
model_id: id,
|
||||
model_version: version,
|
||||
strategy_id: @strategy.id,
|
||||
strategy_version: @strategy.version,
|
||||
digest: digest,
|
||||
embeddings: vector,
|
||||
now: Time.zone.now,
|
||||
)
|
||||
else
|
||||
raise ArgumentError, "Invalid target type"
|
||||
end
|
||||
end
|
||||
|
||||
def inference_client
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
|
|
@ -4,17 +4,16 @@ desc "Backfill embeddings for all topics and posts"
|
|||
task "ai:embeddings:backfill", %i[model concurrency] => [:environment] do |_, args|
|
||||
public_categories = Category.where(read_restricted: false).pluck(:id)
|
||||
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
if args[:model].present?
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(args[:model]).new(
|
||||
strategy,
|
||||
)
|
||||
else
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
end
|
||||
table_name = vector_rep.topic_table_name
|
||||
table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
|
||||
|
||||
topics =
|
||||
Topic
|
||||
|
|
|
@ -6,10 +6,7 @@ RSpec.describe Jobs::DigestRagUpload do
|
|||
|
||||
let(:document_file) { StringIO.new("some text" * 200) }
|
||||
|
||||
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
|
||||
let(:vector_rep) do
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
|
||||
end
|
||||
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
||||
|
||||
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
|
||||
|
||||
|
|
|
@ -2,10 +2,7 @@
|
|||
|
||||
RSpec.describe Jobs::GenerateRagEmbeddings do
|
||||
describe "#execute" do
|
||||
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
|
||||
let(:vector_rep) do
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
|
||||
end
|
||||
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
||||
|
||||
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
|
||||
|
||||
|
@ -30,7 +27,9 @@ RSpec.describe Jobs::GenerateRagEmbeddings do
|
|||
subject.execute(fragment_ids: [rag_document_fragment_1.id, rag_document_fragment_2.id])
|
||||
|
||||
embeddings_count =
|
||||
DB.query_single("SELECT COUNT(*) from #{vector_rep.rag_fragments_table_name}").first
|
||||
DB.query_single(
|
||||
"SELECT COUNT(*) from #{DiscourseAi::Embeddings::Schema::RAG_DOCS_TABLE}",
|
||||
).first
|
||||
|
||||
expect(embeddings_count).to eq(expected_embeddings)
|
||||
end
|
||||
|
|
|
@ -19,10 +19,7 @@ RSpec.describe Jobs::EmbeddingsBackfill do
|
|||
topic
|
||||
end
|
||||
|
||||
let(:vector_rep) do
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
end
|
||||
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
|
@ -41,7 +38,8 @@ RSpec.describe Jobs::EmbeddingsBackfill do
|
|||
|
||||
Jobs::EmbeddingsBackfill.new.execute({})
|
||||
|
||||
topic_ids = DB.query_single("SELECT topic_id from #{vector_rep.topic_table_name}")
|
||||
topic_ids =
|
||||
DB.query_single("SELECT topic_id from #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE}")
|
||||
|
||||
expect(topic_ids).to eq([first_topic.id])
|
||||
|
||||
|
@ -49,7 +47,8 @@ RSpec.describe Jobs::EmbeddingsBackfill do
|
|||
SiteSetting.ai_embeddings_backfill_batch_size = 100
|
||||
Jobs::EmbeddingsBackfill.new.execute({})
|
||||
|
||||
topic_ids = DB.query_single("SELECT topic_id from #{vector_rep.topic_table_name}")
|
||||
topic_ids =
|
||||
DB.query_single("SELECT topic_id from #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE}")
|
||||
|
||||
expect(topic_ids).to contain_exactly(first_topic.id, second_topic.id, third_topic.id)
|
||||
|
||||
|
@ -62,7 +61,7 @@ RSpec.describe Jobs::EmbeddingsBackfill do
|
|||
|
||||
index_date =
|
||||
DB.query_single(
|
||||
"SELECT updated_at from #{vector_rep.topic_table_name} WHERE topic_id = ?",
|
||||
"SELECT updated_at from #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE} WHERE topic_id = ?",
|
||||
third_topic.id,
|
||||
).first
|
||||
|
||||
|
|
|
@ -326,9 +326,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
|||
fab!(:llm_model) { Fabricate(:fake_model) }
|
||||
|
||||
it "will run the question consolidator" do
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
context_embedding = vector_rep.dimensions.times.map { rand(-1.0...1.0) }
|
||||
EmbeddingsGenerationStubs.discourse_service(
|
||||
SiteSetting.ai_embeddings_model,
|
||||
|
@ -375,41 +373,44 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
|||
end
|
||||
|
||||
context "when a persona has RAG uploads" do
|
||||
def stub_fragments(limit, expected_limit: nil)
|
||||
candidate_ids = []
|
||||
let(:vector_rep) do
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
end
|
||||
let(:embedding_value) { 0.04381 }
|
||||
let(:prompt_cc_embeddings) { [embedding_value] * vector_rep.dimensions }
|
||||
|
||||
limit.times do |i|
|
||||
candidate_ids << Fabricate(
|
||||
:rag_document_fragment,
|
||||
fragment: "fragment-n#{i}",
|
||||
target_id: ai_persona.id,
|
||||
target_type: "AiPersona",
|
||||
upload: upload,
|
||||
).id
|
||||
def stub_fragments(fragment_count, persona: ai_persona)
|
||||
schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep)
|
||||
|
||||
fragment_count.times do |i|
|
||||
fragment =
|
||||
Fabricate(
|
||||
:rag_document_fragment,
|
||||
fragment: "fragment-n#{i}",
|
||||
target_id: persona.id,
|
||||
target_type: "AiPersona",
|
||||
upload: upload,
|
||||
)
|
||||
|
||||
# Similarity is determined left-to-right.
|
||||
embeddings = [embedding_value + "0.000#{i}".to_f] * vector_rep.dimensions
|
||||
|
||||
schema.store(fragment, embeddings, "test")
|
||||
end
|
||||
|
||||
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn
|
||||
.any_instance
|
||||
.expects(:asymmetric_rag_fragment_similarity_search)
|
||||
.with { |args, kwargs| kwargs[:limit] == (expected_limit || limit) }
|
||||
.returns(candidate_ids)
|
||||
end
|
||||
|
||||
before do
|
||||
stored_ai_persona = AiPersona.find(ai_persona.id)
|
||||
UploadReference.ensure_exist!(target: stored_ai_persona, upload_ids: [upload.id])
|
||||
|
||||
context_embedding = [0.049382, 0.9999]
|
||||
EmbeddingsGenerationStubs.discourse_service(
|
||||
SiteSetting.ai_embeddings_model,
|
||||
with_cc.dig(:conversation_context, 0, :content),
|
||||
context_embedding,
|
||||
prompt_cc_embeddings,
|
||||
)
|
||||
end
|
||||
|
||||
context "when persona allows for less fragments" do
|
||||
before { stub_fragments(3) }
|
||||
|
||||
it "will only pick 3 fragments" do
|
||||
custom_ai_persona =
|
||||
Fabricate(
|
||||
|
@ -419,6 +420,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
|||
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
|
||||
)
|
||||
|
||||
stub_fragments(3, persona: custom_ai_persona)
|
||||
|
||||
UploadReference.ensure_exist!(target: custom_ai_persona, upload_ids: [upload.id])
|
||||
|
||||
custom_persona =
|
||||
|
@ -438,14 +441,13 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
|||
context "when the reranker is available" do
|
||||
before do
|
||||
SiteSetting.ai_hugging_face_tei_reranker_endpoint = "https://test.reranker.com"
|
||||
|
||||
# hard coded internal implementation, reranker takes x5 number of chunks
|
||||
stub_fragments(15, expected_limit: 50) # Mimic limit being more than 10 results
|
||||
stub_fragments(15)
|
||||
end
|
||||
|
||||
it "uses the re-ranker to reorder the fragments and pick the top 10 candidates" do
|
||||
skip "This test is flaky needs to be investigated ordering does not come back as expected"
|
||||
expected_reranked = (0..14).to_a.reverse.map { |idx| { index: idx } }
|
||||
# The re-ranker reverses the similarity search, but returns fewer results
|
||||
# to act as a limit for test purposes.
|
||||
expected_reranked = (4..14).to_a.reverse.map { |idx| { index: idx } }
|
||||
|
||||
WebMock.stub_request(:post, "https://test.reranker.com/rerank").to_return(
|
||||
status: 200,
|
||||
|
|
|
@ -110,8 +110,9 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
|
|||
assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model)
|
||||
SiteSetting.ai_embeddings_semantic_search_enabled = true
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
hyde_embedding = [0.049382] * vector_rep.dimensions
|
||||
|
||||
hyde_embedding = [0.049382, 0.9999]
|
||||
EmbeddingsGenerationStubs.discourse_service(
|
||||
SiteSetting.ai_embeddings_model,
|
||||
query,
|
||||
|
@ -126,10 +127,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
|
|||
bot_user: bot_user,
|
||||
)
|
||||
|
||||
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn
|
||||
.any_instance
|
||||
.expects(:asymmetric_topics_similarity_search)
|
||||
.returns([post1.topic_id])
|
||||
DiscourseAi::Embeddings::Schema.for(Topic).store(post1.topic, hyde_embedding, "digest")
|
||||
|
||||
results =
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(["<ai>#{query}</ai>"]) do
|
||||
|
|
|
@ -14,10 +14,7 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
|
|||
fab!(:category)
|
||||
fab!(:topic) { Fabricate(:topic, category: category) }
|
||||
|
||||
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
|
||||
let(:vector_rep) do
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
|
||||
end
|
||||
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
||||
let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) }
|
||||
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
|
||||
|
||||
|
|
|
@ -13,9 +13,9 @@ RSpec.describe Jobs::GenerateEmbeddings do
|
|||
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
|
||||
|
||||
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
|
||||
let(:vector_rep) do
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
|
||||
end
|
||||
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
||||
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) }
|
||||
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector: vector_rep) }
|
||||
|
||||
it "works for topics" do
|
||||
expected_embedding = [0.0038493] * vector_rep.dimensions
|
||||
|
@ -30,7 +30,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
|
|||
|
||||
job.execute(target_id: topic.id, target_type: "Topic")
|
||||
|
||||
expect(vector_rep.topic_id_from_representation(expected_embedding)).to eq(topic.id)
|
||||
expect(topics_schema.find_by_embedding(expected_embedding).topic_id).to eq(topic.id)
|
||||
end
|
||||
|
||||
it "works for posts" do
|
||||
|
@ -42,7 +42,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
|
|||
|
||||
job.execute(target_id: post.id, target_type: "Post")
|
||||
|
||||
expect(vector_rep.post_id_from_representation(expected_embedding)).to eq(post.id)
|
||||
expect(posts_schema.find_by_embedding(expected_embedding).post_id).to eq(post.id)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Embeddings::Schema do
|
||||
subject(:posts_schema) { described_class.for(Post, vector: vector) }
|
||||
|
||||
let(:embeddings) { [0.0038490295] * vector.dimensions }
|
||||
fab!(:post) { Fabricate(:post, post_number: 1) }
|
||||
let(:digest) { OpenSSL::Digest.hexdigest("SHA1", "test") }
|
||||
let(:vector) do
|
||||
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new(
|
||||
DiscourseAi::Embeddings::Strategies::Truncation.new,
|
||||
)
|
||||
end
|
||||
|
||||
before { posts_schema.store(post, embeddings, digest) }
|
||||
|
||||
describe "#find_by_target" do
|
||||
it "gets you the post_id of the record that matches the post" do
|
||||
embeddings_record = posts_schema.find_by_target(post)
|
||||
|
||||
expect(embeddings_record.digest).to eq(digest)
|
||||
expect(JSON.parse(embeddings_record.embeddings)).to eq(embeddings)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#find_by_embedding" do
|
||||
it "gets you the record that matches the embedding" do
|
||||
embeddings_record = posts_schema.find_by_embedding(embeddings)
|
||||
|
||||
expect(embeddings_record.digest).to eq(digest)
|
||||
expect(embeddings_record.post_id).to eq(post.id)
|
||||
end
|
||||
end
|
||||
|
||||
describe "similarity searches" do
|
||||
fab!(:post_2) { Fabricate(:post) }
|
||||
let(:similar_embeddings) { [0.0038490294] * vector.dimensions }
|
||||
|
||||
describe "#symmetric_similarity_search" do
|
||||
before { posts_schema.store(post_2, similar_embeddings, digest) }
|
||||
|
||||
it "returns target_id with similar embeddings" do
|
||||
similar_records = posts_schema.symmetric_similarity_search(post)
|
||||
|
||||
expect(similar_records.map(&:post_id)).to contain_exactly(post.id, post_2.id)
|
||||
end
|
||||
|
||||
it "let's you apply additional scopes to filter results further" do
|
||||
similar_records =
|
||||
posts_schema.symmetric_similarity_search(post) do |builder|
|
||||
builder.where("post_id = ?", post_2.id)
|
||||
end
|
||||
|
||||
expect(similar_records.map(&:post_id)).to contain_exactly(post_2.id)
|
||||
end
|
||||
|
||||
it "let's you join on additional tables and combine with additional scopes" do
|
||||
similar_records =
|
||||
posts_schema.symmetric_similarity_search(post) do |builder|
|
||||
builder.join("posts p on p.id = post_id")
|
||||
builder.join("topics t on t.id = p.topic_id")
|
||||
builder.where("t.id = ?", post_2.topic_id)
|
||||
end
|
||||
|
||||
expect(similar_records.map(&:post_id)).to contain_exactly(post_2.id)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#asymmetric_similarity_search" do
|
||||
it "returns target_id with similar embeddings" do
|
||||
similar_records =
|
||||
posts_schema.asymmetric_similarity_search(similar_embeddings, limit: 1, offset: 0)
|
||||
|
||||
expect(similar_records.map(&:post_id)).to contain_exactly(post.id)
|
||||
end
|
||||
|
||||
it "let's you apply additional scopes to filter results further" do
|
||||
similar_records =
|
||||
posts_schema.asymmetric_similarity_search(
|
||||
similar_embeddings,
|
||||
limit: 1,
|
||||
offset: 0,
|
||||
) { |builder| builder.where("post_id <> ?", post.id) }
|
||||
|
||||
expect(similar_records.map(&:post_id)).to be_empty
|
||||
end
|
||||
|
||||
it "let's you join on additional tables and combine with additional scopes" do
|
||||
similar_records =
|
||||
posts_schema.asymmetric_similarity_search(
|
||||
similar_embeddings,
|
||||
limit: 1,
|
||||
offset: 0,
|
||||
) do |builder|
|
||||
builder.join("posts p on p.id = post_id")
|
||||
builder.join("topics t on t.id = p.topic_id")
|
||||
builder.where("t.id <> ?", post.topic_id)
|
||||
end
|
||||
|
||||
expect(similar_records.map(&:post_id)).to be_empty
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -25,9 +25,7 @@ describe DiscourseAi::Embeddings::SemanticRelated do
|
|||
end
|
||||
|
||||
let(:vector_rep) do
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
end
|
||||
|
||||
it "properly generates embeddings if missing" do
|
||||
|
|
|
@ -11,11 +11,12 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
|
|||
|
||||
describe "#search_for_topics" do
|
||||
let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" }
|
||||
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
||||
let(:hyde_embedding) { [0.049382] * vector_rep.dimensions }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
|
||||
hyde_embedding = [0.049382, 0.9999]
|
||||
EmbeddingsGenerationStubs.discourse_service(
|
||||
SiteSetting.ai_embeddings_model,
|
||||
hypothetical_post,
|
||||
|
@ -25,11 +26,8 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
|
|||
|
||||
after { described_class.clear_cache_for(query) }
|
||||
|
||||
def stub_candidate_ids(candidate_ids)
|
||||
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn
|
||||
.any_instance
|
||||
.expects(:asymmetric_topics_similarity_search)
|
||||
.returns(candidate_ids)
|
||||
def insert_candidate(candidate)
|
||||
DiscourseAi::Embeddings::Schema.for(Topic).store(candidate, hyde_embedding, "digest")
|
||||
end
|
||||
|
||||
def trigger_search(query)
|
||||
|
@ -39,7 +37,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
|
|||
end
|
||||
|
||||
it "returns the first post of a topic included in the asymmetric search results" do
|
||||
stub_candidate_ids([post.topic_id])
|
||||
insert_candidate(post.topic)
|
||||
|
||||
posts = trigger_search(query)
|
||||
|
||||
|
@ -50,7 +48,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
|
|||
context "when the topic is not visible" do
|
||||
it "returns an empty list" do
|
||||
post.topic.update!(visible: false)
|
||||
stub_candidate_ids([post.topic_id])
|
||||
insert_candidate(post.topic)
|
||||
|
||||
posts = trigger_search(query)
|
||||
|
||||
|
@ -61,7 +59,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
|
|||
context "when the post is not public" do
|
||||
it "returns an empty list" do
|
||||
pm_post = Fabricate(:private_message_post)
|
||||
stub_candidate_ids([pm_post.topic_id])
|
||||
insert_candidate(pm_post.topic)
|
||||
|
||||
posts = trigger_search(query)
|
||||
|
||||
|
@ -72,7 +70,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
|
|||
context "when the post type is not visible" do
|
||||
it "returns an empty list" do
|
||||
post.update!(post_type: Post.types[:whisper])
|
||||
stub_candidate_ids([post.topic_id])
|
||||
insert_candidate(post.topic)
|
||||
|
||||
posts = trigger_search(query)
|
||||
|
||||
|
@ -84,7 +82,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
|
|||
it "returns an empty list" do
|
||||
reply = Fabricate(:reply)
|
||||
reply.topic.first_post.trash!
|
||||
stub_candidate_ids([reply.topic_id])
|
||||
insert_candidate(reply.topic)
|
||||
|
||||
posts = trigger_search(query)
|
||||
|
||||
|
@ -95,7 +93,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
|
|||
context "when the post is not a candidate" do
|
||||
it "doesn't include it in the results" do
|
||||
post_2 = Fabricate(:post)
|
||||
stub_candidate_ids([post.topic_id])
|
||||
insert_candidate(post.topic)
|
||||
|
||||
posts = trigger_search(query)
|
||||
|
||||
|
@ -109,7 +107,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
|
|||
|
||||
before do
|
||||
post.topic.update!(category: private_category)
|
||||
stub_candidate_ids([post.topic_id])
|
||||
insert_candidate(post.topic)
|
||||
end
|
||||
|
||||
it "returns an empty list" do
|
||||
|
|
|
@ -9,11 +9,17 @@ describe DiscourseAi::Embeddings::EntryPoint do
|
|||
|
||||
fab!(:target) { Fabricate(:topic) }
|
||||
|
||||
def stub_semantic_search_with(results)
|
||||
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn
|
||||
.any_instance
|
||||
.expects(:symmetric_topics_similarity_search)
|
||||
.returns(results.concat([target.id]))
|
||||
# The distance gap to target increases for each element of topics.
|
||||
def seed_embeddings(topics)
|
||||
schema = DiscourseAi::Embeddings::Schema.for(Topic)
|
||||
base_value = 1
|
||||
|
||||
schema.store(target, [base_value] * 1024, "disgest")
|
||||
|
||||
topics.each do |t|
|
||||
base_value -= 0.01
|
||||
schema.store(t, [base_value] * 1024, "digest")
|
||||
end
|
||||
end
|
||||
|
||||
after { DiscourseAi::Embeddings::SemanticRelated.clear_cache_for(target) }
|
||||
|
@ -21,7 +27,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
|
|||
context "when the semantic search returns an unlisted topic" do
|
||||
fab!(:unlisted_topic) { Fabricate(:topic, visible: false) }
|
||||
|
||||
before { stub_semantic_search_with([unlisted_topic.id]) }
|
||||
before { seed_embeddings([unlisted_topic]) }
|
||||
|
||||
it "filters it out" do
|
||||
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
|
||||
|
@ -31,7 +37,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
|
|||
context "when the semantic search returns a private topic" do
|
||||
fab!(:private_topic) { Fabricate(:private_message_topic) }
|
||||
|
||||
before { stub_semantic_search_with([private_topic.id]) }
|
||||
before { seed_embeddings([private_topic]) }
|
||||
|
||||
it "filters it out" do
|
||||
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
|
||||
|
@ -43,7 +49,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
|
|||
fab!(:category) { Fabricate(:private_category, group: group) }
|
||||
fab!(:secured_category_topic) { Fabricate(:topic, category: category) }
|
||||
|
||||
before { stub_semantic_search_with([secured_category_topic.id]) }
|
||||
before { seed_embeddings([secured_category_topic]) }
|
||||
|
||||
it "filters it out" do
|
||||
expect(topic_query.list_semantic_related_topics(target).topics).to be_empty
|
||||
|
@ -63,7 +69,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
|
|||
|
||||
before do
|
||||
SiteSetting.ai_embeddings_semantic_related_include_closed_topics = false
|
||||
stub_semantic_search_with([closed_topic.id])
|
||||
seed_embeddings([closed_topic])
|
||||
end
|
||||
|
||||
it "filters it out" do
|
||||
|
@ -80,7 +86,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
|
|||
category_id: category.id,
|
||||
notification_level: CategoryUser.notification_levels[:muted],
|
||||
)
|
||||
stub_semantic_search_with([topic.id])
|
||||
seed_embeddings([topic])
|
||||
expect(topic_query.list_semantic_related_topics(target).topics).not_to include(topic)
|
||||
end
|
||||
end
|
||||
|
@ -91,11 +97,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
|
|||
fab!(:normal_topic_3) { Fabricate(:topic) }
|
||||
fab!(:closed_topic) { Fabricate(:topic, closed: true) }
|
||||
|
||||
before do
|
||||
stub_semantic_search_with(
|
||||
[closed_topic.id, normal_topic_1.id, normal_topic_2.id, normal_topic_3.id],
|
||||
)
|
||||
end
|
||||
before { seed_embeddings([closed_topic, normal_topic_1, normal_topic_2, normal_topic_3]) }
|
||||
|
||||
it "filters it out" do
|
||||
expect(topic_query.list_semantic_related_topics(target).topics).to eq(
|
||||
|
@ -117,7 +119,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
|
|||
fab!(:included_topic) { Fabricate(:topic) }
|
||||
fab!(:excluded_topic) { Fabricate(:topic) }
|
||||
|
||||
before { stub_semantic_search_with([included_topic.id, excluded_topic.id]) }
|
||||
before { seed_embeddings([included_topic, excluded_topic]) }
|
||||
|
||||
let(:modifier_block) { Proc.new { |query| query.where.not(id: excluded_topic.id) } }
|
||||
|
||||
|
|
|
@ -4,6 +4,9 @@ RSpec.shared_examples "generates and store embedding using with vector represent
|
|||
let(:expected_embedding_1) { [0.0038493] * vector_rep.dimensions }
|
||||
let(:expected_embedding_2) { [0.0037684] * vector_rep.dimensions }
|
||||
|
||||
let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep) }
|
||||
let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector: vector_rep) }
|
||||
|
||||
describe "#vector_from" do
|
||||
it "creates a vector from a given string" do
|
||||
text = "This is a piece of text"
|
||||
|
@ -29,7 +32,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent
|
|||
|
||||
vector_rep.generate_representation_from(topic)
|
||||
|
||||
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
|
||||
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
|
||||
end
|
||||
|
||||
it "creates a vector from a post and stores it in the database" do
|
||||
|
@ -43,7 +46,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent
|
|||
|
||||
vector_rep.generate_representation_from(post)
|
||||
|
||||
expect(vector_rep.post_id_from_representation(expected_embedding_1)).to eq(post.id)
|
||||
expect(posts_schema.find_by_embedding(expected_embedding_1).post_id).to eq(post.id)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -76,8 +79,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent
|
|||
|
||||
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))
|
||||
|
||||
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
|
||||
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
|
||||
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
|
||||
end
|
||||
|
||||
it "does nothing if passed record has no content" do
|
||||
|
@ -99,69 +101,15 @@ RSpec.shared_examples "generates and store embedding using with vector represent
|
|||
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
|
||||
end
|
||||
# check vector exists
|
||||
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
|
||||
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
|
||||
|
||||
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
|
||||
last_update =
|
||||
DB.query_single(
|
||||
"SELECT updated_at FROM #{vector_rep.topic_table_name} WHERE topic_id = #{topic.id} LIMIT 1",
|
||||
"SELECT updated_at FROM #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE} WHERE topic_id = #{topic.id} LIMIT 1",
|
||||
).first
|
||||
|
||||
expect(last_update).to eq(original_vector_gen)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#asymmetric_topics_similarity_search" do
|
||||
fab!(:topic) { Fabricate(:topic) }
|
||||
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
|
||||
|
||||
it "finds IDs of similar topics with a given embedding" do
|
||||
similar_vector = [0.0038494] * vector_rep.dimensions
|
||||
text =
|
||||
truncation.prepare_text_from(
|
||||
topic,
|
||||
vector_rep.tokenizer,
|
||||
vector_rep.max_sequence_length - 2,
|
||||
)
|
||||
stub_vector_mapping(text, expected_embedding_1)
|
||||
vector_rep.generate_representation_from(topic)
|
||||
|
||||
expect(
|
||||
vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0),
|
||||
).to contain_exactly(topic.id)
|
||||
end
|
||||
|
||||
it "can exclude categories" do
|
||||
similar_vector = [0.0038494] * vector_rep.dimensions
|
||||
text =
|
||||
truncation.prepare_text_from(
|
||||
topic,
|
||||
vector_rep.tokenizer,
|
||||
vector_rep.max_sequence_length - 2,
|
||||
)
|
||||
stub_vector_mapping(text, expected_embedding_1)
|
||||
vector_rep.generate_representation_from(topic)
|
||||
|
||||
expect(
|
||||
vector_rep.asymmetric_topics_similarity_search(
|
||||
similar_vector,
|
||||
limit: 1,
|
||||
offset: 0,
|
||||
exclude_category_ids: [topic.category_id],
|
||||
),
|
||||
).to be_empty
|
||||
|
||||
child_category = Fabricate(:category, parent_category_id: topic.category_id)
|
||||
topic.update!(category_id: child_category.id)
|
||||
|
||||
expect(
|
||||
vector_rep.asymmetric_topics_similarity_search(
|
||||
similar_vector,
|
||||
limit: 1,
|
||||
offset: 0,
|
||||
exclude_category_ids: [topic.category_id],
|
||||
),
|
||||
).to be_empty
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -74,10 +74,7 @@ RSpec.describe RagDocumentFragment do
|
|||
end
|
||||
|
||||
describe ".indexing_status" do
|
||||
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
|
||||
let(:vector_rep) do
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
|
||||
end
|
||||
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }
|
||||
|
||||
fab!(:rag_document_fragment_1) do
|
||||
Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
|
||||
|
|
|
@ -19,9 +19,7 @@ describe DiscourseAi::Embeddings::EmbeddingsController do
|
|||
fab!(:post_in_subcategory) { Fabricate(:post, topic: topic_in_subcategory) }
|
||||
|
||||
def index(topic)
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
|
||||
|
||||
stub_request(:post, "https://api.openai.com/v1/embeddings").to_return(
|
||||
status: 200,
|
||||
|
|
Loading…
Reference in New Issue