DEV: Add tests for the sentiment module
This commit is contained in:
parent
ef6c785aca
commit
e8bffcdd64
|
@ -67,7 +67,7 @@ plugins:
|
|||
|
||||
ai_nsfw_live_detection_enabled: false
|
||||
ai_nsfw_inference_service_api_endpoint:
|
||||
default: "https://nsfw-testing.demo-by-discourse.com"
|
||||
default: ""
|
||||
ai_nsfw_inference_service_api_key:
|
||||
default: ""
|
||||
ai_nsfw_flag_automatically: true
|
||||
|
|
|
@ -3,10 +3,12 @@
|
|||
module DiscourseAI
|
||||
module NSFW
|
||||
class EntryPoint
|
||||
def inject_into(plugin)
|
||||
def load_files
|
||||
require_relative "evaluation.rb"
|
||||
require_relative "jobs/regular/evaluate_content.rb"
|
||||
end
|
||||
|
||||
def inject_into(plugin)
|
||||
plugin.add_model_callback(Upload, :after_create) do
|
||||
Jobs.enqueue(:evaluate_content, upload_id: self.id)
|
||||
end
|
||||
|
|
|
@ -2,19 +2,22 @@
|
|||
module DiscourseAI
|
||||
module Sentiment
|
||||
class EntryPoint
|
||||
def inject_into(plugin)
|
||||
require_relative "event_handler.rb"
|
||||
def load_files
|
||||
require_relative "post_classifier.rb"
|
||||
require_relative "jobs/regular/sentiment_classify_post.rb"
|
||||
|
||||
plugin.on(:post_created) do |post|
|
||||
DiscourseAI::Sentiment::EventHandler.handle_post_async(post)
|
||||
require_relative "jobs/regular/post_sentiment_analysis.rb"
|
||||
end
|
||||
|
||||
plugin.on(:post_edited) do |post|
|
||||
DiscourseAI::Sentiment::EventHandler.handle_post_async(post)
|
||||
def inject_into(plugin)
|
||||
sentiment_analysis_cb =
|
||||
Proc.new do |post|
|
||||
if SiteSetting.ai_sentiment_enabled
|
||||
Jobs.enqueue(:post_sentiment_analysis, post_id: post.id)
|
||||
end
|
||||
end
|
||||
|
||||
plugin.on(:post_created, &sentiment_analysis_cb)
|
||||
plugin.on(:post_edited, &sentiment_analysis_cb)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,14 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module ::DiscourseAI
|
||||
module Sentiment
|
||||
class EventHandler
|
||||
class << self
|
||||
def handle_post_async(post)
|
||||
return unless SiteSetting.ai_sentiment_enabled
|
||||
Jobs.enqueue(:sentiment_classify_post, post_id: post.id)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,17 +1,15 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module ::Jobs
|
||||
class SentimentClassifyPost < ::Jobs::Base
|
||||
class PostSentimentAnalysis < ::Jobs::Base
|
||||
def execute(args)
|
||||
return unless SiteSetting.ai_sentiment_enabled
|
||||
|
||||
post_id = args[:post_id]
|
||||
return if post_id.blank?
|
||||
return if (post_id = args[:post_id]).blank?
|
||||
|
||||
post = Post.find_by(id: post_id, post_type: Post.types[:regular])
|
||||
return if post&.raw.blank?
|
||||
|
||||
::DiscourseAI::Sentiment::PostClassifier.new(post).classify!
|
||||
::DiscourseAI::Sentiment::PostClassifier.new.classify!(post)
|
||||
end
|
||||
end
|
||||
end
|
|
@ -7,34 +7,36 @@ module ::DiscourseAI
|
|||
|
||||
SENTIMENT_LABELS = %w[negative neutral positive]
|
||||
|
||||
def initialize(object)
|
||||
@object = object
|
||||
def classify!(post)
|
||||
available_models.each do |model|
|
||||
classification = request_classification(post, model)
|
||||
|
||||
store_classification(post, model, classification)
|
||||
end
|
||||
end
|
||||
|
||||
def content
|
||||
@object.post_number == 1 ? "#{@object.topic.title}\n#{@object.raw}" : @object.raw
|
||||
def available_models
|
||||
SiteSetting.ai_sentiment_models.split("|")
|
||||
end
|
||||
|
||||
def classify!
|
||||
SiteSetting
|
||||
.ai_sentiment_models
|
||||
.split("|")
|
||||
.each do |model|
|
||||
classification =
|
||||
private
|
||||
|
||||
def request_classification(post, model)
|
||||
::DiscourseAI::InferenceManager.perform!(
|
||||
"#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify",
|
||||
model,
|
||||
content,
|
||||
content(post),
|
||||
SiteSetting.ai_sentiment_inference_service_api_key,
|
||||
)
|
||||
|
||||
store_classification(model, classification)
|
||||
end
|
||||
end
|
||||
|
||||
def store_classification(model, classification)
|
||||
def content(post)
|
||||
post.post_number == 1 ? "#{post.topic.title}\n#{post.raw}" : post.raw
|
||||
end
|
||||
|
||||
def store_classification(post, model, classification)
|
||||
PostCustomField.create!(
|
||||
post_id: @object.id,
|
||||
post_id: post.id,
|
||||
name: "ai-sentiment-#{model}",
|
||||
value: { classification: classification }.to_json,
|
||||
)
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
module DiscourseAI
|
||||
module Toxicity
|
||||
class EntryPoint
|
||||
def inject_into(plugin)
|
||||
def load_files
|
||||
require_relative "event_handler.rb"
|
||||
require_relative "classifier.rb"
|
||||
require_relative "post_classifier.rb"
|
||||
|
@ -10,7 +10,9 @@ module DiscourseAI
|
|||
|
||||
require_relative "jobs/regular/toxicity_classify_post.rb"
|
||||
require_relative "jobs/regular/toxicity_classify_chat_message.rb"
|
||||
end
|
||||
|
||||
def inject_into(plugin)
|
||||
plugin.on(:post_created) do |post|
|
||||
DiscourseAI::Toxicity::EventHandler.handle_post_async(post)
|
||||
end
|
||||
|
|
31
plugin.rb
31
plugin.rb
|
@ -9,22 +9,25 @@
|
|||
|
||||
enabled_site_setting :discourse_ai_enabled
|
||||
|
||||
require_relative "lib/shared/inference_manager"
|
||||
|
||||
require_relative "lib/modules/nsfw/entry_point"
|
||||
require_relative "lib/modules/toxicity/entry_point"
|
||||
require_relative "lib/modules/sentiment/entry_point"
|
||||
|
||||
after_initialize do
|
||||
modules = [
|
||||
DiscourseAI::NSFW::EntryPoint.new,
|
||||
DiscourseAI::Toxicity::EntryPoint.new,
|
||||
DiscourseAI::Sentiment::EntryPoint.new,
|
||||
]
|
||||
|
||||
modules.each do |a_module|
|
||||
a_module.load_files
|
||||
a_module.inject_into(self)
|
||||
end
|
||||
|
||||
module ::DiscourseAI
|
||||
PLUGIN_NAME = "discourse-ai"
|
||||
end
|
||||
|
||||
require_relative "lib/shared/inference_manager.rb"
|
||||
|
||||
require_relative "lib/modules/nsfw/entry_point.rb"
|
||||
require_relative "lib/modules/toxicity/entry_point.rb"
|
||||
require_relative "lib/modules/sentiment/entry_point.rb"
|
||||
|
||||
modules = [
|
||||
DiscourseAI::NSFW::EntryPoint,
|
||||
DiscourseAI::Toxicity::EntryPoint,
|
||||
DiscourseAI::Sentiment::EntryPoint,
|
||||
]
|
||||
|
||||
modules.each { |a_module| a_module.new.inject_into(self) }
|
||||
end
|
||||
|
|
|
@ -4,7 +4,10 @@ require "rails_helper"
|
|||
require_relative "../../../support/nsfw_inference_stubs"
|
||||
|
||||
describe DiscourseAI::NSFW::Evaluation do
|
||||
before { SiteSetting.ai_nsfw_live_detection_enabled = true }
|
||||
before do
|
||||
SiteSetting.ai_nsfw_inference_service_api_endpoint = "http://test.com"
|
||||
SiteSetting.ai_nsfw_live_detection_enabled = true
|
||||
end
|
||||
|
||||
fab!(:image) { Fabricate(:s3_image_upload) }
|
||||
|
||||
|
|
|
@ -7,6 +7,8 @@ describe Jobs::EvaluateContent do
|
|||
fab!(:image) { Fabricate(:s3_image_upload) }
|
||||
|
||||
describe "#execute" do
|
||||
before { SiteSetting.ai_nsfw_inference_service_api_endpoint = "http://test.com" }
|
||||
|
||||
context "when we conclude content is NSFW" do
|
||||
before { NSFWInferenceStubs.positive(image) }
|
||||
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require "rails_helper"
|
||||
|
||||
describe DiscourseAI::Sentiment::EntryPoint do
|
||||
fab!(:user) { Fabricate(:user) }
|
||||
|
||||
describe "registering event callbacks" do
|
||||
context "when creating a post" do
|
||||
let(:creator) do
|
||||
PostCreator.new(
|
||||
user,
|
||||
raw: "this is the new content for my topic",
|
||||
title: "this is my new topic title",
|
||||
)
|
||||
end
|
||||
|
||||
it "queues a job on create if sentiment analysis is enabled" do
|
||||
SiteSetting.ai_sentiment_enabled = true
|
||||
|
||||
expect { creator.create }.to change(Jobs::PostSentimentAnalysis.jobs, :size).by(1)
|
||||
end
|
||||
|
||||
it "does nothing if sentiment analysis is disabled" do
|
||||
SiteSetting.ai_sentiment_enabled = false
|
||||
|
||||
expect { creator.create }.not_to change(Jobs::PostSentimentAnalysis.jobs, :size)
|
||||
end
|
||||
end
|
||||
|
||||
context "when editing a post" do
|
||||
fab!(:post) { Fabricate(:post, user: user) }
|
||||
let(:revisor) { PostRevisor.new(post) }
|
||||
|
||||
it "queues a job on update if sentiment analysis is enabled" do
|
||||
SiteSetting.ai_sentiment_enabled = true
|
||||
|
||||
expect { revisor.revise!(user, raw: "This is my new test") }.to change(
|
||||
Jobs::PostSentimentAnalysis.jobs,
|
||||
:size,
|
||||
).by(1)
|
||||
end
|
||||
|
||||
it "does nothing if sentiment analysis is disabled" do
|
||||
SiteSetting.ai_sentiment_enabled = false
|
||||
|
||||
expect { revisor.revise!(user, raw: "This is my new test") }.not_to change(
|
||||
Jobs::PostSentimentAnalysis.jobs,
|
||||
:size,
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,54 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require "rails_helper"
|
||||
require_relative "../../../../../support/sentiment_inference_stubs"
|
||||
|
||||
describe Jobs::PostSentimentAnalysis do
|
||||
describe "#execute" do
|
||||
let(:post) { Fabricate(:post) }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_sentiment_enabled = true
|
||||
SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com"
|
||||
end
|
||||
|
||||
describe "scenarios where we return early without doing anything" do
|
||||
it "does nothing when ai_sentiment_enabled is disabled" do
|
||||
SiteSetting.ai_sentiment_enabled = false
|
||||
|
||||
subject.execute({ post_id: post.id })
|
||||
|
||||
expect(PostCustomField.where(post: post).count).to be_zero
|
||||
end
|
||||
|
||||
it "does nothing if there's no arg called post_id" do
|
||||
subject.execute({})
|
||||
|
||||
expect(PostCustomField.where(post: post).count).to be_zero
|
||||
end
|
||||
|
||||
it "does nothing if no post match the given id" do
|
||||
subject.execute({ post_id: nil })
|
||||
|
||||
expect(PostCustomField.where(post: post).count).to be_zero
|
||||
end
|
||||
|
||||
it "does nothing if the post content is blank" do
|
||||
post.update_columns(raw: "")
|
||||
|
||||
subject.execute({ post_id: post.id })
|
||||
|
||||
expect(PostCustomField.where(post: post).count).to be_zero
|
||||
end
|
||||
end
|
||||
|
||||
it "succesfully classifies the post" do
|
||||
expected_analysis = SiteSetting.ai_sentiment_models.split("|").length
|
||||
SentimentInferenceStubs.stub_classification(post)
|
||||
|
||||
subject.execute({ post_id: post.id })
|
||||
|
||||
expect(PostCustomField.where(post: post).count).to eq(expected_analysis)
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,26 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require "rails_helper"
|
||||
require_relative "../../../support/sentiment_inference_stubs"
|
||||
|
||||
describe DiscourseAI::Sentiment::PostClassifier do
|
||||
fab!(:post) { Fabricate(:post) }
|
||||
|
||||
before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" }
|
||||
|
||||
describe "#classify!" do
|
||||
it "stores each model classification in a post custom field" do
|
||||
SentimentInferenceStubs.stub_classification(post)
|
||||
|
||||
subject.classify!(post)
|
||||
|
||||
subject.available_models.each do |model|
|
||||
stored_classification = PostCustomField.find_by(post: post, name: "ai-sentiment-#{model}")
|
||||
expect(stored_classification).to be_present
|
||||
expect(stored_classification.value).to eq(
|
||||
{ classification: SentimentInferenceStubs.model_response(model) }.to_json,
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,26 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
class SentimentInferenceStubs
|
||||
class << self
|
||||
def endpoint
|
||||
"#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify"
|
||||
end
|
||||
|
||||
def model_response(model)
|
||||
{ negative: 72, neutral: 23, positive: 4 } if model == "sentiment"
|
||||
|
||||
{ sadness: 99, surprise: 0, neutral: 0, fear: 0, anger: 0, joy: 0, disgust: 0 }
|
||||
end
|
||||
|
||||
def stub_classification(post)
|
||||
content = post.post_number == 1 ? "#{post.topic.title}\n#{post.raw}" : post.raw
|
||||
|
||||
DiscourseAI::Sentiment::PostClassifier.new.available_models.each do |model|
|
||||
WebMock
|
||||
.stub_request(:post, endpoint)
|
||||
.with(body: JSON.dump(model: model, content: content))
|
||||
.to_return(status: 200, body: JSON.dump(model_response(model)))
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
Loading…
Reference in New Issue