DEV: Backfill embeddings concurrently. (#941)

We are adding a new method for generating and storing embeddings in bulk, which relies on `Concurrent::Promises::Future`. Generating an embedding consists of three steps:

1. Prepare the text.
2. Make an HTTP call to retrieve the vector.
3. Save it to the DB.
Each one is independently executed on whatever thread the pool gives us.

We use a custom thread pool instead of the global executor because we want control over how many threads we spawn, which limits concurrency and avoids firing thousands of simultaneous HTTP requests when working with large batches.
This commit is contained in:
Roman Rizzi 2024-11-26 14:12:32 -03:00 committed by GitHub
parent 23193ee6f2
commit ddf2bf7034
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 111 additions and 31 deletions

View File

@ -75,9 +75,9 @@ module Jobs
# First, we'll try to backfill embeddings for posts that have none # First, we'll try to backfill embeddings for posts that have none
posts posts
.where("#{table_name}.post_id IS NULL") .where("#{table_name}.post_id IS NULL")
.find_each do |t| .find_in_batches do |batch|
vector_rep.generate_representation_from(t) vector_rep.gen_bulk_reprensentations(batch)
rebaked += 1 rebaked += batch.size
end end
return if rebaked >= limit return if rebaked >= limit
@ -90,24 +90,28 @@ module Jobs
OR OR
#{table_name}.strategy_version < #{strategy.version} #{table_name}.strategy_version < #{strategy.version}
SQL SQL
.find_each do |t| .find_in_batches do |batch|
vector_rep.generate_representation_from(t) vector_rep.gen_bulk_reprensentations(batch)
rebaked += 1 rebaked += batch.size
end end
return if rebaked >= limit return if rebaked >= limit
# Finally, we'll try to backfill embeddings for posts that have outdated # Finally, we'll try to backfill embeddings for posts that have outdated
# embeddings due to edits. Here we only do 10% of the limit # embeddings due to edits. Here we only do 10% of the limit
posts posts_batch_size = 1000
.where("#{table_name}.updated_at < ?", 7.days.ago)
.order("random()") outdated_post_ids =
.limit((limit - rebaked) / 10) posts
.pluck(:id) .where("#{table_name}.updated_at < ?", 7.days.ago)
.each do |id| .order("random()")
vector_rep.generate_representation_from(Post.find_by(id: id)) .limit((limit - rebaked) / 10)
rebaked += 1 .pluck(:id)
end
outdated_post_ids.each_slice(posts_batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Post.where(id: batch).order("topics.bumped_at DESC"))
rebaked += batch.length
end
rebaked rebaked
end end
@ -120,14 +124,13 @@ module Jobs
topics = topics.where("#{vector_rep.topic_table_name}.topic_id IS NULL") if !force topics = topics.where("#{vector_rep.topic_table_name}.topic_id IS NULL") if !force
ids = topics.pluck("topics.id") ids = topics.pluck("topics.id")
batch_size = 1000
ids.each do |id| ids.each_slice(batch_size) do |batch|
topic = Topic.find_by(id: id) vector_rep.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
if topic done += batch.length
vector_rep.generate_representation_from(topic)
done += 1
end
end end
done done
end end
end end

View File

@ -50,8 +50,38 @@ module DiscourseAi
raise NotImplementedError raise NotImplementedError
end end
# Generates embeddings for every record in +relation+ concurrently and
# persists each one to the DB.
#
# Each record goes through three steps — prepare text, HTTP call to fetch
# the vector, save to DB — with the expensive HTTP step scheduled on a
# dedicated thread pool. A bounded pool (rather than the global executor)
# caps concurrency so large batches don't fire thousands of simultaneous
# HTTP requests.
#
# NOTE(review): unlike #generate_representation_from, blank prepared text
# is not skipped here and will be sent to the inference client — confirm
# that is intended for bulk callers.
#
# @param relation [Enumerable] records to embed (e.g. an AR relation)
def gen_bulk_reprensentations(relation)
  http_pool_size = 100
  pool =
    Concurrent::CachedThreadPool.new(
      min_threads: 0,
      max_threads: http_pool_size,
      idletime: 30,
    )

  embedding_gen = inference_client
  promised_embeddings =
    relation.map do |record|
      materials = { target: record, text: prepare_text(record) }

      Concurrent::Promises
        .fulfilled_future(materials, pool)
        .then_on(pool) do |w_prepared_text|
          w_prepared_text.merge(
            embedding: embedding_gen.perform!(w_prepared_text[:text]),
            digest: OpenSSL::Digest::SHA1.hexdigest(w_prepared_text[:text]),
          )
        end
    end

  Concurrent::Promises
    .zip(*promised_embeddings)
    .value!
    .each { |e| save_to_db(e[:target], e[:embedding], e[:digest]) }
ensure
  # BUG FIX: the pool was previously never shut down, so its worker threads
  # lingered (up to the 30s idletime) after every bulk call — and leaked on
  # the error path when value! re-raised. Always tear it down.
  pool&.shutdown
  pool&.wait_for_termination
end
def generate_representation_from(target, persist: true) def generate_representation_from(target, persist: true)
text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2) text = prepare_text(target)
return if text.blank? return if text.blank?
target_column = target_column =
@ -429,6 +459,10 @@ module DiscourseAi
def inference_client def inference_client
raise NotImplementedError raise NotImplementedError
end end
# Builds the (possibly truncated) embedding input text for +record+ using
# the configured strategy, reserving two tokens for special markers.
def prepare_text(record)
  token_budget = max_sequence_length - 2
  @strategy.prepare_text_from(record, tokenizer, token_budget)
