DEV: Add tests for the sentiment module

This commit is contained in:
Roman Rizzi 2023-02-23 15:50:10 -03:00
parent ef6c785aca
commit e8bffcdd64
No known key found for this signature in database
GPG Key ID: 64024A71CE7330D3
14 changed files with 227 additions and 66 deletions

View File

@ -67,7 +67,7 @@ plugins:
ai_nsfw_live_detection_enabled: false
ai_nsfw_inference_service_api_endpoint:
default: "https://nsfw-testing.demo-by-discourse.com"
default: ""
ai_nsfw_inference_service_api_key:
default: ""
ai_nsfw_flag_automatically: true

View File

@ -3,10 +3,12 @@
module DiscourseAI
module NSFW
class EntryPoint
def inject_into(plugin)
def load_files
require_relative "evaluation.rb"
require_relative "jobs/regular/evaluate_content.rb"
end
def inject_into(plugin)
plugin.add_model_callback(Upload, :after_create) do
Jobs.enqueue(:evaluate_content, upload_id: self.id)
end

View File

@ -2,18 +2,21 @@
module DiscourseAI
module Sentiment
class EntryPoint
def inject_into(plugin)
require_relative "event_handler.rb"
def load_files
require_relative "post_classifier.rb"
require_relative "jobs/regular/sentiment_classify_post.rb"
require_relative "jobs/regular/post_sentiment_analysis.rb"
end
plugin.on(:post_created) do |post|
DiscourseAI::Sentiment::EventHandler.handle_post_async(post)
end
def inject_into(plugin)
sentiment_analysis_cb =
Proc.new do |post|
if SiteSetting.ai_sentiment_enabled
Jobs.enqueue(:post_sentiment_analysis, post_id: post.id)
end
end
plugin.on(:post_edited) do |post|
DiscourseAI::Sentiment::EventHandler.handle_post_async(post)
end
plugin.on(:post_created, &sentiment_analysis_cb)
plugin.on(:post_edited, &sentiment_analysis_cb)
end
end
end

View File

@ -1,14 +0,0 @@
# frozen_string_literal: true
module ::DiscourseAI
module Sentiment
class EventHandler
class << self
def handle_post_async(post)
return unless SiteSetting.ai_sentiment_enabled
Jobs.enqueue(:sentiment_classify_post, post_id: post.id)
end
end
end
end
end

View File

@ -1,17 +1,15 @@
# frozen_string_literal: true
module ::Jobs
class SentimentClassifyPost < ::Jobs::Base
class PostSentimentAnalysis < ::Jobs::Base
def execute(args)
return unless SiteSetting.ai_sentiment_enabled
post_id = args[:post_id]
return if post_id.blank?
return if (post_id = args[:post_id]).blank?
post = Post.find_by(id: post_id, post_type: Post.types[:regular])
return if post&.raw.blank?
::DiscourseAI::Sentiment::PostClassifier.new(post).classify!
::DiscourseAI::Sentiment::PostClassifier.new.classify!(post)
end
end
end

View File

@ -7,34 +7,36 @@ module ::DiscourseAI
SENTIMENT_LABELS = %w[negative neutral positive]
def initialize(object)
@object = object
def classify!(post)
available_models.each do |model|
classification = request_classification(post, model)
store_classification(post, model, classification)
end
end
def content
@object.post_number == 1 ? "#{@object.topic.title}\n#{@object.raw}" : @object.raw
def available_models
SiteSetting.ai_sentiment_models.split("|")
end
def classify!
SiteSetting
.ai_sentiment_models
.split("|")
.each do |model|
classification =
::DiscourseAI::InferenceManager.perform!(
"#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify",
model,
content,
SiteSetting.ai_sentiment_inference_service_api_key,
)
private
store_classification(model, classification)
end
def request_classification(post, model)
::DiscourseAI::InferenceManager.perform!(
"#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify",
model,
content(post),
SiteSetting.ai_sentiment_inference_service_api_key,
)
end
def store_classification(model, classification)
def content(post)
post.post_number == 1 ? "#{post.topic.title}\n#{post.raw}" : post.raw
end
def store_classification(post, model, classification)
PostCustomField.create!(
post_id: @object.id,
post_id: post.id,
name: "ai-sentiment-#{model}",
value: { classification: classification }.to_json,
)

View File

@ -2,7 +2,7 @@
module DiscourseAI
module Toxicity
class EntryPoint
def inject_into(plugin)
def load_files
require_relative "event_handler.rb"
require_relative "classifier.rb"
require_relative "post_classifier.rb"
@ -10,7 +10,9 @@ module DiscourseAI
require_relative "jobs/regular/toxicity_classify_post.rb"
require_relative "jobs/regular/toxicity_classify_chat_message.rb"
end
def inject_into(plugin)
plugin.on(:post_created) do |post|
DiscourseAI::Toxicity::EventHandler.handle_post_async(post)
end

View File

@ -9,22 +9,25 @@
enabled_site_setting :discourse_ai_enabled
require_relative "lib/shared/inference_manager"
require_relative "lib/modules/nsfw/entry_point"
require_relative "lib/modules/toxicity/entry_point"
require_relative "lib/modules/sentiment/entry_point"
after_initialize do
modules = [
DiscourseAI::NSFW::EntryPoint.new,
DiscourseAI::Toxicity::EntryPoint.new,
DiscourseAI::Sentiment::EntryPoint.new,
]
modules.each do |a_module|
a_module.load_files
a_module.inject_into(self)
end
module ::DiscourseAI
PLUGIN_NAME = "discourse-ai"
end
require_relative "lib/shared/inference_manager.rb"
require_relative "lib/modules/nsfw/entry_point.rb"
require_relative "lib/modules/toxicity/entry_point.rb"
require_relative "lib/modules/sentiment/entry_point.rb"
modules = [
DiscourseAI::NSFW::EntryPoint,
DiscourseAI::Toxicity::EntryPoint,
DiscourseAI::Sentiment::EntryPoint,
]
modules.each { |a_module| a_module.new.inject_into(self) }
end

