FIX: improve embedding generation (#452)
1. On failure we were queuing a job to generate embeddings with the wrong params; this is now fixed and covered by a test. 2. Backfill embeddings in order of bumped_at, so the newest content is embedded first; covered by a test. 3. Add a safeguard (via a hidden site setting) that caps an embedding job run at batches of 50k. Previously old embeddings were updated in a random order; this changes it so we update them in a consistent order.
This commit is contained in:
parent
abcf5ea94a
commit
dcafc8032f
|
@ -10,6 +10,14 @@ module Jobs
|
||||||
return unless SiteSetting.ai_embeddings_enabled
|
return unless SiteSetting.ai_embeddings_enabled
|
||||||
|
|
||||||
limit = SiteSetting.ai_embeddings_backfill_batch_size
|
limit = SiteSetting.ai_embeddings_backfill_batch_size
|
||||||
|
|
||||||
|
if limit > 50_000
|
||||||
|
limit = 50_000
|
||||||
|
Rails.logger.warn(
|
||||||
|
"Limiting backfill batch size to 50,000 to avoid OOM errors, reduce ai_embeddings_backfill_batch_size to avoid this warning",
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
rebaked = 0
|
rebaked = 0
|
||||||
|
|
||||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||||
|
@ -22,15 +30,10 @@ module Jobs
|
||||||
.joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id")
|
.joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id")
|
||||||
.where(archetype: Archetype.default)
|
.where(archetype: Archetype.default)
|
||||||
.where(deleted_at: nil)
|
.where(deleted_at: nil)
|
||||||
|
.order("topics.bumped_at DESC")
|
||||||
.limit(limit - rebaked)
|
.limit(limit - rebaked)
|
||||||
|
|
||||||
# First, we'll try to backfill embeddings for topics that have none
|
rebaked += populate_topic_embeddings(vector_rep, topics)
|
||||||
topics
|
|
||||||
.where("#{table_name}.topic_id IS NULL")
|
|
||||||
.find_each do |t|
|
|
||||||
vector_rep.generate_representation_from(t)
|
|
||||||
rebaked += 1
|
|
||||||
end
|
|
||||||
|
|
||||||
vector_rep.consider_indexing
|
vector_rep.consider_indexing
|
||||||
|
|
||||||
|
@ -38,30 +41,22 @@ module Jobs
|
||||||
|
|
||||||
# Then, we'll try to backfill embeddings for topics that have outdated
|
# Then, we'll try to backfill embeddings for topics that have outdated
|
||||||
# embeddings, be it model or strategy version
|
# embeddings, be it model or strategy version
|
||||||
topics
|
relation = topics.where(<<~SQL)
|
||||||
.where(<<~SQL)
|
|
||||||
#{table_name}.model_version < #{vector_rep.version}
|
#{table_name}.model_version < #{vector_rep.version}
|
||||||
OR
|
OR
|
||||||
#{table_name}.strategy_version < #{strategy.version}
|
#{table_name}.strategy_version < #{strategy.version}
|
||||||
SQL
|
SQL
|
||||||
.find_each do |t|
|
|
||||||
vector_rep.generate_representation_from(t)
|
rebaked += populate_topic_embeddings(vector_rep, relation)
|
||||||
rebaked += 1
|
|
||||||
end
|
|
||||||
|
|
||||||
return if rebaked >= limit
|
return if rebaked >= limit
|
||||||
|
|
||||||
# Finally, we'll try to backfill embeddings for topics that have outdated
|
# Finally, we'll try to backfill embeddings for topics that have outdated
|
||||||
# embeddings due to edits or new replies. Here we only do 10% of the limit
|
# embeddings due to edits or new replies. Here we only do 10% of the limit
|
||||||
topics
|
relation =
|
||||||
.where("#{table_name}.updated_at < ?", 7.days.ago)
|
topics.where("#{table_name}.updated_at < ?", 7.days.ago).limit((limit - rebaked) / 10)
|
||||||
.order("random()")
|
|
||||||
.limit((limit - rebaked) / 10)
|
populate_topic_embeddings(vector_rep, relation)
|
||||||
.pluck(:id)
|
|
||||||
.each do |id|
|
|
||||||
vector_rep.generate_representation_from(Topic.find_by(id: id))
|
|
||||||
rebaked += 1
|
|
||||||
end
|
|
||||||
|
|
||||||
return if rebaked >= limit
|
return if rebaked >= limit
|
||||||
|
|
||||||
|
@ -117,5 +112,21 @@ module Jobs
|
||||||
|
|
||||||
rebaked
|
rebaked
|
||||||
end
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
def populate_topic_embeddings(vector_rep, topics)
|
||||||
|
done = 0
|
||||||
|
ids = topics.where("#{vector_rep.topic_table_name}.topic_id IS NULL").pluck("topics.id")
|
||||||
|
|
||||||
|
ids.each do |id|
|
||||||
|
topic = Topic.find_by(id: id)
|
||||||
|
if topic
|
||||||
|
vector_rep.generate_representation_from(topic)
|
||||||
|
done += 1
|
||||||
|
end
|
||||||
|
end
|
||||||
|
done
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -39,7 +39,7 @@ module DiscourseAi
|
||||||
ex: 15.minutes.to_i,
|
ex: 15.minutes.to_i,
|
||||||
nx: true,
|
nx: true,
|
||||||
)
|
)
|
||||||
Jobs.enqueue(:generate_embeddings, topic_id: topic.id)
|
Jobs.enqueue(:generate_embeddings, target_type: "Topic", target_id: topic.id)
|
||||||
end
|
end
|
||||||
[]
|
[]
|
||||||
end
|
end
|
||||||
|
|
|
@ -0,0 +1,55 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
RSpec.describe Jobs::EmbeddingsBackfill do
|
||||||
|
fab!(:second_topic) do
|
||||||
|
topic = Fabricate(:topic, created_at: 1.year.ago, bumped_at: 2.day.ago)
|
||||||
|
Fabricate(:post, topic: topic)
|
||||||
|
topic
|
||||||
|
end
|
||||||
|
|
||||||
|
fab!(:first_topic) do
|
||||||
|
topic = Fabricate(:topic, created_at: 1.year.ago, bumped_at: 1.day.ago)
|
||||||
|
Fabricate(:post, topic: topic)
|
||||||
|
topic
|
||||||
|
end
|
||||||
|
|
||||||
|
fab!(:third_topic) do
|
||||||
|
topic = Fabricate(:topic, created_at: 1.year.ago, bumped_at: 3.day.ago)
|
||||||
|
Fabricate(:post, topic: topic)
|
||||||
|
topic
|
||||||
|
end
|
||||||
|
|
||||||
|
let(:vector_rep) do
|
||||||
|
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||||
|
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "backfills topics based on bumped_at date" do
|
||||||
|
SiteSetting.ai_embeddings_enabled = true
|
||||||
|
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||||
|
SiteSetting.ai_embeddings_backfill_batch_size = 1
|
||||||
|
|
||||||
|
Jobs.run_immediately!
|
||||||
|
|
||||||
|
embedding = Array.new(1024) { 1 }
|
||||||
|
|
||||||
|
WebMock.stub_request(
|
||||||
|
:post,
|
||||||
|
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||||
|
).to_return(status: 200, body: JSON.dump(embedding))
|
||||||
|
|
||||||
|
Jobs::EmbeddingsBackfill.new.execute({})
|
||||||
|
|
||||||
|
topic_ids = DB.query_single("SELECT topic_id from #{vector_rep.topic_table_name}")
|
||||||
|
|
||||||
|
expect(topic_ids).to eq([first_topic.id])
|
||||||
|
|
||||||
|
# pulse again for the rest (and cover code)
|
||||||
|
SiteSetting.ai_embeddings_backfill_batch_size = 100
|
||||||
|
Jobs::EmbeddingsBackfill.new.execute({})
|
||||||
|
|
||||||
|
topic_ids = DB.query_single("SELECT topic_id from #{vector_rep.topic_table_name}")
|
||||||
|
|
||||||
|
expect(topic_ids).to contain_exactly(first_topic.id, second_topic.id, third_topic.id)
|
||||||
|
end
|
||||||
|
end
|
|
@ -17,20 +17,64 @@ describe DiscourseAi::Embeddings::SemanticRelated do
|
||||||
|
|
||||||
describe "#related_topic_ids_for" do
|
describe "#related_topic_ids_for" do
|
||||||
context "when embeddings do not exist" do
|
context "when embeddings do not exist" do
|
||||||
let(:topic) { Fabricate(:topic).tap { described_class.clear_cache_for(target) } }
|
let(:topic) do
|
||||||
|
post = Fabricate(:post)
|
||||||
|
topic = post.topic
|
||||||
|
described_class.clear_cache_for(target)
|
||||||
|
topic
|
||||||
|
end
|
||||||
|
|
||||||
|
let(:vector_rep) do
|
||||||
|
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||||
|
|
||||||
|
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "properly generates embeddings if missing" do
|
||||||
|
SiteSetting.ai_embeddings_enabled = true
|
||||||
|
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||||
|
Jobs.run_immediately!
|
||||||
|
|
||||||
|
embedding = Array.new(1024) { 1 }
|
||||||
|
|
||||||
|
WebMock.stub_request(
|
||||||
|
:post,
|
||||||
|
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||||
|
).to_return(status: 200, body: JSON.dump(embedding))
|
||||||
|
|
||||||
|
# miss first
|
||||||
|
ids = semantic_related.related_topic_ids_for(topic)
|
||||||
|
|
||||||
|
# clear cache so we lookup
|
||||||
|
described_class.clear_cache_for(topic)
|
||||||
|
|
||||||
|
# hit cause we queued generation
|
||||||
|
ids = semantic_related.related_topic_ids_for(topic)
|
||||||
|
|
||||||
|
# at this point though the only embedding is ourselves
|
||||||
|
expect(ids).to eq([topic.id])
|
||||||
|
end
|
||||||
|
|
||||||
it "queues job only once per 15 minutes" do
|
it "queues job only once per 15 minutes" do
|
||||||
results = nil
|
results = nil
|
||||||
|
|
||||||
expect_enqueued_with(job: :generate_embeddings, args: { topic_id: topic.id }) do
|
expect_enqueued_with(
|
||||||
results = semantic_related.related_topic_ids_for(topic)
|
job: :generate_embeddings,
|
||||||
end
|
args: {
|
||||||
|
target_id: topic.id,
|
||||||
|
target_type: "Topic",
|
||||||
|
},
|
||||||
|
) { results = semantic_related.related_topic_ids_for(topic) }
|
||||||
|
|
||||||
expect(results).to eq([])
|
expect(results).to eq([])
|
||||||
|
|
||||||
expect_not_enqueued_with(job: :generate_embeddings, args: { topic_id: topic.id }) do
|
expect_not_enqueued_with(
|
||||||
results = semantic_related.related_topic_ids_for(topic)
|
job: :generate_embeddings,
|
||||||
end
|
args: {
|
||||||
|
target_id: topic.id,
|
||||||
|
target_type: "Topic",
|
||||||
|
},
|
||||||
|
) { results = semantic_related.related_topic_ids_for(topic) }
|
||||||
|
|
||||||
expect(results).to eq([])
|
expect(results).to eq([])
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in New Issue