Migrate sentiment to a TEI backend (#886)
This commit is contained in:
parent
bffe9dfa07
commit
772ee934ab
|
@ -56,22 +56,9 @@ discourse_ai:
|
|||
ai_sentiment_enabled:
|
||||
default: false
|
||||
client: true
|
||||
ai_sentiment_inference_service_api_endpoint:
|
||||
default: "https://sentiment-testing.demo-by-discourse.com"
|
||||
ai_sentiment_inference_service_api_endpoint_srv:
|
||||
ai_sentiment_model_configs:
|
||||
default: ""
|
||||
hidden: true
|
||||
ai_sentiment_inference_service_api_key:
|
||||
default: ""
|
||||
secret: true
|
||||
ai_sentiment_models:
|
||||
type: list
|
||||
list_type: compact
|
||||
default: "emotion|sentiment"
|
||||
allow_any: false
|
||||
choices:
|
||||
- sentiment
|
||||
- emotion
|
||||
json_schema: DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema
|
||||
|
||||
ai_nsfw_detection_enabled:
|
||||
default: false
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
# frozen_string_literal: true
|
||||
# Rewrites stored sentiment and emotion classification results in place:
# `model_used` changes from the legacy short names ('sentiment', 'emotion')
# to full model identifiers, and every score is divided by 100.
class MigrateSentimentClassificationResultFormat < ActiveRecord::Migration[7.1]
  def up
    DB.exec(<<~SQL)
      UPDATE classification_results
      SET
        model_used = 'cardiffnlp/twitter-roberta-base-sentiment-latest',
        classification = jsonb_build_object(
          'neutral', (classification->>'neutral')::float / 100,
          'negative', (classification->>'negative')::float / 100,
          'positive', (classification->>'positive')::float / 100
        )
      WHERE model_used = 'sentiment';

      UPDATE classification_results
      SET
        model_used = 'j-hartmann/emotion-english-distilroberta-base',
        classification = jsonb_build_object(
          'sadness', (classification->>'sadness')::float / 100,
          'surprise', (classification->>'surprise')::float / 100,
          'fear', (classification->>'fear')::float / 100,
          'anger', (classification->>'anger')::float / 100,
          'joy', (classification->>'joy')::float / 100,
          'disgust', (classification->>'disgust')::float / 100,
          'neutral', (classification->>'neutral')::float / 100
        )
      WHERE model_used = 'emotion';
    SQL
  end

  # `up` overwrites both columns in place, so rows it rewrote cannot be told
  # apart from rows stored later under the new model names. Fail loudly rather
  # than let a rollback pass without restoring the old format.
  def down
    raise ActiveRecord::IrreversibleMigration
  end
end
|
|
@ -58,6 +58,29 @@ module ::DiscourseAi
|
|||
JSON.parse(response.body, symbolize_names: true)
|
||||
end
|
||||
|
||||
# Posts `content` to the classification endpoint described by `model_config`
# and returns the parsed JSON response with symbolized keys.
#
# `model_config` must respond to `endpoint` and `api_key`. An endpoint that
# starts with "srv://" is resolved through a DNS SRV lookup and then contacted
# over HTTPS at the returned target and port.
#
# Raises Net::HTTPBadResponse when the response status is not 200.
def classify(content, model_config)
  headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }

  # Only attach credentials when a key is configured, so endpoints without
  # authentication are not sent an empty "Bearer " token.
  if model_config.api_key.present?
    headers["X-API-KEY"] = model_config.api_key
    headers["Authorization"] = "Bearer #{model_config.api_key}"
  end

  body = { inputs: content, truncate: true }.to_json

  api_endpoint = model_config.endpoint
  if api_endpoint.present? && api_endpoint.start_with?("srv://")
    service = DiscourseAi::Utils::DnsSrv.lookup(api_endpoint.delete_prefix("srv://"))
    api_endpoint = "https://#{service.target}:#{service.port}"
  end

  conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
  response = conn.post(api_endpoint, body, headers)

  if response.status != 200
    raise Net::HTTPBadResponse, "Status: #{response.status}\n\n#{response.body}"
  end

  JSON.parse(response.body, symbolize_names: true)
