From dcafc8032f62218317681c631b989708269303d2 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 1 Feb 2024 00:38:47 +1100 Subject: [PATCH] FIX: improve embedding generation (#452) 1. on failure we were queuing a job to generate embeddings, it had the wrong params. This is both fixed and covered in a test. 2. backfill embedding in the order of bumped_at, so newest content is embedded first, cover with a test 3. add a safeguard for hidden site setting that only allows batches of 50k in an embedding job run Previously old embeddings were updated in a random order, this changes it so we update in a consistent order --- app/jobs/scheduled/embeddings_backfill.rb | 55 +++++++++++------- lib/embeddings/semantic_related.rb | 2 +- .../scheduled/embeddings_backfill_spec.rb | 55 ++++++++++++++++++ .../embeddings/semantic_related_spec.rb | 58 ++++++++++++++++--- 4 files changed, 140 insertions(+), 30 deletions(-) create mode 100644 spec/jobs/scheduled/embeddings_backfill_spec.rb diff --git a/app/jobs/scheduled/embeddings_backfill.rb b/app/jobs/scheduled/embeddings_backfill.rb index bba792e5..a5f10f3e 100644 --- a/app/jobs/scheduled/embeddings_backfill.rb +++ b/app/jobs/scheduled/embeddings_backfill.rb @@ -10,6 +10,14 @@ module Jobs return unless SiteSetting.ai_embeddings_enabled limit = SiteSetting.ai_embeddings_backfill_batch_size + + if limit > 50_000 + limit = 50_000 + Rails.logger.warn( + "Limiting backfill batch size to 50,000 to avoid OOM errors, reduce ai_embeddings_backfill_batch_size to avoid this warning", + ) + end + rebaked = 0 strategy = DiscourseAi::Embeddings::Strategies::Truncation.new @@ -22,15 +30,10 @@ module Jobs .joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id") .where(archetype: Archetype.default) .where(deleted_at: nil) + .order("topics.bumped_at DESC") .limit(limit - rebaked) - # First, we'll try to backfill embeddings for topics that have none - topics - .where("#{table_name}.topic_id IS NULL") - .find_each do |t| - vector_rep.generate_representation_from(t) - rebaked += 1 - end + rebaked += populate_topic_embeddings(vector_rep, topics) vector_rep.consider_indexing @@ -38,30 +41,22 @@ module Jobs # Then, we'll try to backfill embeddings for topics that have outdated # embeddings, be it model or strategy version - topics - .where(<<~SQL) + relation = topics.where(<<~SQL) #{table_name}.model_version < #{vector_rep.version} OR #{table_name}.strategy_version < #{strategy.version} SQL - .find_each do |t| - vector_rep.generate_representation_from(t) - rebaked += 1 - end + + rebaked += populate_topic_embeddings(vector_rep, relation) return if rebaked >= limit # Finally, we'll try to backfill embeddings for topics that have outdated # embeddings due to edits or new replies. Here we only do 10% of the limit - topics - .where("#{table_name}.updated_at < ?", 7.days.ago) - .order("random()") - .limit((limit - rebaked) / 10) - .pluck(:id) - .each do |id| - vector_rep.generate_representation_from(Topic.find_by(id: id)) - rebaked += 1 - end + relation = + topics.where("#{table_name}.updated_at < ?", 7.days.ago).limit((limit - rebaked) / 10) + + populate_topic_embeddings(vector_rep, relation) return if rebaked >= limit @@ -117,5 +112,21 @@ module Jobs rebaked end + + private + + def populate_topic_embeddings(vector_rep, topics) + done = 0 + ids = topics.where("#{vector_rep.topic_table_name}.topic_id IS NULL").pluck("topics.id") + + ids.each do |id| + topic = Topic.find_by(id: id) + if topic + vector_rep.generate_representation_from(topic) + done += 1 + end + end + done + end end end diff --git a/lib/embeddings/semantic_related.rb b/lib/embeddings/semantic_related.rb index e9f5ed75..e8b8c517 100644 --- a/lib/embeddings/semantic_related.rb +++ b/lib/embeddings/semantic_related.rb @@ -39,7 +39,7 @@ module DiscourseAi ex: 15.minutes.to_i, nx: true, ) - Jobs.enqueue(:generate_embeddings, topic_id: topic.id) + Jobs.enqueue(:generate_embeddings, target_type: "Topic", target_id: topic.id) end [] end diff --git a/spec/jobs/scheduled/embeddings_backfill_spec.rb b/spec/jobs/scheduled/embeddings_backfill_spec.rb new file mode 100644 index 00000000..74cec6dc --- /dev/null +++ b/spec/jobs/scheduled/embeddings_backfill_spec.rb @@ -0,0 +1,55 @@ +# frozen_string_literal: true + +RSpec.describe Jobs::EmbeddingsBackfill do + fab!(:second_topic) do + topic = Fabricate(:topic, created_at: 1.year.ago, bumped_at: 2.day.ago) + Fabricate(:post, topic: topic) + topic + end + + fab!(:first_topic) do + topic = Fabricate(:topic, created_at: 1.year.ago, bumped_at: 1.day.ago) + Fabricate(:post, topic: topic) + topic + end + + fab!(:third_topic) do + topic = Fabricate(:topic, created_at: 1.year.ago, bumped_at: 3.day.ago) + Fabricate(:post, topic: topic) + topic + end + + let(:vector_rep) do + strategy = DiscourseAi::Embeddings::Strategies::Truncation.new + DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) + end + + it "backfills topics based on bumped_at date" do + SiteSetting.ai_embeddings_enabled = true + SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" + SiteSetting.ai_embeddings_backfill_batch_size = 1 + + Jobs.run_immediately! + + embedding = Array.new(1024) { 1 } + + WebMock.stub_request( + :post, + "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", + ).to_return(status: 200, body: JSON.dump(embedding)) + + Jobs::EmbeddingsBackfill.new.execute({}) + + topic_ids = DB.query_single("SELECT topic_id from #{vector_rep.topic_table_name}") + + expect(topic_ids).to eq([first_topic.id]) + + # pulse again for the rest (and cover code) + SiteSetting.ai_embeddings_backfill_batch_size = 100 + Jobs::EmbeddingsBackfill.new.execute({}) + + topic_ids = DB.query_single("SELECT topic_id from #{vector_rep.topic_table_name}") + + expect(topic_ids).to contain_exactly(first_topic.id, second_topic.id, third_topic.id) + end +end diff --git a/spec/lib/modules/embeddings/semantic_related_spec.rb b/spec/lib/modules/embeddings/semantic_related_spec.rb index 7a82ab02..37965aa1 100644 --- a/spec/lib/modules/embeddings/semantic_related_spec.rb +++ b/spec/lib/modules/embeddings/semantic_related_spec.rb @@ -17,20 +17,64 @@ describe DiscourseAi::Embeddings::SemanticRelated do describe "#related_topic_ids_for" do context "when embeddings do not exist" do - let(:topic) { Fabricate(:topic).tap { described_class.clear_cache_for(target) } } + let(:topic) do + post = Fabricate(:post) + topic = post.topic + described_class.clear_cache_for(target) + topic + end + + let(:vector_rep) do + strategy = DiscourseAi::Embeddings::Strategies::Truncation.new + + DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) + end + + it "properly generates embeddings if missing" do + SiteSetting.ai_embeddings_enabled = true + SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" + Jobs.run_immediately! + + embedding = Array.new(1024) { 1 } + + WebMock.stub_request( + :post, + "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", + ).to_return(status: 200, body: JSON.dump(embedding)) + + # miss first + ids = semantic_related.related_topic_ids_for(topic) + + # clear cache so we lookup + described_class.clear_cache_for(topic) + + # hit cause we queued generation + ids = semantic_related.related_topic_ids_for(topic) + + # at this point though the only embedding is ourselves + expect(ids).to eq([topic.id]) + end it "queues job only once per 15 minutes" do results = nil - expect_enqueued_with(job: :generate_embeddings, args: { topic_id: topic.id }) do - results = semantic_related.related_topic_ids_for(topic) - end + expect_enqueued_with( + job: :generate_embeddings, + args: { + target_id: topic.id, + target_type: "Topic", + }, + ) { results = semantic_related.related_topic_ids_for(topic) } expect(results).to eq([]) - expect_not_enqueued_with(job: :generate_embeddings, args: { topic_id: topic.id }) do - results = semantic_related.related_topic_ids_for(topic) - end + expect_not_enqueued_with( + job: :generate_embeddings, + args: { + target_id: topic.id, + target_type: "Topic", + }, + ) { results = semantic_related.related_topic_ids_for(topic) } expect(results).to eq([]) end