REFACTOR: Simplify sentiment classification (#977)

This change adds a simpler class for sentiment classification, replacing the soon-to-be removed `Classificator` hierarchy. Additionally, it adds a method for classifying concurrently, speeding up the backfill rake task.
This commit is contained in:
Roman Rizzi 2024-11-28 15:38:23 -03:00 committed by GitHub
parent 6456a4f44a
commit c980c34d77
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 163 additions and 10 deletions

View File

@ -9,9 +9,7 @@ module ::Jobs
post = Post.find_by(id: post_id, post_type: Post.types[:regular])
return if post&.raw.blank?
DiscourseAi::PostClassificator.new(
DiscourseAi::Sentiment::SentimentClassification.new,
).classify!(post)
DiscourseAi::Sentiment::PostClassification.new.classify!(post)
end
end
end

View File

@ -64,8 +64,8 @@ module ::DiscourseAi
JSON.parse(response.body, symbolize_names: true)
end
def classify(content, model_config)
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
def classify(content, model_config, base_url = Discourse.base_url)
headers = { "Referer" => base_url, "Content-Type" => "application/json" }
headers["X-API-KEY"] = model_config.api_key
headers["Authorization"] = "Bearer #{model_config.api_key}"

View File

@ -0,0 +1,111 @@
# frozen_string_literal: true
module DiscourseAi
module Sentiment
class PostClassification
def bulk_classify!(relation)
http_pool_size = 100
pool =
Concurrent::CachedThreadPool.new(
min_threads: 0,
max_threads: http_pool_size,
idletime: 30,
)
available_classifiers = classifiers
base_url = Discourse.base_url
promised_classifications =
relation
.map do |record|
text = prepare_text(record)
next if text.blank?
Concurrent::Promises
.fulfilled_future({ target: record, text: text }, pool)
.then_on(pool) do |w_text|
results = Concurrent::Hash.new
promised_target_results =
available_classifiers.map do |c|
Concurrent::Promises.future_on(pool) do
results[c.model_name] = request_with(w_text[:text], c, base_url)
end
end
Concurrent::Promises
.zip(*promised_target_results)
.then_on(pool) { |_| w_text.merge(classification: results) }
end
.flat(1)
end
.compact
Concurrent::Promises
.zip(*promised_classifications)
.value!
.each { |r| store_classification(r[:target], r[:classification]) }
pool.shutdown
pool.wait_for_termination
end
def classify!(target)
return if target.blank?
to_classify = prepare_text(target)
return if to_classify.blank?
results =
classifiers.reduce({}) do |memo, model|
memo[model.model_name] = request_with(to_classify, model)
memo
end
store_classification(target, results)
end
private
def prepare_text(target)
content =
if target.post_number == 1
"#{target.topic.title}\n#{target.raw}"
else
target.raw
end
Tokenizer::BertTokenizer.truncate(content, 512)
end
def classifiers
DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values
end
def request_with(content, config, base_url = Discourse.base_url)
DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, config, base_url)
end
def store_classification(target, classification)
attrs =
classification.map do |model_name, classifications|
{
model_used: model_name,
target_id: target.id,
target_type: target.class.sti_name,
classification_type: :sentiment,
classification: classifications,
updated_at: DateTime.now,
created_at: DateTime.now,
}
end
ClassificationResult.upsert_all(
attrs,
unique_by: %i[target_id target_type model_used],
update_only: %i[classification],
)
end
end
end
end

View File

@ -14,11 +14,8 @@ task "ai:sentiment:backfill", [:start_post] => [:environment] do |_, args|
.where("category_id IN (?)", public_categories)
.where(posts: { deleted_at: nil })
.where(topics: { deleted_at: nil })
.order("posts.id ASC")
.find_each do |post|
.find_in_batches do |batch|
print "."
DiscourseAi::PostClassificator.new(
DiscourseAi::Sentiment::SentimentClassification.new,
).classify!(post)
DiscourseAi::Sentiment::PostClassification.new.bulk_classify!(batch)
end
end

View File

@ -0,0 +1,47 @@
# frozen_string_literal: true
require_relative "../../../support/sentiment_inference_stubs"
RSpec.describe DiscourseAi::Sentiment::PostClassification do
fab!(:post_1) { Fabricate(:post, post_number: 2) }
before do
SiteSetting.ai_sentiment_enabled = true
SiteSetting.ai_sentiment_model_configs =
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
end
describe "#classify!" do
it "does nothing if the post content is blank" do
post_1.update_columns(raw: "")
subject.classify!(post_1)
expect(ClassificationResult.where(target: post_1).count).to be_zero
end
it "successfully classifies the post" do
expected_analysis = DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.length
SentimentInferenceStubs.stub_classification(post_1)
subject.classify!(post_1)
expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
end
end
describe "#classify_bulk!" do
fab!(:post_2) { Fabricate(:post, post_number: 2) }
it "classifies all given posts" do
expected_analysis = DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.length
SentimentInferenceStubs.stub_classification(post_1)
SentimentInferenceStubs.stub_classification(post_2)
subject.bulk_classify!(Post.where(id: [post_1.id, post_2.id]))
expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
expect(ClassificationResult.where(target: post_2).count).to eq(expected_analysis)
end
end
end