end
|
||||
|
||||
def reranker_configured?
|
||||
SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? ||
|
||||
SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?
|
||||
|
|
|
@ -16,11 +16,11 @@ module DiscourseAi
|
|||
|
||||
plugin.add_report("overall_sentiment") do |report|
|
||||
report.modes = [:stacked_chart]
|
||||
threshold = 60
|
||||
threshold = 0.6
|
||||
|
||||
sentiment_count_sql = Proc.new { |sentiment| <<~SQL }
|
||||
COUNT(
|
||||
CASE WHEN (cr.classification::jsonb->'#{sentiment}')::integer > :threshold THEN 1 ELSE NULL END
|
||||
CASE WHEN (cr.classification::jsonb->'#{sentiment}')::float > :threshold THEN 1 ELSE NULL END
|
||||
) AS #{sentiment}_count
|
||||
SQL
|
||||
|
||||
|
@ -39,7 +39,7 @@ module DiscourseAi
|
|||
WHERE
|
||||
t.archetype = 'regular' AND
|
||||
p.user_id > 0 AND
|
||||
cr.model_used = 'sentiment' AND
|
||||
cr.model_used = 'cardiffnlp/twitter-roberta-base-sentiment-latest' AND
|
||||
(p.created_at > :report_start AND p.created_at < :report_end)
|
||||
GROUP BY DATE_TRUNC('day', p.created_at)
|
||||
SQL
|
||||
|
@ -68,11 +68,11 @@ module DiscourseAi
|
|||
|
||||
plugin.add_report("post_emotion") do |report|
|
||||
report.modes = [:stacked_line_chart]
|
||||
threshold = 30
|
||||
threshold = 0.3
|
||||
|
||||
emotion_count_clause = Proc.new { |emotion| <<~SQL }
|
||||
COUNT(
|
||||
CASE WHEN (cr.classification::jsonb->'#{emotion}')::integer > :threshold THEN 1 ELSE NULL END
|
||||
CASE WHEN (cr.classification::jsonb->'#{emotion}')::float > :threshold THEN 1 ELSE NULL END
|
||||
) AS #{emotion}_count
|
||||
SQL
|
||||
|
||||
|
@ -96,7 +96,7 @@ module DiscourseAi
|
|||
WHERE
|
||||
t.archetype = 'regular' AND
|
||||
p.user_id > 0 AND
|
||||
cr.model_used = 'emotion' AND
|
||||
cr.model_used = 'j-hartmann/emotion-english-distilroberta-base' AND
|
||||
(p.created_at > :report_start AND p.created_at < :report_end)
|
||||
GROUP BY DATE_TRUNC('day', p.created_at)
|
||||
SQL
|
||||
|
|
|
@ -7,8 +7,8 @@ module DiscourseAi
|
|||
:sentiment
|
||||
end
|
||||
|
||||
def available_models
|
||||
SiteSetting.ai_sentiment_models.split("|")
|
||||
def available_classifiers
|
||||
DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values
|
||||
end
|
||||
|
||||
def can_classify?(target)
|
||||
|
@ -16,8 +16,8 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def get_verdicts(_)
|
||||
available_models.reduce({}) do |memo, model|
|
||||
memo[model] = false
|
||||
available_classifiers.reduce({}) do |memo, model|
|
||||
memo[model.model_name] = false
|
||||
memo
|
||||
end
|
||||
end
|
||||
|
@ -30,21 +30,23 @@ module DiscourseAi
|
|||
def request(target_to_classify)
|
||||
target_content = content_of(target_to_classify)
|
||||
|
||||
available_models.reduce({}) do |memo, model|
|
||||
memo[model] = request_with(model, target_content)
|
||||
available_classifiers.reduce({}) do |memo, model|
|
||||
memo[model.model_name] = request_with(target_content, model)
|
||||
memo
|
||||
end
|
||||
end
|
||||
|
||||
# Flattens a classifier response (a list of entries carrying `:label` and
# `:score`) into one hash that maps each label to its score. Entries are
# visited in order, so a repeated label keeps the score of its last entry.
def transform_result(result)
  result.each_with_object({}) { |entry, scores| scores[entry[:label]] = entry[:score] }
