diff --git a/app/jobs/scheduled/embeddings_backfill.rb b/app/jobs/scheduled/embeddings_backfill.rb
index 3300b479..db0828f2 100644
--- a/app/jobs/scheduled/embeddings_backfill.rb
+++ b/app/jobs/scheduled/embeddings_backfill.rb
@@ -75,9 +75,9 @@ module Jobs
       # First, we'll try to backfill embeddings for posts that have none
       posts
         .where("#{table_name}.post_id IS NULL")
-        .find_each do |t|
-          vector_rep.generate_representation_from(t)
-          rebaked += 1
+        .find_in_batches do |batch|
+          vector_rep.gen_bulk_reprensentations(batch)
+          rebaked += batch.size
         end
 
       return if rebaked >= limit
@@ -90,24 +90,28 @@ module Jobs
            OR #{table_name}.strategy_version < #{strategy.version}
         SQL
-        .find_each do |t|
-          vector_rep.generate_representation_from(t)
-          rebaked += 1
+        .find_in_batches do |batch|
+          vector_rep.gen_bulk_reprensentations(batch)
+          rebaked += batch.size
         end
 
       return if rebaked >= limit
 
       # Finally, we'll try to backfill embeddings for posts that have outdated
       # embeddings due to edits. Here we only do 10% of the limit
-      posts
-        .where("#{table_name}.updated_at < ?", 7.days.ago)
-        .order("random()")
-        .limit((limit - rebaked) / 10)
-        .pluck(:id)
-        .each do |id|
-          vector_rep.generate_representation_from(Post.find_by(id: id))
-          rebaked += 1
-        end
+      posts_batch_size = 1000
+
+      outdated_post_ids =
+        posts
+          .where("#{table_name}.updated_at < ?", 7.days.ago)
+          .order("random()")
+          .limit((limit - rebaked) / 10)
+          .pluck(:id)
+
+      outdated_post_ids.each_slice(posts_batch_size) do |batch|
+        vector_rep.gen_bulk_reprensentations(Post.where(id: batch).order("topics.bumped_at DESC"))
+        rebaked += batch.length
+      end
 
       rebaked
     end
@@ -120,14 +124,13 @@ module Jobs
       topics = topics.where("#{vector_rep.topic_table_name}.topic_id IS NULL") if !force
 
       ids = topics.pluck("topics.id")
+      batch_size = 1000
 
-      ids.each do |id|
-        topic = Topic.find_by(id: id)
-        if topic
-          vector_rep.generate_representation_from(topic)
-          done += 1
-        end
+      ids.each_slice(batch_size) do |batch|
+        vector_rep.gen_bulk_reprensentations(Topic.where(id: batch).order("topics.bumped_at DESC"))
+        done += batch.length
       end
+
       done
     end
   end
diff --git a/lib/embeddings/vector_representations/base.rb b/lib/embeddings/vector_representations/base.rb
index be6b46b5..e1f3ff49 100644
--- a/lib/embeddings/vector_representations/base.rb
+++ b/lib/embeddings/vector_representations/base.rb
@@ -50,8 +50,38 @@ module DiscourseAi
           raise NotImplementedError
         end
 
+        def gen_bulk_reprensentations(relation)
+          http_pool_size = 100
+          pool =
+            Concurrent::CachedThreadPool.new(
+              min_threads: 0,
+              max_threads: http_pool_size,
+              idletime: 30,
+            )
+
+          embedding_gen = inference_client
+          promised_embeddings =
+            relation.map do |record|
+              materials = { target: record, text: prepare_text(record) }
+
+              Concurrent::Promises
+                .fulfilled_future(materials, pool)
+                .then_on(pool) do |w_prepared_text|
+                  w_prepared_text.merge(
+                    embedding: embedding_gen.perform!(w_prepared_text[:text]),
+                    digest: OpenSSL::Digest::SHA1.hexdigest(w_prepared_text[:text]),
+                  )
+                end
+            end
+
+          Concurrent::Promises
+            .zip(*promised_embeddings)
+            .value!
+            .each { |e| save_to_db(e[:target], e[:embedding], e[:digest]) }
+        end
+
         def generate_representation_from(target, persist: true)
-          text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
+          text = prepare_text(target)
           return if text.blank?
 
           target_column =
@@ -429,6 +459,10 @@ module DiscourseAi
         def inference_client
           raise NotImplementedError
         end
+
+        def prepare_text(record)
+          @strategy.prepare_text_from(record, tokenizer, max_sequence_length - 2)
+        end
       end
     end
   end
diff --git a/lib/embeddings/vector_representations/multilingual_e5_large.rb b/lib/embeddings/vector_representations/multilingual_e5_large.rb
index c7ef3c0f..605ec8b5 100644
--- a/lib/embeddings/vector_representations/multilingual_e5_large.rb
+++ b/lib/embeddings/vector_representations/multilingual_e5_large.rb
@@ -34,7 +34,7 @@ module DiscourseAi
           needs_truncation = client.class.name.include?("HuggingFaceTextEmbeddings")
           if needs_truncation
             text = tokenizer.truncate(text, max_sequence_length - 2)
-          else
+          elsif !text.starts_with?("query:")
             text = "query: #{text}"
           end
 
@@ -79,6 +79,14 @@ module DiscourseAi
             raise "No inference endpoint configured"
           end
         end
+
+        def prepare_text(record)
+          if inference_client.class.name.include?("DiscourseClassifier")
+            return "query: #{super(record)}"
+          end
+
+          super(record)
+        end
       end
     end
   end
diff --git a/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb b/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb
index 9689a3c6..fce9f612 100644
--- a/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb
+++ b/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb
@@ -1,14 +1,15 @@
 # frozen_string_literal: true
 
 RSpec.shared_examples "generates and store embedding using with vector representation" do
-  before { @expected_embedding = [0.0038493] * vector_rep.dimensions }
+  let(:expected_embedding_1) { [0.0038493] * vector_rep.dimensions }
+  let(:expected_embedding_2) { [0.0037684] * vector_rep.dimensions }
 
   describe "#vector_from" do
     it "creates a vector from a given string" do
       text = "This is a piece of text"
-      stub_vector_mapping(text, @expected_embedding)
+      stub_vector_mapping(text, expected_embedding_1)
 
-      expect(vector_rep.vector_from(text)).to eq(@expected_embedding)
+      expect(vector_rep.vector_from(text)).to eq(expected_embedding_1)
     end
   end
 
@@ -24,11 +25,11 @@ RSpec.shared_examples "generates and store embedding using with vector represent
         vector_rep.tokenizer,
         vector_rep.max_sequence_length - 2,
       )
-      stub_vector_mapping(text, @expected_embedding)
+      stub_vector_mapping(text, expected_embedding_1)
 
       vector_rep.generate_representation_from(topic)
 
-      expect(vector_rep.topic_id_from_representation(@expected_embedding)).to eq(topic.id)
+      expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
     end
 
     it "creates a vector from a post and stores it in the database" do
@@ -38,11 +39,45 @@ RSpec.shared_examples "generates and store embedding using with vector represent
         vector_rep.tokenizer,
         vector_rep.max_sequence_length - 2,
       )
-      stub_vector_mapping(text, @expected_embedding)
+      stub_vector_mapping(text, expected_embedding_1)
 
       vector_rep.generate_representation_from(post)
 
-      expect(vector_rep.post_id_from_representation(@expected_embedding)).to eq(post.id)
+      expect(vector_rep.post_id_from_representation(expected_embedding_1)).to eq(post.id)
+    end
+  end
+
+  describe "#gen_bulk_reprensentations" do
+    fab!(:topic) { Fabricate(:topic) }
+    fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
+    fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
+
+    fab!(:topic_2) { Fabricate(:topic) }
+    fab!(:post_2_1) { Fabricate(:post, post_number: 1, topic: topic_2) }
+    fab!(:post_2_2) { Fabricate(:post, post_number: 2, topic: topic_2) }
+
+    it "creates a vector for each object in the relation" do
+      text =
+        truncation.prepare_text_from(
+          topic,
+          vector_rep.tokenizer,
+          vector_rep.max_sequence_length - 2,
+        )
+
+      text2 =
+        truncation.prepare_text_from(
+          topic_2,
+          vector_rep.tokenizer,
+          vector_rep.max_sequence_length - 2,
+        )
+
+      stub_vector_mapping(text, expected_embedding_1)
+      stub_vector_mapping(text2, expected_embedding_2)
+
+      vector_rep.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))
+
+      expect(vector_rep.topic_id_from_representation(expected_embedding_1)).to eq(topic.id)
+      expect(vector_rep.topic_id_from_representation(expected_embedding_2)).to eq(topic_2.id)
     end
   end
 
@@ -58,7 +93,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent
         vector_rep.tokenizer,
         vector_rep.max_sequence_length - 2,
       )
-      stub_vector_mapping(text, @expected_embedding)
+      stub_vector_mapping(text, expected_embedding_1)
 
       vector_rep.generate_representation_from(topic)
 
       expect(
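
Reviewer note on the concurrency shape in Base#gen_bulk_reprensentations: each record's text is prepared up front, handed to a Concurrent::CachedThreadPool as an already-fulfilled future, embedded and digested on a pool thread, and the per-record futures are joined with Concurrent::Promises.zip before anything is saved. Below is a minimal, self-contained sketch of that fan-out/join pattern using the same concurrent-ruby APIs; fake_embed and the sample texts are hypothetical stand-ins for the real inference client and post content, not part of the plugin.

require "concurrent"

# Hypothetical stand-in for the real inference client: the plugin's
# inference_client.perform! makes an HTTP call per text, while this
# lambda just derives a tiny "embedding" locally so the sketch runs anywhere.
fake_embed = ->(text) { [text.length * 0.001] * 3 }

# Same pool configuration style as the patch: threads are created on
# demand up to max_threads and reaped after 30s idle.
pool = Concurrent::CachedThreadPool.new(min_threads: 0, max_threads: 4, idletime: 30)

texts = ["first post", "second post", "third post"]

# Fan out: one future per text, each resolved on a pool thread.
futures =
  texts.map do |text|
    Concurrent::Promises
      .fulfilled_future(text, pool)
      .then_on(pool) { |t| { text: t, embedding: fake_embed.call(t) } }
  end

# Join: zip waits for every future; value! blocks and re-raises the
# first error from any worker thread.
results = Concurrent::Promises.zip(*futures).value!
results.each { |r| puts "#{r[:text]} => #{r[:embedding].inspect}" }

pool.shutdown
pool.wait_for_termination

One property worth keeping in mind: because value! re-raises the first worker failure, a single failed perform! call aborts the whole batch with nothing saved; per-record rescue/retry would have to be layered on top if partial progress matters.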