diff --git a/config/settings.yml b/config/settings.yml index 01705470..9a1fa7ea 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -56,22 +56,9 @@ discourse_ai: ai_sentiment_enabled: default: false client: true - ai_sentiment_inference_service_api_endpoint: - default: "https://sentiment-testing.demo-by-discourse.com" - ai_sentiment_inference_service_api_endpoint_srv: + ai_sentiment_model_configs: default: "" - hidden: true - ai_sentiment_inference_service_api_key: - default: "" - secret: true - ai_sentiment_models: - type: list - list_type: compact - default: "emotion|sentiment" - allow_any: false - choices: - - sentiment - - emotion + json_schema: DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema ai_nsfw_detection_enabled: default: false diff --git a/db/post_migrate/20241031041242_migrate_sentiment_classification_result_format.rb b/db/post_migrate/20241031041242_migrate_sentiment_classification_result_format.rb new file mode 100644 index 00000000..59821da0 --- /dev/null +++ b/db/post_migrate/20241031041242_migrate_sentiment_classification_result_format.rb @@ -0,0 +1,30 @@ +# frozen_string_literal: true +class MigrateSentimentClassificationResultFormat < ActiveRecord::Migration[7.1] + def up + DB.exec(<<~SQL) + UPDATE classification_results + SET + model_used = 'cardiffnlp/twitter-roberta-base-sentiment-latest', + classification = jsonb_build_object( + 'neutral', (classification->>'neutral')::float / 100, + 'negative', (classification->>'negative')::float / 100, + 'positive', (classification->>'positive')::float / 100 + ) + WHERE model_used = 'sentiment'; + + UPDATE classification_results + SET + model_used = 'j-hartmann/emotion-english-distilroberta-base', + classification = jsonb_build_object( + 'sadness', (classification->>'sadness')::float / 100, + 'surprise', (classification->>'surprise')::float / 100, + 'fear', (classification->>'fear')::float / 100, + 'anger', (classification->>'anger')::float / 100, + 'joy', (classification->>'joy')::float / 100, + 'disgust', (classification->>'disgust')::float / 100, + 'neutral', (classification->>'neutral')::float / 100 + ) + WHERE model_used = 'emotion'; + SQL + end +end diff --git a/lib/inference/hugging_face_text_embeddings.rb b/lib/inference/hugging_face_text_embeddings.rb index e30ce57f..0e904a94 100644 --- a/lib/inference/hugging_face_text_embeddings.rb +++ b/lib/inference/hugging_face_text_embeddings.rb @@ -58,6 +58,29 @@ module ::DiscourseAi JSON.parse(response.body, symbolize_names: true) end + def classify(content, model_config) + headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" } + headers["X-API-KEY"] = model_config.api_key + headers["Authorization"] = "Bearer #{model_config.api_key}" + + body = { inputs: content, truncate: true }.to_json + + api_endpoint = model_config.endpoint + if api_endpoint.present? && api_endpoint.start_with?("srv://") + service = DiscourseAi::Utils::DnsSrv.lookup(api_endpoint.delete_prefix("srv://")) + api_endpoint = "https://#{service.target}:#{service.port}" + end + + conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter } + response = conn.post(api_endpoint, body, headers) + + if response.status != 200 + raise Net::HTTPBadResponse.new("Status: #{response.status}\n\n#{response.body}") + end + + JSON.parse(response.body, symbolize_names: true) + end + def reranker_configured? SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? || SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present? diff --git a/lib/sentiment/entry_point.rb b/lib/sentiment/entry_point.rb index e25eea4e..6cf688e9 100644 --- a/lib/sentiment/entry_point.rb +++ b/lib/sentiment/entry_point.rb @@ -16,11 +16,11 @@ module DiscourseAi plugin.add_report("overall_sentiment") do |report| report.modes = [:stacked_chart] - threshold = 60 + threshold = 0.6 sentiment_count_sql = Proc.new { |sentiment| <<~SQL } COUNT( - CASE WHEN (cr.classification::jsonb->'#{sentiment}')::integer > :threshold THEN 1 ELSE NULL END + CASE WHEN (cr.classification::jsonb->'#{sentiment}')::float > :threshold THEN 1 ELSE NULL END ) AS #{sentiment}_count SQL @@ -39,7 +39,7 @@ module DiscourseAi WHERE t.archetype = 'regular' AND p.user_id > 0 AND - cr.model_used = 'sentiment' AND + cr.model_used = 'cardiffnlp/twitter-roberta-base-sentiment-latest' AND (p.created_at > :report_start AND p.created_at < :report_end) GROUP BY DATE_TRUNC('day', p.created_at) SQL @@ -68,11 +68,11 @@ module DiscourseAi plugin.add_report("post_emotion") do |report| report.modes = [:stacked_line_chart] - threshold = 30 + threshold = 0.3 emotion_count_clause = Proc.new { |emotion| <<~SQL } COUNT( - CASE WHEN (cr.classification::jsonb->'#{emotion}')::integer > :threshold THEN 1 ELSE NULL END + CASE WHEN (cr.classification::jsonb->'#{emotion}')::float > :threshold THEN 1 ELSE NULL END ) AS #{emotion}_count SQL @@ -96,7 +96,7 @@ module DiscourseAi WHERE t.archetype = 'regular' AND p.user_id > 0 AND - cr.model_used = 'emotion' AND + cr.model_used = 'j-hartmann/emotion-english-distilroberta-base' AND (p.created_at > :report_start AND p.created_at < :report_end) GROUP BY DATE_TRUNC('day', p.created_at) SQL diff --git a/lib/sentiment/sentiment_classification.rb b/lib/sentiment/sentiment_classification.rb index 1d36e6a1..f73447ca 100644 --- a/lib/sentiment/sentiment_classification.rb +++ b/lib/sentiment/sentiment_classification.rb @@ -7,8 +7,8 @@ module DiscourseAi :sentiment end - def available_models - SiteSetting.ai_sentiment_models.split("|") + def available_classifiers + DiscourseAi::Sentiment::SentimentSiteSettingJsonSchema.values end def can_classify?(target) @@ -16,8 +16,8 @@ module DiscourseAi end def get_verdicts(_) - available_models.reduce({}) do |memo, model| - memo[model] = false + available_classifiers.reduce({}) do |memo, model| + memo[model.model_name] = false memo end end @@ -30,21 +30,23 @@ module DiscourseAi def request(target_to_classify) target_content = content_of(target_to_classify) - available_models.reduce({}) do |memo, model| - memo[model] = request_with(model, target_content) + available_classifiers.reduce({}) do |memo, model| + memo[model.model_name] = request_with(target_content, model) memo end end + def transform_result(result) + hash_result = {} + result.each { |r| hash_result[r[:label]] = r[:score] } + hash_result + end + private - def request_with(model, content) - ::DiscourseAi::Inference::DiscourseClassifier.perform!( - "#{endpoint}/api/v1/classify", - model, - content, - SiteSetting.ai_sentiment_inference_service_api_key, - ) + def request_with(content, model_config) + result = ::DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, model_config) + transform_result(result) end def content_of(target_to_classify) @@ -57,18 +59,6 @@ module DiscourseAi Tokenizer::BertTokenizer.truncate(content, 512) end - - def endpoint - if SiteSetting.ai_sentiment_inference_service_api_endpoint_srv.present? - service = - DiscourseAi::Utils::DnsSrv.lookup( - SiteSetting.ai_sentiment_inference_service_api_endpoint_srv, - ) - "https://#{service.target}:#{service.port}" - else - SiteSetting.ai_sentiment_inference_service_api_endpoint - end - end end end end diff --git a/lib/sentiment/sentiment_site_setting_json_schema.rb b/lib/sentiment/sentiment_site_setting_json_schema.rb new file mode 100644 index 00000000..520a0494 --- /dev/null +++ b/lib/sentiment/sentiment_site_setting_json_schema.rb @@ -0,0 +1,34 @@ +# frozen_string_literal: true + +module DiscourseAi + module Sentiment + class SentimentSiteSettingJsonSchema + def self.schema + @schema ||= { + type: "array", + items: { + type: "object", + format: "table", + title: "model", + properties: { + model_name: { + type: "string", + }, + endpoint: { + type: "string", + }, + api_key: { + type: "string", + }, + }, + required: %w[model_name endpoint api_key], + }, + } + end + + def self.values + JSON.parse(SiteSetting.ai_sentiment_model_configs, object_class: OpenStruct) + end + end + end +end diff --git a/spec/db/migrate/20241031041242_migrate_sentiment_classification_result_format_spec.rb b/spec/db/migrate/20241031041242_migrate_sentiment_classification_result_format_spec.rb new file mode 100644 index 00000000..ab152d5f --- /dev/null +++ b/spec/db/migrate/20241031041242_migrate_sentiment_classification_result_format_spec.rb @@ -0,0 +1,52 @@ +# frozen_string_literal: true + +require "rails_helper" +require Rails.root.join( + "plugins/discourse-ai/db/post_migrate/20241031041242_migrate_sentiment_classification_result_format", + ) + +RSpec.describe MigrateSentimentClassificationResultFormat do + let(:connection) { ActiveRecord::Base.connection } + + before { connection.execute(<<~SQL) } + INSERT INTO classification_results (model_used, classification, created_at, updated_at) VALUES + ('sentiment', '{"neutral": 65, "negative": 20, "positive": 14}', NOW(), NOW()), + ('emotion', '{"sadness": 10, "surprise": 15, "fear": 5, "anger": 20, "joy": 30, "disgust": 8, "neutral": 10}', NOW(), NOW()); + SQL + + after { connection.execute("DELETE FROM classification_results") } + + describe "#up" do + before { described_class.new.up } + + it "migrates sentiment classifications correctly" do + sentiment_result = connection.execute(<<~SQL).first + SELECT * FROM classification_results + WHERE model_used = 'cardiffnlp/twitter-roberta-base-sentiment-latest'; + SQL + + expected_sentiment = { "neutral" => 0.65, "negative" => 0.20, "positive" => 0.14 } + + expect(JSON.parse(sentiment_result["classification"])).to eq(expected_sentiment) + end + + it "migrates emotion classifications correctly" do + emotion_result = connection.execute(<<~SQL).first + SELECT * FROM classification_results + WHERE model_used = 'j-hartmann/emotion-english-distilroberta-base'; + SQL + + expected_emotion = { + "sadness" => 0.10, + "surprise" => 0.15, + "fear" => 0.05, + "anger" => 0.20, + "joy" => 0.30, + "disgust" => 0.08, + "neutral" => 0.10, + } + + expect(JSON.parse(emotion_result["classification"])).to eq(expected_emotion) + end + end +end diff --git a/spec/fabricators/classification_result_fabricator.rb b/spec/fabricators/classification_result_fabricator.rb index 525c59cd..23e63c28 100644 --- a/spec/fabricators/classification_result_fabricator.rb +++ b/spec/fabricators/classification_result_fabricator.rb @@ -6,11 +6,13 @@ Fabricator(:classification_result) do end Fabricator(:sentiment_classification, from: :classification_result) do - model_used "sentiment" - classification { { negative: 72, neutral: 23, positive: 4 } } + model_used "cardiffnlp/twitter-roberta-base-sentiment-latest" + classification { { negative: 0.72, neutral: 0.23, positive: 0.4 } } end Fabricator(:emotion_classification, from: :classification_result) do - model_used "emotion" - classification { { negative: 72, neutral: 23, positive: 4 } } + model_used "j-hartmann/emotion-english-distilroberta-base" + classification do + { sadness: 0.72, surprise: 0.23, fear: 0.4, anger: 0.87, joy: 0.22, disgust: 0.70 } + end end diff --git a/spec/lib/modules/sentiment/entry_point_spec.rb b/spec/lib/modules/sentiment/entry_point_spec.rb index 02cf1c18..3d892d17 100644 --- a/spec/lib/modules/sentiment/entry_point_spec.rb +++ b/spec/lib/modules/sentiment/entry_point_spec.rb @@ -53,7 +53,10 @@ RSpec.describe DiscourseAi::Sentiment::EntryPoint do end describe "custom reports" do - before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" } + before do + SiteSetting.ai_sentiment_model_configs = + "[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]" + end fab!(:pm) { Fabricate(:private_message_post) } @@ -61,8 +64,8 @@ RSpec.describe DiscourseAi::Sentiment::EntryPoint do fab!(:post_2) { Fabricate(:post) } describe "overall_sentiment report" do - let(:positive_classification) { { negative: 2, neutral: 30, positive: 70 } } - let(:negative_classification) { { negative: 65, neutral: 2, positive: 10 } } + let(:positive_classification) { { negative: 0.2, neutral: 0.3, positive: 0.7 } } + let(:negative_classification) { { negative: 0.65, neutral: 0.2, positive: 0.1 } } def sentiment_classification(post, classification) Fabricate(:sentiment_classification, target: post, classification: classification) @@ -84,12 +87,28 @@ RSpec.describe DiscourseAi::Sentiment::EntryPoint do describe "post_emotion report" do let(:emotion_1) do - { sadness: 49, surprise: 23, neutral: 6, fear: 34, anger: 87, joy: 22, disgust: 70 } + { + sadness: 0.49, + surprise: 0.23, + neutral: 0.6, + fear: 0.34, + anger: 0.87, + joy: 0.22, + disgust: 0.70, + } end let(:emotion_2) do - { sadness: 19, surprise: 63, neutral: 45, fear: 44, anger: 27, joy: 62, disgust: 30 } + { + sadness: 0.19, + surprise: 0.63, + neutral: 0.45, + fear: 0.44, + anger: 0.27, + joy: 0.62, + disgust: 0.30, + } end - let(:model_used) { "emotion" } + let(:model_used) { "j-hartmann/emotion-english-distilroberta-base" } def emotion_classification(post, classification) Fabricate( @@ -106,7 +125,7 @@ RSpec.describe DiscourseAi::Sentiment::EntryPoint do end it "calculate averages using only public posts" do - threshold = 30 + threshold = 0.30 emotion_classification(post_1, emotion_1) emotion_classification(post_2, emotion_2) diff --git a/spec/lib/modules/sentiment/jobs/regular/post_sentiment_analysis_spec.rb b/spec/lib/modules/sentiment/jobs/regular/post_sentiment_analysis_spec.rb index 4a4ffff9..e23b9e5f 100644 --- a/spec/lib/modules/sentiment/jobs/regular/post_sentiment_analysis_spec.rb +++ b/spec/lib/modules/sentiment/jobs/regular/post_sentiment_analysis_spec.rb @@ -8,7 +8,8 @@ describe Jobs::PostSentimentAnalysis do before do SiteSetting.ai_sentiment_enabled = true - SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" + SiteSetting.ai_sentiment_model_configs = + "[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]" end describe "scenarios where we return early without doing anything" do @@ -42,7 +43,8 @@ describe Jobs::PostSentimentAnalysis do end it "successfully classifies the post" do - expected_analysis = SiteSetting.ai_sentiment_models.split("|").length + expected_analysis = + DiscourseAi::Sentiment::SentimentClassification.new.available_classifiers.length SentimentInferenceStubs.stub_classification(post) subject.execute({ post_id: post.id }) diff --git a/spec/lib/modules/sentiment/sentiment_classification_spec.rb b/spec/lib/modules/sentiment/sentiment_classification_spec.rb index fdc32999..372a4284 100644 --- a/spec/lib/modules/sentiment/sentiment_classification_spec.rb +++ b/spec/lib/modules/sentiment/sentiment_classification_spec.rb @@ -6,15 +6,20 @@ describe DiscourseAi::Sentiment::SentimentClassification do fab!(:target) { Fabricate(:post) } describe "#request" do - before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" } + before do + SiteSetting.ai_sentiment_model_configs = + "[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]" + end it "returns the classification and the model used for it" do SentimentInferenceStubs.stub_classification(target) result = subject.request(target) - subject.available_models.each do |model| - expect(result[model]).to eq(SentimentInferenceStubs.model_response(model)) + subject.available_classifiers.each do |model_config| + expect(result[model_config.model_name]).to eq( + subject.transform_result(SentimentInferenceStubs.model_response(model_config.model_name)), + ) end end end diff --git a/spec/shared/classificator_spec.rb b/spec/shared/classificator_spec.rb index 41be592e..c598759e 100644 --- a/spec/shared/classificator_spec.rb +++ b/spec/shared/classificator_spec.rb @@ -6,21 +6,25 @@ require_relative "../support/sentiment_inference_stubs" describe DiscourseAi::Classificator do describe "#classify!" do describe "saving the classification result" do + let(:model) { DiscourseAi::Sentiment::SentimentClassification.new } + let(:classification_raw_result) do model - .available_models - .reduce({}) do |memo, model_name| - memo[model_name] = SentimentInferenceStubs.model_response(model_name) + .available_classifiers + .reduce({}) do |memo, model_config| + memo[model_config.model_name] = model.transform_result( + SentimentInferenceStubs.model_response(model_config.model_name), + ) memo end end - let(:model) { DiscourseAi::Sentiment::SentimentClassification.new } let(:classification) { DiscourseAi::PostClassificator.new(model) } fab!(:target) { Fabricate(:post) } before do - SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" + SiteSetting.ai_sentiment_model_configs = + "[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]" SentimentInferenceStubs.stub_classification(target) end @@ -28,28 +32,29 @@ describe DiscourseAi::Classificator do classification.classify!(target) stored_results = ClassificationResult.where(target: target) - expect(stored_results.length).to eq(model.available_models.length) + expect(stored_results.length).to eq(model.available_classifiers.length) - model.available_models.each do |model_name| - result = stored_results.detect { |c| c.model_used == model_name } + model.available_classifiers.each do |model_config| + result = stored_results.detect { |c| c.model_used == model_config.model_name } expect(result.classification_type).to eq(model.type.to_s) expect(result.created_at).to be_present expect(result.updated_at).to be_present - expected_classification = SentimentInferenceStubs.model_response(model) + expected_classification = SentimentInferenceStubs.model_response(model_config.model_name) + transformed_classification = model.transform_result(expected_classification) - expect(result.classification.deep_symbolize_keys).to eq(expected_classification) + expect(result.classification).to eq(transformed_classification) end end it "updates an existing classification result" do original_creation = 3.days.ago - model.available_models.each do |model_name| + model.available_classifiers.each do |model_config| ClassificationResult.create!( target: target, - model_used: model_name, + model_used: model_config.model_name, classification_type: model.type, created_at: original_creation, updated_at: original_creation, @@ -61,18 +66,16 @@ describe DiscourseAi::Classificator do classification.classify!(target) stored_results = ClassificationResult.where(target: target) - expect(stored_results.length).to eq(model.available_models.length) + expect(stored_results.length).to eq(model.available_classifiers.length) - model.available_models.each do |model_name| - result = stored_results.detect { |c| c.model_used == model_name } + model.available_classifiers.each do |model_config| + result = stored_results.detect { |c| c.model_used == model_config.model_name } expect(result.classification_type).to eq(model.type.to_s) expect(result.updated_at).to be > original_creation expect(result.created_at).to eq_time(original_creation) - expect(result.classification.deep_symbolize_keys).to eq( - classification_raw_result[model_name], - ) + expect(result.classification).to eq(classification_raw_result[model_config.model_name]) end end end diff --git a/spec/support/sentiment_inference_stubs.rb b/spec/support/sentiment_inference_stubs.rb index 209fcf8e..95c21d0e 100644 --- a/spec/support/sentiment_inference_stubs.rb +++ b/spec/support/sentiment_inference_stubs.rb @@ -2,24 +2,69 @@ class SentimentInferenceStubs class << self - def endpoint - "#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify" - end - def model_response(model) - { negative: 72, neutral: 23, positive: 4 } if model == "sentiment" - - { sadness: 99, surprise: 0, neutral: 0, fear: 0, anger: 0, joy: 0, disgust: 0 } + case model + when "SamLowe/roberta-base-go_emotions" + [ + { score: 0.90261286, label: "anger" }, + { score: 0.04127813, label: "annoyance" }, + { score: 0.03183503, label: "neutral" }, + { score: 0.005037033, label: "disgust" }, + { score: 0.0031153716, label: "disapproval" }, + { score: 0.0019118421, label: "disappointment" }, + { score: 0.0015849728, label: "sadness" }, + { score: 0.0012343781, label: "curiosity" }, + { score: 0.0010682651, label: "amusement" }, + { score: 0.00100747, label: "confusion" }, + { score: 0.0010035422, label: "admiration" }, + { score: 0.0009957326, label: "approval" }, + { score: 0.0009726665, label: "surprise" }, + { score: 0.0007754773, label: "realization" }, + { score: 0.0006978541, label: "love" }, + { score: 0.00064793555, label: "fear" }, + { score: 0.0006454095, label: "optimism" }, + { score: 0.0005969062, label: "joy" }, + { score: 0.0005498958, label: "embarrassment" }, + { score: 0.00050068577, label: "excitement" }, + { score: 0.00047403979, label: "caring" }, + { score: 0.00038841428, label: "gratitude" }, + { score: 0.00034546282, label: "desire" }, + { score: 0.00023012784, label: "grief" }, + { score: 0.00018133638, label: "remorse" }, + { score: 0.00012511834, label: "nervousness" }, + { score: 0.00012079607, label: "pride" }, + { score: 0.000063159685, label: "relief" }, + ] + when "cardiffnlp/twitter-roberta-base-sentiment-latest" + [ + { score: 0.627579, label: "negative" }, + { score: 0.29482335, label: "neutral" }, + { score: 0.07759768, label: "positive" }, + ] + when "j-hartmann/emotion-english-distilroberta-base" + [ + { score: 0.7033674, label: "anger" }, + { score: 0.2701151, label: "disgust" }, + { score: 0.009492096, label: "sadness" }, + { score: 0.0080775, label: "neutral" }, + { score: 0.0049473303, label: "fear" }, + { score: 0.0023369535, label: "surprise" }, + { score: 0.001663634, label: "joy" }, + ] + end end def stub_classification(post) content = post.post_number == 1 ? "#{post.topic.title}\n#{post.raw}" : post.raw - DiscourseAi::Sentiment::SentimentClassification.new.available_models.each do |model| + DiscourseAi::Sentiment::SentimentClassification + .new + .available_classifiers + .each do |model_config| WebMock - .stub_request(:post, endpoint) - .with(body: JSON.dump(model: model, content: content)) - .to_return(status: 200, body: JSON.dump(model_response(model))) + .stub_request(:post, model_config.endpoint) + .with(body: JSON.dump(inputs: content, truncate: true)) + .to_return(status: 200, body: JSON.dump(model_response(model_config.model_name))) end end end diff --git a/spec/tasks/backfill_spec.rb b/spec/tasks/backfill_spec.rb index 7c0d51ce..dc33d309 100644 --- a/spec/tasks/backfill_spec.rb +++ b/spec/tasks/backfill_spec.rb @@ -9,7 +9,10 @@ RSpec.describe "assets:precompile" do end describe "ai:sentiment:backfill" do - before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" } + before do + SiteSetting.ai_sentiment_model_configs = + "[{\"model_name\":\"SamLowe/roberta-base-go_emotions\",\"endpoint\":\"http://samlowe-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"j-hartmann/emotion-english-distilroberta-base\",\"endpoint\":\"http://jhartmann-emotion.com\",\"api_key\":\"123\"},{\"model_name\":\"cardiffnlp/twitter-roberta-base-sentiment-latest\",\"endpoint\":\"http://cardiffnlp-sentiment.com\",\"api_key\":\"123\"}]" + end it "does nothing if the topic is soft-deleted" do target = Fabricate(:post)