end
|
||||
|
||||
private
|
||||
|
||||
def request_with(model, content)
|
||||
::DiscourseAi::Inference::DiscourseClassifier.perform!(
|
||||
"#{endpoint}/api/v1/classify",
|
||||
model,
|
||||
content,
|
||||
SiteSetting.ai_sentiment_inference_service_api_key,
|
||||
)
|
||||
def request_with(content, model_config)
|
||||
result = ::DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, model_config)
|
||||
transform_result(result)
|
||||
end
|
||||
|
||||
def content_of(target_to_classify)
|
||||
|
@ -57,18 +59,6 @@ module DiscourseAi
|
|||
|
||||
Tokenizer::BertTokenizer.truncate(content, 512)
|
||||
end
|
||||
|
||||
def endpoint
|
||||
if SiteSetting.ai_sentiment_inference_service_api_endpoint_srv.present?
|
||||
service =
|
||||
DiscourseAi::Utils::DnsSrv.lookup(
|
||||
SiteSetting.ai_sentiment_inference_service_api_endpoint_srv,
|
||||
)
|
||||
"https://#{service.target}:#{service.port}"
|
||||
else
|
||||
SiteSetting.ai_sentiment_inference_service_api_endpoint
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
  module Sentiment
    # JSON schema and reader for the `ai_sentiment_model_configs` site setting.
    class SentimentSiteSettingJsonSchema
      # Schema for the setting's value: an array of objects, each requiring
      # string `model_name`, `endpoint` and `api_key` properties.
      # The hash is built once and memoized on the class.
      def self.schema
        @schema ||= {
          type: "array",
          items: {
            type: "object",
            format: "table",
            title: "model",
            properties: {
              model_name: {
                type: "string",
              },
              endpoint: {
                type: "string",
              },
              api_key: {
                type: "string",
              },
            },
            required: %w[model_name endpoint api_key],
          },
        }
      end

      # Parses the setting and returns the configured models, with each JSON
      # object turned into an OpenStruct (so entries answer to `model_name`,
      # `endpoint` and `api_key`).
      #
      # Returns an empty array when the setting is nil or blank: an empty
      # string is not valid JSON and would otherwise raise JSON::ParserError.
      def self.values
        raw_configs = SiteSetting.ai_sentiment_model_configs
        return [] if raw_configs.to_s.strip.empty?

        JSON.parse(raw_configs, object_class: OpenStruct)
      end
    end
  end
