Migrate sentiment to a TEI backend (#886)
This commit is contained in:
parent
bffe9dfa07
commit
772ee934ab
|
@ -56,22 +56,9 @@ discourse_ai:
|
||||||
ai_sentiment_enabled:
|
ai_sentiment_enabled:
|
||||||
default: false
|
default: false
|
||||||
client: true
|
client: true
|
||||||
ai_sentiment_inference_service_api_endpoint:
|
ai_sentiment_model_configs:
|
||||||
default: "https://sentiment-testing.demo-by-discourse.com"
|
|
||||||
ai_sentiment_inference_service_api_endpoint_srv:
|
|
||||||
default: ""
|
default: ""
|
||||||
hidden: true
|
json_schema: DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema
|
||||||
ai_sentiment_inference_service_api_key:
|
|
||||||
default: ""
|
|
||||||
secret: true
|
|
||||||
ai_sentiment_models:
|
|
||||||
type: list
|
|
||||||
list_type: compact
|
|
||||||
default: "emotion|sentiment"
|
|
||||||
allow_any: false
|
|
||||||
choices:
|
|
||||||
- sentiment
|
|
||||||
- emotion
|
|
||||||
|
|
||||||
ai_nsfw_detection_enabled:
|
ai_nsfw_detection_enabled:
|
||||||
default: false
|
default: false
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
# Converts stored classification rows to the new result format:
# * rows with model_used 'sentiment' are re-labelled with the
#   cardiffnlp/twitter-roberta-base-sentiment-latest model name,
# * rows with model_used 'emotion' are re-labelled with the
#   j-hartmann/emotion-english-distilroberta-base model name,
# and each listed score is rewritten as its previous value divided by 100.
class MigrateSentimentClassificationResultFormat < ActiveRecord::Migration[7.1]
  def up
    # Both UPDATE statements are handed to the database in one DB.exec call.
    statements = <<~SQL
      UPDATE classification_results
      SET
      model_used = 'cardiffnlp/twitter-roberta-base-sentiment-latest',
      classification = jsonb_build_object(
      'neutral', (classification->>'neutral')::float / 100,
      'negative', (classification->>'negative')::float / 100,
      'positive', (classification->>'positive')::float / 100
      )
      WHERE model_used = 'sentiment';

      UPDATE classification_results
      SET
      model_used = 'j-hartmann/emotion-english-distilroberta-base',
      classification = jsonb_build_object(
      'sadness', (classification->>'sadness')::float / 100,
      'surprise', (classification->>'surprise')::float / 100,
      'fear', (classification->>'fear')::float / 100,
      'anger', (classification->>'anger')::float / 100,
      'joy', (classification->>'joy')::float / 100,
      'disgust', (classification->>'disgust')::float / 100,
      'neutral', (classification->>'neutral')::float / 100
      )
      WHERE model_used = 'emotion';
    SQL

    DB.exec(statements)
  end
end
|
|
@ -58,6 +58,29 @@ module ::DiscourseAi
|
||||||
JSON.parse(response.body, symbolize_names: true)
|
JSON.parse(response.body, symbolize_names: true)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Posts +content+ to the classification endpoint described by +model_config+
# and returns the parsed JSON response with symbolized keys.
#
# @param content [String] text to classify; sent as the "inputs" field of the
#   JSON body together with "truncate": true.
# @param model_config [#endpoint, #api_key] per-model configuration. An
#   endpoint starting with "srv://" is resolved through
#   DiscourseAi::Utils::DnsSrv and contacted over https on the resolved
#   target and port.
# @return [Object] the response body as parsed by
#   JSON.parse(..., symbolize_names: true).
# @raise [Net::HTTPBadResponse] when the response status is anything but 200;
#   the message carries the status and the raw body.
def classify(content, model_config)
  headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }

  # Only attach credentials when a key is actually configured; otherwise the
  # request would carry an empty X-API-KEY and a bare "Bearer " header.
  if model_config.api_key.present?
    headers["X-API-KEY"] = model_config.api_key
    headers["Authorization"] = "Bearer #{model_config.api_key}"
  end

  body = { inputs: content, truncate: true }.to_json

  api_endpoint = model_config.endpoint
  if api_endpoint.present? && api_endpoint.start_with?("srv://")
    service = DiscourseAi::Utils::DnsSrv.lookup(api_endpoint.delete_prefix("srv://"))
    api_endpoint = "https://#{service.target}:#{service.port}"
  end

  conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
  response = conn.post(api_endpoint, body, headers)

  if response.status != 200
    raise Net::HTTPBadResponse, "Status: #{response.status}\n\n#{response.body}"
  end

  JSON.parse(response.body, symbolize_names: true)
