mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-03-06 09:20:14 +00:00
FIX: Sentiment classification results needs to be transformed before saving (#983)
This commit is contained in:
parent
120a20c5cd
commit
0abd4b1244
37
db/migrate/20241129190708_fix_classification_data.rb
Normal file
37
db/migrate/20241129190708_fix_classification_data.rb
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
# Data migration for sentiment rows in classification_results.
#
# Rows whose serialized classification starts with '[' are rewritten so the
# column holds a single { label => score } JSON object built from the array's
# "label"/"score" entries.
class FixClassificationData < ActiveRecord::Migration[7.2]
  # Selects the affected rows, reshapes each classification in Ruby, and
  # writes all of them back with one UPDATE fed from a JSON payload.
  def up
    classifications = DB.query(<<~SQL)
      SELECT id, classification
      FROM classification_results
      WHERE classification_type = 'sentiment'
      AND SUBSTRING(LTRIM(classification::text), 1, 1) = '['
    SQL

    # Nothing matched the filter, so there is nothing to write back.
    return if classifications.empty?

    # NOTE(review): assumes the driver hands back `classification` already
    # decoded into an array of string-keyed hashes — confirm.
    transformed =
      classifications.map do |c|
        fixed = c.classification.to_h { |r| [r["label"], r["score"]] }
        { id: c.id, fixed_classification: fixed }
      end

    # The whole batch travels as one JSON array and is unpacked server-side.
    DB.exec(<<~SQL, values: transformed.to_json)
      UPDATE classification_results
      SET classification = N.fixed_classification
      FROM (
        SELECT (value::jsonb->'id')::integer AS id, (value::jsonb->'fixed_classification')::jsonb AS fixed_classification
        FROM jsonb_array_elements(:values::jsonb)
      ) N
      WHERE classification_results.id = N.id
      AND classification_type = 'sentiment'
    SQL
  end

  # The original array form is not kept, so the change cannot be undone.
  #
  # @raise [ActiveRecord::IrreversibleMigration] always
  def down
    raise ActiveRecord::IrreversibleMigration
  end
end
|
@ -83,7 +83,15 @@ module DiscourseAi
|
|||||||
end
|
end
|
||||||
|
|
||||||
# Delegates to DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify and
# returns that response after passing it through transform_result.
def request_with(content, config, base_url = Discourse.base_url)
  transform_result(
    DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, config, base_url),
  )
end
|
||||||
|
|
||||||
|
# Collapses a list of classifier entries into one { label => score } hash.
#
# @param result [#each] entries read through r[:label] and r[:score]
#   (presumably the parsed inference response — confirm against the caller)
# @return [Hash] scores keyed by label; a repeated label keeps its last score
def transform_result(result)
  result.each_with_object({}) { |r, scores| scores[r[:label]] = r[:score] }
end
|
||||||
|
|
||||||
def store_classification(target, classification)
|
def store_classification(target, classification)
|
||||||
|
@ -11,6 +11,16 @@ RSpec.describe DiscourseAi::Sentiment::PostClassification do
|
|||||||
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
|
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Looks up the ClassificationResult stored for the post under the cardiffnlp
# sentiment model and expects its classification keys to be exactly
# negative, neutral and positive.
def check_classification_for(post)
  result =
    ClassificationResult.find_by(
      model_used: "cardiffnlp/twitter-roberta-base-sentiment-latest",
      target: post,
    )

  # Fail with a readable expectation instead of a NoMethodError on nil when
  # no result was stored for the post.
  expect(result).not_to be_nil
  expect(result.classification.keys).to contain_exactly("negative", "neutral", "positive")
end
|
||||||
|
|
||||||
describe "#classify!" do
|
describe "#classify!" do
|
||||||
it "does nothing if the post content is blank" do
|
it "does nothing if the post content is blank" do
|
||||||
post_1.update_columns(raw: "")
|
post_1.update_columns(raw: "")
|
||||||
@ -28,6 +38,13 @@ RSpec.describe DiscourseAi::Sentiment::PostClassification do
|
|||||||
|
|
||||||
expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
|
expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Stubs the classification request for post_1, runs classify! on it, then checks the stored keys via check_classification_for.
it "classification results must be { emotion => score }" do
|
||||||
|
SentimentInferenceStubs.stub_classification(post_1)
|
||||||
|
|
||||||
|
subject.classify!(post_1)
|
||||||
|
check_classification_for(post_1)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "#classify_bulk!" do
|
describe "#classify_bulk!" do
|
||||||
@ -43,5 +60,15 @@ RSpec.describe DiscourseAi::Sentiment::PostClassification do
|
|||||||
expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
|
expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
|
||||||
expect(ClassificationResult.where(target: post_2).count).to eq(expected_analysis)
|
expect(ClassificationResult.where(target: post_2).count).to eq(expected_analysis)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Stubs the classification request for both posts, runs bulk_classify! over them, then checks each stored result via check_classification_for.
it "classification results must be { emotion => score }" do
|
||||||
|
SentimentInferenceStubs.stub_classification(post_1)
|
||||||
|
SentimentInferenceStubs.stub_classification(post_2)
|
||||||
|
|
||||||
|
subject.bulk_classify!(Post.where(id: [post_1.id, post_2.id]))
|
||||||
|
|
||||||
|
check_classification_for(post_1)
|
||||||
|
check_classification_for(post_2)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
Loading…
x
Reference in New Issue
Block a user