2023-09-01 20:10:58 -04:00
|
|
|
# frozen_string_literal: true
|
|
|
|
module DiscourseAi
  module AiHelper
    # Suggests categories and tags for a piece of text. It looks up the topics
    # whose embeddings are nearest to the text and aggregates their
    # categories/tags, weighting each topic by its similarity to the text.
    class SemanticCategorizer
      # Number of nearest topics requested from the embeddings search.
      NEIGHBOR_LIMIT = 100
      # Maximum number of suggestions returned by #categories.
      CATEGORY_LIMIT = 5
      # Maximum number of suggestions returned by #tags.
      TAG_LIMIT = 7

      # @param input [Hash] request params; only input[:text] is read
      # @param user [User, nil] user the suggestions are computed for; nil is
      #   handled as an anonymous visitor
      def initialize(input, user)
        @user = user
        @text = input[:text]
      end

      # Categories of the topics most similar to the input text, restricted to
      # categories the user is allowed to create topics in.
      #
      # @return [Array<Hash>] up to CATEGORY_LIMIT hashes with :id, :name,
      #   :slug, :color, :topicCount and :score, highest score first. Returns
      #   [] when there is no text, embeddings are disabled, or no neighbours
      #   were found.
      def categories
        return [] if @text.blank?
        return [] unless SiteSetting.ai_embeddings_enabled

        candidates = nearest_neighbors(limit: NEIGHBOR_LIMIT)
        return [] if candidates.empty?

        candidate_ids = candidates.map(&:first)
        allowed_category_ids = Category.topic_create_allowed(guardian).pluck(:id)

        # candidate_ids is non-empty (guarded above), so the interpolated
        # ARRAY[...] literal is never the untyped empty array. The ids are
        # topic ids returned by the search.
        rows =
          ::Topic
            .joins(:category)
            .where(id: candidate_ids)
            .where("categories.id IN (?)", allowed_category_ids)
            .order("array_position(ARRAY#{candidate_ids}, topics.id)")
            .pluck(
              "topics.id",
              "categories.id",
              "categories.name",
              "categories.slug",
              "categories.color",
              "categories.topic_count",
            )

        rank_categories(rows, candidates.to_h)
      end

      # Tags of the topics most similar to the input text, restricted to tags
      # visible to the user.
      #
      # @return [Array<Hash>] up to TAG_LIMIT hashes with :name, :score, :id
      #   (nil if no Tag record matches the name) and :count (0 if no record),
      #   highest score first. Returns [] under the same conditions as
      #   #categories.
      def tags
        return [] if @text.blank?
        return [] unless SiteSetting.ai_embeddings_enabled

        candidates = nearest_neighbors(limit: NEIGHBOR_LIMIT)
        return [] if candidates.empty?

        candidate_ids = candidates.map(&:first)
        count_column = Tag.topic_count_column(guardian)
        visible_tag_ids = DiscourseTagging.visible_tags(guardian).pluck(:id)

        # One row per topic: [topic_id, [tag names]]. Topics without a visible
        # tag produce no row, which is why scores are looked up by topic id
        # rather than by row position.
        rows =
          ::Topic
            .joins(:topic_tags, :tags)
            .where(id: candidate_ids)
            .where("tags.id IN (?)", visible_tag_ids)
            .group("topics.id")
            .order("array_position(ARRAY#{candidate_ids}, topics.id)")
            .pluck("topics.id", "array_agg(tags.name)")

        suggestions = rank_tags(rows, candidates.to_h)
        models = Tag.where(name: suggestions.map { |tag| tag[:name] }).index_by(&:name)

        suggestions.each do |tag|
          model = models[tag[:name]]
          tag[:id] = model&.id
          tag[:count] = model&.public_send(count_column) || 0
        end
      end

      private

      # Guardian for the current user. Guardian.new accepts nil, which gives
      # anonymous permissions instead of a NoMethodError on nil.guardian.
      def guardian
        @guardian ||= Guardian.new(@user)
      end

      # Turns a distance into a similarity score: 1 / (distance + 1), so a
      # distance of 0 scores 1.0 and larger distances score less. Float
      # division avoids truncation if a distance is ever an Integer.
      def inverse_distance(distance)
        1.0 / (distance + 1)
      end

      # Builds category suggestions from plucked rows.
      #
      # @param rows [Array<Array>] [topic_id, category_id, name, slug, color,
      #   topic_count] per matching topic
      # @param distances [Hash] topic_id => distance from the embeddings search
      # @return [Array<Hash>] per-category summed scores, best first, capped at
      #   CATEGORY_LIMIT
      def rank_categories(rows, distances)
        rows
          .map do |topic_id, id, name, slug, color, topic_count|
            {
              id: id,
              name: name,
              slug: slug,
              color: color,
              topicCount: topic_count,
              score: inverse_distance(distances[topic_id]),
            }
          end
          # Group by id: category names are only unique per parent, so two
          # distinct categories may share a name.
          .group_by { |category| category[:id] }
          .map { |_id, entries| entries.first.merge(score: entries.sum { |e| e[:score] }) }
          .sort_by { |category| -category[:score] }
          .take(CATEGORY_LIMIT)
      end

      # Builds tag suggestions from plucked rows.
      #
      # @param rows [Array<Array>] [topic_id, tag_names] per matching topic;
      #   tag_names may contain duplicates because of the double join
      # @param distances [Hash] topic_id => distance from the embeddings search
      # @return [Array<Hash>] {name:, score:} with per-tag summed scores, best
      #   first, capped at TAG_LIMIT
      def rank_tags(rows, distances)
        rows
          .flat_map do |topic_id, tag_names|
            score = inverse_distance(distances[topic_id])
            tag_names.uniq.map { |name| { name: name, score: score } }
          end
          .group_by { |tag| tag[:name] }
          .map { |name, entries| { name: name, score: entries.sum { |e| e[:score] } } }
          .sort_by { |tag| -tag[:score] }
          .take(TAG_LIMIT)
      end

      # Ids of the categories the user has muted; [] when there is no user.
      def muted_category_ids
        return [] unless @user

        CategoryUser.where(
          user: @user,
          notification_level: CategoryUser.notification_levels[:muted],
        ).pluck(:category_id)
      end

      # Runs the embeddings similarity search for @text over topics, excluding
      # topics in muted categories and their direct subcategories.
      #
      # @param limit [Integer] maximum number of neighbours requested
      # @return [Array<Array>] [topic_id, distance] pairs in the order the
      #   search returns them
      def nearest_neighbors(limit: 100)
        vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
        schema = DiscourseAi::Embeddings::Schema.for(Topic, vector: vector_rep)
        raw_vector = vector_rep.vector_from(@text)
        muted_ids = muted_category_ids

        schema
          .asymmetric_similarity_search(raw_vector, limit: limit, offset: 0) do |builder|
            builder.join("topics t on t.id = topic_id")

            # Only filter when there is something to exclude. An empty list
            # bound into NOT IN (...) would render as NOT IN (NULL), which
            # matches no rows, or as NOT IN (), which is a syntax error.
            if muted_ids.any?
              builder.where(<<~SQL, exclude_category_ids: muted_ids.map(&:to_i))
                t.category_id NOT IN (:exclude_category_ids) AND
                t.category_id NOT IN (SELECT categories.id FROM categories WHERE categories.parent_category_id IN (:exclude_category_ids))
              SQL
            end
          end
          .map { |r| [r.topic_id, r.distance] }
      end
    end
  end
end
|