diff --git a/db/migrate/20241129190708_fix_classification_data.rb b/db/migrate/20241129190708_fix_classification_data.rb
new file mode 100644
index 00000000..5e91b6e2
--- /dev/null
+++ b/db/migrate/20241129190708_fix_classification_data.rb
@@ -0,0 +1,44 @@
+# frozen_string_literal: true
+
+# Converts sentiment rows in classification_results whose classification is
+# still a JSON array of { "label" => ..., "score" => ... } entries into a single
+# { label => score } JSON object.
+class FixClassificationData < ActiveRecord::Migration[7.2]
+  def up
+    # A serialized value starting with '[' marks a row still in the array layout;
+    # every other sentiment row is left untouched.
+    classifications = DB.query(<<~SQL)
+      SELECT id, classification
+      FROM classification_results
+      WHERE classification_type = 'sentiment'
+      AND SUBSTRING(LTRIM(classification::text), 1, 1) = '['
+    SQL
+
+    transformed =
+      classifications.map do |c|
+        scores_by_label = c.classification.to_h { |r| [r["label"], r["score"]] }
+        { id: c.id, fixed_classification: scores_by_label }
+      end
+
+    # No array-shaped rows: skip the UPDATE instead of running a no-op statement.
+    return if transformed.empty?
+
+    # Every fix travels in one JSON document that Postgres expands into rows, so
+    # all affected records are rewritten by a single UPDATE.
+    DB.exec(<<~SQL, values: transformed.to_json)
+      UPDATE classification_results
+      SET classification = N.fixed_classification
+      FROM (
+        SELECT (value::jsonb->'id')::integer AS id, (value::jsonb->'fixed_classification')::jsonb AS fixed_classification
+        FROM jsonb_array_elements(:values::jsonb)
+      ) N
+      WHERE classification_results.id = N.id
+      AND classification_type = 'sentiment'
+    SQL
+  end
+
+  # Rolling back is not supported: always raises IrreversibleMigration.
+  def down
+    raise ActiveRecord::IrreversibleMigration
+  end
+end
diff --git a/lib/sentiment/post_classification.rb b/lib/sentiment/post_classification.rb
index eea10ce9..18c9c186 100644
--- a/lib/sentiment/post_classification.rb
+++ b/lib/sentiment/post_classification.rb
@@ -83,7 +83,17 @@ module DiscourseAi
       end
 
       def request_with(content, config, base_url = Discourse.base_url)
-        DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, config, base_url)
+        result =
+          DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, config, base_url)
+        transform_result(result)
+      end
+
+      # Flattens a classifier response into a { label => score } hash.
+      #
+      # @param result [Enumerable] entries readable via [:label] and [:score]
+      # @return [Hash] each entry's label mapped to its score
+      def transform_result(result)
+        result.to_h { |r| [r[:label], r[:score]] }
       end
 
       def store_classification(target, classification)
diff --git 
a/spec/lib/modules/sentiment/post_classification_spec.rb b/spec/lib/modules/sentiment/post_classification_spec.rb
index 99ce44de..fc852151 100644
--- a/spec/lib/modules/sentiment/post_classification_spec.rb
+++ b/spec/lib/modules/sentiment/post_classification_spec.rb
@@ -11,6 +11,17 @@ RSpec.describe DiscourseAi::Sentiment::PostClassification do
     "[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
   end
 
+  # Asserts the cardiffnlp result stored for post has exactly the three sentiment keys.
+  def check_classification_for(post)
+    result =
+      ClassificationResult.find_by(
+        model_used: "cardiffnlp/twitter-roberta-base-sentiment-latest",
+        target: post,
+      )
+    expect(result).not_to be_nil # fail clearly instead of raising NoMethodError on nil
+    expect(result.classification.keys).to contain_exactly("negative", "neutral", "positive")
+  end
+
   describe "#classify!" do
     it "does nothing if the post content is blank" do
       post_1.update_columns(raw: "")
@@ -28,6 +39,12 @@ RSpec.describe DiscourseAi::Sentiment::PostClassification do
 
       expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
     end
+
+    it "classification results must be { emotion => score }" do
+      SentimentInferenceStubs.stub_classification(post_1)
+      subject.classify!(post_1)
+      check_classification_for(post_1)
+    end
   end
 
   describe "#classify_bulk!" 
do
@@ -43,5 +60,15 @@ RSpec.describe DiscourseAi::Sentiment::PostClassification do
       expect(ClassificationResult.where(target: post_1).count).to eq(expected_analysis)
       expect(ClassificationResult.where(target: post_2).count).to eq(expected_analysis)
     end
+
+    it "classification results must be { emotion => score }" do
+      posts = [post_1, post_2]
+      posts.each { |post| SentimentInferenceStubs.stub_classification(post) }
+
+      # Classify both posts in one bulk call, then verify each stored result.
+      subject.bulk_classify!(Post.where(id: posts.map(&:id)))
+
+      posts.each { |post| check_classification_for(post) }
+    end
   end
 end