mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-06-26 09:32:40 +00:00
DEV: Add tests for the sentiment module
This commit is contained in:
parent
ef6c785aca
commit
e8bffcdd64
@ -67,7 +67,7 @@ plugins:
|
|||||||
|
|
||||||
ai_nsfw_live_detection_enabled: false
|
ai_nsfw_live_detection_enabled: false
|
||||||
ai_nsfw_inference_service_api_endpoint:
|
ai_nsfw_inference_service_api_endpoint:
|
||||||
default: "https://nsfw-testing.demo-by-discourse.com"
|
default: ""
|
||||||
ai_nsfw_inference_service_api_key:
|
ai_nsfw_inference_service_api_key:
|
||||||
default: ""
|
default: ""
|
||||||
ai_nsfw_flag_automatically: true
|
ai_nsfw_flag_automatically: true
|
||||||
|
@ -3,10 +3,12 @@
|
|||||||
module DiscourseAI
|
module DiscourseAI
|
||||||
module NSFW
|
module NSFW
|
||||||
class EntryPoint
|
class EntryPoint
|
||||||
def inject_into(plugin)
|
def load_files
|
||||||
require_relative "evaluation.rb"
|
require_relative "evaluation.rb"
|
||||||
require_relative "jobs/regular/evaluate_content.rb"
|
require_relative "jobs/regular/evaluate_content.rb"
|
||||||
|
end
|
||||||
|
|
||||||
|
def inject_into(plugin)
|
||||||
plugin.add_model_callback(Upload, :after_create) do
|
plugin.add_model_callback(Upload, :after_create) do
|
||||||
Jobs.enqueue(:evaluate_content, upload_id: self.id)
|
Jobs.enqueue(:evaluate_content, upload_id: self.id)
|
||||||
end
|
end
|
||||||
|
@ -2,18 +2,21 @@
|
|||||||
module DiscourseAI
|
module DiscourseAI
|
||||||
module Sentiment
|
module Sentiment
|
||||||
class EntryPoint
|
class EntryPoint
|
||||||
def inject_into(plugin)
|
def load_files
|
||||||
require_relative "event_handler.rb"
|
|
||||||
require_relative "post_classifier.rb"
|
require_relative "post_classifier.rb"
|
||||||
require_relative "jobs/regular/sentiment_classify_post.rb"
|
require_relative "jobs/regular/post_sentiment_analysis.rb"
|
||||||
|
end
|
||||||
|
|
||||||
plugin.on(:post_created) do |post|
|
def inject_into(plugin)
|
||||||
DiscourseAI::Sentiment::EventHandler.handle_post_async(post)
|
sentiment_analysis_cb =
|
||||||
end
|
Proc.new do |post|
|
||||||
|
if SiteSetting.ai_sentiment_enabled
|
||||||
|
Jobs.enqueue(:post_sentiment_analysis, post_id: post.id)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
plugin.on(:post_edited) do |post|
|
plugin.on(:post_created, &sentiment_analysis_cb)
|
||||||
DiscourseAI::Sentiment::EventHandler.handle_post_async(post)
|
plugin.on(:post_edited, &sentiment_analysis_cb)
|
||||||
end
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -1,14 +0,0 @@
|
|||||||
# frozen_string_literal: true
|
|
||||||
|
|
||||||
module ::DiscourseAI
|
|
||||||
module Sentiment
|
|
||||||
class EventHandler
|
|
||||||
class << self
|
|
||||||
def handle_post_async(post)
|
|
||||||
return unless SiteSetting.ai_sentiment_enabled
|
|
||||||
Jobs.enqueue(:sentiment_classify_post, post_id: post.id)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
@ -1,17 +1,15 @@
|
|||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
module ::Jobs
|
module ::Jobs
|
||||||
class SentimentClassifyPost < ::Jobs::Base
|
class PostSentimentAnalysis < ::Jobs::Base
|
||||||
def execute(args)
|
def execute(args)
|
||||||
return unless SiteSetting.ai_sentiment_enabled
|
return unless SiteSetting.ai_sentiment_enabled
|
||||||
|
return if (post_id = args[:post_id]).blank?
|
||||||
post_id = args[:post_id]
|
|
||||||
return if post_id.blank?
|
|
||||||
|
|
||||||
post = Post.find_by(id: post_id, post_type: Post.types[:regular])
|
post = Post.find_by(id: post_id, post_type: Post.types[:regular])
|
||||||
return if post&.raw.blank?
|
return if post&.raw.blank?
|
||||||
|
|
||||||
::DiscourseAI::Sentiment::PostClassifier.new(post).classify!
|
::DiscourseAI::Sentiment::PostClassifier.new.classify!(post)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
@ -7,34 +7,36 @@ module ::DiscourseAI
|
|||||||
|
|
||||||
SENTIMENT_LABELS = %w[negative neutral positive]
|
SENTIMENT_LABELS = %w[negative neutral positive]
|
||||||
|
|
||||||
def initialize(object)
|
def classify!(post)
|
||||||
@object = object
|
available_models.each do |model|
|
||||||
|
classification = request_classification(post, model)
|
||||||
|
|
||||||
|
store_classification(post, model, classification)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
def content
|
def available_models
|
||||||
@object.post_number == 1 ? "#{@object.topic.title}\n#{@object.raw}" : @object.raw
|
SiteSetting.ai_sentiment_models.split("|")
|
||||||
end
|
end
|
||||||
|
|
||||||
def classify!
|
private
|
||||||
SiteSetting
|
|
||||||
.ai_sentiment_models
|
|
||||||
.split("|")
|
|
||||||
.each do |model|
|
|
||||||
classification =
|
|
||||||
::DiscourseAI::InferenceManager.perform!(
|
|
||||||
"#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify",
|
|
||||||
model,
|
|
||||||
content,
|
|
||||||
SiteSetting.ai_sentiment_inference_service_api_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
store_classification(model, classification)
|
def request_classification(post, model)
|
||||||
end
|
::DiscourseAI::InferenceManager.perform!(
|
||||||
|
"#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify",
|
||||||
|
model,
|
||||||
|
content(post),
|
||||||
|
SiteSetting.ai_sentiment_inference_service_api_key,
|
||||||
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
def store_classification(model, classification)
|
def content(post)
|
||||||
|
post.post_number == 1 ? "#{post.topic.title}\n#{post.raw}" : post.raw
|
||||||
|
end
|
||||||
|
|
||||||
|
def store_classification(post, model, classification)
|
||||||
PostCustomField.create!(
|
PostCustomField.create!(
|
||||||
post_id: @object.id,
|
post_id: post.id,
|
||||||
name: "ai-sentiment-#{model}",
|
name: "ai-sentiment-#{model}",
|
||||||
value: { classification: classification }.to_json,
|
value: { classification: classification }.to_json,
|
||||||
)
|
)
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
module DiscourseAI
|
module DiscourseAI
|
||||||
module Toxicity
|
module Toxicity
|
||||||
class EntryPoint
|
class EntryPoint
|
||||||
def inject_into(plugin)
|
def load_files
|
||||||
require_relative "event_handler.rb"
|
require_relative "event_handler.rb"
|
||||||
require_relative "classifier.rb"
|
require_relative "classifier.rb"
|
||||||
require_relative "post_classifier.rb"
|
require_relative "post_classifier.rb"
|
||||||
@ -10,7 +10,9 @@ module DiscourseAI
|
|||||||
|
|
||||||
require_relative "jobs/regular/toxicity_classify_post.rb"
|
require_relative "jobs/regular/toxicity_classify_post.rb"
|
||||||
require_relative "jobs/regular/toxicity_classify_chat_message.rb"
|
require_relative "jobs/regular/toxicity_classify_chat_message.rb"
|
||||||
|
end
|
||||||
|
|
||||||
|
def inject_into(plugin)
|
||||||
plugin.on(:post_created) do |post|
|
plugin.on(:post_created) do |post|
|
||||||
DiscourseAI::Toxicity::EventHandler.handle_post_async(post)
|
DiscourseAI::Toxicity::EventHandler.handle_post_async(post)
|
||||||
end
|
end
|
||||||
|
31
plugin.rb
31
plugin.rb
@ -9,22 +9,25 @@
|
|||||||
|
|
||||||
enabled_site_setting :discourse_ai_enabled
|
enabled_site_setting :discourse_ai_enabled
|
||||||
|
|
||||||
|
require_relative "lib/shared/inference_manager"
|
||||||
|
|
||||||
|
require_relative "lib/modules/nsfw/entry_point"
|
||||||
|
require_relative "lib/modules/toxicity/entry_point"
|
||||||
|
require_relative "lib/modules/sentiment/entry_point"
|
||||||
|
|
||||||
after_initialize do
|
after_initialize do
|
||||||
|
modules = [
|
||||||
|
DiscourseAI::NSFW::EntryPoint.new,
|
||||||
|
DiscourseAI::Toxicity::EntryPoint.new,
|
||||||
|
DiscourseAI::Sentiment::EntryPoint.new,
|
||||||
|
]
|
||||||
|
|
||||||
|
modules.each do |a_module|
|
||||||
|
a_module.load_files
|
||||||
|
a_module.inject_into(self)
|
||||||
|
end
|
||||||
|
|
||||||
module ::DiscourseAI
|
module ::DiscourseAI
|
||||||
PLUGIN_NAME = "discourse-ai"
|
PLUGIN_NAME = "discourse-ai"
|
||||||
end
|
end
|
||||||
|
|
||||||
require_relative "lib/shared/inference_manager.rb"
|
|
||||||
|
|
||||||
require_relative "lib/modules/nsfw/entry_point.rb"
|
|
||||||
require_relative "lib/modules/toxicity/entry_point.rb"
|
|
||||||
require_relative "lib/modules/sentiment/entry_point.rb"
|
|
||||||
|
|
||||||
modules = [
|
|
||||||
DiscourseAI::NSFW::EntryPoint,
|
|
||||||
DiscourseAI::Toxicity::EntryPoint,
|
|
||||||
DiscourseAI::Sentiment::EntryPoint,
|
|
||||||
]
|
|
||||||
|
|
||||||
modules.each { |a_module| a_module.new.inject_into(self) }
|
|
||||||
end
|
end
|
||||||
|
@ -4,7 +4,10 @@ require "rails_helper"
|
|||||||
require_relative "../../../support/nsfw_inference_stubs"
|
require_relative "../../../support/nsfw_inference_stubs"
|
||||||
|
|
||||||
describe DiscourseAI::NSFW::Evaluation do
|
describe DiscourseAI::NSFW::Evaluation do
|
||||||
before { SiteSetting.ai_nsfw_live_detection_enabled = true }
|
before do
|
||||||
|
SiteSetting.ai_nsfw_inference_service_api_endpoint = "http://test.com"
|
||||||
|
SiteSetting.ai_nsfw_live_detection_enabled = true
|
||||||
|
end
|
||||||
|
|
||||||
fab!(:image) { Fabricate(:s3_image_upload) }
|
fab!(:image) { Fabricate(:s3_image_upload) }
|
||||||
|
|
||||||
|
@ -7,6 +7,8 @@ describe Jobs::EvaluateContent do
|
|||||||
fab!(:image) { Fabricate(:s3_image_upload) }
|
fab!(:image) { Fabricate(:s3_image_upload) }
|
||||||
|
|
||||||
describe "#execute" do
|
describe "#execute" do
|
||||||
|
before { SiteSetting.ai_nsfw_inference_service_api_endpoint = "http://test.com" }
|
||||||
|
|
||||||
context "when we conclude content is NSFW" do
|
context "when we conclude content is NSFW" do
|
||||||
before { NSFWInferenceStubs.positive(image) }
|
before { NSFWInferenceStubs.positive(image) }
|
||||||
|
|
||||||
|
54
spec/lib/modules/sentiment/entry_point_spec.rb
Normal file
54
spec/lib/modules/sentiment/entry_point_spec.rb
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
require "rails_helper"
|
||||||
|
|
||||||
|
describe DiscourseAI::Sentiment::EntryPoint do
|
||||||
|
fab!(:user) { Fabricate(:user) }
|
||||||
|
|
||||||
|
describe "registering event callbacks" do
|
||||||
|
context "when creating a post" do
|
||||||
|
let(:creator) do
|
||||||
|
PostCreator.new(
|
||||||
|
user,
|
||||||
|
raw: "this is the new content for my topic",
|
||||||
|
title: "this is my new topic title",
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "queues a job on create if sentiment analysis is enabled" do
|
||||||
|
SiteSetting.ai_sentiment_enabled = true
|
||||||
|
|
||||||
|
expect { creator.create }.to change(Jobs::PostSentimentAnalysis.jobs, :size).by(1)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "does nothing if sentiment analysis is disabled" do
|
||||||
|
SiteSetting.ai_sentiment_enabled = false
|
||||||
|
|
||||||
|
expect { creator.create }.not_to change(Jobs::PostSentimentAnalysis.jobs, :size)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
context "when editing a post" do
|
||||||
|
fab!(:post) { Fabricate(:post, user: user) }
|
||||||
|
let(:revisor) { PostRevisor.new(post) }
|
||||||
|
|
||||||
|
it "queues a job on update if sentiment analysis is enabled" do
|
||||||
|
SiteSetting.ai_sentiment_enabled = true
|
||||||
|
|
||||||
|
expect { revisor.revise!(user, raw: "This is my new test") }.to change(
|
||||||
|
Jobs::PostSentimentAnalysis.jobs,
|
||||||
|
:size,
|
||||||
|
).by(1)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "does nothing if sentiment analysis is disabled" do
|
||||||
|
SiteSetting.ai_sentiment_enabled = false
|
||||||
|
|
||||||
|
expect { revisor.revise!(user, raw: "This is my new test") }.not_to change(
|
||||||
|
Jobs::PostSentimentAnalysis.jobs,
|
||||||
|
:size,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
@ -0,0 +1,54 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
require "rails_helper"
|
||||||
|
require_relative "../../../../../support/sentiment_inference_stubs"
|
||||||
|
|
||||||
|
describe Jobs::PostSentimentAnalysis do
|
||||||
|
describe "#execute" do
|
||||||
|
let(:post) { Fabricate(:post) }
|
||||||
|
|
||||||
|
before do
|
||||||
|
SiteSetting.ai_sentiment_enabled = true
|
||||||
|
SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com"
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "scenarios where we return early without doing anything" do
|
||||||
|
it "does nothing when ai_sentiment_enabled is disabled" do
|
||||||
|
SiteSetting.ai_sentiment_enabled = false
|
||||||
|
|
||||||
|
subject.execute({ post_id: post.id })
|
||||||
|
|
||||||
|
expect(PostCustomField.where(post: post).count).to be_zero
|
||||||
|
end
|
||||||
|
|
||||||
|
it "does nothing if there's no arg called post_id" do
|
||||||
|
subject.execute({})
|
||||||
|
|
||||||
|
expect(PostCustomField.where(post: post).count).to be_zero
|
||||||
|
end
|
||||||
|
|
||||||
|
it "does nothing if no post match the given id" do
|
||||||
|
subject.execute({ post_id: nil })
|
||||||
|
|
||||||
|
expect(PostCustomField.where(post: post).count).to be_zero
|
||||||
|
end
|
||||||
|
|
||||||
|
it "does nothing if the post content is blank" do
|
||||||
|
post.update_columns(raw: "")
|
||||||
|
|
||||||
|
subject.execute({ post_id: post.id })
|
||||||
|
|
||||||
|
expect(PostCustomField.where(post: post).count).to be_zero
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
it "succesfully classifies the post" do
|
||||||
|
expected_analysis = SiteSetting.ai_sentiment_models.split("|").length
|
||||||
|
SentimentInferenceStubs.stub_classification(post)
|
||||||
|
|
||||||
|
subject.execute({ post_id: post.id })
|
||||||
|
|
||||||
|
expect(PostCustomField.where(post: post).count).to eq(expected_analysis)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
26
spec/lib/modules/sentiment/post_classifier_spec.rb
Normal file
26
spec/lib/modules/sentiment/post_classifier_spec.rb
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
require "rails_helper"
|
||||||
|
require_relative "../../../support/sentiment_inference_stubs"
|
||||||
|
|
||||||
|
describe DiscourseAI::Sentiment::PostClassifier do
|
||||||
|
fab!(:post) { Fabricate(:post) }
|
||||||
|
|
||||||
|
before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" }
|
||||||
|
|
||||||
|
describe "#classify!" do
|
||||||
|
it "stores each model classification in a post custom field" do
|
||||||
|
SentimentInferenceStubs.stub_classification(post)
|
||||||
|
|
||||||
|
subject.classify!(post)
|
||||||
|
|
||||||
|
subject.available_models.each do |model|
|
||||||
|
stored_classification = PostCustomField.find_by(post: post, name: "ai-sentiment-#{model}")
|
||||||
|
expect(stored_classification).to be_present
|
||||||
|
expect(stored_classification.value).to eq(
|
||||||
|
{ classification: SentimentInferenceStubs.model_response(model) }.to_json,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
26
spec/support/sentiment_inference_stubs.rb
Normal file
26
spec/support/sentiment_inference_stubs.rb
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
class SentimentInferenceStubs
|
||||||
|
class << self
|
||||||
|
def endpoint
|
||||||
|
"#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify"
|
||||||
|
end
|
||||||
|
|
||||||
|
def model_response(model)
|
||||||
|
{ negative: 72, neutral: 23, positive: 4 } if model == "sentiment"
|
||||||
|
|
||||||
|
{ sadness: 99, surprise: 0, neutral: 0, fear: 0, anger: 0, joy: 0, disgust: 0 }
|
||||||
|
end
|
||||||
|
|
||||||
|
def stub_classification(post)
|
||||||
|
content = post.post_number == 1 ? "#{post.topic.title}\n#{post.raw}" : post.raw
|
||||||
|
|
||||||
|
DiscourseAI::Sentiment::PostClassifier.new.available_models.each do |model|
|
||||||
|
WebMock
|
||||||
|
.stub_request(:post, endpoint)
|
||||||
|
.with(body: JSON.dump(model: model, content: content))
|
||||||
|
.to_return(status: 200, body: JSON.dump(model_response(model)))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
Loading…
x
Reference in New Issue
Block a user