mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-07-01 12:02:16 +00:00
REFACTOR: Simplify sentiment classification (#977)
This change adds a simpler class for sentiment classification, replacing the soon-to-be removed `Classificator` hierarchy. Additionally, it adds a method for classifying concurrently, speeding up the backfill rake task.
This commit is contained in:
parent
6456a4f44a
commit
c980c34d77
@ -9,9 +9,7 @@ module ::Jobs
|
|||||||
post = Post.find_by(id: post_id, post_type: Post.types[:regular])
|
post = Post.find_by(id: post_id, post_type: Post.types[:regular])
|
||||||
return if post&.raw.blank?
|
return if post&.raw.blank?
|
||||||
|
|
||||||
DiscourseAi::PostClassificator.new(
|
DiscourseAi::Sentiment::PostClassification.new.classify!(post)
|
||||||
DiscourseAi::Sentiment::SentimentClassification.new,
|
|
||||||
).classify!(post)
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -64,8 +64,8 @@ module ::DiscourseAi
|
|||||||
JSON.parse(response.body, symbolize_names: true)
|
JSON.parse(response.body, symbolize_names: true)
|
||||||
end
|
end
|
||||||
|
|
||||||
def classify(content, model_config)
|
def classify(content, model_config, base_url = Discourse.base_url)
|
||||||
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
headers = { "Referer" => base_url, "Content-Type" => "application/json" }
|
||||||
headers["X-API-KEY"] = model_config.api_key
|
headers["X-API-KEY"] = model_config.api_key
|
||||||
headers["Authorization"] = "Bearer #{model_config.api_key}"
|
headers["Authorization"] = "Bearer #{model_config.api_key}"
|
||||||
|
|
||||||
|
111
lib/sentiment/post_classification.rb
Normal file
111
lib/sentiment/post_classification.rb
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
  module Sentiment
    # Runs every configured sentiment model against posts and stores one
    # ClassificationResult row per (post, model) pair.
    class PostClassification
      # Classifies a collection of posts concurrently.
      #
      # Each post gets one HTTP request per configured classifier; all requests
      # are issued on a dedicated thread pool and the results are stored once
      # every request has completed.
      #
      # @param relation [Enumerable] posts to classify (e.g. a batch of Post records)
      # @raise re-raises the first failure reported by any classification request
      def bulk_classify!(relation)
        http_pool_size = 100

        # A plain ThreadPoolExecutor is used on purpose: CachedThreadPool
        # overrides `max_threads` with its own (effectively unbounded) default,
        # so the cap above would be ignored. A bounded pool is safe here because
        # no task blocks waiting on another task; they only chain futures.
        pool =
          Concurrent::ThreadPoolExecutor.new(
            min_threads: 0,
            max_threads: http_pool_size,
            idletime: 30,
          )

        # Resolved once on the calling thread and handed to the workers instead
        # of being looked up inside them.
        available_classifiers = classifiers
        base_url = Discourse.base_url

        promised_classifications =
          relation
            .map do |record|
              text = prepare_text(record)
              next if text.blank?

              promise_classification(record, text, available_classifiers, base_url, pool)
            end
            .compact

        Concurrent::Promises
          .zip(*promised_classifications)
          .value!
          .each { |r| store_classification(r[:target], r[:classification]) }
      ensure
        # Always release the worker threads, including when a request or the
        # database write raised.
        if pool
          pool.shutdown
          pool.wait_for_termination
        end
      end

      # Classifies a single post synchronously with every configured classifier
      # and stores the results.
      #
      # @param target [Post, nil] the post to classify; blank targets are ignored
      def classify!(target)
        return if target.blank?

        to_classify = prepare_text(target)
        return if to_classify.blank?

        results =
          classifiers.each_with_object({}) do |model, memo|
            memo[model.model_name] = request_with(to_classify, model)
          end

        store_classification(target, results)
      end

      private

      # Builds the future that classifies one post with every classifier.
      #
      # @return a future resolving to `{ target:, text:, classification: }`, where
      #   `classification` maps each model name to that model's response
      def promise_classification(record, text, available_classifiers, base_url, pool)
        Concurrent::Promises
          .fulfilled_future({ target: record, text: text }, pool)
          .then_on(pool) do |w_text|
            # Written to from several worker threads, hence the thread-safe hash.
            results = Concurrent::Hash.new

            promised_target_results =
              available_classifiers.map do |c|
                Concurrent::Promises.future_on(pool) do
                  results[c.model_name] = request_with(w_text[:text], c, base_url)
                end
              end

            Concurrent::Promises
              .zip(*promised_target_results)
              .then_on(pool) { |_| w_text.merge(classification: results) }
          end
          .flat(1) # unwrap the future returned from the block above
      end

      # Text sent to the classifiers: the raw post, prefixed with the topic title
      # when the post is the first one in its topic, truncated by the BERT
      # tokenizer with a limit of 512 (presumably tokens; confirm against the
      # tokenizer's API).
      def prepare_text(target)
        content =
          if target.post_number == 1
            "#{target.topic.title}\n#{target.raw}"
          else
            target.raw
          end

        Tokenizer::BertTokenizer.truncate(content, 512)
      end

      # Classifier configurations as exposed by the sentiment site setting schema.
      def classifiers
        DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values
      end

      # Performs one classification request for `content` using `config`.
      def request_with(content, config, base_url = Discourse.base_url)
        DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, config, base_url)
      end

      # Upserts one ClassificationResult per model for `target`. Existing rows
      # (same target and model) only get their `classification` replaced.
      #
      # @param target [Post] the classified record
      # @param classification [Hash] model name => classification payload
      def store_classification(target, classification)
        # One timestamp for the whole write so created_at/updated_at agree.
        now = DateTime.now

        attrs =
          classification.map do |model_name, classifications|
            {
              model_used: model_name,
              target_id: target.id,
              target_type: target.class.sti_name,
              classification_type: :sentiment,
              classification: classifications,
              updated_at: now,
              created_at: now,
            }
          end

        # Nothing to write when no classifier produced a result; some Rails
        # versions raise when upsert_all is given an empty list.
        return if attrs.empty?

        ClassificationResult.upsert_all(
          attrs,
          unique_by: %i[target_id target_type model_used],
          update_only: %i[classification],
        )
      end
    end
  end