end
|
|
@ -0,0 +1,52 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require "rails_helper"
|
||||
require Rails.root.join(
|
||||
"plugins/discourse-ai/db/post_migrate/20241031041242_migrate_sentiment_classification_result_format",
|
||||
)
|
||||
|
||||
RSpec.describe MigrateSentimentClassificationResultFormat do
  let(:connection) { ActiveRecord::Base.connection }

  # Seed one row per legacy model name, with integer scores as stored before
  # the migration.
  before do
    connection.execute(<<~SQL)
      INSERT INTO classification_results (model_used, classification, created_at, updated_at) VALUES
      ('sentiment', '{"neutral": 65, "negative": 20, "positive": 14}', NOW(), NOW()),
      ('emotion', '{"sadness": 10, "surprise": 15, "fear": 5, "anger": 20, "joy": 30, "disgust": 8, "neutral": 10}', NOW(), NOW());
    SQL
  end

  after { connection.execute("DELETE FROM classification_results") }

  describe "#up" do
    before { described_class.new.up }

    it "migrates sentiment classifications correctly" do
      migrated_row = connection.execute(<<~SQL).first
        SELECT * FROM classification_results
        WHERE model_used = 'cardiffnlp/twitter-roberta-base-sentiment-latest';
      SQL

      stored_scores = JSON.parse(migrated_row["classification"])
      rescaled_scores = { "neutral" => 0.65, "negative" => 0.20, "positive" => 0.14 }

      expect(stored_scores).to eq(rescaled_scores)
    end

    it "migrates emotion classifications correctly" do
      migrated_row = connection.execute(<<~SQL).first
        SELECT * FROM classification_results
        WHERE model_used = 'j-hartmann/emotion-english-distilroberta-base';
      SQL

      stored_scores = JSON.parse(migrated_row["classification"])
      rescaled_scores = {
        "sadness" => 0.10,
        "surprise" => 0.15,
        "fear" => 0.05,
        "anger" => 0.20,
        "joy" => 0.30,
        "disgust" => 0.08,
        "neutral" => 0.10,
      }

      expect(stored_scores).to eq(rescaled_scores)
    end
  end
