diff --git a/lib/ai_helper/semantic_categorizer.rb b/lib/ai_helper/semantic_categorizer.rb index 7fa46f29..612030c8 100644 --- a/lib/ai_helper/semantic_categorizer.rb +++ b/lib/ai_helper/semantic_categorizer.rb @@ -12,6 +12,8 @@ module DiscourseAi return [] unless SiteSetting.ai_embeddings_enabled candidates = nearest_neighbors(limit: 100) + return [] if candidates.empty? + candidate_ids = candidates.map(&:first) ::Topic @@ -52,6 +54,8 @@ module DiscourseAi return [] unless SiteSetting.ai_embeddings_enabled candidates = nearest_neighbors(limit: 100) + return [] if candidates.empty? + candidate_ids = candidates.map(&:first) count_column = Tag.topic_count_column(@user.guardian) # Determine the count column @@ -94,11 +98,21 @@ module DiscourseAi raw_vector = vector_rep.vector_from(@text) + muted_category_ids = nil + if @user.present? + muted_category_ids = + CategoryUser.where( + user: @user, + notification_level: CategoryUser.notification_levels[:muted], + ).pluck(:category_id) + end + vector_rep.asymmetric_topics_similarity_search( raw_vector, limit: limit, offset: 0, return_distance: true, + exclude_category_ids: muted_category_ids, ) end end diff --git a/lib/embeddings/vector_representations/base.rb b/lib/embeddings/vector_representations/base.rb index 5f113002..831a9870 100644 --- a/lib/embeddings/vector_representations/base.rb +++ b/lib/embeddings/vector_representations/base.rb @@ -151,16 +151,22 @@ module DiscourseAi SQL end - def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_distance: false) - results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset) + def asymmetric_topics_similarity_search( + raw_vector, + limit:, + offset:, + return_distance: false, + exclude_category_ids: nil + ) + builder = DB.build(<<~SQL) WITH candidates AS ( SELECT topic_id, embeddings::halfvec(#{dimensions}) AS embeddings FROM #{topic_table_name} - WHERE - model_id = #{id} AND strategy_id = #{@strategy.id} + /*join*/ + /*where*/ ORDER BY binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions})) LIMIT :limit * 2 @@ -176,6 +182,22 @@ module DiscourseAi OFFSET :offset SQL + builder.where( + "model_id = :model_id AND strategy_id = :strategy_id", + model_id: id, + strategy_id: @strategy.id, + ) + + if exclude_category_ids.present? + builder.join("topics t on t.id = topic_id") + builder.where(<<~SQL, exclude_category_ids: exclude_category_ids.map(&:to_i)) + t.category_id NOT IN (:exclude_category_ids) AND + t.category_id NOT IN (SELECT categories.id FROM categories WHERE categories.parent_category_id IN (:exclude_category_ids)) + SQL + end + + results = builder.query(query_embedding: raw_vector, limit: limit, offset: offset) + if return_distance results.map { |r| [r.topic_id, r.distance] } else diff --git a/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb b/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb new file mode 100644 index 00000000..8d40b572 --- /dev/null +++ b/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb @@ -0,0 +1,43 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do + fab!(:user) + fab!(:muted_category) { Fabricate(:category) } + fab!(:category_mute) do + CategoryUser.create!( + user: user, + category: muted_category, + notification_level: CategoryUser.notification_levels[:muted], + ) + end + fab!(:muted_topic) { Fabricate(:topic, category: muted_category) } + fab!(:category) + fab!(:topic) { Fabricate(:topic, category: category) } + + let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } + let(:vector_rep) do + DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation) + end + let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) } + let(:expected_embedding) { [0.0038493] * vector_rep.dimensions } + + before do + SiteSetting.ai_embeddings_enabled = true + SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" + SiteSetting.ai_embeddings_model = "bge-large-en" + + WebMock.stub_request( + :post, + "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", + ).to_return(status: 200, body: JSON.dump(expected_embedding)) + + vector_rep.generate_representation_from(topic) + vector_rep.generate_representation_from(muted_topic) + end + + it "respects user muted categories when making suggestions" do + category_ids = categorizer.categories.map { |c| c[:id] } + expect(category_ids).not_to include(muted_category.id) + expect(category_ids).to include(category.id) + end +end diff --git a/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb b/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb index 075e6930..a9134403 100644 --- a/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb +++ b/spec/lib/modules/embeddings/vector_representations/vector_rep_shared_examples.rb @@ -104,5 +104,38 @@ RSpec.shared_examples "generates and store embedding using with vector represent vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0), ).to contain_exactly(topic.id) end + + it "can exclude categories" do + similar_vector = [0.0038494] * vector_rep.dimensions + text = + truncation.prepare_text_from( + topic, + vector_rep.tokenizer, + vector_rep.max_sequence_length - 2, + ) + stub_vector_mapping(text, expected_embedding_1) + vector_rep.generate_representation_from(topic) + + expect( + vector_rep.asymmetric_topics_similarity_search( + similar_vector, + limit: 1, + offset: 0, + exclude_category_ids: [topic.category_id], + ), + ).to be_empty + + child_category = Fabricate(:category, parent_category_id: topic.category_id) + topic.update!(category_id: child_category.id) + + expect( + vector_rep.asymmetric_topics_similarity_search( + similar_vector, + limit: 1, + offset: 0, + exclude_category_ids: [topic.category_id], + ), + ).to be_empty + end end end