FEATURE: exclude muted categories from category suggester (#979)

The logic here is that users do not particularly care about
topics in categories they have muted, so we can exclude those
topics from tag and category suggestions.
This commit is contained in:
Sam 2024-11-29 12:17:28 +11:00 committed by GitHub
parent 80adefa1c1
commit 0cb2c413ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 116 additions and 4 deletions

View File

@ -12,6 +12,8 @@ module DiscourseAi
return [] unless SiteSetting.ai_embeddings_enabled
candidates = nearest_neighbors(limit: 100)
return [] if candidates.empty?
candidate_ids = candidates.map(&:first)
::Topic
@ -52,6 +54,8 @@ module DiscourseAi
return [] unless SiteSetting.ai_embeddings_enabled
candidates = nearest_neighbors(limit: 100)
return [] if candidates.empty?
candidate_ids = candidates.map(&:first)
count_column = Tag.topic_count_column(@user.guardian) # Determine the count column
@ -94,11 +98,21 @@ module DiscourseAi
raw_vector = vector_rep.vector_from(@text)
muted_category_ids = nil
if @user.present?
muted_category_ids =
CategoryUser.where(
user: @user,
notification_level: CategoryUser.notification_levels[:muted],
).pluck(:category_id)
end
vector_rep.asymmetric_topics_similarity_search(
raw_vector,
limit: limit,
offset: 0,
return_distance: true,
exclude_category_ids: muted_category_ids,
)
end
end

View File

@ -151,16 +151,22 @@ module DiscourseAi
SQL
end
def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_distance: false)
results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
def asymmetric_topics_similarity_search(
raw_vector,
limit:,
offset:,
return_distance: false,
exclude_category_ids: nil
)
builder = DB.build(<<~SQL)
WITH candidates AS (
SELECT
topic_id,
embeddings::halfvec(#{dimensions}) AS embeddings
FROM
#{topic_table_name}
WHERE
model_id = #{id} AND strategy_id = #{@strategy.id}
/*join*/
/*where*/
ORDER BY
binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
LIMIT :limit * 2
@ -176,6 +182,22 @@ module DiscourseAi
OFFSET :offset
SQL
builder.where(
"model_id = :model_id AND strategy_id = :strategy_id",
model_id: id,
strategy_id: @strategy.id,
)
if exclude_category_ids.present?
builder.join("topics t on t.id = topic_id")
builder.where(<<~SQL, exclude_category_ids: exclude_category_ids.map(&:to_i))
t.category_id NOT IN (:exclude_category_ids) AND
t.category_id NOT IN (SELECT categories.id FROM categories WHERE categories.parent_category_id IN (:exclude_category_ids))
SQL
end
results = builder.query(query_embedding: raw_vector, limit: limit, offset: offset)
if return_distance
results.map { |r| [r.topic_id, r.distance] }
else

View File

@ -0,0 +1,43 @@
# frozen_string_literal: true
# Checks that category suggestions produced by SemanticCategorizer leave out
# categories the requesting user has muted.
RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
  fab!(:user)
  fab!(:muted_category) { Fabricate(:category) }

  # Records muted_category as muted for user.
  fab!(:category_mute) do
    CategoryUser.create!(
      notification_level: CategoryUser.notification_levels[:muted],
      category: muted_category,
      user: user,
    )
  end

  # One topic inside the muted category, one inside an ordinary category.
  fab!(:muted_topic) { Fabricate(:topic, category: muted_category) }
  fab!(:category)
  fab!(:topic) { Fabricate(:topic, category: category) }

  let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
  let(:vector_rep) do
    DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
  end
  let(:categorizer) { described_class.new({ text: "hello" }, user) }
  let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }

  before do
    SiteSetting.ai_embeddings_enabled = true
    SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
    SiteSetting.ai_embeddings_model = "bge-large-en"

    # Any POST to the classify endpoint answers with the same canned vector.
    classify_url = "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify"
    WebMock.stub_request(:post, classify_url).to_return(
      status: 200,
      body: JSON.dump(expected_embedding),
    )

    [topic, muted_topic].each { |t| vector_rep.generate_representation_from(t) }
  end

  it "respects user muted categories when making suggestions" do
    suggested_ids = categorizer.categories.map { |suggestion| suggestion[:id] }

    expect(suggested_ids).not_to include(muted_category.id)
    expect(suggested_ids).to include(category.id)
  end
end

View File

@ -104,5 +104,38 @@ RSpec.shared_examples "generates and store embedding using with vector represent
vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0),
).to contain_exactly(topic.id)
end
# Verifies that exclude_category_ids filters out topics that sit directly in
# an excluded category AND topics that sit in a subcategory of one.
it "can exclude categories" do
  similar_vector = [0.0038494] * vector_rep.dimensions
  text =
    truncation.prepare_text_from(
      topic,
      vector_rep.tokenizer,
      vector_rep.max_sequence_length - 2,
    )
  stub_vector_mapping(text, expected_embedding_1)
  vector_rep.generate_representation_from(topic)

  # Remember the original category now: once the topic is moved below,
  # topic.category_id will point at the child category instead.
  parent_category_id = topic.category_id

  # Case 1: the topic lives directly in the excluded category.
  expect(
    vector_rep.asymmetric_topics_similarity_search(
      similar_vector,
      limit: 1,
      offset: 0,
      exclude_category_ids: [parent_category_id],
    ),
  ).to be_empty

  child_category = Fabricate(:category, parent_category_id: parent_category_id)
  topic.update!(category_id: child_category.id)

  # Case 2: the topic now lives in a subcategory; excluding only the PARENT
  # must still hide it. Passing topic.category_id here would exclude the
  # child directly and never exercise the parent-category lookup.
  expect(
    vector_rep.asymmetric_topics_similarity_search(
      similar_vector,
      limit: 1,
      offset: 0,
      exclude_category_ids: [parent_category_id],
    ),
  ).to be_empty
end
end
end