FIX: Add a digest check to avoid repeteadly generating embeddings (bulk) (#1001)

This commit is contained in:
Roman Rizzi 2024-12-04 17:47:28 -03:00 committed by GitHub
parent d6beac48f8
commit b32b1cf241
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 61 additions and 29 deletions

View File

@ -66,13 +66,16 @@ module DiscourseAi
prepared_text = prepare_text(record) prepared_text = prepare_text(record)
next if prepared_text.blank? next if prepared_text.blank?
new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text)
next if find_digest_of(record) == new_digest
Concurrent::Promises Concurrent::Promises
.fulfilled_future({ target: record, text: prepared_text }, pool) .fulfilled_future(
{ target: record, text: prepared_text, digest: new_digest },
pool,
)
.then_on(pool) do |w_prepared_text| .then_on(pool) do |w_prepared_text|
w_prepared_text.merge( w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text]))
embedding: embedding_gen.perform!(w_prepared_text[:text]),
digest: OpenSSL::Digest::SHA1.hexdigest(w_prepared_text[:text]),
)
end end
end end
.compact .compact
@ -90,31 +93,8 @@ module DiscourseAi
text = prepare_text(target) text = prepare_text(target)
return if text.blank? return if text.blank?
target_column =
case target
when Topic
"topic_id"
when Post
"post_id"
when RagDocumentFragment
"rag_document_fragment_id"
else
raise ArgumentError, "Invalid target type"
end
new_digest = OpenSSL::Digest::SHA1.hexdigest(text) new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
current_digest = DB.query_single(<<~SQL, target_id: target.id).first return if find_digest_of(target) == new_digest
SELECT
digest
FROM
#{table_name(target)}
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id} AND
#{target_column} = :target_id
LIMIT 1
SQL
return if current_digest == new_digest
vector = vector_from(text) vector = vector_from(text)
@ -412,6 +392,32 @@ module DiscourseAi
protected protected
def find_digest_of(target)
target_column =
case target
when Topic
"topic_id"
when Post
"post_id"
when RagDocumentFragment
"rag_document_fragment_id"
else
raise ArgumentError, "Invalid target type"
end
DB.query_single(<<~SQL, target_id: target.id).first
SELECT
digest
FROM
#{table_name(target)}
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id} AND
#{target_column} = :target_id
LIMIT 1
SQL
end
def save_to_db(target, vector, digest) def save_to_db(target, vector, digest)
if target.is_a?(Topic) if target.is_a?(Topic)
DB.exec( DB.exec(

View File

@ -83,6 +83,32 @@ RSpec.shared_examples "generates and store embedding using with vector represent
it "does nothing if passed record has no content" do it "does nothing if passed record has no content" do
expect { vector_rep.gen_bulk_reprensentations([Topic.new]) }.not_to raise_error expect { vector_rep.gen_bulk_reprensentations([Topic.new]) }.not_to raise_error
end end
it "doesn't ask for a new embedding if digest is the same" do
text =
truncation.prepare_text_from(
topic,
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, expected_embedding_1)
original_vector_gen = Time.zone.parse("2021-06-04 10:00")
freeze_time(original_vector_gen) do
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
end
# check vector exists
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
last_update =
DB.query_single(
"SELECT updated_at FROM #{vector_rep.topic_table_name} WHERE topic_id = #{topic.id} LIMIT 1",
).first
expect(last_update).to eq(original_vector_gen)
end
end end
describe "#asymmetric_topics_similarity_search" do describe "#asymmetric_topics_similarity_search" do