Migrate sentiment to a TEI backend (#886)

This commit is contained in:
Rafael dos Santos Silva 2024-11-04 09:14:34 -03:00 committed by GitHub
parent bffe9dfa07
commit 772ee934ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 287 additions and 92 deletions

View File

@ -56,22 +56,9 @@ discourse_ai:
ai_sentiment_enabled:
default: false
client: true
ai_sentiment_inference_service_api_endpoint:
default: "https://sentiment-testing.demo-by-discourse.com"
ai_sentiment_inference_service_api_endpoint_srv:
ai_sentiment_model_configs:
default: ""
hidden: true
ai_sentiment_inference_service_api_key:
default: ""
secret: true
ai_sentiment_models:
type: list
list_type: compact
default: "emotion|sentiment"
allow_any: false
choices:
- sentiment
- emotion
json_schema: DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema
ai_nsfw_detection_enabled:
default: false

View File

@ -0,0 +1,30 @@
# frozen_string_literal: true
# Rewrites stored sentiment and emotion classification results: rows are
# renamed from the short model ids ('sentiment' / 'emotion') to full model
# names, and every score is divided by 100.
class MigrateSentimentClassificationResultFormat < ActiveRecord::Migration[7.1]
  # Runs both UPDATE statements in a single call. Each statement only matches
  # rows still carrying the old model id, so rows that were already rewritten
  # are not touched again.
  def up
    DB.exec(<<~SQL)
      UPDATE classification_results
      SET
        model_used = 'cardiffnlp/twitter-roberta-base-sentiment-latest',
        classification = jsonb_build_object(
          'neutral', (classification->>'neutral')::float / 100,
          'negative', (classification->>'negative')::float / 100,
          'positive', (classification->>'positive')::float / 100
        )
      WHERE model_used = 'sentiment';
      UPDATE classification_results
      SET
        model_used = 'j-hartmann/emotion-english-distilroberta-base',
        classification = jsonb_build_object(
          'sadness', (classification->>'sadness')::float / 100,
          'surprise', (classification->>'surprise')::float / 100,
          'fear', (classification->>'fear')::float / 100,
          'anger', (classification->>'anger')::float / 100,
          'joy', (classification->>'joy')::float / 100,
          'disgust', (classification->>'disgust')::float / 100,
          'neutral', (classification->>'neutral')::float / 100
        )
      WHERE model_used = 'emotion';
    SQL
  end

  # Without an explicit `down`, a rollback would silently do nothing and still
  # mark this migration as reverted. After `up`, migrated rows cannot be told
  # apart from rows stored later under the same model names, so the rewrite
  # cannot be undone safely.
  def down
    raise ActiveRecord::IrreversibleMigration
  end
end

View File

@ -58,6 +58,29 @@ module ::DiscourseAi
JSON.parse(response.body, symbolize_names: true)
end
# Posts +content+ to the classification endpoint described by +model_config+
# and returns the parsed JSON response.
#
# @param content the text to classify; sent as `inputs` with `truncate: true`
# @param model_config an object responding to `endpoint` and `api_key`. An
#   endpoint starting with "srv://" is resolved through
#   DiscourseAi::Utils::DnsSrv.lookup into an "https://target:port" URL.
# @return the response body parsed as JSON with symbolized names
# @raise [Net::HTTPBadResponse] when the response status is not 200
def classify(content, model_config)
  headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }

  # Only attach credentials when a key is configured; otherwise we would send
  # an empty X-API-KEY and a malformed "Bearer " Authorization header.
  api_key = model_config.api_key
  if api_key.present?
    headers["X-API-KEY"] = api_key
    headers["Authorization"] = "Bearer #{api_key}"
  end

  api_endpoint = model_config.endpoint
  if api_endpoint.present? && api_endpoint.start_with?("srv://")
    service = DiscourseAi::Utils::DnsSrv.lookup(api_endpoint.delete_prefix("srv://"))
    api_endpoint = "https://#{service.target}:#{service.port}"
  end

  conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
  response = conn.post(api_endpoint, { inputs: content, truncate: true }.to_json, headers)

  if response.status != 200
    raise Net::HTTPBadResponse, "Status: #{response.status}\n\n#{response.body}"
  end

  JSON.parse(response.body, symbolize_names: true)
end
def reranker_configured?
SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? ||
SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?

View File

