mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-06-29 19:12:15 +00:00
REFACTOR: Simplify sentiment classification (#977)
This change adds a simpler class for sentiment classification, replacing the soon-to-be removed `Classificator` hierarchy. Additionally, it adds a method for classifying concurrently, speeding up the backfill rake task.
This commit is contained in:
parent
6456a4f44a
commit
c980c34d77
@ -9,9 +9,7 @@ module ::Jobs
|
||||
post = Post.find_by(id: post_id, post_type: Post.types[:regular])
|
||||
return if post&.raw.blank?
|
||||
|
||||
DiscourseAi::PostClassificator.new(
|
||||
DiscourseAi::Sentiment::SentimentClassification.new,
|
||||
).classify!(post)
|
||||
DiscourseAi::Sentiment::PostClassification.new.classify!(post)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -64,8 +64,8 @@ module ::DiscourseAi
|
||||
JSON.parse(response.body, symbolize_names: true)
|
||||
end
|
||||
|
||||
def classify(content, model_config)
|
||||
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
||||
def classify(content, model_config, base_url = Discourse.base_url)
|
||||
headers = { "Referer" => base_url, "Content-Type" => "application/json" }
|
||||
headers["X-API-KEY"] = model_config.api_key
|
||||
headers["Authorization"] = "Bearer #{model_config.api_key}"
|
||||
|
||||
|
111
lib/sentiment/post_classification.rb
Normal file
111
lib/sentiment/post_classification.rb
Normal file
@ -0,0 +1,111 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Sentiment
|
||||
class PostClassification
|
||||
def bulk_classify!(relation)
|
||||
http_pool_size = 100
|
||||
pool =
|
||||
Concurrent::CachedThreadPool.new(
|
||||
min_threads: 0,
|
||||
max_threads: http_pool_size,
|
||||
idletime: 30,
|
||||
)
|
||||
|
||||
available_classifiers = classifiers
|
||||
base_url = Discourse.base_url
|
||||
|
||||
promised_classifications =
|
||||
relation
|
||||
.map do |record|
|
||||
text = prepare_text(record)
|
||||
next if text.blank?
|
||||
|
||||
Concurrent::Promises
|
||||
.fulfilled_future({ target: record, text: text }, pool)
|
||||
.then_on(pool) do |w_text|
|
||||
results = Concurrent::Hash.new
|
||||
|
||||
promised_target_results =
|
||||
available_classifiers.map do |c|
|
||||
Concurrent::Promises.future_on(pool) do
|
||||
results[c.model_name] = request_with(w_text[:text], c, base_url)
|
||||
end
|
||||
end
|
||||
|
||||
Concurrent::Promises
|
||||
.zip(*promised_target_results)
|
||||
.then_on(pool) { |_| w_text.merge(classification: results) }
|
||||
end
|
||||
.flat(1)
|
||||
end
|
||||
.compact
|
||||
|
||||
Concurrent::Promises
|
||||
.zip(*promised_classifications)
|
||||
.value!
|
||||
.each { |r| store_classification(r[:target], r[:classification]) }
|
||||
|
||||
pool.shutdown
|
||||
pool.wait_for_termination
|
||||
end
|
||||
|
||||
def classify!(target)
|
||||
return if target.blank?
|
||||
|
||||
to_classify = prepare_text(target)
|
||||
return if to_classify.blank?
|
||||
|
||||
results =
|
||||
classifiers.reduce({}) do |memo, model|
|
||||
memo[model.model_name] = request_with(to_classify, model)
|
||||
memo
|
||||
end
|
||||
|
||||
store_classification(target, results)
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def prepare_text(target)
|
||||
content =
|
||||
if target.post_number == 1
|
||||
"#{target.topic.title}\n#{target.raw}"
|
||||
else
|
||||
target.raw
|
||||
end
|
||||
|
||||
Tokenizer::BertTokenizer.truncate(content, 512)
|
||||
end
|
||||
|
||||
def classifiers
|
||||
DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values
|
||||
end
|
||||
|
||||
def request_with(content, config, base_url = Discourse.base_url)
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, config, base_url)
|
||||
end
|
||||
|
||||
def store_classification(target, classification)
|
||||
attrs =
|
||||
classification.map do |model_name, classifications|
|
||||
{
|
||||
model_used: model_name,
|
||||
target_id: target.id,
|
||||
target_type: target.class.sti_name,
|
||||
classification_type: :sentiment,
|
||||
classification: classifications,
|
||||
updated_at: DateTime.now,
|
||||
created_at: DateTime.now,
|
||||
}
|
||||
end
|
||||
|
||||
ClassificationResult.upsert_all(
|
||||
attrs,
|
||||
unique_by: %i[target_id target_type model_used],
|
||||
update_only: %i[classification],
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
@ -14,11 +14,8 @@ task "ai:sentiment:backfill", [:start_post] => [:environment] do |_, args|
|
||||
.where("category_id IN (?)", public_categories)
|
||||
.where(posts: { deleted_at: nil })
|
||||
.where(topics: { deleted_at: nil })
|
||||
.order("posts.id ASC")
|
||||
.find_each do |post|
|
||||
.find_in_batches do |batch|
|
||||
print "."
|
||||
DiscourseAi::PostClassificator.new(
|
||||
DiscourseAi::Sentiment::SentimentClassification.new,
|
||||
).classify!(post)
|
||||
DiscourseAi::Sentiment::PostClassification.new.bulk_classify!(batch)
|
||||
end
|
||||
end
|
||||
|
47
spec/lib/modules/sentiment/post_classification_spec.rb
Normal file
47
spec/lib/modules/sentiment/post_classification_spec.rb
Normal file
@ -0,0 +1,47 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
require_relative "../../../support/sentiment_inference_stubs"
|
||||
|
||||
RSpec.describe DiscourseAi::Sentiment::PostClassification do
|
||||
fab!(:post_1) { Fabricate(:post, post_number: 2) }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_sentiment_enabled = true
|
||||
SiteSetting.ai_sentiment_model_configs =
|
||||
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
|
||||
end
|
||||
|
||||
describe "#classify!" do
|
||||
it "does nothing if the post content is blank" do
|
||||
post_1.update_columns(raw: "")
|
||||
|
||||
subject.classify!(post_1)
|
||||
|
||||
expect(ClassificationResult.where(target: post_1).count).to be_zero
|
||||
end
|
||||
|
||||
it "successfully classifies the post" do
|
||||
expected_analysis = DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.length
|
||||
SentimentInferenceStubs.stub_classification(post_1)
|
||||
|
||||
subject.classify!(post_1)
|
||||
|
||||
expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#classify_bulk!" do
|
||||
fab!(:post_2) { Fabricate(:post, post_number: 2) }
|
||||
|
||||
it "classifies all given posts" do
|
||||
expected_analysis = DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.length
|
||||
SentimentInferenceStubs.stub_classification(post_1)
|
||||
SentimentInferenceStubs.stub_classification(post_2)
|
||||
|
||||
subject.bulk_classify!(Post.where(id: [post_1.id, post_2.id]))
|
||||
|
||||
expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
|
||||
expect(ClassificationResult.where(target: post_2).count).to eq(expected_analysis)
|
||||
end
|
||||
end
|
||||
end
|
Loading…
x
Reference in New Issue
Block a user