end
|
@ -14,11 +14,8 @@ task "ai:sentiment:backfill", [:start_post] => [:environment] do |_, args|
|
|||||||
.where("category_id IN (?)", public_categories)
|
.where("category_id IN (?)", public_categories)
|
||||||
.where(posts: { deleted_at: nil })
|
.where(posts: { deleted_at: nil })
|
||||||
.where(topics: { deleted_at: nil })
|
.where(topics: { deleted_at: nil })
|
||||||
.order("posts.id ASC")
|
.find_in_batches do |batch|
|
||||||
.find_each do |post|
|
|
||||||
print "."
|
print "."
|
||||||
DiscourseAi::PostClassificator.new(
|
DiscourseAi::Sentiment::PostClassification.new.bulk_classify!(batch)
|
||||||
DiscourseAi::Sentiment::SentimentClassification.new,
|
|
||||||
).classify!(post)
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
47
spec/lib/modules/sentiment/post_classification_spec.rb
Normal file
47
spec/lib/modules/sentiment/post_classification_spec.rb
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
require_relative "../../../support/sentiment_inference_stubs"
|
||||||
|
|
||||||
|
# Covers both entry points of PostClassification: the synchronous single-post
# path and the concurrent bulk path. Inference endpoints are stubbed.
RSpec.describe DiscourseAi::Sentiment::PostClassification do
  fab!(:post_1) { Fabricate(:post, post_number: 2) }

  before do
    SiteSetting.ai_sentiment_enabled = true
    # Three classifier configs, so a classified post is expected to end up
    # with one ClassificationResult per configured model.
    SiteSetting.ai_sentiment_model_configs =
      "[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
  end

  describe "#classify!" do
    it "does nothing if the post content is blank" do
      post_1.update_columns(raw: "")

      subject.classify!(post_1)

      expect(ClassificationResult.where(target: post_1).count).to be_zero
    end

    it "successfully classifies the post" do
      expected_analysis = DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.length
      SentimentInferenceStubs.stub_classification(post_1)

      subject.classify!(post_1)

      expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
    end
  end

  # The label matches the actual method name (`bulk_classify!`).
  describe "#bulk_classify!" do
    fab!(:post_2) { Fabricate(:post, post_number: 2) }

    it "classifies all given posts" do
      expected_analysis = DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values.length
      SentimentInferenceStubs.stub_classification(post_1)
      SentimentInferenceStubs.stub_classification(post_2)

      subject.bulk_classify!(Post.where(id: [post_1.id, post_2.id]))

      expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
      expect(ClassificationResult.where(target: post_2).count).to eq(expected_analysis)
    end
  end
end
|
Loading…
x
Reference in New Issue
Block a user