@ -16,11 +16,11 @@ module DiscourseAi
plugin.add_report("overall_sentiment") do |report|
report.modes = [:stacked_chart]
threshold = 60
threshold = 0.6
sentiment_count_sql = Proc.new { |sentiment| <<~SQL }
COUNT(
CASE WHEN (cr.classification::jsonb->'#{sentiment}')::integer > :threshold THEN 1 ELSE NULL END
CASE WHEN (cr.classification::jsonb->'#{sentiment}')::float > :threshold THEN 1 ELSE NULL END
) AS #{sentiment}_count
SQL
@ -39,7 +39,7 @@ module DiscourseAi
WHERE
t.archetype = 'regular' AND
p.user_id > 0 AND
cr.model_used = 'sentiment' AND
cr.model_used = 'cardiffnlp/twitter-roberta-base-sentiment-latest' AND
(p.created_at > :report_start AND p.created_at < :report_end)
GROUP BY DATE_TRUNC('day', p.created_at)
SQL
@ -68,11 +68,11 @@ module DiscourseAi
plugin.add_report("post_emotion") do |report|
report.modes = [:stacked_line_chart]
threshold = 30
threshold = 0.3
emotion_count_clause = Proc.new { |emotion| <<~SQL }
COUNT(
CASE WHEN (cr.classification::jsonb->'#{emotion}')::integer > :threshold THEN 1 ELSE NULL END
CASE WHEN (cr.classification::jsonb->'#{emotion}')::float > :threshold THEN 1 ELSE NULL END
) AS #{emotion}_count
SQL
@ -96,7 +96,7 @@ module DiscourseAi
WHERE
t.archetype = 'regular' AND
p.user_id > 0 AND
cr.model_used = 'emotion' AND
cr.model_used = 'j-hartmann/emotion-english-distilroberta-base' AND
(p.created_at > :report_start AND p.created_at < :report_end)
GROUP BY DATE_TRUNC('day', p.created_at)
SQL

View File

@ -7,8 +7,8 @@ module DiscourseAi
:sentiment
end
def available_models
SiteSetting.ai_sentiment_models.split("|")
# Returns whatever SentimentSiteSettingJsonSchema.values yields. Callers in
# this class read `model_name` from each entry and pass the entry on as the
# model config for a classification request.
def available_classifiers
DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values
end
def can_classify?(target)
@ -16,8 +16,8 @@ module DiscourseAi
end
def get_verdicts(_)
available_models.reduce({}) do |memo, model|
memo[model] = false
available_classifiers.reduce({}) do |memo, model|
memo[model.model_name] = false
memo
end
end
@ -30,21 +30,23 @@ module DiscourseAi
def request(target_to_classify)
target_content = content_of(target_to_classify)
available_models.reduce({}) do |memo, model|
memo[model] = request_with(model, target_content)
available_classifiers.reduce({}) do |memo, model|
memo[model.model_name] = request_with(target_content, model)
memo
end
end
# Collapses a list of entries carrying :label and :score into a single hash
# mapping each label to its score. When a label repeats, the last entry wins.
def transform_result(result)
  result.each_with_object({}) { |entry, scores| scores[entry[:label]] = entry[:score] }
end
private
def request_with(model, content)
::DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{endpoint}/api/v1/classify",
model,
content,
SiteSetting.ai_sentiment_inference_service_api_key,
)
# Classifies +content+ with the given model config and returns the response
# reshaped by transform_result.
def request_with(content, model_config)
  transform_result(
    ::DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, model_config),
  )
end
def content_of(target_to_classify)
@ -57,18 +59,6 @@ module DiscourseAi
Tokenizer::BertTokenizer.truncate(content, 512)
end
def endpoint
if SiteSetting.ai_sentiment_inference_service_api_endpoint_srv.present?
service =
DiscourseAi::Utils::DnsSrv.lookup(
SiteSetting.ai_sentiment_inference_service_api_endpoint_srv,
)
"https://#{service.target}:#{service.port}"
else
SiteSetting.ai_sentiment_inference_service_api_endpoint
end
end
end
end
end

View File

@ -0,0 +1,34 @@
# frozen_string_literal: true
module DiscourseAi
  module Sentiment
    # JSON schema and parsed-value accessor for the
    # `ai_sentiment_model_configs` site setting.
    class SentimentSiteSettingJsonSchema
      # @return [Hash] the memoized schema: an array of objects, each with the
      #   required string properties `model_name`, `endpoint` and `api_key`.
      def self.schema
        @schema ||= {
          type: "array",
          items: {
            type: "object",
            format: "table",
            title: "model",
            properties: {
              model_name: {
                type: "string",
              },
              endpoint: {
                type: "string",
              },
              api_key: {
                type: "string",
              },
            },
            required: %w[model_name endpoint api_key],
          },
        }
      end

      # Parses the setting into objects exposing each JSON key as a method
      # (e.g. `model_name`, `endpoint`, `api_key`).
      #
      # @return [Array<OpenStruct>] the configured entries; an empty array when
      #   the setting is blank
      # @raise [JSON::ParserError] when the setting holds non-blank invalid JSON
      def self.values
        raw = SiteSetting.ai_sentiment_model_configs

        # JSON.parse raises on an empty document, so a blank setting (nothing
        # configured) is treated as "no models" instead of an error.
        return [] if raw.to_s.strip.empty?

        JSON.parse(raw, object_class: OpenStruct)
      end
    end
  end