end
end end
end end
end end

View File

@ -34,7 +34,7 @@ module DiscourseAi
needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings") needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
if needs_truncation if needs_truncation
text = tokenizer.truncate(text, max_sequence_length - 2) text = tokenizer.truncate(text, max_sequence_length - 2)
else elsif !text.starts_with?("query:")
text = "query: #{text}" text = "query: #{text}"
end end
@ -79,6 +79,14 @@ module DiscourseAi
raise "No inference endpoint configured" raise "No inference endpoint configured"
end end
end end
# Prefixes the prepared text with "query: " when the configured inference
# client is a DiscourseClassifier, which expects that marker on its inputs.
def prepare_text(record)
  needs_query_prefix = inference_client.class.name.include?("DiscourseClassifier")
  prepared = super(record)
  needs_query_prefix ? "query: #{prepared}" : prepared
end
end end
end end
end end

View File

@ -1,14 +1,15 @@
# frozen_string_literal: true # frozen_string_literal: true
RSpec.shared_examples "generates and store embedding using with vector representation" do RSpec.shared_examples "generates and store embedding using with vector representation" do
before { @expected_embedding = [0.0038493] * vector_rep.dimensions } let(:expected_embedding_1) { [0.0038493] * vector_rep.dimensions }
let(:expected_embedding_2) { [0.0037684] * vector_rep.dimensions }
describe "#vector_from" do describe "#vector_from" do
it "creates a vector from a given string" do it "creates a vector from a given string" do
text = "This is a piece of text" text = "This is a piece of text"
stub_vector_mapping(text, @expected_embedding) stub_vector_mapping(text, expected_embedding_1)
expect(vector_rep.vector_from(text)).to eq(@expected_embedding) expect(vector_rep.vector_from(text)).to eq(expected_embedding_1)
end end
end end
@ -24,11 +25,11 @@ RSpec.shared_examples "generates and store embedding using with vector represent
vector_rep.tokenizer, vector_rep.tokenizer,
vector_rep.max_sequence_length - 2, vector_rep.max_sequence_length - 2,
) )
stub_vector_mapping(text, @expected_embedding) stub_vector_mapping(text, expected_embedding_1)
vector_rep.generate_representation_from(topic) vector_rep.generate_representation_from(topic)
expect(vector_rep.topic_id_from_representation(@expected_embedding)).to eq(topic.id) expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
end end
it "creates a vector from a post and stores it in the database" do it "creates a vector from a post and stores it in the database" do
@ -38,11 +39,45 @@ RSpec.shared_examples "generates and store embedding using with vector represent
vector_rep.tokenizer, vector_rep.tokenizer,
vector_rep.max_sequence_length - 2, vector_rep.max_sequence_length - 2,
) )
stub_vector_mapping(text, @expected_embedding) stub_vector_mapping(text, expected_embedding_1)
vector_rep.generate_representation_from(post) vector_rep.generate_representation_from(post)
expect(vector_rep.post_id_from_representation(@expected_embedding)).to eq(post.id) expect(vector_rep.post_id_from_representation(expected_embedding_1)).to eq(post.id)
end
end
describe "#gen_bulk_reprensentations" do
  fab!(:topic) { Fabricate(:topic) }
  fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
  fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }

  fab!(:topic_2) { Fabricate(:topic) }
  fab!(:post_2_1) { Fabricate(:post, post_number: 1, topic: topic_2) }
  fab!(:post_2_2) { Fabricate(:post, post_number: 2, topic: topic_2) }

  it "creates a vector for each object in the relation" do
    text =
      truncation.prepare_text_from(
        topic,
        vector_rep.tokenizer,
        vector_rep.max_sequence_length - 2,
      )

    text2 =
      truncation.prepare_text_from(
        topic_2,
        vector_rep.tokenizer,
        vector_rep.max_sequence_length - 2,
      )

    stub_vector_mapping(text, expected_embedding_1)
    stub_vector_mapping(text2, expected_embedding_2)

    vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))

    expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
    # BUG FIX: this assertion previously duplicated the one above, so the
    # second topic's embedding was never actually verified.
    expect(vector_rep.topic_id_from_representation(expected_embedding_2)).to eq(topic_2.id)
  end
end
@ -58,7 +93,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent
vector_rep.tokenizer, vector_rep.tokenizer,
vector_rep.max_sequence_length - 2, vector_rep.max_sequence_length - 2,
) )
stub_vector_mapping(text, @expected_embedding) stub_vector_mapping(text, expected_embedding_1)
vector_rep.generate_representation_from(topic) vector_rep.generate_representation_from(topic)
expect( expect(