diff --git a/app/jobs/scheduled/embeddings_backfill.rb b/app/jobs/scheduled/embeddings_backfill.rb index 4ae9f70c..163534cc 100644 --- a/app/jobs/scheduled/embeddings_backfill.rb +++ b/app/jobs/scheduled/embeddings_backfill.rb @@ -18,105 +18,115 @@ module Jobs ) end - rebaked = 0 + production_vector = DiscourseAi::Embeddings::Vector.instance - vector = DiscourseAi::Embeddings::Vector.instance - vector_def = vector.vdef - table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE - - topics = - Topic - .joins( - "LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id AND #{table_name}.model_id = #{vector_def.id}", + if SiteSetting.ai_embeddings_backfill_model.present? && + SiteSetting.ai_embeddings_backfill_model != SiteSetting.ai_embeddings_selected_model + backfill_vector = + DiscourseAi::Embeddings::Vector.new( + EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_backfill_model), ) - .where(archetype: Archetype.default) - .where(deleted_at: nil) - .order("topics.bumped_at DESC") + end - rebaked += populate_topic_embeddings(vector, topics.limit(limit - rebaked)) + topic_work_list = [] + topic_work_list << production_vector + topic_work_list << backfill_vector if backfill_vector - return if rebaked >= limit + topic_work_list.each do |vector| + rebaked = 0 + table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE + vector_def = vector.vdef - # Then, we'll try to backfill embeddings for topics that have outdated - # embeddings, be it model or strategy version - relation = topics.where(<<~SQL).limit(limit - rebaked) - #{table_name}.model_version < #{vector_def.version} - OR - #{table_name}.strategy_version < #{vector_def.strategy_version} - SQL + topics = + Topic + .joins( + "LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id AND #{table_name}.model_id = #{vector_def.id}", + ) + .where(archetype: Archetype.default) + .where(deleted_at: nil) + .order("topics.bumped_at DESC") - rebaked += populate_topic_embeddings(vector, relation, force: true) + rebaked += populate_topic_embeddings(vector, topics.limit(limit - rebaked)) - return if rebaked >= limit + next if rebaked >= limit - # Finally, we'll try to backfill embeddings for topics that have outdated - # embeddings due to edits or new replies. Here we only do 10% of the limit - relation = - topics - .where("#{table_name}.updated_at < ?", 6.hours.ago) - .where("#{table_name}.updated_at < topics.updated_at") + # Then, we'll try to backfill embeddings for topics that have outdated + # embeddings, be it model or strategy version + relation = topics.where(<<~SQL).limit(limit - rebaked) + #{table_name}.model_version < #{vector_def.version} + OR + #{table_name}.strategy_version < #{vector_def.strategy_version} + SQL + + rebaked += populate_topic_embeddings(vector, relation, force: true) + + next if rebaked >= limit + + # Finally, we'll try to backfill embeddings for topics that have outdated + # embeddings due to edits or new replies. Here we only do 10% of the limit + relation = + topics + .where("#{table_name}.updated_at < ?", 6.hours.ago) + .where("#{table_name}.updated_at < topics.updated_at") + .limit((limit - rebaked) / 10) + + populate_topic_embeddings(vector, relation, force: true) + + next unless SiteSetting.ai_embeddings_per_post_enabled + + # Now for posts + table_name = DiscourseAi::Embeddings::Schema::POSTS_TABLE + posts_batch_size = 1000 + + posts = + Post + .joins( + "LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id AND #{table_name}.model_id = #{vector_def.id}", + ) + .where(deleted_at: nil) + .where(post_type: Post.types[:regular]) + + # First, we'll try to backfill embeddings for posts that have none + posts + .where("#{table_name}.post_id IS NULL") + .limit(limit - rebaked) + .pluck(:id) + .each_slice(posts_batch_size) do |batch| + vector.gen_bulk_reprensentations(Post.where(id: batch)) + rebaked += batch.length + end + + next if rebaked >= limit + + # Then, we'll try to backfill embeddings for posts that have outdated + # embeddings, be it model or strategy version + posts + .where(<<~SQL) + #{table_name}.model_version < #{vector_def.version} + OR + #{table_name}.strategy_version < #{vector_def.strategy_version} + SQL + .limit(limit - rebaked) + .pluck(:id) + .each_slice(posts_batch_size) do |batch| + vector.gen_bulk_reprensentations(Post.where(id: batch)) + rebaked += batch.length + end + + next if rebaked >= limit + + # Finally, we'll try to backfill embeddings for posts that have outdated + # embeddings due to edits. Here we only do 10% of the limit + posts + .where("#{table_name}.updated_at < ?", 7.days.ago) + .order("random()") .limit((limit - rebaked) / 10) - - populate_topic_embeddings(vector, relation, force: true) - - return if rebaked >= limit - - return unless SiteSetting.ai_embeddings_per_post_enabled - - # Now for posts - table_name = DiscourseAi::Embeddings::Schema::POSTS_TABLE - posts_batch_size = 1000 - - posts = - Post - .joins( - "LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id AND #{table_name}.model_id = #{vector_def.id}", - ) - .where(deleted_at: nil) - .where(post_type: Post.types[:regular]) - - # First, we'll try to backfill embeddings for posts that have none - posts - .where("#{table_name}.post_id IS NULL") - .limit(limit - rebaked) - .pluck(:id) - .each_slice(posts_batch_size) do |batch| - vector.gen_bulk_reprensentations(Post.where(id: batch)) - rebaked += batch.length - end - - return if rebaked >= limit - - # Then, we'll try to backfill embeddings for posts that have outdated - # embeddings, be it model or strategy version - posts - .where(<<~SQL) - #{table_name}.model_version < #{vector_def.version} - OR - #{table_name}.strategy_version < #{vector_def.strategy_version} - SQL - .limit(limit - rebaked) - .pluck(:id) - .each_slice(posts_batch_size) do |batch| - vector.gen_bulk_reprensentations(Post.where(id: batch)) - rebaked += batch.length - end - - return if rebaked >= limit - - # Finally, we'll try to backfill embeddings for posts that have outdated - # embeddings due to edits. Here we only do 10% of the limit - posts - .where("#{table_name}.updated_at < ?", 7.days.ago) - .order("random()") - .limit((limit - rebaked) / 10) - .pluck(:id) - .each_slice(posts_batch_size) do |batch| - vector.gen_bulk_reprensentations(Post.where(id: batch)) - rebaked += batch.length - end - - rebaked + .pluck(:id) + .each_slice(posts_batch_size) do |batch| + vector.gen_bulk_reprensentations(Post.where(id: batch)) + rebaked += batch.length + end + end end private diff --git a/config/settings.yml b/config/settings.yml index 69c56080..117b4115 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -230,20 +230,26 @@ discourse_ai: enum: "DiscourseAi::Configuration::EmbeddingDefsEnumerator" validator: "DiscourseAi::Configuration::EmbeddingDefsValidator" area: "ai-features/embeddings" + ai_embeddings_backfill_model: + type: enum + default: "" + allow_any: false + enum: "DiscourseAi::Configuration::EmbeddingDefsEnumerator" + hidden: true ai_embeddings_per_post_enabled: default: false hidden: true - ai_embeddings_generate_for_pms: + ai_embeddings_generate_for_pms: default: false area: "ai-features/embeddings" ai_embeddings_semantic_related_topics_enabled: default: false client: true area: "ai-features/embeddings" - ai_embeddings_semantic_related_topics: + ai_embeddings_semantic_related_topics: default: 5 area: "ai-features/embeddings" - ai_embeddings_semantic_related_include_closed_topics: + ai_embeddings_semantic_related_include_closed_topics: default: true area: "ai-features/embeddings" ai_embeddings_backfill_batch_size: diff --git a/lib/embeddings/schema.rb b/lib/embeddings/schema.rb index 4d540fa7..a43dbbf4 100644 --- a/lib/embeddings/schema.rb +++ b/lib/embeddings/schema.rb @@ -20,8 +20,11 @@ module DiscourseAi MissingEmbeddingError = Class.new(StandardError) class << self - def for(target_klass) - vector_def = EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_selected_model) + def for(target_klass, vector_def: nil) + vector_def = + EmbeddingDefinition.find_by( + id: SiteSetting.ai_embeddings_selected_model, + ) if vector_def.nil? raise "Invalid embeddings selected model" if vector_def.nil? case target_klass&.name diff --git a/lib/embeddings/semantic_related.rb b/lib/embeddings/semantic_related.rb index c3bf728b..8c0376cc 100644 --- a/lib/embeddings/semantic_related.rb +++ b/lib/embeddings/semantic_related.rb @@ -3,6 +3,8 @@ module DiscourseAi module Embeddings class SemanticRelated + CACHE_PREFIX = "semantic-suggested-topic-" + def self.clear_cache_for(topic) Discourse.cache.delete("semantic-suggested-topic-#{topic.id}") Discourse.redis.del("build-semantic-suggested-topic-#{topic.id}") @@ -79,14 +81,21 @@ module DiscourseAi ) end + def self.clear_cache! + Discourse + .cache + .keys("#{CACHE_PREFIX}*") + .each { |key| Discourse.cache.delete(key.split(":").last) } + end + private def semantic_suggested_key(topic_id) - "semantic-suggested-topic-#{topic_id}" + "#{CACHE_PREFIX}#{topic_id}" end def build_semantic_suggested_key(topic_id) - "build-semantic-suggested-topic-#{topic_id}" + "build-#{CACHE_PREFIX}#{topic_id}" end end end diff --git a/lib/embeddings/vector.rb b/lib/embeddings/vector.rb index c4edd78a..4e847f17 100644 --- a/lib/embeddings/vector.rb +++ b/lib/embeddings/vector.rb @@ -25,7 +25,7 @@ module DiscourseAi idletime: 30, ) - schema = DiscourseAi::Embeddings::Schema.for(relation.first.class) + schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector_def: @vdef) embedding_gen = vdef.inference_client promised_embeddings = @@ -58,7 +58,7 @@ module DiscourseAi text = vdef.prepare_target_text(target) return if text.blank? - schema = DiscourseAi::Embeddings::Schema.for(target.class) + schema = DiscourseAi::Embeddings::Schema.for(target.class, vector_def: @vdef) new_digest = OpenSSL::Digest::SHA1.hexdigest(text) return if schema.find_by_target(target)&.digest == new_digest diff --git a/spec/jobs/scheduled/embeddings_backfill_spec.rb b/spec/jobs/scheduled/embeddings_backfill_spec.rb index bd74e1f3..dfb9b64a 100644 --- a/spec/jobs/scheduled/embeddings_backfill_spec.rb +++ b/spec/jobs/scheduled/embeddings_backfill_spec.rb @@ -20,6 +20,8 @@ RSpec.describe Jobs::EmbeddingsBackfill do end fab!(:vector_def) { Fabricate(:embedding_definition) } + fab!(:vector_def2) { Fabricate(:embedding_definition) } + fab!(:embedding_array) { Array.new(1024) { 1 } } before do SiteSetting.ai_embeddings_selected_model = vector_def.id @@ -27,16 +29,14 @@ RSpec.describe Jobs::EmbeddingsBackfill do SiteSetting.ai_embeddings_backfill_batch_size = 1 SiteSetting.ai_embeddings_per_post_enabled = true Jobs.run_immediately! - end - - it "backfills topics based on bumped_at date" do - embedding = Array.new(1024) { 1 } WebMock.stub_request(:post, "https://test.com/embeddings").to_return( status: 200, - body: JSON.dump(embedding), + body: JSON.dump(embedding_array), ) + end + it "backfills topics based on bumped_at date" do Jobs::EmbeddingsBackfill.new.execute({}) topic_ids = @@ -68,4 +68,19 @@ RSpec.describe Jobs::EmbeddingsBackfill do expect(index_date).to be_within_one_second_of(Time.zone.now) end + + it "backfills embeddings for the ai_embeddings_backfill_model" do + SiteSetting.ai_embeddings_backfill_model = vector_def2.id + SiteSetting.ai_embeddings_backfill_batch_size = 100 + + Jobs::EmbeddingsBackfill.new.execute({}) + + topic_ids = + DB.query_single( + "SELECT topic_id from #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE} WHERE model_id = ?", + vector_def2.id, + ) + + expect(topic_ids).to contain_exactly(first_topic.id, second_topic.id, third_topic.id) + end end