DEV: Backfill embeddings concurrently. (#941)

We are adding a new method for generating and storing embeddings in bulk, which relies on `Concurrent::Promises::Future`. Generating an embedding consists of three steps:

1. Prepare the text.
2. Make an HTTP call to retrieve the embedding vector.
3. Save it to the DB.

Each step is independently executed on whatever thread the pool gives us.

We are bringing in a custom thread pool instead of using the global executor because we want control over how many threads we spawn, to limit concurrency. This also avoids firing thousands of simultaneous HTTP requests when working with large batches.
This commit is contained in:
Roman Rizzi 2024-11-26 14:12:32 -03:00 committed by GitHub
parent 23193ee6f2
commit ddf2bf7034
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 111 additions and 31 deletions

View File

@ -75,9 +75,9 @@ module Jobs
# First, we'll try to backfill embeddings for posts that have none
posts
.where("#{table_name}.post_id IS NULL")
.find_each do |t|
vector_rep.generate_representation_from(t)
rebaked += 1
.find_in_batches do |batch|
vector_rep.gen_bulk_reprensentations(batch)
rebaked += batch.size
end
return if rebaked >= limit
@ -90,24 +90,28 @@ module Jobs
OR
#{table_name}.strategy_version < #{strategy.version}
SQL
.find_each do |t|
vector_rep.generate_representation_from(t)
rebaked += 1
.find_in_batches do |batch|
vector_rep.gen_bulk_reprensentations(batch)
rebaked += batch.size
end
return if rebaked >= limit
# Finally, we'll try to backfill embeddings for posts that have outdated
# embeddings due to edits. Here we only do 10% of the limit
posts
.where("#{table_name}.updated_at < ?", 7.days.ago)
.order("random()")
.limit((limit - rebaked) / 10)
.pluck(:id)
.each do |id|
vector_rep.generate_representation_from(Post.find_by(id: id))
rebaked += 1
end
posts_batch_size = 1000
outdated_post_ids =
posts
.where("#{table_name}.updated_at < ?", 7.days.ago)
.order("random()")
.limit((limit - rebaked) / 10)
.pluck(:id)
outdated_post_ids.each_slice(posts_batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Post.where(id: batch).order("topics.bumped_at DESC"))
rebaked += batch.length
end
rebaked
end
@ -120,14 +124,13 @@ module Jobs
topics = topics.where("#{vector_rep.topic_table_name}.topic_id IS NULL") if !force
ids = topics.pluck("topics.id")
batch_size = 1000
ids.each do |id|
topic = Topic.find_by(id: id)
if topic
vector_rep.generate_representation_from(topic)
done += 1
end
ids.each_slice(batch_size) do |batch|
vector_rep.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
done += batch.length
end
done
end
end

View File

@ -50,8 +50,38 @@ module DiscourseAi
raise NotImplementedError
end
# Generates and persists embeddings for every record in +relation+ concurrently.
#
# For each record we build a future that (1) prepares the text, (2) performs the
# HTTP inference call, and (3) computes a SHA1 digest of the text; results are
# then saved to the DB sequentially on the calling thread.
#
# A dedicated CachedThreadPool capped at +http_pool_size+ threads bounds
# concurrency so large relations don't fire thousands of HTTP requests at once.
# The pool is explicitly shut down in +ensure+ — previously it was leaked on
# every invocation (and on any raised error from .value!).
#
# @param relation [Enumerable] records to embed (e.g. an ActiveRecord relation)
def gen_bulk_reprensentations(relation)
  http_pool_size = 100
  pool =
    Concurrent::CachedThreadPool.new(
      min_threads: 0,
      max_threads: http_pool_size,
      idletime: 30,
    )

  embedding_gen = inference_client
  promised_embeddings =
    relation.map do |record|
      materials = { target: record, text: prepare_text(record) }

      Concurrent::Promises
        .fulfilled_future(materials, pool)
        .then_on(pool) do |w_prepared_text|
          w_prepared_text.merge(
            embedding: embedding_gen.perform!(w_prepared_text[:text]),
            digest: OpenSSL::Digest::SHA1.hexdigest(w_prepared_text[:text]),
          )
        end
    end

  # .value! blocks until all futures settle and re-raises the first failure.
  Concurrent::Promises
    .zip(*promised_embeddings)
    .value!
    .each { |e| save_to_db(e[:target], e[:embedding], e[:digest]) }
ensure
  # Always release the pool's threads, even when an inference call raised.
  pool&.shutdown
  pool&.wait_for_termination
end
def generate_representation_from(target, persist: true)
text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
text = prepare_text(target)
return if text.blank?
target_column =
@ -429,6 +459,10 @@ module DiscourseAi
# Abstract hook: concrete representation subclasses must return the client
# object used to perform inference requests (it must respond to #perform!,
# as used by gen_bulk_reprensentations above).
def inference_client
  raise NotImplementedError
end
# Builds the text to embed for +record+ via the configured strategy.
# Two tokens are reserved out of the max sequence length for the model's
# special tokens.
def prepare_text(record)
  token_budget = max_sequence_length - 2
  @strategy.prepare_text_from(record, tokenizer, token_budget)
end
end
end
end

View File

@ -34,7 +34,7 @@ module DiscourseAi
needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
if needs_truncation
text = tokenizer.truncate(text, max_sequence_length - 2)
else
elsif !text.starts_with?("query:")
text = "query: #{text}"
end
@ -79,6 +79,14 @@ module DiscourseAi
raise "No inference endpoint configured"
end
end
# Prepares the text like the base class, but prefixes it with "query: "
# when the inference endpoint is a DiscourseClassifier, which expects
# that marker on its inputs.
def prepare_text(record)
  needs_query_prefix = inference_client.class.name.include?("DiscourseClassifier")
  prepared = super(record)
  needs_query_prefix ? "query: #{prepared}" : prepared
end
end
end
end

View File

@ -1,14 +1,15 @@
# frozen_string_literal: true
RSpec.shared_examples "generates and store embedding using with vector representation" do
before { @expected_embedding = [0.0038493] * vector_rep.dimensions }
let(:expected_embedding_1) { [0.0038493] * vector_rep.dimensions }
let(:expected_embedding_2) { [0.0037684] * vector_rep.dimensions }
describe "#vector_from" do
it "creates a vector from a given string" do
text = "This is a piece of text"
stub_vector_mapping(text, @expected_embedding)
stub_vector_mapping(text, expected_embedding_1)
expect(vector_rep.vector_from(text)).to eq(@expected_embedding)
expect(vector_rep.vector_from(text)).to eq(expected_embedding_1)
end
end
@ -24,11 +25,11 @@ RSpec.shared_examples "generates and store embedding using with vector represent
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, @expected_embedding)
stub_vector_mapping(text, expected_embedding_1)
vector_rep.generate_representation_from(topic)
expect(vector_rep.topic_id_from_representation(@expected_embedding)).to eq(topic.id)
expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
end
it "creates a vector from a post and stores it in the database" do
@ -38,11 +39,45 @@ RSpec.shared_examples "generates and store embedding using with vector represent
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, @expected_embedding)
stub_vector_mapping(text, expected_embedding_1)
vector_rep.generate_representation_from(post)
expect(vector_rep.post_id_from_representation(@expected_embedding)).to eq(post.id)
expect(vector_rep.post_id_from_representation(expected_embedding_1)).to eq(post.id)
end
end
describe "#gen_bulk_reprensentations" do
  fab!(:topic) { Fabricate(:topic) }
  fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
  fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
  fab!(:topic_2) { Fabricate(:topic) }
  fab!(:post_2_1) { Fabricate(:post, post_number: 1, topic: topic_2) }
  fab!(:post_2_2) { Fabricate(:post, post_number: 2, topic: topic_2) }

  it "creates a vector for each object in the relation" do
    text =
      truncation.prepare_text_from(
        topic,
        vector_rep.tokenizer,
        vector_rep.max_sequence_length - 2,
      )
    text2 =
      truncation.prepare_text_from(
        topic_2,
        vector_rep.tokenizer,
        vector_rep.max_sequence_length - 2,
      )
    stub_vector_mapping(text, expected_embedding_1)
    stub_vector_mapping(text2, expected_embedding_2)

    vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))

    # Each topic must resolve from the embedding stubbed for its own text.
    # (Previously the second expectation duplicated the first, so topic_2's
    # embedding was never verified.)
    expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
    expect(vector_rep.topic_id_from_representation(expected_embedding_2)).to eq(topic_2.id)
  end
end
@ -58,7 +93,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, @expected_embedding)
stub_vector_mapping(text, expected_embedding_1)
vector_rep.generate_representation_from(topic)
expect(