end
|
||||||
|
|
||||||
def reranker_configured?
|
def reranker_configured?
|
||||||
SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? ||
|
SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? ||
|
||||||
SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?
|
SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?
|
||||||
|
|
|
@ -16,11 +16,11 @@ module DiscourseAi
|
||||||
|
|
||||||
plugin.add_report("overall_sentiment") do |report|
|
plugin.add_report("overall_sentiment") do |report|
|
||||||
report.modes = [:stacked_chart]
|
report.modes = [:stacked_chart]
|
||||||
threshold = 60
|
threshold = 0.6
|
||||||
|
|
||||||
sentiment_count_sql = Proc.new { |sentiment| <<~SQL }
|
sentiment_count_sql = Proc.new { |sentiment| <<~SQL }
|
||||||
COUNT(
|
COUNT(
|
||||||
CASE WHEN (cr.classification::jsonb->'#{sentiment}')::integer > :threshold THEN 1 ELSE NULL END
|
CASE WHEN (cr.classification::jsonb->'#{sentiment}')::float > :threshold THEN 1 ELSE NULL END
|
||||||
) AS #{sentiment}_count
|
) AS #{sentiment}_count
|
||||||
SQL
|
SQL
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ module DiscourseAi
|
||||||
WHERE
|
WHERE
|
||||||
t.archetype = 'regular' AND
|
t.archetype = 'regular' AND
|
||||||
p.user_id > 0 AND
|
p.user_id > 0 AND
|
||||||
cr.model_used = 'sentiment' AND
|
cr.model_used = 'cardiffnlp/twitter-roberta-base-sentiment-latest' AND
|
||||||
(p.created_at > :report_start AND p.created_at < :report_end)
|
(p.created_at > :report_start AND p.created_at < :report_end)
|
||||||
GROUP BY DATE_TRUNC('day', p.created_at)
|
GROUP BY DATE_TRUNC('day', p.created_at)
|
||||||
SQL
|
SQL
|
||||||
|
@ -68,11 +68,11 @@ module DiscourseAi
|
||||||
|
|
||||||
plugin.add_report("post_emotion") do |report|
|
plugin.add_report("post_emotion") do |report|
|
||||||
report.modes = [:stacked_line_chart]
|
report.modes = [:stacked_line_chart]
|
||||||
threshold = 30
|
threshold = 0.3
|
||||||
|
|
||||||
emotion_count_clause = Proc.new { |emotion| <<~SQL }
|
emotion_count_clause = Proc.new { |emotion| <<~SQL }
|
||||||
COUNT(
|
COUNT(
|
||||||
CASE WHEN (cr.classification::jsonb->'#{emotion}')::integer > :threshold THEN 1 ELSE NULL END
|
CASE WHEN (cr.classification::jsonb->'#{emotion}')::float > :threshold THEN 1 ELSE NULL END
|
||||||
) AS #{emotion}_count
|
) AS #{emotion}_count
|
||||||
SQL
|
SQL
|
||||||
|
|
||||||
|
@ -96,7 +96,7 @@ module DiscourseAi
|
||||||
WHERE
|
WHERE
|
||||||
t.archetype = 'regular' AND
|
t.archetype = 'regular' AND
|
||||||
p.user_id > 0 AND
|
p.user_id > 0 AND
|
||||||
cr.model_used = 'emotion' AND
|
cr.model_used = 'j-hartmann/emotion-english-distilroberta-base' AND
|
||||||
(p.created_at > :report_start AND p.created_at < :report_end)
|
(p.created_at > :report_start AND p.created_at < :report_end)
|
||||||
GROUP BY DATE_TRUNC('day', p.created_at)
|
GROUP BY DATE_TRUNC('day', p.created_at)
|
||||||
SQL
|
SQL
|
||||||
|
|
|
@ -7,8 +7,8 @@ module DiscourseAi
|
||||||
:sentiment
|
:sentiment
|
||||||
end
|
end
|
||||||
|
|
||||||
def available_models
|
def available_classifiers
|
||||||
SiteSetting.ai_sentiment_models.split("|")
|
DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values
|
||||||
end
|
end
|
||||||
|
|
||||||
def can_classify?(target)
|
def can_classify?(target)
|
||||||
|
@ -16,8 +16,8 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def get_verdicts(_)
|
def get_verdicts(_)
|
||||||
available_models.reduce({}) do |memo, model|
|
available_classifiers.reduce({}) do |memo, model|
|
||||||
memo[model] = false
|
memo[model.model_name] = false
|
||||||
memo
|
memo
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -30,21 +30,23 @@ module DiscourseAi
|
||||||
def request(target_to_classify)
|
def request(target_to_classify)
|
||||||
target_content = content_of(target_to_classify)
|
target_content = content_of(target_to_classify)
|
||||||
|
|
||||||
available_models.reduce({}) do |memo, model|
|
available_classifiers.reduce({}) do |memo, model|
|
||||||
memo[model] = request_with(model, target_content)
|
memo[model.model_name] = request_with(target_content, model)
|
||||||
memo
|
memo
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Flattens a list of label/score entries into one hash keyed by label.
#
# @param result [#each] entries that respond to [] with :label and :score
# @return [Hash] every entry's :label mapped to its :score; when a label
#   occurs more than once the last entry wins.
def transform_result(result)
  result.each_with_object({}) { |entry, scores| scores[entry[:label]] = entry[:score] }
