FIX: Add a digest check to avoid repeteadly generating embeddings (bulk) (#1001)
This commit is contained in:
parent
d6beac48f8
commit
b32b1cf241
|
@ -66,13 +66,16 @@ module DiscourseAi
|
||||||
prepared_text = prepare_text(record)
|
prepared_text = prepare_text(record)
|
||||||
next if prepared_text.blank?
|
next if prepared_text.blank?
|
||||||
|
|
||||||
|
new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text)
|
||||||
|
next if find_digest_of(record) == new_digest
|
||||||
|
|
||||||
Concurrent::Promises
|
Concurrent::Promises
|
||||||
.fulfilled_future({ target: record, text: prepared_text }, pool)
|
.fulfilled_future(
|
||||||
.then_on(pool) do |w_prepared_text|
|
{ target: record, text: prepared_text, digest: new_digest },
|
||||||
w_prepared_text.merge(
|
pool,
|
||||||
embedding: embedding_gen.perform!(w_prepared_text[:text]),
|
|
||||||
digest: OpenSSL::Digest::SHA1.hexdigest(w_prepared_text[:text]),
|
|
||||||
)
|
)
|
||||||
|
.then_on(pool) do |w_prepared_text|
|
||||||
|
w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text]))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
.compact
|
.compact
|
||||||
|
@ -90,31 +93,8 @@ module DiscourseAi
|
||||||
text = prepare_text(target)
|
text = prepare_text(target)
|
||||||
return if text.blank?
|
return if text.blank?
|
||||||
|
|
||||||
target_column =
|
|
||||||
case target
|
|
||||||
when Topic
|
|
||||||
"topic_id"
|
|
||||||
when Post
|
|
||||||
"post_id"
|
|
||||||
when RagDocumentFragment
|
|
||||||
"rag_document_fragment_id"
|
|
||||||
else
|
|
||||||
raise ArgumentError, "Invalid target type"
|
|
||||||
end
|
|
||||||
|
|
||||||
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
|
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
|
||||||
current_digest = DB.query_single(<<~SQL, target_id: target.id).first
|
return if find_digest_of(target) == new_digest
|
||||||
SELECT
|
|
||||||
digest
|
|
||||||
FROM
|
|
||||||
#{table_name(target)}
|
|
||||||
WHERE
|
|
||||||
model_id = #{id} AND
|
|
||||||
strategy_id = #{@strategy.id} AND
|
|
||||||
#{target_column} = :target_id
|
|
||||||
LIMIT 1
|
|
||||||
SQL
|
|
||||||
return if current_digest == new_digest
|
|
||||||
|
|
||||||
vector = vector_from(text)
|
vector = vector_from(text)
|
||||||
|
|
||||||
|
@ -412,6 +392,32 @@ module DiscourseAi
|
||||||
|
|
||||||
protected
|
protected
|
||||||
|
|
||||||
|
def find_digest_of(target)
|
||||||
|
target_column =
|
||||||
|
case target
|
||||||
|
when Topic
|
||||||
|
"topic_id"
|
||||||
|
when Post
|
||||||
|
"post_id"
|
||||||
|
when RagDocumentFragment
|
||||||
|
"rag_document_fragment_id"
|
||||||
|
else
|
||||||
|
raise ArgumentError, "Invalid target type"
|
||||||
|
end
|
||||||
|
|
||||||
|
DB.query_single(<<~SQL, target_id: target.id).first
|
||||||
|
SELECT
|
||||||
|
digest
|
||||||
|
FROM
|
||||||
|
#{table_name(target)}
|
||||||
|
WHERE
|
||||||
|
model_id = #{id} AND
|
||||||
|
strategy_id = #{@strategy.id} AND
|
||||||
|
#{target_column} = :target_id
|
||||||
|
LIMIT 1
|
||||||
|
SQL
|
||||||
|
end
|
||||||
|
|
||||||
def save_to_db(target, vector, digest)
|
def save_to_db(target, vector, digest)
|
||||||
if target.is_a?(Topic)
|
if target.is_a?(Topic)
|
||||||
DB.exec(
|
DB.exec(
|
||||||
|
|
|
@ -83,6 +83,32 @@ RSpec.shared_examples "generates and store embedding using with vector represent
|
||||||
it "does nothing if passed record has no content" do
|
it "does nothing if passed record has no content" do
|
||||||
expect { vector_rep.gen_bulk_reprensentations([Topic.new]) }.not_to raise_error
|
expect { vector_rep.gen_bulk_reprensentations([Topic.new]) }.not_to raise_error
|
||||||
end
|
end
|
||||||
|
|
||||||
|
it "doesn't ask for a new embedding if digest is the same" do
|
||||||
|
text =
|
||||||
|
truncation.prepare_text_from(
|
||||||
|
topic,
|
||||||
|
vector_rep.tokenizer,
|
||||||
|
vector_rep.max_sequence_length - 2,
|
||||||
|
)
|
||||||
|
stub_vector_mapping(text, expected_embedding_1)
|
||||||
|
|
||||||
|
original_vector_gen = Time.zone.parse("2021-06-04 10:00")
|
||||||
|
|
||||||
|
freeze_time(original_vector_gen) do
|
||||||
|
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
|
||||||
|
end
|
||||||
|
# check vector exists
|
||||||
|
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
|
||||||
|
|
||||||
|
vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id]))
|
||||||
|
last_update =
|
||||||
|
DB.query_single(
|
||||||
|
"SELECT updated_at FROM #{vector_rep.topic_table_name} WHERE topic_id = #{topic.id} LIMIT 1",
|
||||||
|
).first
|
||||||
|
|
||||||
|
expect(last_update).to eq(original_vector_gen)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "#asymmetric_topics_similarity_search" do
|
describe "#asymmetric_topics_similarity_search" do
|
||||||
|
|
Loading…
Reference in New Issue