end

View File

@ -0,0 +1,52 @@
# frozen_string_literal: true
require "rails_helper"
require Rails.root.join(
"plugins/discourse-ai/db/post_migrate/20241031041242_migrate_sentiment_classification_result_format",
)
RSpec.describe MigrateSentimentClassificationResultFormat do
  let(:connection) { ActiveRecord::Base.connection }

  # Seed one row per legacy model id, with integer scores, for `up` to rewrite.
  before do
    connection.execute(<<~SQL)
      INSERT INTO classification_results (model_used, classification, created_at, updated_at) VALUES
      ('sentiment', '{"neutral": 65, "negative": 20, "positive": 14}', NOW(), NOW()),
      ('emotion', '{"sadness": 10, "surprise": 15, "fear": 5, "anger": 20, "joy": 30, "disgust": 8, "neutral": 10}', NOW(), NOW());
    SQL
  end

  after { connection.execute("DELETE FROM classification_results") }

  describe "#up" do
    before { described_class.new.up }

    it "migrates sentiment classifications correctly" do
      migrated_row = connection.execute(<<~SQL).first
        SELECT * FROM classification_results
        WHERE model_used = 'cardiffnlp/twitter-roberta-base-sentiment-latest';
      SQL

      stored_scores = JSON.parse(migrated_row["classification"])

      expect(stored_scores).to eq({ "neutral" => 0.65, "negative" => 0.20, "positive" => 0.14 })
    end

    it "migrates emotion classifications correctly" do
      migrated_row = connection.execute(<<~SQL).first
        SELECT * FROM classification_results
        WHERE model_used = 'j-hartmann/emotion-english-distilroberta-base';
      SQL

      stored_scores = JSON.parse(migrated_row["classification"])

      expect(stored_scores).to eq(
        {
          "sadness" => 0.10,
          "surprise" => 0.15,
          "fear" => 0.05,
          "anger" => 0.20,
          "joy" => 0.30,
          "disgust" => 0.08,
          "neutral" => 0.10,
        },
      )
    end
  end
end

View File

@ -6,11 +6,13 @@ Fabricator(:classification_result) do
end
Fabricator(:sentiment_classification, from: :classification_result) do
model_used "sentiment"
classification { { negative: 72, neutral: 23, positive: 4 } }
model_used "cardiffnlp/twitter-roberta-base-sentiment-latest"
classification { { negative: 0.72, neutral: 0.23, positive: 0.4 } }
end
Fabricator(:emotion_classification, from: :classification_result) do
model_used "emotion"
classification { { negative: 72, neutral: 23, positive: 4 } }
model_used "j-hartmann/emotion-english-distilroberta-base"
classification do
{ sadness: 0.72, surprise: 0.23, fear: 0.4, anger: 0.87, joy: 0.22, disgust: 0.70 }
end
end

View File