end
|
||||||
|
|
||||||
private
|
private
|
||||||
|
|
||||||
def request_with(model, content)
|
def request_with(content, model_config)
|
||||||
::DiscourseAi::Inference::DiscourseClassifier.perform!(
|
result = ::DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, model_config)
|
||||||
"#{endpoint}/api/v1/classify",
|
transform_result(result)
|
||||||
model,
|
|
||||||
content,
|
|
||||||
SiteSetting.ai_sentiment_inference_service_api_key,
|
|
||||||
)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
def content_of(target_to_classify)
|
def content_of(target_to_classify)
|
||||||
|
@ -57,18 +59,6 @@ module DiscourseAi
|
||||||
|
|
||||||
Tokenizer::BertTokenizer.truncate(content, 512)
|
Tokenizer::BertTokenizer.truncate(content, 512)
|
||||||
end
|
end
|
||||||
|
|
||||||
def endpoint
|
|
||||||
if SiteSetting.ai_sentiment_inference_service_api_endpoint_srv.present?
|
|
||||||
service =
|
|
||||||
DiscourseAi::Utils::DnsSrv.lookup(
|
|
||||||
SiteSetting.ai_sentiment_inference_service_api_endpoint_srv,
|
|
||||||
)
|
|
||||||
"https://#{service.target}:#{service.port}"
|
|
||||||
else
|
|
||||||
SiteSetting.ai_sentiment_inference_service_api_endpoint
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
  module Sentiment
    # JSON schema for, and parsed access to, the `ai_sentiment_model_configs`
    # site setting.
    class SentimentSiteSettingJsonSchema
      # Schema of the setting: an array of objects, each requiring the string
      # properties model_name, endpoint and api_key.
      #
      # @return [Hash] the memoized schema definition
      def self.schema
        @schema ||= {
          type: "array",
          items: {
            type: "object",
            format: "table",
            title: "model",
            properties: {
              model_name: {
                type: "string",
              },
              endpoint: {
                type: "string",
              },
              api_key: {
                type: "string",
              },
            },
            required: %w[model_name endpoint api_key],
          },
        }
      end

      # Parses the configured model list.
      #
      # @return [Array<OpenStruct>] one entry per configured model, exposing
      #   the JSON properties as methods; an empty array when the setting is
      #   blank.
      # @raise [JSON::ParserError] when the setting holds non-blank invalid JSON
      def self.values
        raw = SiteSetting.ai_sentiment_model_configs.to_s
        # A blank setting is not valid JSON and JSON.parse would raise on it,
        # so treat it as "no models configured".
        return [] if raw.strip.empty?

        JSON.parse(raw, object_class: OpenStruct)
      end
    end
  end
