2024-11-28 15:38:23 -03:00
|
|
|
# frozen_string_literal: true
|
|
|
|
|
|
|
|
module DiscourseAi
|
|
|
|
module Sentiment
|
|
|
|
class PostClassification
|
2024-12-03 10:27:03 -03:00
|
|
|
# Builds a relation of posts that still lack a sentiment classification from at
# least one configured classifier.
#
# One sub-query is built per model name returned by
# SentimentSiteSettingJsonSchema.values. Each sub-query selects regular posts in
# non-deleted, non-private-message topics that have no "sentiment" row in
# classification_results for that model. The sub-queries are combined with UNION.
#
# @param from_post_id [#to_i, nil] when present, only posts with id >= this value
# @param max_age_days [#to_i, nil] when present, only posts created within this many days
# @return [ActiveRecord::Relation] posts to classify; Post.none when no classifier is configured
def self.backfill_query(from_post_id: nil, max_age_days: nil)
  available_classifier_names =
    DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.map { _1.model_name }

  # With zero sub-queries the UNION below would render "() as posts",
  # which is not valid SQL, so return an empty relation instead.
  return Post.none if available_classifier_names.blank?

  queries =
    available_classifier_names.map do |classifier_name|
      # The model name is configuration-driven; quote it rather than splicing it
      # into the SQL verbatim so a stray quote cannot break the statement.
      quoted_classifier_name = Post.connection.quote(classifier_name)

      base_query =
        Post
          .includes(:sentiment_classifications)
          .joins("INNER JOIN topics ON topics.id = posts.topic_id")
          .where(post_type: Post.types[:regular])
          .where.not(topics: { archetype: Archetype.private_message })
          .where(posts: { deleted_at: nil })
          .where(topics: { deleted_at: nil })
          .joins(<<~SQL)
            LEFT JOIN classification_results crs
              ON crs.target_id = posts.id
              AND crs.target_type = 'Post'
              AND crs.classification_type = 'sentiment'
              AND crs.model_used = #{quoted_classifier_name}
          SQL
          .where("crs.id IS NULL")

      base_query = base_query.where("posts.id >= ?", from_post_id.to_i) if from_post_id.present?

      if max_age_days.present?
        # to_i guarantees only an integer is interpolated into the interval.
        base_query =
          base_query.where(
            "posts.created_at > current_date - INTERVAL '#{max_age_days.to_i} DAY'",
          )
      end

      base_query
    end

  # UNION (not UNION ALL) so a post missing several classifiers is listed once.
  unioned_queries = queries.map(&:to_sql).join(" UNION ")

  Post.from(Arel.sql("(#{unioned_queries}) as posts"))
end
|
|
|
|
|
2025-06-20 16:06:03 +10:00
|
|
|
# Upper bound on worker threads bulk_classify! uses for concurrent classification
# requests (passed as max_threads to its thread pool). The name misspells
# "CLASSIFICATIONS"; it is kept as-is because code refers to it by this name.
CONCURRENT_CLASSFICATIONS = 40
|
|
|
|
|
2024-11-28 15:38:23 -03:00
|
|
|
# Classifies every post in +relation+ with each configured sentiment classifier
# that has not already produced a result for it.
#
# Requests are posted to a Scheduler::ThreadPool (at most
# CONCURRENT_CLASSFICATIONS threads). Their outcomes are collected through a
# Queue and persisted on the calling thread, with one store_classification call
# per successful request. Failed requests are not retried. After all requests
# finish, a single warning is reported via Discourse.warn_exception. It carries
# the first error, the number of distinct failed posts, and up to five example
# post ids.
#
# @param relation [Enumerable] posts to classify (e.g. the result of backfill_query)
def bulk_classify!(relation)
  available_classifiers = classifiers
  return if available_classifiers.blank?

  # Created only once there is work to do; `ensure` below tolerates a nil pool.
  pool =
    Scheduler::ThreadPool.new(
      min_threads: 0,
      max_threads: CONCURRENT_CLASSFICATIONS,
      idle_time: 30,
    )

  results = Queue.new
  queued = 0

  relation.each do |record|
    text = prepare_text(record)
    next if text.blank?

    # map, not pluck(&:model_used): pluck ignores the block, so the skip below
    # would never match and every classifier would be re-run.
    already_classified = record.sentiment_classifications.map(&:model_used)
    missing_classifiers =
      available_classifiers.reject { |ac| already_classified.include?(ac[:model_name]) }

    missing_classifiers.each do |classifier|
      pool.post do
        result = { target: record, classifier: classifier }
        begin
          result[:classification] = request_with(classifier[:client], text)
        rescue StandardError => e
          # Errors are handed back to the calling thread instead of being lost in the worker.
          result[:error] = e
        end
        results << result
      end
      queued += 1
    end
  end

  errors = []

  # Exactly one result is pushed per posted job, so pop that many.
  queued.times do
    result = results.pop
    if result[:error]
      errors << result
    else
      store_classification(
        result[:target],
        [[result[:classifier][:model_name], result[:classification]]],
      )
    end
  end

  if errors.any?
    # A post can fail for several classifiers; count and list each post once.
    failed_post_ids = errors.map { |e| e[:target].id }.uniq
    example_posts = failed_post_ids.take(5).join(", ")
    Discourse.warn_exception(
      errors[0][:error],
      "Discourse AI: Errors during bulk classification: Failed to classify #{failed_post_ids.count} posts (example ids: #{example_posts})",
    )
  end