@ -53,7 +53,10 @@ RSpec.describe DiscourseAi::Sentiment::EntryPoint do
end
describe "custom reports" do
before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" }
before do
SiteSetting.ai_sentiment_model_configs =
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
end
fab!(:pm) { Fabricate(:private_message_post) }
@ -61,8 +64,8 @@ RSpec.describe DiscourseAi::Sentiment::EntryPoint do
fab!(:post_2) { Fabricate(:post) }
describe "overall_sentiment report" do
let(:positive_classification) { { negative: 2, neutral: 30, positive: 70 } }
let(:negative_classification) { { negative: 65, neutral: 2, positive: 10 } }
let(:positive_classification) { { negative: 0.2, neutral: 0.3, positive: 0.7 } }
let(:negative_classification) { { negative: 0.65, neutral: 0.2, positive: 0.1 } }
def sentiment_classification(post, classification)
Fabricate(:sentiment_classification, target: post, classification: classification)
@ -84,12 +87,28 @@ RSpec.describe DiscourseAi::Sentiment::EntryPoint do
describe "post_emotion report" do
let(:emotion_1) do
{ sadness: 49, surprise: 23, neutral: 6, fear: 34, anger: 87, joy: 22, disgust: 70 }
{
sadness: 0.49,
surprise: 0.23,
neutral: 0.6,
fear: 0.34,
anger: 0.87,
joy: 0.22,
disgust: 0.70,
}
end
let(:emotion_2) do
{ sadness: 19, surprise: 63, neutral: 45, fear: 44, anger: 27, joy: 62, disgust: 30 }
{
sadness: 0.19,
surprise: 0.63,
neutral: 0.45,
fear: 0.44,
anger: 0.27,
joy: 0.62,
disgust: 0.30,
}
end
let(:model_used) { "emotion" }
let(:model_used) { "j-hartmann/emotion-english-distilroberta-base" }
def emotion_classification(post, classification)
Fabricate(
@ -106,7 +125,7 @@ RSpec.describe DiscourseAi::Sentiment::EntryPoint do
end
it "calculate averages using only public posts" do
threshold = 30
threshold = 0.30
emotion_classification(post_1, emotion_1)
emotion_classification(post_2, emotion_2)

View File

@ -8,7 +8,8 @@ describe Jobs::PostSentimentAnalysis do
before do
SiteSetting.ai_sentiment_enabled = true
SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com"
SiteSetting.ai_sentiment_model_configs =
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
end
describe "scenarios where we return early without doing anything" do
@ -42,7 +43,8 @@ describe Jobs::PostSentimentAnalysis do
end
it "successfully classifies the post" do
expected_analysis = SiteSetting.ai_sentiment_models.split("|").length
expected_analysis =
DiscourseAi::Sentiment::SentimentClassification.new.available_classifiers.length
SentimentInferenceStubs.stub_classification(post)
subject.execute({ post_id: post.id })

View File

@ -6,15 +6,20 @@ describe DiscourseAi::Sentiment::SentimentClassification do
fab!(:target) { Fabricate(:post) }
describe "#request" do
before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" }
before do
SiteSetting.ai_sentiment_model_configs =
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
end
it "returns the classification and the model used for it" do
SentimentInferenceStubs.stub_classification(target)
result = subject.request(target)
subject.available_models.each do |model|
expect(result[model]).to eq(SentimentInferenceStubs.model_response(model))
subject.available_classifiers.each do |model_config|
expect(result[model_config.model_name]).to eq(
subject.transform_result(SentimentInferenceStubs.model_response(model_config.model_name)),
)
end
end
end

View File

@ -6,21 +6,25 @@ require_relative "../support/sentiment_inference_stubs"
describe DiscourseAi::Classificator do
describe "#classify!" do
describe "saving the classification result" do
let(:model) { DiscourseAi::Sentiment::SentimentClassification.new }
let(:classification_raw_result) do
model
.available_models
.reduce({}) do |memo, model_name|
memo[model_name] = SentimentInferenceStubs.model_response(model_name)
.available_classifiers
.reduce({}) do |memo, model_config|
memo[model_config.model_name] = model.transform_result(
SentimentInferenceStubs.model_response(model_config.model_name),
)
memo
end
end
let(:model) { DiscourseAi::Sentiment::SentimentClassification.new }
let(:classification) { DiscourseAi::PostClassificator.new(model) }
fab!(:target) { Fabricate(:post) }
before do
SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com"
SiteSetting.ai_sentiment_model_configs =
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
SentimentInferenceStubs.stub_classification(target)
end
@ -28,28 +32,29 @@ describe DiscourseAi::Classificator do
classification.classify!(target)
stored_results = ClassificationResult.where(target: target)
expect(stored_results.length).to eq(model.available_models.length)
expect(stored_results.length).to eq(model.available_classifiers.length)
model.available_models.each do |model_name|
result = stored_results.detect { |c| c.model_used == model_name }
model.available_classifiers.each do |model_config|
result = stored_results.detect { |c| c.model_used == model_config.model_name }
expect(result.classification_type).to eq(model.type.to_s)
expect(result.created_at).to be_present
expect(result.updated_at).to be_present
expected_classification = SentimentInferenceStubs.model_response(model)
expected_classification = SentimentInferenceStubs.model_response(model_config.model_name)
transformed_classification = model.transform_result(expected_classification)
expect(result.classification.deep_symbolize_keys).to eq(expected_classification)
expect(result.classification).to eq(transformed_classification)
end
end
it "updates an existing classification result" do
original_creation = 3.days.ago
model.available_models.each do |model_name|
model.available_classifiers.each do |model_config|
ClassificationResult.create!(
target: target,
model_used: model_name,
model_used: model_config.model_name,
classification_type: model.type,
created_at: original_creation,
updated_at: original_creation,
@ -61,18 +66,16 @@ describe DiscourseAi::Classificator do
classification.classify!(target)
stored_results = ClassificationResult.where(target: target)
expect(stored_results.length).to eq(model.available_models.length)
expect(stored_results.length).to eq(model.available_classifiers.length)
model.available_models.each do |model_name|
result = stored_results.detect { |c| c.model_used == model_name }
model.available_classifiers.each do |model_config|
result = stored_results.detect { |c| c.model_used == model_config.model_name }
expect(result.classification_type).to eq(model.type.to_s)
expect(result.updated_at).to be > original_creation
expect(result.created_at).to eq_time(original_creation)
expect(result.classification.deep_symbolize_keys).to eq(
classification_raw_result[model_name],
)
expect(result.classification).to eq(classification_raw_result[model_config.model_name])
end
end
end

View File

@ -2,24 +2,69 @@
class SentimentInferenceStubs
class << self
def endpoint
"#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify"
end
def model_response(model)
{ negative: 72, neutral: 23, positive: 4 } if model == "sentiment"
{ sadness: 99, surprise: 0, neutral: 0, fear: 0, anger: 0, joy: 0, disgust: 0 }
case model
when "SamLowe/roberta-base-go_emotions"
[
{ score: 0.90261286, label: "anger" },
{ score: 0.04127813, label: "annoyance" },
{ score: 0.03183503, label: "neutral" },
{ score: 0.005037033, label: "disgust" },
{ score: 0.0031153716, label: "disapproval" },
{ score: 0.0019118421, label: "disappointment" },
{ score: 0.0015849728, label: "sadness" },
{ score: 0.0012343781, label: "curiosity" },
{ score: 0.0010682651, label: "amusement" },
{ score: 0.00100747, label: "confusion" },
{ score: 0.0010035422, label: "admiration" },
{ score: 0.0009957326, label: "approval" },
{ score: 0.0009726665, label: "surprise" },
{ score: 0.0007754773, label: "realization" },
{ score: 0.0006978541, label: "love" },
{ score: 0.00064793555, label: "fear" },
{ score: 0.0006454095, label: "optimism" },
{ score: 0.0005969062, label: "joy" },
{ score: 0.0005498958, label: "embarrassment" },
{ score: 0.00050068577, label: "excitement" },
{ score: 0.00047403979, label: "caring" },
{ score: 0.00038841428, label: "gratitude" },
{ score: 0.00034546282, label: "desire" },
{ score: 0.00023012784, label: "grief" },
{ score: 0.00018133638, label: "remorse" },
{ score: 0.00012511834, label: "nervousness" },
{ score: 0.00012079607, label: "pride" },
{ score: 0.000063159685, label: "relief" },
]
when "cardiffnlp/twitter-roberta-base-sentiment-latest"
[
{ score: 0.627579, label: "negative" },
{ score: 0.29482335, label: "neutral" },
{ score: 0.07759768, label: "positive" },
]
when "j-hartmann/emotion-english-distilroberta-base"
[
{ score: 0.7033674, label: "anger" },
{ score: 0.2701151, label: "disgust" },
{ score: 0.009492096, label: "sadness" },
{ score: 0.0080775, label: "neutral" },
{ score: 0.0049473303, label: "fear" },
{ score: 0.0023369535, label: "surprise" },
{ score: 0.001663634, label: "joy" },
]
end
end
def stub_classification(post)
content = post.post_number == 1 ? "#{post.topic.title}\n#{post.raw}" : post.raw
DiscourseAi::Sentiment::SentimentClassification.new.available_models.each do |model|
DiscourseAi::Sentiment::SentimentClassification
.new
.available_classifiers
.each do |model_config|
WebMock
.stub_request(:post, endpoint)
.with(body: JSON.dump(model: model, content: content))
.to_return(status: 200, body: JSON.dump(model_response(model)))
.stub_request(:post, model_config.endpoint)
.with(body: JSON.dump(inputs: content, truncate: true))
.to_return(status: 200, body: JSON.dump(model_response(model_config.model_name)))
end
end
end

View File

@ -9,7 +9,10 @@ RSpec.describe "assets:precompile" do
end
describe "ai:sentiment:backfill" do
before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" }
before do
SiteSetting.ai_sentiment_model_configs =
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
end
it "does nothing if the topic is soft-deleted" do
target = Fabricate(:post)