FEATURE: Seamless embedding model upgrades (#1486)

This commit is contained in:
Rafael dos Santos Silva 2025-07-04 16:44:03 -03:00 committed by GitHub
parent ab5edae121
commit 6247906c13
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 147 additions and 104 deletions

View File

@ -18,11 +18,24 @@ module Jobs
) )
end end
rebaked = 0 production_vector = DiscourseAi::Embeddings::Vector.instance
vector = DiscourseAi::Embeddings::Vector.instance if SiteSetting.ai_embeddings_backfill_model.present? &&
vector_def = vector.vdef SiteSetting.ai_embeddings_backfill_model != SiteSetting.ai_embeddings_selected_model
backfill_vector =
DiscourseAi::Embeddings::Vector.new(
EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_backfill_model),
)
end
topic_work_list = []
topic_work_list << production_vector
topic_work_list << backfill_vector if backfill_vector
topic_work_list.each do |vector|
rebaked = 0
table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
vector_def = vector.vdef
topics = topics =
Topic Topic
@ -35,7 +48,7 @@ module Jobs
rebaked += populate_topic_embeddings(vector, topics.limit(limit - rebaked)) rebaked += populate_topic_embeddings(vector, topics.limit(limit - rebaked))
return if rebaked >= limit next if rebaked >= limit
# Then, we'll try to backfill embeddings for topics that have outdated # Then, we'll try to backfill embeddings for topics that have outdated
# embeddings, be it model or strategy version # embeddings, be it model or strategy version
@ -47,7 +60,7 @@ module Jobs
rebaked += populate_topic_embeddings(vector, relation, force: true) rebaked += populate_topic_embeddings(vector, relation, force: true)
return if rebaked >= limit next if rebaked >= limit
# Finally, we'll try to backfill embeddings for topics that have outdated # Finally, we'll try to backfill embeddings for topics that have outdated
# embeddings due to edits or new replies. Here we only do 10% of the limit # embeddings due to edits or new replies. Here we only do 10% of the limit
@ -59,9 +72,7 @@ module Jobs
populate_topic_embeddings(vector, relation, force: true) populate_topic_embeddings(vector, relation, force: true)
return if rebaked >= limit next unless SiteSetting.ai_embeddings_per_post_enabled
return unless SiteSetting.ai_embeddings_per_post_enabled
# Now for posts # Now for posts
table_name = DiscourseAi::Embeddings::Schema::POSTS_TABLE table_name = DiscourseAi::Embeddings::Schema::POSTS_TABLE
@ -85,7 +96,7 @@ module Jobs
rebaked += batch.length rebaked += batch.length
end end
return if rebaked >= limit next if rebaked >= limit
# Then, we'll try to backfill embeddings for posts that have outdated # Then, we'll try to backfill embeddings for posts that have outdated
# embeddings, be it model or strategy version # embeddings, be it model or strategy version
@ -102,7 +113,7 @@ module Jobs
rebaked += batch.length rebaked += batch.length
end end
return if rebaked >= limit next if rebaked >= limit
# Finally, we'll try to backfill embeddings for posts that have outdated # Finally, we'll try to backfill embeddings for posts that have outdated
# embeddings due to edits. Here we only do 10% of the limit # embeddings due to edits. Here we only do 10% of the limit
@ -115,8 +126,7 @@ module Jobs
vector.gen_bulk_reprensentations(Post.where(id: batch)) vector.gen_bulk_reprensentations(Post.where(id: batch))
rebaked += batch.length rebaked += batch.length
end end
end
rebaked
end end
private private

View File

@ -230,6 +230,12 @@ discourse_ai:
enum: "DiscourseAi::Configuration::EmbeddingDefsEnumerator" enum: "DiscourseAi::Configuration::EmbeddingDefsEnumerator"
validator: "DiscourseAi::Configuration::EmbeddingDefsValidator" validator: "DiscourseAi::Configuration::EmbeddingDefsValidator"
area: "ai-features/embeddings" area: "ai-features/embeddings"
ai_embeddings_backfill_model:
type: enum
default: ""
allow_any: false
enum: "DiscourseAi::Configuration::EmbeddingDefsEnumerator"
hidden: true
ai_embeddings_per_post_enabled: ai_embeddings_per_post_enabled:
default: false default: false
hidden: true hidden: true

View File

@ -20,8 +20,11 @@ module DiscourseAi
MissingEmbeddingError = Class.new(StandardError) MissingEmbeddingError = Class.new(StandardError)
class << self class << self
def for(target_klass) def for(target_klass, vector_def: nil)
vector_def = EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_selected_model) vector_def =
EmbeddingDefinition.find_by(
id: SiteSetting.ai_embeddings_selected_model,
) if vector_def.nil?
raise "Invalid embeddings selected model" if vector_def.nil? raise "Invalid embeddings selected model" if vector_def.nil?
case target_klass&.name case target_klass&.name