View File

@ -4,7 +4,10 @@ require "rails_helper"
require_relative "../../../support/nsfw_inference_stubs"
describe DiscourseAI::NSFW::Evaluation do
before { SiteSetting.ai_nsfw_live_detection_enabled = true }
before do
SiteSetting.ai_nsfw_inference_service_api_endpoint = "http://test.com"
SiteSetting.ai_nsfw_live_detection_enabled = true
end
fab!(:image) { Fabricate(:s3_image_upload) }

View File

@ -7,6 +7,8 @@ describe Jobs::EvaluateContent do
fab!(:image) { Fabricate(:s3_image_upload) }
describe "#execute" do
before { SiteSetting.ai_nsfw_inference_service_api_endpoint = "http://test.com" }
context "when we conclude content is NSFW" do
before { NSFWInferenceStubs.positive(image) }

View File

@ -0,0 +1,54 @@
# frozen_string_literal: true
require "rails_helper"
describe DiscourseAI::Sentiment::EntryPoint do
fab!(:user) { Fabricate(:user) }
describe "registering event callbacks" do
context "when creating a post" do
let(:creator) do
PostCreator.new(
user,
raw: "this is the new content for my topic",
title: "this is my new topic title",
)
end
it "queues a job on create if sentiment analysis is enabled" do
SiteSetting.ai_sentiment_enabled = true
expect { creator.create }.to change(Jobs::PostSentimentAnalysis.jobs, :size).by(1)
end
it "does nothing if sentiment analysis is disabled" do
SiteSetting.ai_sentiment_enabled = false
expect { creator.create }.not_to change(Jobs::PostSentimentAnalysis.jobs, :size)
end
end
context "when editing a post" do
fab!(:post) { Fabricate(:post, user: user) }
let(:revisor) { PostRevisor.new(post) }
it "queues a job on update if sentiment analysis is enabled" do
SiteSetting.ai_sentiment_enabled = true
expect { revisor.revise!(user, raw: "This is my new test") }.to change(
Jobs::PostSentimentAnalysis.jobs,
:size,
).by(1)
end
it "does nothing if sentiment analysis is disabled" do
SiteSetting.ai_sentiment_enabled = false
expect { revisor.revise!(user, raw: "This is my new test") }.not_to change(
Jobs::PostSentimentAnalysis.jobs,
:size,
)
end
end
end
end

View File

@ -0,0 +1,54 @@
# frozen_string_literal: true
require "rails_helper"
require_relative "../../../../../support/sentiment_inference_stubs"
describe Jobs::PostSentimentAnalysis do
describe "#execute" do
let(:post) { Fabricate(:post) }
before do
SiteSetting.ai_sentiment_enabled = true
SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com"
end
describe "scenarios where we return early without doing anything" do
it "does nothing when ai_sentiment_enabled is disabled" do
SiteSetting.ai_sentiment_enabled = false
subject.execute({ post_id: post.id })
expect(PostCustomField.where(post: post).count).to be_zero
end
it "does nothing if there's no arg called post_id" do
subject.execute({})
expect(PostCustomField.where(post: post).count).to be_zero
end
it "does nothing if no post match the given id" do
subject.execute({ post_id: nil })
expect(PostCustomField.where(post: post).count).to be_zero
end
it "does nothing if the post content is blank" do
post.update_columns(raw: "")
subject.execute({ post_id: post.id })
expect(PostCustomField.where(post: post).count).to be_zero
end
end
it "succesfully classifies the post" do
expected_analysis = SiteSetting.ai_sentiment_models.split("|").length
SentimentInferenceStubs.stub_classification(post)
subject.execute({ post_id: post.id })
expect(PostCustomField.where(post: post).count).to eq(expected_analysis)
end
end
end

View File

@ -0,0 +1,26 @@
# frozen_string_literal: true
require "rails_helper"
require_relative "../../../support/sentiment_inference_stubs"
describe DiscourseAI::Sentiment::PostClassifier do
fab!(:post) { Fabricate(:post) }
before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" }
describe "#classify!" do
it "stores each model classification in a post custom field" do
SentimentInferenceStubs.stub_classification(post)
subject.classify!(post)
subject.available_models.each do |model|
stored_classification = PostCustomField.find_by(post: post, name: "ai-sentiment-#{model}")
expect(stored_classification).to be_present
expect(stored_classification.value).to eq(
{ classification: SentimentInferenceStubs.model_response(model) }.to_json,
)
end
end
end
end

View File

@ -0,0 +1,26 @@
# frozen_string_literal: true
class SentimentInferenceStubs
class << self
def endpoint
"#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify"
end
def model_response(model)
{ negative: 72, neutral: 23, positive: 4 } if model == "sentiment"
{ sadness: 99, surprise: 0, neutral: 0, fear: 0, anger: 0, joy: 0, disgust: 0 }
end
def stub_classification(post)
content = post.post_number == 1 ? "#{post.topic.title}\n#{post.raw}" : post.raw
DiscourseAI::Sentiment::PostClassifier.new.available_models.each do |model|
WebMock
.stub_request(:post, endpoint)
.with(body: JSON.dump(model: model, content: content))
.to_return(status: 200, body: JSON.dump(model_response(model)))
end
end
end
end