FEATURE: Seamless embedding model upgrades (#1486)

Rafael dos Santos Silva, 2025-07-04 16:44:03 -03:00, committed by GitHub
parent ab5edae121 · commit 6247906c13
6 changed files with 147 additions and 104 deletions

View File

@@ -18,105 +18,115 @@ module Jobs
         )
       end
 
-      rebaked = 0
-      vector = DiscourseAi::Embeddings::Vector.instance
-      vector_def = vector.vdef
-      table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
-
-      topics =
-        Topic
-          .joins(
-            "LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id AND #{table_name}.model_id = #{vector_def.id}",
-          )
-          .where(archetype: Archetype.default)
-          .where(deleted_at: nil)
-          .order("topics.bumped_at DESC")
-
-      rebaked += populate_topic_embeddings(vector, topics.limit(limit - rebaked))
-
-      return if rebaked >= limit
-
-      # Then, we'll try to backfill embeddings for topics that have outdated
-      # embeddings, be it model or strategy version
-      relation = topics.where(<<~SQL).limit(limit - rebaked)
-        #{table_name}.model_version < #{vector_def.version}
-        OR
-        #{table_name}.strategy_version < #{vector_def.strategy_version}
-      SQL
-
-      rebaked += populate_topic_embeddings(vector, relation, force: true)
-
-      return if rebaked >= limit
-
-      # Finally, we'll try to backfill embeddings for topics that have outdated
-      # embeddings due to edits or new replies. Here we only do 10% of the limit
-      relation =
-        topics
-          .where("#{table_name}.updated_at < ?", 6.hours.ago)
-          .where("#{table_name}.updated_at < topics.updated_at")
-          .limit((limit - rebaked) / 10)
-
-      populate_topic_embeddings(vector, relation, force: true)
-
-      return if rebaked >= limit
-
-      return unless SiteSetting.ai_embeddings_per_post_enabled
-
-      # Now for posts
-      table_name = DiscourseAi::Embeddings::Schema::POSTS_TABLE
-      posts_batch_size = 1000
-
-      posts =
-        Post
-          .joins(
-            "LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id AND #{table_name}.model_id = #{vector_def.id}",
-          )
-          .where(deleted_at: nil)
-          .where(post_type: Post.types[:regular])
-
-      # First, we'll try to backfill embeddings for posts that have none
-      posts
-        .where("#{table_name}.post_id IS NULL")
-        .limit(limit - rebaked)
-        .pluck(:id)
-        .each_slice(posts_batch_size) do |batch|
-          vector.gen_bulk_reprensentations(Post.where(id: batch))
-          rebaked += batch.length
-        end
-
-      return if rebaked >= limit
-
-      # Then, we'll try to backfill embeddings for posts that have outdated
-      # embeddings, be it model or strategy version
-      posts
-        .where(<<~SQL)
-          #{table_name}.model_version < #{vector_def.version}
-          OR
-          #{table_name}.strategy_version < #{vector_def.strategy_version}
-        SQL
-        .limit(limit - rebaked)
-        .pluck(:id)
-        .each_slice(posts_batch_size) do |batch|
-          vector.gen_bulk_reprensentations(Post.where(id: batch))
-          rebaked += batch.length
-        end
-
-      return if rebaked >= limit
-
-      # Finally, we'll try to backfill embeddings for posts that have outdated
-      # embeddings due to edits. Here we only do 10% of the limit
-      posts
-        .where("#{table_name}.updated_at < ?", 7.days.ago)
-        .order("random()")
-        .limit((limit - rebaked) / 10)
-        .pluck(:id)
-        .each_slice(posts_batch_size) do |batch|
-          vector.gen_bulk_reprensentations(Post.where(id: batch))
-          rebaked += batch.length
-        end
-
-      rebaked
+      production_vector = DiscourseAi::Embeddings::Vector.instance
+
+      if SiteSetting.ai_embeddings_backfill_model.present? &&
+           SiteSetting.ai_embeddings_backfill_model != SiteSetting.ai_embeddings_selected_model
+        backfill_vector =
+          DiscourseAi::Embeddings::Vector.new(
+            EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_backfill_model),
+          )
+      end
+
+      topic_work_list = []
+      topic_work_list << production_vector
+      topic_work_list << backfill_vector if backfill_vector
+
+      topic_work_list.each do |vector|
+        rebaked = 0
+        table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
+        vector_def = vector.vdef
+
+        topics =
+          Topic
+            .joins(
+              "LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id AND #{table_name}.model_id = #{vector_def.id}",
+            )
+            .where(archetype: Archetype.default)
+            .where(deleted_at: nil)
+            .order("topics.bumped_at DESC")
+
+        rebaked += populate_topic_embeddings(vector, topics.limit(limit - rebaked))
+
+        next if rebaked >= limit
+
+        # Then, we'll try to backfill embeddings for topics that have outdated
+        # embeddings, be it model or strategy version
+        relation = topics.where(<<~SQL).limit(limit - rebaked)
+          #{table_name}.model_version < #{vector_def.version}
+          OR
+          #{table_name}.strategy_version < #{vector_def.strategy_version}
+        SQL
+
+        rebaked += populate_topic_embeddings(vector, relation, force: true)
+
+        next if rebaked >= limit
+
+        # Finally, we'll try to backfill embeddings for topics that have outdated
+        # embeddings due to edits or new replies. Here we only do 10% of the limit
+        relation =
+          topics
+            .where("#{table_name}.updated_at < ?", 6.hours.ago)
+            .where("#{table_name}.updated_at < topics.updated_at")
+            .limit((limit - rebaked) / 10)
+
+        populate_topic_embeddings(vector, relation, force: true)
+
+        next unless SiteSetting.ai_embeddings_per_post_enabled
+
+        # Now for posts
+        table_name = DiscourseAi::Embeddings::Schema::POSTS_TABLE
+        posts_batch_size = 1000
+
+        posts =
+          Post
+            .joins(
+              "LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id AND #{table_name}.model_id = #{vector_def.id}",
+            )
+            .where(deleted_at: nil)
+            .where(post_type: Post.types[:regular])
+
+        # First, we'll try to backfill embeddings for posts that have none
+        posts
+          .where("#{table_name}.post_id IS NULL")
+          .limit(limit - rebaked)
+          .pluck(:id)
+          .each_slice(posts_batch_size) do |batch|
+            vector.gen_bulk_reprensentations(Post.where(id: batch))
+            rebaked += batch.length
+          end
+
+        next if rebaked >= limit
+
+        # Then, we'll try to backfill embeddings for posts that have outdated
+        # embeddings, be it model or strategy version
+        posts
+          .where(<<~SQL)
+            #{table_name}.model_version < #{vector_def.version}
+            OR
+            #{table_name}.strategy_version < #{vector_def.strategy_version}
+          SQL
+          .limit(limit - rebaked)
+          .pluck(:id)
+          .each_slice(posts_batch_size) do |batch|
+            vector.gen_bulk_reprensentations(Post.where(id: batch))
+            rebaked += batch.length
+          end
+
+        next if rebaked >= limit
+
+        # Finally, we'll try to backfill embeddings for posts that have outdated
+        # embeddings due to edits. Here we only do 10% of the limit
+        posts
+          .where("#{table_name}.updated_at < ?", 7.days.ago)
+          .order("random()")
+          .limit((limit - rebaked) / 10)
+          .pluck(:id)
+          .each_slice(posts_batch_size) do |batch|
+            vector.gen_bulk_reprensentations(Post.where(id: batch))
+            rebaked += batch.length
+          end
+      end
     end
 
     private
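The work list above is the heart of the feature: the production embedding model keeps serving reads while a second, hidden setting enrolls an upcoming model in the same scheduled backfill. A minimal console sketch of driving it (the model ids are hypothetical, purely for illustration):

    # The current production model keeps answering queries...
    SiteSetting.ai_embeddings_selected_model = 3 # hypothetical id
    # ...while each job run also fills the embeddings table for the upcoming model.
    SiteSetting.ai_embeddings_backfill_model = 7 # hypothetical id
    Jobs::EmbeddingsBackfill.new.execute({}) # processes both vectors, up to the batch limit each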

View File

@@ -230,20 +230,26 @@ discourse_ai:
     enum: "DiscourseAi::Configuration::EmbeddingDefsEnumerator"
     validator: "DiscourseAi::Configuration::EmbeddingDefsValidator"
     area: "ai-features/embeddings"
+  ai_embeddings_backfill_model:
+    type: enum
+    default: ""
+    allow_any: false
+    enum: "DiscourseAi::Configuration::EmbeddingDefsEnumerator"
+    hidden: true
   ai_embeddings_per_post_enabled:
     default: false
     hidden: true
   ai_embeddings_generate_for_pms:
     default: false
     area: "ai-features/embeddings"
   ai_embeddings_semantic_related_topics_enabled:
     default: false
     client: true
     area: "ai-features/embeddings"
   ai_embeddings_semantic_related_topics:
     default: 5
     area: "ai-features/embeddings"
   ai_embeddings_semantic_related_include_closed_topics:
     default: true
     area: "ai-features/embeddings"
   ai_embeddings_backfill_batch_size:
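The new ai_embeddings_backfill_model setting is hidden and defaults to empty, so nothing changes until an operator opts in. Once the second model's embeddings are fully populated, the switch-over is presumably just flipping the selected model; a hedged sketch of that promotion step, which is not itself part of this commit:

    SiteSetting.ai_embeddings_selected_model = SiteSetting.ai_embeddings_backfill_model
    SiteSetting.ai_embeddings_backfill_model = "" # back to the default: no more dual writes
    DiscourseAi::Embeddings::SemanticRelated.clear_cache! # drop caches built from the old model (see below)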

View File

@@ -20,8 +20,11 @@ module DiscourseAi
     MissingEmbeddingError = Class.new(StandardError)
 
     class << self
-      def for(target_klass)
-        vector_def = EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_selected_model)
+      def for(target_klass, vector_def: nil)
+        vector_def =
+          EmbeddingDefinition.find_by(
+            id: SiteSetting.ai_embeddings_selected_model,
+          ) if vector_def.nil?
         raise "Invalid embeddings selected model" if vector_def.nil?
 
         case target_klass&.name
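Usage sketch for the new keyword argument: callers can now pin a schema to a specific embedding definition instead of the site-wide selected model (backfill_def is a hypothetical variable name):

    backfill_def = EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_backfill_model)
    schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: backfill_def)
    # Omitting vector_def: falls back to SiteSetting.ai_embeddings_selected_model, as before.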

View File

@@ -3,6 +3,8 @@
 module DiscourseAi
   module Embeddings
     class SemanticRelated
+      CACHE_PREFIX = "semantic-suggested-topic-"
+
       def self.clear_cache_for(topic)
         Discourse.cache.delete("semantic-suggested-topic-#{topic.id}")
         Discourse.redis.del("build-semantic-suggested-topic-#{topic.id}")
@@ -79,14 +81,21 @@ module DiscourseAi
         )
       end
 
+      def self.clear_cache!
+        Discourse
+          .cache
+          .keys("#{CACHE_PREFIX}*")
+          .each { |key| Discourse.cache.delete(key.split(":").last) }
+      end
+
       private
 
       def semantic_suggested_key(topic_id)
-        "semantic-suggested-topic-#{topic_id}"
+        "#{CACHE_PREFIX}#{topic_id}"
       end
 
       def build_semantic_suggested_key(topic_id)
-        "build-semantic-suggested-topic-#{topic_id}"
+        "build-#{CACHE_PREFIX}#{topic_id}"
       end
     end
   end
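A note on the new clear_cache!: Discourse.cache.keys returns fully namespaced Redis keys while Discourse.cache.delete expects the bare key, which is evidently why the code strips the namespace with key.split(":").last. A one-line usage sketch, e.g. after swapping embedding models:

    DiscourseAi::Embeddings::SemanticRelated.clear_cache! # invalidates every semantic-suggested-topic-* entry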

View File

@@ -25,7 +25,7 @@ module DiscourseAi
           idletime: 30,
         )
 
-      schema = DiscourseAi::Embeddings::Schema.for(relation.first.class)
+      schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector_def: @vdef)
       embedding_gen = vdef.inference_client
       promised_embeddings =
@@ -58,7 +58,7 @@ module DiscourseAi
       text = vdef.prepare_target_text(target)
       return if text.blank?
 
-      schema = DiscourseAi::Embeddings::Schema.for(target.class)
+      schema = DiscourseAi::Embeddings::Schema.for(target.class, vector_def: @vdef)
       new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
       return if schema.find_by_target(target)&.digest == new_digest
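With this plumbing, a Vector built around any EmbeddingDefinition now reads and writes through a schema bound to that same definition, keeping backfill rows separate from production rows. A hedged sketch (topic ids hypothetical; gen_bulk_reprensentations is the method's actual, pre-existing spelling):

    vdef = EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_backfill_model)
    vector = DiscourseAi::Embeddings::Vector.new(vdef)
    vector.gen_bulk_reprensentations(Topic.where(id: [101, 102])) # rows land under vdef's model_id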

View File

@@ -20,6 +20,8 @@ RSpec.describe Jobs::EmbeddingsBackfill do
   end
 
   fab!(:vector_def) { Fabricate(:embedding_definition) }
+  fab!(:vector_def2) { Fabricate(:embedding_definition) }
+  fab!(:embedding_array) { Array.new(1024) { 1 } }
 
   before do
     SiteSetting.ai_embeddings_selected_model = vector_def.id
@@ -27,16 +29,14 @@ RSpec.describe Jobs::EmbeddingsBackfill do
     SiteSetting.ai_embeddings_backfill_batch_size = 1
     SiteSetting.ai_embeddings_per_post_enabled = true
     Jobs.run_immediately!
-  end
-
-  it "backfills topics based on bumped_at date" do
-    embedding = Array.new(1024) { 1 }
 
     WebMock.stub_request(:post, "https://test.com/embeddings").to_return(
       status: 200,
-      body: JSON.dump(embedding),
+      body: JSON.dump(embedding_array),
     )
+  end
 
+  it "backfills topics based on bumped_at date" do
     Jobs::EmbeddingsBackfill.new.execute({})
 
     topic_ids =
@@ -68,4 +68,19 @@ RSpec.describe Jobs::EmbeddingsBackfill do
 
     expect(index_date).to be_within_one_second_of(Time.zone.now)
   end
+
+  it "backfills embeddings for the ai_embeddings_backfill_model" do
+    SiteSetting.ai_embeddings_backfill_model = vector_def2.id
+    SiteSetting.ai_embeddings_backfill_batch_size = 100
+
+    Jobs::EmbeddingsBackfill.new.execute({})
+
+    topic_ids =
+      DB.query_single(
+        "SELECT topic_id from #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE} WHERE model_id = ?",
+        vector_def2.id,
+      )
+
+    expect(topic_ids).to contain_exactly(first_topic.id, second_topic.id, third_topic.id)
+  end
 end