FEATURE: exclude muted categories from category suggester (#979)
The logic here is that users do not particularly care about topics in the category so we can exclude them from tag and category suggestions
This commit is contained in:
parent
80adefa1c1
commit
0cb2c413ba
|
@ -12,6 +12,8 @@ module DiscourseAi
|
||||||
return [] unless SiteSetting.ai_embeddings_enabled
|
return [] unless SiteSetting.ai_embeddings_enabled
|
||||||
|
|
||||||
candidates = nearest_neighbors(limit: 100)
|
candidates = nearest_neighbors(limit: 100)
|
||||||
|
return [] if candidates.empty?
|
||||||
|
|
||||||
candidate_ids = candidates.map(&:first)
|
candidate_ids = candidates.map(&:first)
|
||||||
|
|
||||||
::Topic
|
::Topic
|
||||||
|
@ -52,6 +54,8 @@ module DiscourseAi
|
||||||
return [] unless SiteSetting.ai_embeddings_enabled
|
return [] unless SiteSetting.ai_embeddings_enabled
|
||||||
|
|
||||||
candidates = nearest_neighbors(limit: 100)
|
candidates = nearest_neighbors(limit: 100)
|
||||||
|
return [] if candidates.empty?
|
||||||
|
|
||||||
candidate_ids = candidates.map(&:first)
|
candidate_ids = candidates.map(&:first)
|
||||||
|
|
||||||
count_column = Tag.topic_count_column(@user.guardian) # Determine the count column
|
count_column = Tag.topic_count_column(@user.guardian) # Determine the count column
|
||||||
|
@ -94,11 +98,21 @@ module DiscourseAi
|
||||||
|
|
||||||
raw_vector = vector_rep.vector_from(@text)
|
raw_vector = vector_rep.vector_from(@text)
|
||||||
|
|
||||||
|
muted_category_ids = nil
|
||||||
|
if @user.present?
|
||||||
|
muted_category_ids =
|
||||||
|
CategoryUser.where(
|
||||||
|
user: @user,
|
||||||
|
notification_level: CategoryUser.notification_levels[:muted],
|
||||||
|
).pluck(:category_id)
|
||||||
|
end
|
||||||
|
|
||||||
vector_rep.asymmetric_topics_similarity_search(
|
vector_rep.asymmetric_topics_similarity_search(
|
||||||
raw_vector,
|
raw_vector,
|
||||||
limit: limit,
|
limit: limit,
|
||||||
offset: 0,
|
offset: 0,
|
||||||
return_distance: true,
|
return_distance: true,
|
||||||
|
exclude_category_ids: muted_category_ids,
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -151,16 +151,22 @@ module DiscourseAi
|
||||||
SQL
|
SQL
|
||||||
end
|
end
|
||||||
|
|
||||||
def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_distance: false)
|
def asymmetric_topics_similarity_search(
|
||||||
results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
|
raw_vector,
|
||||||
|
limit:,
|
||||||
|
offset:,
|
||||||
|
return_distance: false,
|
||||||
|
exclude_category_ids: nil
|
||||||
|
)
|
||||||
|
builder = DB.build(<<~SQL)
|
||||||
WITH candidates AS (
|
WITH candidates AS (
|
||||||
SELECT
|
SELECT
|
||||||
topic_id,
|
topic_id,
|
||||||
embeddings::halfvec(#{dimensions}) AS embeddings
|
embeddings::halfvec(#{dimensions}) AS embeddings
|
||||||
FROM
|
FROM
|
||||||
#{topic_table_name}
|
#{topic_table_name}
|
||||||
WHERE
|
/*join*/
|
||||||
model_id = #{id} AND strategy_id = #{@strategy.id}
|
/*where*/
|
||||||
ORDER BY
|
ORDER BY
|
||||||
binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
|
binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
|
||||||
LIMIT :limit * 2
|
LIMIT :limit * 2
|
||||||
|
@ -176,6 +182,22 @@ module DiscourseAi
|
||||||
OFFSET :offset
|
OFFSET :offset
|
||||||
SQL
|
SQL
|
||||||
|
|
||||||
|
builder.where(
|
||||||
|
"model_id = :model_id AND strategy_id = :strategy_id",
|
||||||
|
model_id: id,
|
||||||
|
strategy_id: @strategy.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if exclude_category_ids.present?
|
||||||
|
builder.join("topics t on t.id = topic_id")
|
||||||
|
builder.where(<<~SQL, exclude_category_ids: exclude_category_ids.map(&:to_i))
|
||||||
|
t.category_id NOT IN (:exclude_category_ids) AND
|
||||||
|
t.category_id NOT IN (SELECT categories.id FROM categories WHERE categories.parent_category_id IN (:exclude_category_ids))
|
||||||
|
SQL
|
||||||
|
end
|
||||||
|
|
||||||
|
results = builder.query(query_embedding: raw_vector, limit: limit, offset: offset)
|
||||||
|
|
||||||
if return_distance
|
if return_distance
|
||||||
results.map { |r| [r.topic_id, r.distance] }
|
results.map { |r| [r.topic_id, r.distance] }
|
||||||
else
|
else
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do
|
||||||
|
fab!(:user)
|
||||||
|
fab!(:muted_category) { Fabricate(:category) }
|
||||||
|
fab!(:category_mute) do
|
||||||
|
CategoryUser.create!(
|
||||||
|
user: user,
|
||||||
|
category: muted_category,
|
||||||
|
notification_level: CategoryUser.notification_levels[:muted],
|
||||||
|
)
|
||||||
|
end
|
||||||
|
fab!(:muted_topic) { Fabricate(:topic, category: muted_category) }
|
||||||
|
fab!(:category)
|
||||||
|
fab!(:topic) { Fabricate(:topic, category: category) }
|
||||||
|
|
||||||
|
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
|
||||||
|
let(:vector_rep) do
|
||||||
|
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
|
||||||
|
end
|
||||||
|
let(:categorizer) { DiscourseAi::AiHelper::SemanticCategorizer.new({ text: "hello" }, user) }
|
||||||
|
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
|
||||||
|
|
||||||
|
before do
|
||||||
|
SiteSetting.ai_embeddings_enabled = true
|
||||||
|
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||||
|
SiteSetting.ai_embeddings_model = "bge-large-en"
|
||||||
|
|
||||||
|
WebMock.stub_request(
|
||||||
|
:post,
|
||||||
|
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
|
||||||
|
).to_return(status: 200, body: JSON.dump(expected_embedding))
|
||||||
|
|
||||||
|
vector_rep.generate_representation_from(topic)
|
||||||
|
vector_rep.generate_representation_from(muted_topic)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "respects user muted categories when making suggestions" do
|
||||||
|
category_ids = categorizer.categories.map { |c| c[:id] }
|
||||||
|
expect(category_ids).not_to include(muted_category.id)
|
||||||
|
expect(category_ids).to include(category.id)
|
||||||
|
end
|
||||||
|
end
|
|
@ -104,5 +104,38 @@ RSpec.shared_examples "generates and store embedding using with vector represent
|
||||||
vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0),
|
vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0),
|
||||||
).to contain_exactly(topic.id)
|
).to contain_exactly(topic.id)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
it "can exclude categories" do
|
||||||
|
similar_vector = [0.0038494] * vector_rep.dimensions
|
||||||
|
text =
|
||||||
|
truncation.prepare_text_from(
|
||||||
|
topic,
|
||||||
|
vector_rep.tokenizer,
|
||||||
|
vector_rep.max_sequence_length - 2,
|
||||||
|
)
|
||||||
|
stub_vector_mapping(text, expected_embedding_1)
|
||||||
|
vector_rep.generate_representation_from(topic)
|
||||||
|
|
||||||
|
expect(
|
||||||
|
vector_rep.asymmetric_topics_similarity_search(
|
||||||
|
similar_vector,
|
||||||
|
limit: 1,
|
||||||
|
offset: 0,
|
||||||
|
exclude_category_ids: [topic.category_id],
|
||||||
|
),
|
||||||
|
).to be_empty
|
||||||
|
|
||||||
|
child_category = Fabricate(:category, parent_category_id: topic.category_id)
|
||||||
|
topic.update!(category_id: child_category.id)
|
||||||
|
|
||||||
|
expect(
|
||||||
|
vector_rep.asymmetric_topics_similarity_search(
|
||||||
|
similar_vector,
|
||||||
|
limit: 1,
|
||||||
|
offset: 0,
|
||||||
|
exclude_category_ids: [topic.category_id],
|
||||||
|
),
|
||||||
|
).to be_empty
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in New Issue