DEV: Backfill embeddings concurrently. (#941)
We are adding a new method for generating and storing embeddings in bulk, which relies on `Concurrent::Promises::Future`. Generating an embedding consists of three steps:

1. Prepare the text
2. Make an HTTP call to retrieve the vector
3. Save it to the DB

Each step is executed independently on whatever thread the pool gives us. We bring our own thread pool instead of using the global executor since we want control over how many threads we spawn to limit concurrency, which also keeps us from firing thousands of simultaneous HTTP requests when working with large batches. A distilled sketch of the pattern follows the commit metadata below.
parent 23193ee6f2
commit ddf2bf7034
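To make the concurrency pattern concrete, here is a minimal, runnable sketch of the fan-out approach the new `gen_bulk_reprensentations` method uses, assuming only the concurrent-ruby gem. `prepare`, `fetch_vector`, and `persist` are hypothetical stand-ins for the plugin's text-preparation strategy, inference client, and DB write; they are not the real discourse-ai API.

require "concurrent"

# Hypothetical stand-ins -- not the real discourse-ai methods.
def prepare(record)
  "text for record #{record}"
end

def fetch_vector(text)
  Array.new(3) { text.length * 0.001 } # fake embedding; the real code makes an HTTP call
end

def persist(record, embedding)
  puts "#{record} => #{embedding.inspect}"
end

# A bounded pool: at most 100 worker threads, idle threads reaped after 30s.
# Capping max_threads is what keeps a large batch from opening thousands of
# HTTP connections at once.
pool = Concurrent::CachedThreadPool.new(min_threads: 0, max_threads: 100, idletime: 30)

records = %w[post_1 post_2 post_3]

# One future per record; each step runs on whatever thread the pool provides.
futures =
  records.map do |record|
    Concurrent::Promises
      .fulfilled_future(record, pool)
      .then_on(pool) { |r| { target: r, embedding: fetch_vector(prepare(r)) } }
  end

# zip waits for every future; value! blocks and re-raises the first error, if any.
Concurrent::Promises.zip(*futures).value!.each { |e| persist(e[:target], e[:embedding]) }

pool.shutdown
pool.wait_for_termination

Blocking on `value!` means a single failed HTTP call aborts the whole batch; the code below takes the same all-or-nothing approach, so a retry re-embeds the entire slice.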
@@ -75,9 +75,9 @@ module Jobs
       # First, we'll try to backfill embeddings for posts that have none
       posts
         .where("#{table_name}.post_id IS NULL")
-        .find_each do |t|
-          vector_rep.generate_representation_from(t)
-          rebaked += 1
+        .find_in_batches do |batch|
+          vector_rep.gen_bulk_reprensentations(batch)
+          rebaked += batch.size
         end

       return if rebaked >= limit
@@ -90,24 +90,28 @@ module Jobs
         OR
           #{table_name}.strategy_version < #{strategy.version}
       SQL
-        .find_each do |t|
-          vector_rep.generate_representation_from(t)
-          rebaked += 1
+        .find_in_batches do |batch|
+          vector_rep.gen_bulk_reprensentations(batch)
+          rebaked += batch.size
         end

       return if rebaked >= limit

       # Finally, we'll try to backfill embeddings for posts that have outdated
       # embeddings due to edits. Here we only do 10% of the limit
-      posts
-        .where("#{table_name}.updated_at < ?", 7.days.ago)
-        .order("random()")
-        .limit((limit - rebaked) / 10)
-        .pluck(:id)
-        .each do |id|
-          vector_rep.generate_representation_from(Post.find_by(id: id))
-          rebaked += 1
-        end
+      posts_batch_size = 1000
+
+      outdated_post_ids =
+        posts
+          .where("#{table_name}.updated_at < ?", 7.days.ago)
+          .order("random()")
+          .limit((limit - rebaked) / 10)
+          .pluck(:id)
+
+      outdated_post_ids.each_slice(posts_batch_size) do |batch|
+        vector_rep.gen_bulk_reprensentations(Post.where(id: batch).order("topics.bumped_at DESC"))
+        rebaked += batch.length
+      end

       rebaked
     end
@@ -120,14 +124,13 @@ module Jobs
       topics = topics.where("#{vector_rep.topic_table_name}.topic_id IS NULL") if !force

       ids = topics.pluck("topics.id")
+      batch_size = 1000

-      ids.each do |id|
-        topic = Topic.find_by(id: id)
-        if topic
-          vector_rep.generate_representation_from(topic)
-          done += 1
-        end
+      ids.each_slice(batch_size) do |batch|
+        vector_rep.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
+        done += batch.length
       end

       done
     end
   end
@@ -50,8 +50,38 @@ module DiscourseAi
        raise NotImplementedError
      end

+      def gen_bulk_reprensentations(relation)
+        http_pool_size = 100
+        pool =
+          Concurrent::CachedThreadPool.new(
+            min_threads: 0,
+            max_threads: http_pool_size,
+            idletime: 30,
+          )
+
+        embedding_gen = inference_client
+        promised_embeddings =
+          relation.map do |record|
+            materials = { target: record, text: prepare_text(record) }
+
+            Concurrent::Promises
+              .fulfilled_future(materials, pool)
+              .then_on(pool) do |w_prepared_text|
+                w_prepared_text.merge(
+                  embedding: embedding_gen.perform!(w_prepared_text[:text]),
+                  digest: OpenSSL::Digest::SHA1.hexdigest(w_prepared_text[:text]),
+                )
+              end
+          end
+
+        Concurrent::Promises
+          .zip(*promised_embeddings)
+          .value!
+          .each { |e| save_to_db(e[:target], e[:embedding], e[:digest]) }
+      end
+
      def generate_representation_from(target, persist: true)
-        text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
+        text = prepare_text(target)
        return if text.blank?

        target_column =
@@ -429,6 +459,10 @@ module DiscourseAi
      def inference_client
        raise NotImplementedError
      end
+
+      def prepare_text(record)
+        @strategy.prepare_text_from(record, tokenizer, max_sequence_length - 2)
+      end
    end
  end
 end
@@ -34,7 +34,7 @@ module DiscourseAi
        needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
        if needs_truncation
          text = tokenizer.truncate(text, max_sequence_length - 2)
-        else
+        elsif !text.starts_with?("query:")
          text = "query: #{text}"
        end

@@ -79,6 +79,14 @@ module DiscourseAi
          raise "No inference endpoint configured"
        end
      end
+
+      def prepare_text(record)
+        if inference_client.class.name.include?("DiscourseClassifier")
+          return "query: #{super(record)}"
+        end
+
+        super(record)
+      end
    end
  end
 end
@@ -1,14 +1,15 @@
 # frozen_string_literal: true

 RSpec.shared_examples "generates and store embedding using with vector representation" do
-  before { @expected_embedding = [0.0038493] * vector_rep.dimensions }
+  let(:expected_embedding_1) { [0.0038493] * vector_rep.dimensions }
+  let(:expected_embedding_2) { [0.0037684] * vector_rep.dimensions }

   describe "#vector_from" do
     it "creates a vector from a given string" do
       text = "This is a piece of text"
-      stub_vector_mapping(text, @expected_embedding)
+      stub_vector_mapping(text, expected_embedding_1)

-      expect(vector_rep.vector_from(text)).to eq(@expected_embedding)
+      expect(vector_rep.vector_from(text)).to eq(expected_embedding_1)
     end
   end
@@ -24,11 +25,11 @@ RSpec.shared_examples "generates and store embedding using with vector represent
         vector_rep.tokenizer,
         vector_rep.max_sequence_length - 2,
       )
-      stub_vector_mapping(text, @expected_embedding)
+      stub_vector_mapping(text, expected_embedding_1)

       vector_rep.generate_representation_from(topic)

-      expect(vector_rep.topic_id_from_representation(@expected_embedding)).to eq(topic.id)
+      expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
     end

     it "creates a vector from a post and stores it in the database" do
@@ -38,11 +39,45 @@ RSpec.shared_examples "generates and store embedding using with vector represent
         vector_rep.tokenizer,
         vector_rep.max_sequence_length - 2,
       )
-      stub_vector_mapping(text, @expected_embedding)
+      stub_vector_mapping(text, expected_embedding_1)

       vector_rep.generate_representation_from(post)

-      expect(vector_rep.post_id_from_representation(@expected_embedding)).to eq(post.id)
+      expect(vector_rep.post_id_from_representation(expected_embedding_1)).to eq(post.id)
+    end
+  end
+
+  describe "#gen_bulk_reprensentations" do
+    fab!(:topic) { Fabricate(:topic) }
+    fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
+    fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
+
+    fab!(:topic_2) { Fabricate(:topic) }
+    fab!(:post_2_1) { Fabricate(:post, post_number: 1, topic: topic_2) }
+    fab!(:post_2_2) { Fabricate(:post, post_number: 2, topic: topic_2) }
+
+    it "creates a vector for each object in the relation" do
+      text =
+        truncation.prepare_text_from(
+          topic,
+          vector_rep.tokenizer,
+          vector_rep.max_sequence_length - 2,
+        )
+
+      text2 =
+        truncation.prepare_text_from(
+          topic_2,
+          vector_rep.tokenizer,
+          vector_rep.max_sequence_length - 2,
+        )
+
+      stub_vector_mapping(text, expected_embedding_1)
+      stub_vector_mapping(text2, expected_embedding_2)
+
+      vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))
+
+      expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
+      expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
     end
   end

@@ -58,7 +93,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent
         vector_rep.tokenizer,
         vector_rep.max_sequence_length - 2,
       )
-      stub_vector_mapping(text, @expected_embedding)
+      stub_vector_mapping(text, expected_embedding_1)
       vector_rep.generate_representation_from(topic)

       expect(