View File

@ -3,6 +3,8 @@
module DiscourseAi module DiscourseAi
module Embeddings module Embeddings
class SemanticRelated class SemanticRelated
CACHE_PREFIX = "semantic-suggested-topic-"
def self.clear_cache_for(topic) def self.clear_cache_for(topic)
Discourse.cache.delete("semantic-suggested-topic-#{topic.id}") Discourse.cache.delete("semantic-suggested-topic-#{topic.id}")
Discourse.redis.del("build-semantic-suggested-topic-#{topic.id}") Discourse.redis.del("build-semantic-suggested-topic-#{topic.id}")
@ -79,14 +81,21 @@ module DiscourseAi
) )
end end
def self.clear_cache!
Discourse
.cache
.keys("#{CACHE_PREFIX}*")
.each { |key| Discourse.cache.delete(key.split(":").last) }
end
private private
def semantic_suggested_key(topic_id) def semantic_suggested_key(topic_id)
"semantic-suggested-topic-#{topic_id}" "#{CACHE_PREFIX}#{topic_id}"
end end
def build_semantic_suggested_key(topic_id) def build_semantic_suggested_key(topic_id)
"build-semantic-suggested-topic-#{topic_id}" "build-#{CACHE_PREFIX}#{topic_id}"
end end
end end
end end

View File

@ -25,7 +25,7 @@ module DiscourseAi
idletime: 30, idletime: 30,
) )
schema = DiscourseAi::Embeddings::Schema.for(relation.first.class) schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector_def: @vdef)
embedding_gen = vdef.inference_client embedding_gen = vdef.inference_client
promised_embeddings = promised_embeddings =
@ -58,7 +58,7 @@ module DiscourseAi
text = vdef.prepare_target_text(target) text = vdef.prepare_target_text(target)
return if text.blank? return if text.blank?
schema = DiscourseAi::Embeddings::Schema.for(target.class) schema = DiscourseAi::Embeddings::Schema.for(target.class, vector_def: @vdef)
new_digest = OpenSSL::Digest::SHA1.hexdigest(text) new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
return if schema.find_by_target(target)&.digest == new_digest return if schema.find_by_target(target)&.digest == new_digest

View File

@ -20,6 +20,8 @@ RSpec.describe Jobs::EmbeddingsBackfill do
end end
fab!(:vector_def) { Fabricate(:embedding_definition) } fab!(:vector_def) { Fabricate(:embedding_definition) }
fab!(:vector_def2) { Fabricate(:embedding_definition) }
fab!(:embedding_array) { Array.new(1024) { 1 } }
before do before do
SiteSetting.ai_embeddings_selected_model = vector_def.id SiteSetting.ai_embeddings_selected_model = vector_def.id
@ -27,16 +29,14 @@ RSpec.describe Jobs::EmbeddingsBackfill do
SiteSetting.ai_embeddings_backfill_batch_size = 1 SiteSetting.ai_embeddings_backfill_batch_size = 1
SiteSetting.ai_embeddings_per_post_enabled = true SiteSetting.ai_embeddings_per_post_enabled = true
Jobs.run_immediately! Jobs.run_immediately!
end
it "backfills topics based on bumped_at date" do
embedding = Array.new(1024) { 1 }
WebMock.stub_request(:post, "https://test.com/embeddings").to_return( WebMock.stub_request(:post, "https://test.com/embeddings").to_return(
status: 200, status: 200,
body: JSON.dump(embedding), body: JSON.dump(embedding_array),
) )
end
it "backfills topics based on bumped_at date" do
Jobs::EmbeddingsBackfill.new.execute({}) Jobs::EmbeddingsBackfill.new.execute({})
topic_ids = topic_ids =
@ -68,4 +68,19 @@ RSpec.describe Jobs::EmbeddingsBackfill do
expect(index_date).to be_within_one_second_of(Time.zone.now) expect(index_date).to be_within_one_second_of(Time.zone.now)
end end
it "backfills embeddings for the ai_embeddings_backfill_model" do
SiteSetting.ai_embeddings_backfill_model = vector_def2.id
SiteSetting.ai_embeddings_backfill_batch_size = 100
Jobs::EmbeddingsBackfill.new.execute({})
topic_ids =
DB.query_single(
"SELECT topic_id from #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE} WHERE model_id = ?",
vector_def2.id,
)
expect(topic_ids).to contain_exactly(first_topic.id, second_topic.id, third_topic.id)
end
end end