end
|
|
@ -0,0 +1,52 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
require "rails_helper"
|
||||||
|
require Rails.root.join(
|
||||||
|
"plugins/discourse-ai/db/post_migrate/20241031041242_migrate_sentiment_classification_result_format",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Exercises the migration against two legacy rows (one 'sentiment', one
# 'emotion') and checks the rows found under the new model names.
RSpec.describe MigrateSentimentClassificationResultFormat do
  let(:connection) { ActiveRecord::Base.connection }

  # Seed one row per legacy model identifier.
  before do
    connection.execute(<<~SQL)
      INSERT INTO classification_results (model_used, classification, created_at, updated_at) VALUES
      ('sentiment', '{"neutral": 65, "negative": 20, "positive": 14}', NOW(), NOW()),
      ('emotion', '{"sadness": 10, "surprise": 15, "fear": 5, "anger": 20, "joy": 30, "disgust": 8, "neutral": 10}', NOW(), NOW());
    SQL
  end

  after { connection.execute("DELETE FROM classification_results") }

  describe "#up" do
    before { described_class.new.up }

    # Parsed classification of the first row stored under +model_used+.
    def migrated_classification(model_used)
      row = connection.execute(<<~SQL).first
        SELECT * FROM classification_results
        WHERE model_used = '#{model_used}';
      SQL

      JSON.parse(row["classification"])
    end

    it "migrates sentiment classifications correctly" do
      expected_sentiment = { "neutral" => 0.65, "negative" => 0.20, "positive" => 0.14 }

      expect(
        migrated_classification("cardiffnlp/twitter-roberta-base-sentiment-latest"),
      ).to eq(expected_sentiment)
    end

    it "migrates emotion classifications correctly" do
      expected_emotion = {
        "sadness" => 0.10,
        "surprise" => 0.15,
        "fear" => 0.05,
        "anger" => 0.20,
        "joy" => 0.30,
        "disgust" => 0.08,
        "neutral" => 0.10,
      }

      expect(
        migrated_classification("j-hartmann/emotion-english-distilroberta-base"),
      ).to eq(expected_emotion)
    end
  end
end
|
|
@ -6,11 +6,13 @@ Fabricator(:classification_result) do
|
||||||
end
|
end
|
||||||
|
|
||||||
Fabricator(:sentiment_classification, from: :classification_result) do
|
Fabricator(:sentiment_classification, from: :classification_result) do
|
||||||
model_used "sentiment"
|
model_used "cardiffnlp/twitter-roberta-base-sentiment-latest"
|
||||||
classification { { negative: 72, neutral: 23, positive: 4 } }
|
classification { { negative: 0.72, neutral: 0.23, positive: 0.4 } }
|
||||||
end
|
end
|
||||||
|
|
||||||
Fabricator(:emotion_classification, from: :classification_result) do
|
Fabricator(:emotion_classification, from: :classification_result) do
|
||||||
model_used "emotion"
|
model_used "j-hartmann/emotion-english-distilroberta-base"
|
||||||
classification { { negative: 72, neutral: 23, positive: 4 } }
|
classification do
|
||||||
|
{ sadness: 0.72, surprise: 0.23, fear: 0.4, anger: 0.87, joy: 0.22, disgust: 0.70 }
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -53,7 +53,10 @@ RSpec.describe DiscourseAi::Sentiment::EntryPoint do
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "custom reports" do
|
describe "custom reports" do
|
||||||
before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" }
|
before do
|
||||||
|
SiteSetting.ai_sentiment_model_configs =
|
||||||
|
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
|
||||||
|
end
|
||||||
|
|
||||||
fab!(:pm) { Fabricate(:private_message_post) }
|
fab!(:pm) { Fabricate(:private_message_post) }
|
||||||
|
|
||||||
|
@ -61,8 +64,8 @@ RSpec.describe DiscourseAi::Sentiment::EntryPoint do
|
||||||
fab!(:post_2) { Fabricate(:post) }
|
fab!(:post_2) { Fabricate(:post) }
|
||||||
|
|
||||||
describe "overall_sentiment report" do
|
describe "overall_sentiment report" do
|
||||||
let(:positive_classification) { { negative: 2, neutral: 30, positive: 70 } }
|
let(:positive_classification) { { negative: 0.2, neutral: 0.3, positive: 0.7 } }
|
||||||
let(:negative_classification) { { negative: 65, neutral: 2, positive: 10 } }
|
let(:negative_classification) { { negative: 0.65, neutral: 0.2, positive: 0.1 } }
|
||||||
|
|
||||||
def sentiment_classification(post, classification)
|
def sentiment_classification(post, classification)
|
||||||
Fabricate(:sentiment_classification, target: post, classification: classification)
|
Fabricate(:sentiment_classification, target: post, classification: classification)
|
||||||
|
@ -84,12 +87,28 @@ RSpec.describe DiscourseAi::Sentiment::EntryPoint do
|
||||||
|
|
||||||
describe "post_emotion report" do
|
describe "post_emotion report" do
|
||||||
let(:emotion_1) do
|
let(:emotion_1) do
|
||||||
{ sadness: 49, surprise: 23, neutral: 6, fear: 34, anger: 87, joy: 22, disgust: 70 }
|
{
|
||||||
|
sadness: 0.49,
|
||||||
|
surprise: 0.23,
|
||||||
|
neutral: 0.6,
|
||||||
|
fear: 0.34,
|
||||||
|
anger: 0.87,
|
||||||
|
joy: 0.22,
|
||||||
|
disgust: 0.70,
|
||||||
|
}
|
||||||
end
|
end
|
||||||
let(:emotion_2) do
|
let(:emotion_2) do
|
||||||
{ sadness: 19, surprise: 63, neutral: 45, fear: 44, anger: 27, joy: 62, disgust: 30 }
|
{
|
||||||
|
sadness: 0.19,
|
||||||
|
surprise: 0.63,
|
||||||
|
neutral: 0.45,
|
||||||
|
fear: 0.44,
|
||||||
|
anger: 0.27,
|
||||||
|
joy: 0.62,
|
||||||
|
disgust: 0.30,
|
||||||
|
}
|
||||||
end
|
end
|
||||||
let(:model_used) { "emotion" }
|
let(:model_used) { "j-hartmann/emotion-english-distilroberta-base" }
|
||||||
|
|
||||||
def emotion_classification(post, classification)
|
def emotion_classification(post, classification)
|
||||||
Fabricate(
|
Fabricate(
|
||||||
|
@ -106,7 +125,7 @@ RSpec.describe DiscourseAi::Sentiment::EntryPoint do
|
||||||
end
|
end
|
||||||
|
|
||||||
it "calculate averages using only public posts" do
|
it "calculate averages using only public posts" do
|
||||||
threshold = 30
|
threshold = 0.30
|
||||||
|
|
||||||
emotion_classification(post_1, emotion_1)
|
emotion_classification(post_1, emotion_1)
|
||||||
emotion_classification(post_2, emotion_2)
|
emotion_classification(post_2, emotion_2)
|
||||||
|
|
|
@ -8,7 +8,8 @@ describe Jobs::PostSentimentAnalysis do
|
||||||
|
|
||||||
before do
|
before do
|
||||||
SiteSetting.ai_sentiment_enabled = true
|
SiteSetting.ai_sentiment_enabled = true
|
||||||
SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com"
|
SiteSetting.ai_sentiment_model_configs =
|
||||||
|
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "scenarios where we return early without doing anything" do
|
describe "scenarios where we return early without doing anything" do
|
||||||
|
@ -42,7 +43,8 @@ describe Jobs::PostSentimentAnalysis do
|
||||||
end
|
end
|
||||||
|
|
||||||
it "successfully classifies the post" do
|
it "successfully classifies the post" do
|
||||||
expected_analysis = SiteSetting.ai_sentiment_models.split("|").length
|
expected_analysis =
|
||||||
|
DiscourseAi::Sentiment::SentimentClassification.new.available_classifiers.length
|
||||||
SentimentInferenceStubs.stub_classification(post)
|
SentimentInferenceStubs.stub_classification(post)
|
||||||
|
|
||||||
subject.execute({ post_id: post.id })
|
subject.execute({ post_id: post.id })
|
||||||
|
|
|
@ -6,15 +6,20 @@ describe DiscourseAi::Sentiment::SentimentClassification do
|
||||||
fab!(:target) { Fabricate(:post) }
|
fab!(:target) { Fabricate(:post) }
|
||||||
|
|
||||||
describe "#request" do
|
describe "#request" do
|
||||||
before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" }
|
before do
|
||||||
|
SiteSetting.ai_sentiment_model_configs =
|
||||||
|
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
|
||||||
|
end
|
||||||
|
|
||||||
it "returns the classification and the model used for it" do
|
it "returns the classification and the model used for it" do
|
||||||
SentimentInferenceStubs.stub_classification(target)
|
SentimentInferenceStubs.stub_classification(target)
|
||||||
|
|
||||||
result = subject.request(target)
|
result = subject.request(target)
|
||||||
|
|
||||||
subject.available_models.each do |model|
|
subject.available_classifiers.each do |model_config|
|
||||||
expect(result[model]).to eq(SentimentInferenceStubs.model_response(model))
|
expect(result[model_config.model_name]).to eq(
|
||||||
|
subject.transform_result(SentimentInferenceStubs.model_response(model_config.model_name)),
|
||||||
|
)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -6,21 +6,25 @@ require_relative "../support/sentiment_inference_stubs"
|
||||||
describe DiscourseAi::Classificator do
|
describe DiscourseAi::Classificator do
|
||||||
describe "#classify!" do
|
describe "#classify!" do
|
||||||
describe "saving the classification result" do
|
describe "saving the classification result" do
|
||||||
|
let(:model) { DiscourseAi::Sentiment::SentimentClassification.new }
|
||||||
|
|
||||||
let(:classification_raw_result) do
|
let(:classification_raw_result) do
|
||||||
model
|
model
|
||||||
.available_models
|
.available_classifiers
|
||||||
.reduce({}) do |memo, model_name|
|
.reduce({}) do |memo, model_config|
|
||||||
memo[model_name] = SentimentInferenceStubs.model_response(model_name)
|
memo[model_config.model_name] = model.transform_result(
|
||||||
|
SentimentInferenceStubs.model_response(model_config.model_name),
|
||||||
|
)
|
||||||
memo
|
memo
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
let(:model) { DiscourseAi::Sentiment::SentimentClassification.new }
|
|
||||||
let(:classification) { DiscourseAi::PostClassificator.new(model) }
|
let(:classification) { DiscourseAi::PostClassificator.new(model) }
|
||||||
fab!(:target) { Fabricate(:post) }
|
fab!(:target) { Fabricate(:post) }
|
||||||
|
|
||||||
before do
|
before do
|
||||||
SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com"
|
SiteSetting.ai_sentiment_model_configs =
|
||||||
|
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
|
||||||
SentimentInferenceStubs.stub_classification(target)
|
SentimentInferenceStubs.stub_classification(target)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -28,28 +32,29 @@ describe DiscourseAi::Classificator do
|
||||||
classification.classify!(target)
|
classification.classify!(target)
|
||||||
|
|
||||||
stored_results = ClassificationResult.where(target: target)
|
stored_results = ClassificationResult.where(target: target)
|
||||||
expect(stored_results.length).to eq(model.available_models.length)
|
expect(stored_results.length).to eq(model.available_classifiers.length)
|
||||||
|
|
||||||
model.available_models.each do |model_name|
|
model.available_classifiers.each do |model_config|
|
||||||
result = stored_results.detect { |c| c.model_used == model_name }
|
result = stored_results.detect { |c| c.model_used == model_config.model_name }
|
||||||
|
|
||||||
expect(result.classification_type).to eq(model.type.to_s)
|
expect(result.classification_type).to eq(model.type.to_s)
|
||||||
expect(result.created_at).to be_present
|
expect(result.created_at).to be_present
|
||||||
expect(result.updated_at).to be_present
|
expect(result.updated_at).to be_present
|
||||||
|
|
||||||
expected_classification = SentimentInferenceStubs.model_response(model)
|
expected_classification = SentimentInferenceStubs.model_response(model_config.model_name)
|
||||||
|
transformed_classification = model.transform_result(expected_classification)
|
||||||
|
|
||||||
expect(result.classification.deep_symbolize_keys).to eq(expected_classification)
|
expect(result.classification).to eq(transformed_classification)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
it "updates an existing classification result" do
|
it "updates an existing classification result" do
|
||||||
original_creation = 3.days.ago
|
original_creation = 3.days.ago
|
||||||
|
|
||||||
model.available_models.each do |model_name|
|
model.available_classifiers.each do |model_config|
|
||||||
ClassificationResult.create!(
|
ClassificationResult.create!(
|
||||||
target: target,
|
target: target,
|
||||||
model_used: model_name,
|
model_used: model_config.model_name,
|
||||||
classification_type: model.type,
|
classification_type: model.type,
|
||||||
created_at: original_creation,
|
created_at: original_creation,
|
||||||
updated_at: original_creation,
|
updated_at: original_creation,
|
||||||
|
@ -61,18 +66,16 @@ describe DiscourseAi::Classificator do
|
||||||
classification.classify!(target)
|
classification.classify!(target)
|
||||||
|
|
||||||
stored_results = ClassificationResult.where(target: target)
|
stored_results = ClassificationResult.where(target: target)
|
||||||
expect(stored_results.length).to eq(model.available_models.length)
|
expect(stored_results.length).to eq(model.available_classifiers.length)
|
||||||
|
|
||||||
model.available_models.each do |model_name|
|
model.available_classifiers.each do |model_config|
|
||||||
result = stored_results.detect { |c| c.model_used == model_name }
|
result = stored_results.detect { |c| c.model_used == model_config.model_name }
|
||||||
|
|
||||||
expect(result.classification_type).to eq(model.type.to_s)
|
expect(result.classification_type).to eq(model.type.to_s)
|
||||||
expect(result.updated_at).to be > original_creation
|
expect(result.updated_at).to be > original_creation
|
||||||
expect(result.created_at).to eq_time(original_creation)
|
expect(result.created_at).to eq_time(original_creation)
|
||||||
|
|
||||||
expect(result.classification.deep_symbolize_keys).to eq(
|
expect(result.classification).to eq(classification_raw_result[model_config.model_name])
|
||||||
classification_raw_result[model_name],
|
|
||||||
)
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -2,24 +2,69 @@
|
||||||
|
|
||||||
class SentimentInferenceStubs
|
class SentimentInferenceStubs
|
||||||
class << self
|
class << self
|
||||||
def endpoint
|
|
||||||
"#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify"
|
|
||||||
end
|
|
||||||
|
|
||||||
def model_response(model)
|
def model_response(model)
|
||||||
{ negative: 72, neutral: 23, positive: 4 } if model == "sentiment"
|
case model
|
||||||
|
when "SamLowe/roberta-base-go_emotions"
|
||||||
{ sadness: 99, surprise: 0, neutral: 0, fear: 0, anger: 0, joy: 0, disgust: 0 }
|
[
|
||||||
|
{ score: 0.90261286, label: "anger" },
|
||||||
|
{ score: 0.04127813, label: "annoyance" },
|
||||||
|
{ score: 0.03183503, label: "neutral" },
|
||||||
|
{ score: 0.005037033, label: "disgust" },
|
||||||
|
{ score: 0.0031153716, label: "disapproval" },
|
||||||
|
{ score: 0.0019118421, label: "disappointment" },
|
||||||
|
{ score: 0.0015849728, label: "sadness" },
|
||||||
|
{ score: 0.0012343781, label: "curiosity" },
|
||||||
|
{ score: 0.0010682651, label: "amusement" },
|
||||||
|
{ score: 0.00100747, label: "confusion" },
|
||||||
|
{ score: 0.0010035422, label: "admiration" },
|
||||||
|
{ score: 0.0009957326, label: "approval" },
|
||||||
|
{ score: 0.0009726665, label: "surprise" },
|
||||||
|
{ score: 0.0007754773, label: "realization" },
|
||||||
|
{ score: 0.0006978541, label: "love" },
|
||||||
|
{ score: 0.00064793555, label: "fear" },
|
||||||
|
{ score: 0.0006454095, label: "optimism" },
|
||||||
|
{ score: 0.0005969062, label: "joy" },
|
||||||
|
{ score: 0.0005498958, label: "embarrassment" },
|
||||||
|
{ score: 0.00050068577, label: "excitement" },
|
||||||
|
{ score: 0.00047403979, label: "caring" },
|
||||||
|
{ score: 0.00038841428, label: "gratitude" },
|
||||||
|
{ score: 0.00034546282, label: "desire" },
|
||||||
|
{ score: 0.00023012784, label: "grief" },
|
||||||
|
{ score: 0.00018133638, label: "remorse" },
|
||||||
|
{ score: 0.00012511834, label: "nervousness" },
|
||||||
|
{ score: 0.00012079607, label: "pride" },
|
||||||
|
{ score: 0.000063159685, label: "relief" },
|
||||||
|
]
|
||||||
|
when "cardiffnlp/twitter-roberta-base-sentiment-latest"
|
||||||
|
[
|
||||||
|
{ score: 0.627579, label: "negative" },
|
||||||
|
{ score: 0.29482335, label: "neutral" },
|
||||||
|
{ score: 0.07759768, label: "positive" },
|
||||||
|
]
|
||||||
|
when "j-hartmann/emotion-english-distilroberta-base"
|
||||||
|
[
|
||||||
|
{ score: 0.7033674, label: "anger" },
|
||||||
|
{ score: 0.2701151, label: "disgust" },
|
||||||
|
{ score: 0.009492096, label: "sadness" },
|
||||||
|
{ score: 0.0080775, label: "neutral" },
|
||||||
|
{ score: 0.0049473303, label: "fear" },
|
||||||
|
{ score: 0.0023369535, label: "surprise" },
|
||||||
|
{ score: 0.001663634, label: "joy" },
|
||||||
|
]
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
def stub_classification(post)
|
def stub_classification(post)
|
||||||
content = post.post_number == 1 ? "#{post.topic.title}\n#{post.raw}" : post.raw
|
content = post.post_number == 1 ? "#{post.topic.title}\n#{post.raw}" : post.raw
|
||||||
|
|
||||||
DiscourseAi::Sentiment::SentimentClassification.new.available_models.each do |model|
|
DiscourseAi::Sentiment::SentimentClassification
|
||||||
|
.new
|
||||||
|
.available_classifiers
|
||||||
|
.each do |model_config|
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, endpoint)
|
.stub_request(:post, model_config.endpoint)
|
||||||
.with(body: JSON.dump(model: model, content: content))
|
.with(body: JSON.dump(inputs: content, truncate: true))
|
||||||
.to_return(status: 200, body: JSON.dump(model_response(model)))
|
.to_return(status: 200, body: JSON.dump(model_response(model_config.model_name)))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -9,7 +9,10 @@ RSpec.describe "assets:precompile" do
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "ai:sentiment:backfill" do
|
describe "ai:sentiment:backfill" do
|
||||||
before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" }
|
before do
|
||||||
|
SiteSetting.ai_sentiment_model_configs =
|
||||||
|
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
|
||||||
|
end
|
||||||
|
|
||||||
it "does nothing if the topic is soft-deleted" do
|
it "does nothing if the topic is soft-deleted" do
|
||||||
target = Fabricate(:post)
|
target = Fabricate(:post)
|
||||||
|
|
Loading…
Reference in New Issue