ensure
  pool&.shutdown
  pool&.wait_for_termination(timeout: 30)
end
|
|
|
|
|
|
|
|
# Classifies a single post with every configured sentiment classifier that has
# not already produced a result for it, then stores the new results.
#
# Errors raised by request_with are not rescued here and propagate to the caller.
#
# @param target [Post, nil] post to classify; blank targets are ignored
# @return [nil] when there is nothing to classify, otherwise store_classification's result
def classify!(target)
  return if target.blank?

  available_classifiers = classifiers
  return if available_classifiers.blank?

  to_classify = prepare_text(target)
  return if to_classify.blank?

  already_classified = target.sentiment_classifications.map(&:model_used)
  classifiers_for_target =
    available_classifiers.reject { |ac| already_classified.include?(ac[:model_name]) }

  # Every classifier already has a result: skip the upsert of an empty list.
  return if classifiers_for_target.empty?

  # model_name => transformed classification for each missing classifier.
  results =
    classifiers_for_target.to_h do |cft|
      [cft[:model_name], request_with(cft[:client], to_classify)]
    end

  store_classification(target, results)
end
|
|
|
|
|
2024-12-02 14:18:03 -03:00
|
|
|
# Builds one classifier descriptor per configured sentiment model.
#
# Endpoints written as "srv://<name>" are resolved through
# DiscourseAi::Utils::DnsSrv.lookup and rewritten to "https://<target>:<port>".
# Any other endpoint, including a blank one, is used unchanged.
#
# @return [Array<Hash>] entries of the form { model_name:, client: }
def classifiers
  srv_scheme = "srv://"

  DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.map do |setting|
    endpoint = setting.endpoint

    if endpoint.present? && endpoint.start_with?(srv_scheme)
      srv_record = DiscourseAi::Utils::DnsSrv.lookup(endpoint.delete_prefix(srv_scheme))
      endpoint = "https://#{srv_record.target}:#{srv_record.port}"
    end

    model_name = setting.model_name
    client = DiscourseAi::Inference::HuggingFaceTextEmbeddings.new(endpoint, setting.api_key)

    { model_name: model_name, client: client }
  end
end
|
|
|
|
|
2024-12-03 10:27:03 -03:00
|
|
|
# Whether at least one sentiment classifier is configured.
#
# Inspects the configured entries directly instead of calling #classifiers.
# That method resolves srv:// endpoints via DNS and instantiates an HTTP client
# per entry, which is wasted work when only presence matters. #classifiers maps
# each configured entry to exactly one descriptor, so the answer is the same.
#
# @return [Boolean]
def has_classifiers?
  DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.present?
end
|
|
|
|
|
2024-11-28 15:38:23 -03:00
|
|
|
private
|
|
|
|
|
|
|
|
# Builds the text sent to the sentiment classifiers for +target+.
#
# For the opening post of a topic (post_number == 1) the topic title is
# prepended on its own line. The result is passed through
# Tokenizer::BertTokenizer.truncate with a limit of 512 (presumably tokens;
# see BertTokenizer).
#
# @param target [Post] post whose raw text is classified
# @return [String] the truncated text
def prepare_text(target)
  opening_post = target.post_number == 1
  text = opening_post ? "#{target.topic.title}\n#{target.raw}" : target.raw

  Tokenizer::BertTokenizer.truncate(text, 512)
end
|
|
|
|
|
2025-02-06 13:11:10 -03:00
|
|
|
# Sends +content+ to +client+ for sentiment classification and converts the
# response with transform_result.
#
# @param client [#classify_by_sentiment!] inference client built by #classifiers
# @param content [String] text to classify
# @return [Hash] whatever transform_result produces for the raw response
def request_with(client, content)
  transform_result(client.classify_by_sentiment!(content))
end
|
|
|
|
|
|
|
|
# Converts a classifier response into a { label => score } hash.
#
# Each element of +result+ is read with [:label] and [:score] (presumably an
# array of hashes; verify against classify_by_sentiment!). If a label repeats,
# the later score wins.
#
# @param result [Enumerable] raw classifier response entries
# @return [Hash] label => score
def transform_result(result)
  result.each_with_object({}) { |entry, scores| scores[entry[:label]] = entry[:score] }
end
|
|
|
|
|
|
|
|
# Persists sentiment results for +target+ as ClassificationResult rows.
#
# One row is upserted per (model_name, classifications) pair. On a conflict on
# (target_id, target_type, model_used), only the classification column is
# updated.
#
# @param target [ActiveRecord::Base] classified record (its id and class.sti_name are stored)
# @param classification [Hash, Array<Array>] model_name => classification payload pairs
# @return [nil] when there is nothing to store, otherwise upsert_all's result
def store_classification(target, classification)
  # A single timestamp keeps created_at/updated_at identical within and across rows.
  now = Time.zone.now

  attrs =
    classification.map do |model_name, classifications|
      {
        model_used: model_name,
        target_id: target.id,
        target_type: target.class.sti_name,
        classification_type: :sentiment,
        classification: classifications,
        updated_at: now,
        created_at: now,
      }
    end

  # Nothing to write: skip the upsert call entirely.
  return if attrs.empty?

  ClassificationResult.upsert_all(
    attrs,
    unique_by: %i[target_id target_type model_used],
    update_only: %i[classification],
  )
end
|
|
|
|
end
|
|
|
|
end
|
|
|
|
end
|