end
|
|
@ -6,11 +6,13 @@ Fabricator(:classification_result) do
|
|||
end
|
||||
|
||||
Fabricator(:sentiment_classification, from: :classification_result) do
|
||||
model_used "sentiment"
|
||||
classification { { negative: 72, neutral: 23, positive: 4 } }
|
||||
model_used "cardiffnlp/twitter-roberta-base-sentiment-latest"
|
||||
classification { { negative: 0.72, neutral: 0.23, positive: 0.4 } }
|
||||
end
|
||||
|
||||
Fabricator(:emotion_classification, from: :classification_result) do
|
||||
model_used "emotion"
|
||||
classification { { negative: 72, neutral: 23, positive: 4 } }
|
||||
model_used "j-hartmann/emotion-english-distilroberta-base"
|
||||
classification do
|
||||
{ sadness: 0.72, surprise: 0.23, fear: 0.4, anger: 0.87, joy: 0.22, disgust: 0.70 }
|
||||
end
|
||||
end
|
||||
|
|
|
@ -53,7 +53,10 @@ RSpec.describe DiscourseAi::Sentiment::EntryPoint do
|
|||
end
|
||||
|
||||
describe "custom reports" do
|
||||
before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" }
|
||||
before do
|
||||
SiteSetting.ai_sentiment_model_configs =
|
||||
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
|
||||
end
|
||||
|
||||
fab!(:pm) { Fabricate(:private_message_post) }
|
||||
|
||||
|
@ -61,8 +64,8 @@ RSpec.describe DiscourseAi::Sentiment::EntryPoint do
|
|||
fab!(:post_2) { Fabricate(:post) }
|
||||
|
||||
describe "overall_sentiment report" do
|
||||
let(:positive_classification) { { negative: 2, neutral: 30, positive: 70 } }
|
||||
let(:negative_classification) { { negative: 65, neutral: 2, positive: 10 } }
|
||||
let(:positive_classification) { { negative: 0.2, neutral: 0.3, positive: 0.7 } }
|
||||
let(:negative_classification) { { negative: 0.65, neutral: 0.2, positive: 0.1 } }
|
||||
|
||||
def sentiment_classification(post, classification)
|
||||
Fabricate(:sentiment_classification, target: post, classification: classification)
|
||||
|
@ -84,12 +87,28 @@ RSpec.describe DiscourseAi::Sentiment::EntryPoint do
|
|||
|
||||
describe "post_emotion report" do
|
||||
let(:emotion_1) do
|
||||
{ sadness: 49, surprise: 23, neutral: 6, fear: 34, anger: 87, joy: 22, disgust: 70 }
|
||||
{
|
||||
sadness: 0.49,
|
||||
surprise: 0.23,
|
||||
neutral: 0.6,
|
||||
fear: 0.34,
|
||||
anger: 0.87,
|
||||
joy: 0.22,
|
||||
disgust: 0.70,
|
||||
}
|
||||
end
|
||||
let(:emotion_2) do
|
||||
{ sadness: 19, surprise: 63, neutral: 45, fear: 44, anger: 27, joy: 62, disgust: 30 }
|
||||
{
|
||||
sadness: 0.19,
|
||||
surprise: 0.63,
|
||||
neutral: 0.45,
|
||||
fear: 0.44,
|
||||
anger: 0.27,
|
||||
joy: 0.62,
|
||||
disgust: 0.30,
|
||||
}
|
||||
end
|
||||
let(:model_used) { "emotion" }
|
||||
let(:model_used) { "j-hartmann/emotion-english-distilroberta-base" }
|
||||
|
||||
def emotion_classification(post, classification)
|
||||
Fabricate(
|
||||
|
@ -106,7 +125,7 @@ RSpec.describe DiscourseAi::Sentiment::EntryPoint do
|
|||
end
|
||||
|
||||
it "calculate averages using only public posts" do
|
||||
threshold = 30
|
||||
threshold = 0.30
|
||||
|
||||
emotion_classification(post_1, emotion_1)
|
||||
emotion_classification(post_2, emotion_2)
|
||||
|
|
|
@ -8,7 +8,8 @@ describe Jobs::PostSentimentAnalysis do
|
|||
|
||||
before do
|
||||
SiteSetting.ai_sentiment_enabled = true
|
||||
SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com"
|
||||
SiteSetting.ai_sentiment_model_configs =
|
||||
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
|
||||
end
|
||||
|
||||
describe "scenarios where we return early without doing anything" do
|
||||
|
@ -42,7 +43,8 @@ describe Jobs::PostSentimentAnalysis do
|
|||
end
|
||||
|
||||
it "successfully classifies the post" do
|
||||
expected_analysis = SiteSetting.ai_sentiment_models.split("|").length
|
||||
expected_analysis =
|
||||
DiscourseAi::Sentiment::SentimentClassification.new.available_classifiers.length
|
||||
SentimentInferenceStubs.stub_classification(post)
|
||||
|
||||
subject.execute({ post_id: post.id })
|
||||
|
|
|
@ -6,15 +6,20 @@ describe DiscourseAi::Sentiment::SentimentClassification do
|
|||
fab!(:target) { Fabricate(:post) }
|
||||
|
||||
describe "#request" do
|
||||
before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" }
|
||||
before do
|
||||
SiteSetting.ai_sentiment_model_configs =
|
||||
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
|
||||
end
|
||||
|
||||
it "returns the classification and the model used for it" do
|
||||
SentimentInferenceStubs.stub_classification(target)
|
||||
|
||||
result = subject.request(target)
|
||||
|
||||
subject.available_models.each do |model|
|
||||
expect(result[model]).to eq(SentimentInferenceStubs.model_response(model))
|
||||
subject.available_classifiers.each do |model_config|
|
||||
expect(result[model_config.model_name]).to eq(
|
||||
subject.transform_result(SentimentInferenceStubs.model_response(model_config.model_name)),
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -6,21 +6,25 @@ require_relative "../support/sentiment_inference_stubs"
|
|||
describe DiscourseAi::Classificator do
|
||||
describe "#classify!" do
|
||||
describe "saving the classification result" do
|
||||
let(:model) { DiscourseAi::Sentiment::SentimentClassification.new }
|
||||
|
||||
let(:classification_raw_result) do
|
||||
model
|
||||
.available_models
|
||||
.reduce({}) do |memo, model_name|
|
||||
memo[model_name] = SentimentInferenceStubs.model_response(model_name)
|
||||
.available_classifiers
|
||||
.reduce({}) do |memo, model_config|
|
||||
memo[model_config.model_name] = model.transform_result(
|
||||
SentimentInferenceStubs.model_response(model_config.model_name),
|
||||
)
|
||||
memo
|
||||
end
|
||||
end
|
||||
|
||||
let(:model) { DiscourseAi::Sentiment::SentimentClassification.new }
|
||||
let(:classification) { DiscourseAi::PostClassificator.new(model) }
|
||||
fab!(:target) { Fabricate(:post) }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com"
|
||||
SiteSetting.ai_sentiment_model_configs =
|
||||
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
|
||||
SentimentInferenceStubs.stub_classification(target)
|
||||
end
|
||||
|
||||
|
@ -28,28 +32,29 @@ describe DiscourseAi::Classificator do
|
|||
classification.classify!(target)
|
||||
|
||||
stored_results = ClassificationResult.where(target: target)
|
||||
expect(stored_results.length).to eq(model.available_models.length)
|
||||
expect(stored_results.length).to eq(model.available_classifiers.length)
|
||||
|
||||
model.available_models.each do |model_name|
|
||||
result = stored_results.detect { |c| c.model_used == model_name }
|
||||
model.available_classifiers.each do |model_config|
|
||||
result = stored_results.detect { |c| c.model_used == model_config.model_name }
|
||||
|
||||
expect(result.classification_type).to eq(model.type.to_s)
|
||||
expect(result.created_at).to be_present
|
||||
expect(result.updated_at).to be_present
|
||||
|
||||
expected_classification = SentimentInferenceStubs.model_response(model)
|
||||
expected_classification = SentimentInferenceStubs.model_response(model_config.model_name)
|
||||
transformed_classification = model.transform_result(expected_classification)
|
||||
|
||||
expect(result.classification.deep_symbolize_keys).to eq(expected_classification)
|
||||
expect(result.classification).to eq(transformed_classification)
|
||||
end
|
||||
end
|
||||
|
||||
it "updates an existing classification result" do
|
||||
original_creation = 3.days.ago
|
||||
|
||||
model.available_models.each do |model_name|
|
||||
model.available_classifiers.each do |model_config|
|
||||
ClassificationResult.create!(
|
||||
target: target,
|
||||
model_used: model_name,
|
||||
model_used: model_config.model_name,
|
||||
classification_type: model.type,
|
||||
created_at: original_creation,
|
||||
updated_at: original_creation,
|
||||
|
@ -61,18 +66,16 @@ describe DiscourseAi::Classificator do
|
|||
classification.classify!(target)
|
||||
|
||||
stored_results = ClassificationResult.where(target: target)
|
||||
expect(stored_results.length).to eq(model.available_models.length)
|
||||
expect(stored_results.length).to eq(model.available_classifiers.length)
|
||||
|
||||
model.available_models.each do |model_name|
|
||||
result = stored_results.detect { |c| c.model_used == model_name }
|
||||
model.available_classifiers.each do |model_config|
|
||||
result = stored_results.detect { |c| c.model_used == model_config.model_name }
|
||||
|
||||
expect(result.classification_type).to eq(model.type.to_s)
|
||||
expect(result.updated_at).to be > original_creation
|
||||
expect(result.created_at).to eq_time(original_creation)
|
||||
|
||||
expect(result.classification.deep_symbolize_keys).to eq(
|
||||
classification_raw_result[model_name],
|
||||
)
|
||||
expect(result.classification).to eq(classification_raw_result[model_config.model_name])
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -2,24 +2,69 @@
|
|||
|
||||
class SentimentInferenceStubs
|
||||
class << self
|
||||
def endpoint
|
||||
"#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify"
|
||||
end
|
||||
|
||||
def model_response(model)
|
||||
{ negative: 72, neutral: 23, positive: 4 } if model == "sentiment"
|
||||
|
||||
{ sadness: 99, surprise: 0, neutral: 0, fear: 0, anger: 0, joy: 0, disgust: 0 }
|
||||
case model
|
||||
when "SamLowe/roberta-base-go_emotions"
|
||||
[
|
||||
{ score: 0.90261286, label: "anger" },
|
||||
{ score: 0.04127813, label: "annoyance" },
|
||||
{ score: 0.03183503, label: "neutral" },
|
||||
{ score: 0.005037033, label: "disgust" },
|
||||
{ score: 0.0031153716, label: "disapproval" },
|
||||
{ score: 0.0019118421, label: "disappointment" },
|
||||
{ score: 0.0015849728, label: "sadness" },
|
||||
{ score: 0.0012343781, label: "curiosity" },
|
||||
{ score: 0.0010682651, label: "amusement" },
|
||||
{ score: 0.00100747, label: "confusion" },
|
||||
{ score: 0.0010035422, label: "admiration" },
|
||||
{ score: 0.0009957326, label: "approval" },
|
||||
{ score: 0.0009726665, label: "surprise" },
|
||||
{ score: 0.0007754773, label: "realization" },
|
||||
{ score: 0.0006978541, label: "love" },
|
||||
{ score: 0.00064793555, label: "fear" },
|
||||
{ score: 0.0006454095, label: "optimism" },
|
||||
{ score: 0.0005969062, label: "joy" },
|
||||
{ score: 0.0005498958, label: "embarrassment" },
|
||||
{ score: 0.00050068577, label: "excitement" },
|
||||
{ score: 0.00047403979, label: "caring" },
|
||||
{ score: 0.00038841428, label: "gratitude" },
|
||||
{ score: 0.00034546282, label: "desire" },
|
||||
{ score: 0.00023012784, label: "grief" },
|
||||
{ score: 0.00018133638, label: "remorse" },
|
||||
{ score: 0.00012511834, label: "nervousness" },
|
||||
{ score: 0.00012079607, label: "pride" },
|
||||
{ score: 0.000063159685, label: "relief" },
|
||||
]
|
||||
when "cardiffnlp/twitter-roberta-base-sentiment-latest"
|
||||
[
|
||||
{ score: 0.627579, label: "negative" },
|
||||
{ score: 0.29482335, label: "neutral" },
|
||||
{ score: 0.07759768, label: "positive" },
|
||||
]
|
||||
when "j-hartmann/emotion-english-distilroberta-base"
|
||||
[
|
||||
{ score: 0.7033674, label: "anger" },
|
||||
{ score: 0.2701151, label: "disgust" },
|
||||
{ score: 0.009492096, label: "sadness" },
|
||||
{ score: 0.0080775, label: "neutral" },
|
||||
{ score: 0.0049473303, label: "fear" },
|
||||
{ score: 0.0023369535, label: "surprise" },
|
||||
{ score: 0.001663634, label: "joy" },
|
||||
]
|
||||
end
|
||||
end
|
||||
|
||||
def stub_classification(post)
|
||||
content = post.post_number == 1 ? "#{post.topic.title}\n#{post.raw}" : post.raw
|
||||
|
||||
DiscourseAi::Sentiment::SentimentClassification.new.available_models.each do |model|
|
||||
DiscourseAi::Sentiment::SentimentClassification
|
||||
.new
|
||||
.available_classifiers
|
||||
.each do |model_config|
|
||||
WebMock
|
||||
.stub_request(:post, endpoint)
|
||||
.with(body: JSON.dump(model: model, content: content))
|
||||
.to_return(status: 200, body: JSON.dump(model_response(model)))
|
||||
.stub_request(:post, model_config.endpoint)
|
||||
.with(body: JSON.dump(inputs: content, truncate: true))
|
||||
.to_return(status: 200, body: JSON.dump(model_response(model_config.model_name)))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -9,7 +9,10 @@ RSpec.describe "assets:precompile" do
|
|||
end
|
||||
|
||||
describe "ai:sentiment:backfill" do
|
||||
before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" }
|
||||
before do
|
||||
SiteSetting.ai_sentiment_model_configs =
|
||||
"[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]"
|
||||
end
|
||||
|
||||
it "does nothing if the topic is soft-deleted" do
|
||||
target = Fabricate(:post)
|
||||
|
|
Loading…
Reference in New Issue