FIX: Add a digest check to avoid repeteadly generating embeddings (bulk) (#1001)
This commit is contained in:
parent
d6beac48f8
commit
b32b1cf241
|
@ -66,13 +66,16 @@ module DiscourseAi
|
|||
prepared_text = prepare_text(record)
|
||||
next if prepared_text.blank?
|
||||
|
||||
new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text)
|
||||
next if find_digest_of(record) == new_digest
|
||||
|
||||
Concurrent::Promises
|
||||
.fulfilled_future({ target: record, text: prepared_text }, pool)
|
||||
.fulfilled_future(
|
||||
{ target: record, text: prepared_text, digest: new_digest },
|
||||
pool,
|
||||
)
|
||||
.then_on(pool) do |w_prepared_text|
|
||||
w_prepared_text.merge(
|
||||
embedding: embedding_gen.perform!(w_prepared_text[:text]),
|
||||
digest: OpenSSL::Digest::SHA1.hexdigest(w_prepared_text[:text]),
|
||||
)
|
||||
w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text]))
|
||||
end
|
||||
end
|
||||
.compact
|
||||
|
@ -90,31 +93,8 @@ module DiscourseAi
|
|||
text = prepare_text(target)
|
||||
return if text.blank?
|
||||
|
||||
target_column =
|
||||
case target
|
||||
when Topic
|
||||
"topic_id"
|
||||
when Post
|
||||
"post_id"
|
||||
when RagDocumentFragment
|
||||
"rag_document_fragment_id"
|
||||
else
|
||||
raise ArgumentError, "Invalid target type"
|
||||
end
|
||||
|
||||
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
|
||||
current_digest = DB.query_single(<<~SQL, target_id: target.id).first
|
||||
SELECT
|
||||
digest
|
||||
FROM
|
||||
#{table_name(target)}
|
||||
WHERE
|
||||
model_id = #{id} AND
|
||||
strategy_id = #{@strategy.id} AND
|
||||
#{target_column} = :target_id
|
||||
LIMIT 1
|
||||
SQL
|
||||
return if current_digest == new_digest
|
||||
return if find_digest_of(target) == new_digest
|
||||
|
||||
vector = vector_from(text)
|
||||
|
||||
|
@ -412,6 +392,32 @@ module DiscourseAi
|
|||
|
||||
protected
|
||||
|
||||
def find_digest_of(target)
|
||||
target_column =
|
||||
case target
|
||||
when Topic
|
||||
"topic_id"
|
||||
when Post
|
||||
"post_id"
|
||||
when RagDocumentFragment
|
||||
"rag_document_fragment_id"
|
||||
else
|
||||
raise ArgumentError, "Invalid target type"
|
||||
end
|
||||
|
||||
DB.query_single(<<~SQL, target_id: target.id).first
|
||||
SELECT
|
||||
digest
|
||||
FROM
|
||||
#{table_name(target)}
|
||||
WHERE
|
||||
model_id = #{id} AND
|
||||
strategy_id = #{@strategy.id} AND
|
||||
#{target_column} = :target_id
|
||||
LIMIT 1
|
||||
SQL
|
||||
end
|
||||
|
||||
def save_to_db(target, vector, digest)
|
||||
if target.is_a?(Topic)
|
||||
DB.exec(
|
||||
|
|
|
@ -83,6 +83,32 @@ RSpec.shared_examples "generates and store embedding using with vector represent
|
|||
it "does nothing if passed record has no content" do
|
||||
expect { vector_rep.gen_bulk_reprensentations([Topic.new]) }.not_to raise_error
|
||||
end
|
||||
|
||||
it "doesn't ask for a new embedding if digest is the same" do
|
||||
text =
|
||||
truncation.prepare_text_from(
|
||||
topic,
|
||||
vector_rep.tokenizer,
|
||||
vector_rep.max_sequence_length - 2,
|
||||
)
|
||||
stub_vector_mapping(text, expected_embedding_1)
|
||||
|
||||
original_vector_gen = Time.zone.parse("2021-06-04 10:00")
|
||||
|
||||
freeze_time(original_vector_gen) do
|
||||
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
|
||||
end
|
||||
# check vector exists
|
||||
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
|
||||
|
||||
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
|
||||
last_update =
|
||||
DB.query_single(
|
||||
"SELECT updated_at FROM #{vector_rep.topic_table_name} WHERE topic_id = #{topic.id} LIMIT 1",
|
||||
).first
|
||||
|
||||
expect(last_update).to eq(original_vector_gen)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#asymmetric_topics_similarity_search" do
|
||||
|
|
Loading…
Reference in New Issue