REFACTOR: Streamline flag and classification process
parent 85768cfb1c
commit 5f9597474c
@@ -4,14 +4,15 @@ module DiscourseAI
  module NSFW
    class EntryPoint
      def load_files
        require_relative "evaluation"
        require_relative "nsfw_classification"
        require_relative "jobs/regular/evaluate_post_uploads"
      end

      def inject_into(plugin)
        nsfw_detection_cb =
          Proc.new do |post|
            if SiteSetting.ai_nsfw_detection_enabled && post.uploads.present?
            if SiteSetting.ai_nsfw_detection_enabled &&
                 DiscourseAI::NSFW::NSFWClassification.new.can_classify?(post)
              Jobs.enqueue(:evaluate_post_uploads, post_id: post.id)
            end
          end
@@ -1,50 +0,0 @@
# frozen_string_literal: true

module DiscourseAI
  module NSFW
    class Evaluation
      def perform(upload)
        result = { verdict: false, evaluation: {} }

        SiteSetting
          .ai_nsfw_models
          .split("|")
          .each do |model|
            model_result = evaluate_with_model(model, upload).symbolize_keys!

            result[:evaluation][model.to_sym] = model_result

            result[:verdict] = send("#{model}_verdict?", model_result)
          end

        result
      end

      private

      def evaluate_with_model(model, upload)
        upload_url = Discourse.store.cdn_url(upload.url)
        upload_url = "#{Discourse.base_url_no_prefix}#{upload_url}" if upload_url.starts_with?("/")

        DiscourseAI::InferenceManager.perform!(
          "#{SiteSetting.ai_nsfw_inference_service_api_endpoint}/api/v1/classify",
          model,
          upload_url,
          SiteSetting.ai_nsfw_inference_service_api_key,
        )
      end

      def opennsfw2_verdict?(clasification)
        clasification.values.first.to_i >= SiteSetting.ai_nsfw_flag_threshold_general
      end

      def nsfw_detector_verdict?(classification)
        classification.each do |key, value|
          next if key == :neutral
          return true if value.to_i >= SiteSetting.send("ai_nsfw_flag_threshold_#{key}")
        end
        false
      end
    end
  end
end
@@ -9,13 +9,9 @@ module Jobs
      post = Post.includes(:uploads).find_by_id(post_id)
      return if post.nil? || post.uploads.empty?

      nsfw_evaluation = DiscourseAI::NSFW::Evaluation.new
      return if post.uploads.none? { |u| FileHelper.is_supported_image?(u.url) }

      image_uploads = post.uploads.select { |upload| FileHelper.is_supported_image?(upload.url) }

      results = image_uploads.map { |upload| nsfw_evaluation.perform(upload) }

      DiscourseAI::FlagManager.new(post).flag! if results.any? { |r| r[:verdict] }
      DiscourseAI::PostClassification.new(DiscourseAI::NSFW::NSFWClassification.new).classify!(post)
    end
  end
end
@@ -0,0 +1,72 @@
# frozen_string_literal: true

module DiscourseAI
  module NSFW
    class NSFWClassification
      def type
        :nsfw
      end

      def can_classify?(target)
        content_of(target).present?
      end

      def should_flag_based_on?(classification_data)
        return false if !SiteSetting.ai_nsfw_flag_automatically

        # Flat representation of each model classification of each upload.
        # Each element looks like [model_name, data]
        all_classifications = classification_data.values.flatten.map { |x| x.to_a.flatten }

        all_classifications.any? { |(model_name, data)| send("#{model_name}_verdict?", data) }
      end

      def request(target_to_classify)
        uploads_to_classify = content_of(target_to_classify)

        uploads_to_classify.reduce({}) do |memo, upload|
          memo[upload.id] = available_models.reduce({}) do |per_model, model|
            per_model[model] = evaluate_with_model(model, upload)
            per_model
          end

          memo
        end
      end

      private

      def evaluate_with_model(model, upload)
        upload_url = Discourse.store.cdn_url(upload.url)
        upload_url = "#{Discourse.base_url_no_prefix}#{upload_url}" if upload_url.starts_with?("/")

        DiscourseAI::InferenceManager.perform!(
          "#{SiteSetting.ai_nsfw_inference_service_api_endpoint}/api/v1/classify",
          model,
          upload_url,
          SiteSetting.ai_nsfw_inference_service_api_key,
        )
      end

      def available_models
        SiteSetting.ai_nsfw_models.split("|")
      end

      def content_of(target_to_classify)
        target_to_classify.uploads.to_a.select { |u| FileHelper.is_supported_image?(u.url) }
      end

      def opennsfw2_verdict?(clasification)
        clasification.values.first.to_i >= SiteSetting.ai_nsfw_flag_threshold_general
      end

      def nsfw_detector_verdict?(classification)
        classification.each do |key, value|
          next if key == :neutral
          return true if value.to_i >= SiteSetting.send("ai_nsfw_flag_threshold_#{key}")
        end
        false
      end
    end
  end
end
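
For reference, a minimal sketch (not part of the diff) of the data shape NSFWClassification#request returns and #should_flag_based_on? consumes. The upload ids, model name, and score values below are illustrative placeholders, not real inference output.

# Illustrative only: { upload_id => { model_name => raw_model_output } }
classification_data = {
  101 => { "opennsfw2" => { nsfw_probability: 88 } },
  102 => { "opennsfw2" => { nsfw_probability: 3 } },
}

# The method flattens this into [model_name, data] pairs and dispatches to the
# matching "#{model_name}_verdict?" helper, so one NSFW upload is enough to flag.
DiscourseAI::NSFW::NSFWClassification.new.should_flag_based_on?(classification_data)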
@@ -3,7 +3,7 @@ module DiscourseAI
  module Sentiment
    class EntryPoint
      def load_files
        require_relative "post_classifier"
        require_relative "sentiment_classification"
        require_relative "jobs/regular/post_sentiment_analysis"
      end
@@ -9,7 +9,9 @@ module ::Jobs
      post = Post.find_by(id: post_id, post_type: Post.types[:regular])
      return if post&.raw.blank?

      ::DiscourseAI::Sentiment::PostClassifier.new.classify!(post)
      DiscourseAI::PostClassification.new(
        DiscourseAI::Sentiment::SentimentClassification.new,
      ).classify!(post)
    end
  end
end
@@ -1,42 +0,0 @@
# frozen_string_literal: true

module ::DiscourseAI
  module Sentiment
    class PostClassifier
      def classify!(post)
        available_models.each do |model|
          classification = request_classification(post, model)

          store_classification(post, model, classification)
        end
      end

      def available_models
        SiteSetting.ai_sentiment_models.split("|")
      end

      private

      def request_classification(post, model)
        ::DiscourseAI::InferenceManager.perform!(
          "#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify",
          model,
          content(post),
          SiteSetting.ai_sentiment_inference_service_api_key,
        )
      end

      def content(post)
        post.post_number == 1 ? "#{post.topic.title}\n#{post.raw}" : post.raw
      end

      def store_classification(post, model, classification)
        PostCustomField.create!(
          post_id: post.id,
          name: "ai-sentiment-#{model}",
          value: { classification: classification }.to_json,
        )
      end
    end
  end
end
@@ -0,0 +1,52 @@
# frozen_string_literal: true

module DiscourseAI
  module Sentiment
    class SentimentClassification
      def type
        :sentiment
      end

      def available_models
        SiteSetting.ai_sentiment_models.split("|")
      end

      def can_classify?(target)
        content_of(target).present?
      end

      def should_flag_based_on?(classification_data)
        # We don't flag based on sentiment classification.
        false
      end

      def request(target_to_classify)
        target_content = content_of(target_to_classify)

        available_models.reduce({}) do |memo, model|
          memo[model] = request_with(model, target_content)
          memo
        end
      end

      private

      def request_with(model, content)
        ::DiscourseAI::InferenceManager.perform!(
          "#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify",
          model,
          content,
          SiteSetting.ai_sentiment_inference_service_api_key,
        )
      end

      def content_of(target_to_classify)
        if target_to_classify.post_number == 1
          "#{target_to_classify.topic.title}\n#{target_to_classify.raw}"
        else
          target_to_classify.raw
        end
      end
    end
  end
end
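
A usage sketch (not part of the diff), assuming the plugin environment and any persisted post; the shape of each model's raw response depends on the configured inference service.

post = Post.first # illustrative lookup
result = DiscourseAI::Sentiment::SentimentClassification.new.request(post)
# result is keyed by model name, one entry per value in ai_sentiment_models,
# e.g. { "model-a" => <raw classification>, "model-b" => <raw classification> }.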
@@ -1,33 +0,0 @@
# frozen_string_literal: true

module ::DiscourseAI
  module Toxicity
    class ChatMessageClassifier < Classifier
      private

      def content(chat_message)
        chat_message.message
      end

      def store_classification(chat_message, classification)
        PluginStore.set(
          "toxicity",
          "chat_message_#{chat_message.id}",
          {
            classification: classification,
            model: SiteSetting.ai_toxicity_inference_service_api_model,
            date: Time.now.utc,
          },
        )
      end

      def flag!(chat_message, _toxic_labels)
        Chat::ChatReviewQueue.new.flag_message(
          chat_message,
          Guardian.new(flagger),
          ReviewableScore.types[:inappropriate],
        )
      end
    end
  end
end
@@ -1,66 +0,0 @@
# frozen_string_literal: true

module ::DiscourseAI
  module Toxicity
    class Classifier
      CLASSIFICATION_LABELS = %w[
        toxicity
        severe_toxicity
        obscene
        identity_attack
        insult
        threat
        sexual_explicit
      ]

      def classify!(target)
        classification = request_classification(target)

        store_classification(target, classification)

        toxic_labels = filter_toxic_labels(classification)

        flag!(target, toxic_labels) if should_flag_based_on?(toxic_labels)
      end

      protected

      def flag!(_target, _toxic_labels)
        raise NotImplemented
      end

      def store_classification(_target, _classification)
        raise NotImplemented
      end

      def content(_target)
        raise NotImplemented
      end

      def flagger
        Discourse.system_user
      end

      private

      def request_classification(target)
        ::DiscourseAI::InferenceManager.perform!(
          "#{SiteSetting.ai_toxicity_inference_service_api_endpoint}/api/v1/classify",
          SiteSetting.ai_toxicity_inference_service_api_model,
          content(target),
          SiteSetting.ai_toxicity_inference_service_api_key,
        )
      end

      def filter_toxic_labels(classification)
        CLASSIFICATION_LABELS.filter do |label|
          classification[label] >= SiteSetting.send("ai_toxicity_flag_threshold_#{label}")
        end
      end

      def should_flag_based_on?(toxic_labels)
        SiteSetting.ai_toxicity_flag_automatically && toxic_labels.present?
      end
    end
  end
end
@@ -4,9 +4,7 @@ module DiscourseAI
    class EntryPoint
      def load_files
        require_relative "scan_queue"
        require_relative "classifier"
        require_relative "post_classifier"
        require_relative "chat_message_classifier"
        require_relative "toxicity_classification"

        require_relative "jobs/regular/toxicity_classify_post"
        require_relative "jobs/regular/toxicity_classify_chat_message"
@@ -10,7 +10,9 @@ module ::Jobs
      chat_message = ChatMessage.find_by(id: chat_message_id)
      return if chat_message&.message.blank?

      ::DiscourseAI::Toxicity::ChatMessageClassifier.new.classify!(chat_message)
      DiscourseAI::ChatMessageClassification.new(
        DiscourseAI::Toxicity::ToxicityClassification.new,
      ).classify!(chat_message)
    end
  end
end
@@ -11,7 +11,9 @@ module ::Jobs
      post = Post.find_by(id: post_id, post_type: Post.types[:regular])
      return if post&.raw.blank?

      ::DiscourseAI::Toxicity::PostClassifier.new.classify!(post)
      DiscourseAI::PostClassification.new(
        DiscourseAI::Toxicity::ToxicityClassification.new,
      ).classify!(post)
    end
  end
end
@@ -1,28 +0,0 @@
# frozen_string_literal: true

module ::DiscourseAI
  module Toxicity
    class PostClassifier < Classifier
      private

      def content(post)
        post.post_number == 1 ? "#{post.topic.title}\n#{post.raw}" : post.raw
      end

      def store_classification(post, classification)
        PostCustomField.create!(
          post_id: post.id,
          name: "toxicity",
          value: {
            classification: classification,
            model: SiteSetting.ai_toxicity_inference_service_api_model,
          }.to_json,
        )
      end

      def flag!(target, toxic_labels)
        ::DiscourseAI::FlagManager.new(target, reasons: toxic_labels).flag!
      end
    end
  end
end
@@ -0,0 +1,61 @@
# frozen_string_literal: true

module DiscourseAI
  module Toxicity
    class ToxicityClassification
      CLASSIFICATION_LABELS = %i[
        toxicity
        severe_toxicity
        obscene
        identity_attack
        insult
        threat
        sexual_explicit
      ]

      def type
        :toxicity
      end

      def can_classify?(target)
        content_of(target).present?
      end

      def should_flag_based_on?(classification_data)
        return false if !SiteSetting.ai_toxicity_flag_automatically

        # We only use one model for this classification.
        # Classification_data looks like { model_name => classification }
        _model_used, data = classification_data.to_a.first

        CLASSIFICATION_LABELS.any? do |label|
          data[label] >= SiteSetting.send("ai_toxicity_flag_threshold_#{label}")
        end
      end

      def request(target_to_classify)
        data =
          ::DiscourseAI::InferenceManager.perform!(
            "#{SiteSetting.ai_toxicity_inference_service_api_endpoint}/api/v1/classify",
            SiteSetting.ai_toxicity_inference_service_api_model,
            content_of(target_to_classify),
            SiteSetting.ai_toxicity_inference_service_api_key,
          )

        { SiteSetting.ai_toxicity_inference_service_api_model => data }
      end

      private

      def content_of(target_to_classify)
        return target_to_classify.message if target_to_classify.is_a?(ChatMessage)

        if target_to_classify.post_number == 1
          "#{target_to_classify.topic.title}\n#{target_to_classify.raw}"
        else
          target_to_classify.raw
        end
      end
    end
  end
end
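
A sketch (not part of the diff) of the single-entry hash #request returns and #should_flag_based_on? expects; the label scores are made-up numbers, not real inference output.

data = {
  SiteSetting.ai_toxicity_inference_service_api_model => {
    toxicity: 97,
    severe_toxicity: 2,
    obscene: 5,
    identity_attack: 1,
    insult: 90,
    threat: 0,
    sexual_explicit: 0,
  },
}

# Flags only when ai_toxicity_flag_automatically is on and at least one label
# meets its ai_toxicity_flag_threshold_* site setting.
DiscourseAI::Toxicity::ToxicityClassification.new.should_flag_based_on?(data)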
@@ -0,0 +1,24 @@
# frozen_string_literal: true

module ::DiscourseAI
  class ChatMessageClassification < Classification
    private

    def store_classification(chat_message, type, classification_data)
      PluginStore.set(
        type,
        "chat_message_#{chat_message.id}",
        classification_data.merge(date: Time.now.utc),
      )
    end

    def flag!(chat_message, _toxic_labels)
      Chat::ChatReviewQueue.new.flag_message(
        chat_message,
        Guardian.new(flagger),
        ReviewableScore.types[:inappropriate],
        queue_for_review: true,
      )
    end
  end
end
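
A wiring sketch (not part of the diff) mirroring the job change above, assuming a persisted chat message is in scope.

model = DiscourseAI::Toxicity::ToxicityClassification.new
DiscourseAI::ChatMessageClassification.new(model).classify!(chat_message)
# Persists the data under PluginStore("toxicity", "chat_message_#{chat_message.id}")
# and, when the model says so, flags the message through Chat::ChatReviewQueue.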
@@ -0,0 +1,39 @@
# frozen_string_literal: true

module ::DiscourseAI
  class Classification
    def initialize(classification_model)
      @classification_model = classification_model
    end

    def classify!(target)
      return :cannot_classify unless classification_model.can_classify?(target)

      classification_model
        .request(target)
        .tap do |classification|
          store_classification(target, classification_model.type, classification)

          if classification_model.should_flag_based_on?(classification)
            flag!(target, classification)
          end
        end
    end

    protected

    attr_reader :classification_model

    def flag!(_target, _classification)
      raise NotImplemented
    end

    def store_classification(_target, _classification)
      raise NotImplemented
    end

    def flagger
      Discourse.system_user
    end
  end
end
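
A sketch of the duck-typed contract a classification model plugs into this base class with: type, can_classify?, request, and should_flag_based_on?. EchoClassification is a made-up example, not part of the commit.

class EchoClassification
  def type
    :echo
  end

  def can_classify?(target)
    target.respond_to?(:raw) && target.raw.present?
  end

  def request(target)
    # Real models call the inference service here; this just echoes a length.
    { "echo" => { length: target.raw.length } }
  end

  def should_flag_based_on?(_classification_data)
    false
  end
end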
@@ -1,27 +0,0 @@
# frozen_string_literal: true

module ::DiscourseAI
  class FlagManager
    DEFAULT_FLAGGER = Discourse.system_user
    DEFAULT_REASON = "discourse-ai"

    def initialize(object, flagger: DEFAULT_FLAGGER, type: :inappropriate, reasons: DEFAULT_REASON)
      @flagger = flagger
      @object = object
      @type = type
      @reasons = reasons
    end

    def flag!
      PostActionCreator.new(
        @flagger,
        @object,
        PostActionType.types[:inappropriate],
        reason: @reasons,
        queue_for_review: true,
      ).perform

      @object.publish_change_to_clients! :acted
    end
  end
end
@@ -11,7 +11,7 @@ module ::DiscourseAI

      raise Net::HTTPBadResponse unless response.status == 200

      JSON.parse(response.body)
      JSON.parse(response.body, symbolize_names: true)
    end
  end
end
@@ -0,0 +1,23 @@
# frozen_string_literal: true

module ::DiscourseAI
  class PostClassification < Classification
    private

    def store_classification(post, type, classification_data)
      PostCustomField.create!(post_id: post.id, name: type, value: classification_data.to_json)
    end

    def flag!(post, classification_type)
      PostActionCreator.new(
        flagger,
        post,
        PostActionType.types[:inappropriate],
        reason: classification_type,
        queue_for_review: true,
      ).perform

      post.publish_change_to_clients! :acted
    end
  end
end
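
Usage sketch matching the updated jobs (not part of the diff): wrap any classification model in the subclass that knows how to persist and flag that target type; `post` is assumed to be a persisted record in scope.

DiscourseAI::PostClassification.new(
  DiscourseAI::Toxicity::ToxicityClassification.new,
).classify!(post)
# Writes a PostCustomField named after model.type and, if should_flag_based_on?
# returns true, flags the post via PostActionCreator with queue_for_review: true.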
@@ -15,7 +15,9 @@ after_initialize do
  end

  require_relative "lib/shared/inference_manager"
  require_relative "lib/shared/flag_manager"
  require_relative "lib/shared/classification"
  require_relative "lib/shared/post_classification"
  require_relative "lib/shared/chat_message_classification"

  require_relative "lib/modules/nsfw/entry_point"
  require_relative "lib/modules/toxicity/entry_point"
@@ -1,49 +0,0 @@
# frozen_string_literal: true

require "rails_helper"
require_relative "../../../support/nsfw_inference_stubs"

describe DiscourseAI::NSFW::Evaluation do
  before do
    SiteSetting.ai_nsfw_inference_service_api_endpoint = "http://test.com"
    SiteSetting.ai_nsfw_detection_enabled = true
  end

  fab!(:image) { Fabricate(:s3_image_upload) }

  let(:available_models) { SiteSetting.ai_nsfw_models.split("|") }

  describe "perform" do
    context "when we determine content is NSFW" do
      before { NSFWInferenceStubs.positive(image) }

      it "returns true alongside the evaluation" do
        result = subject.perform(image)

        expect(result[:verdict]).to eq(true)

        available_models.each do |model|
          expect(result.dig(:evaluation, model.to_sym)).to eq(
            NSFWInferenceStubs.positive_result(model),
          )
        end
      end
    end

    context "when we determine content is safe" do
      before { NSFWInferenceStubs.negative(image) }

      it "returns false alongside the evaluation" do
        result = subject.perform(image)

        expect(result[:verdict]).to eq(false)

        available_models.each do |model|
          expect(result.dig(:evaluation, model.to_sym)).to eq(
            NSFWInferenceStubs.negative_result(model),
          )
        end
      end
    end
  end
end
@@ -76,25 +76,5 @@ describe Jobs::EvaluatePostUploads do
          end
        end
      end

      context "when the post has multiple uploads" do
        fab!(:upload_2) { Fabricate(:upload) }

        before { post.uploads << upload_2 }

        context "when we conclude content is NSFW" do
          before do
            NSFWInferenceStubs.negative(upload_1)
            NSFWInferenceStubs.positive(upload_2)
          end

          it "flags and hides the post if at least one upload is considered NSFW" do
            subject.execute({ post_id: post.id })

            expect(ReviewableFlaggedPost.where(target: post).count).to eq(1)
            expect(post.reload.hidden?).to eq(true)
          end
        end
      end
    end
  end
end
@@ -0,0 +1,104 @@
# frozen_string_literal: true

require "rails_helper"
require_relative "../../../support/nsfw_inference_stubs"

describe DiscourseAI::NSFW::NSFWClassification do
  before { SiteSetting.ai_nsfw_inference_service_api_endpoint = "http://test.com" }

  let(:available_models) { SiteSetting.ai_nsfw_models.split("|") }

  describe "#request" do
    fab!(:upload_1) { Fabricate(:s3_image_upload) }
    fab!(:post) { Fabricate(:post, uploads: [upload_1]) }

    def assert_correctly_classified(upload, results, expected)
      available_models.each do |model|
        model_result = results.dig(upload.id, model)

        expect(model_result).to eq(expected[model])
      end
    end

    def build_expected_classification(positive: true)
      available_models.reduce({}) do |memo, model|
        model_expected =
          if positive
            NSFWInferenceStubs.positive_result(model)
          else
            NSFWInferenceStubs.negative_result(model)
          end

        memo[model] = model_expected
        memo
      end
    end

    context "when the target has one upload" do
      it "returns the classification and the model used for it" do
        NSFWInferenceStubs.positive(upload_1)
        expected = build_expected_classification

        classification = subject.request(post)

        assert_correctly_classified(upload_1, classification, expected)
      end

      context "when the target has multiple uploads" do
        fab!(:upload_2) { Fabricate(:upload) }

        before { post.uploads << upload_2 }

        it "returns a classification for each one" do
          NSFWInferenceStubs.positive(upload_1)
          NSFWInferenceStubs.negative(upload_2)
          expected_upload_1 = build_expected_classification
          expected_upload_2 = build_expected_classification(positive: false)

          classification = subject.request(post)

          assert_correctly_classified(upload_1, classification, expected_upload_1)
          assert_correctly_classified(upload_2, classification, expected_upload_2)
        end
      end
    end
  end

  describe "#should_flag_based_on?" do
    before { SiteSetting.ai_nsfw_flag_automatically = true }

    let(:positive_classification) do
      {
        1 => available_models.map { |m| { m => NSFWInferenceStubs.negative_result(m) } },
        2 => available_models.map { |m| { m => NSFWInferenceStubs.positive_result(m) } },
      }
    end

    let(:negative_classification) do
      {
        1 => available_models.map { |m| { m => NSFWInferenceStubs.negative_result(m) } },
        2 => available_models.map { |m| { m => NSFWInferenceStubs.negative_result(m) } },
      }
    end

    it "returns false when NSFW flagging is disabled" do
      SiteSetting.ai_nsfw_flag_automatically = false

      should_flag = subject.should_flag_based_on?(positive_classification)

      expect(should_flag).to eq(false)
    end

    it "returns true if the response is NSFW based on our thresholds" do
      should_flag = subject.should_flag_based_on?(positive_classification)

      expect(should_flag).to eq(true)
    end

    it "returns false if the response is safe based on our thresholds" do
      should_flag = subject.should_flag_based_on?(negative_classification)

      expect(should_flag).to eq(false)
    end
  end
end
@@ -1,26 +0,0 @@
# frozen_string_literal: true

require "rails_helper"
require_relative "../../../support/sentiment_inference_stubs"

describe DiscourseAI::Sentiment::PostClassifier do
  fab!(:post) { Fabricate(:post) }

  before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" }

  describe "#classify!" do
    it "stores each model classification in a post custom field" do
      SentimentInferenceStubs.stub_classification(post)

      subject.classify!(post)

      subject.available_models.each do |model|
        stored_classification = PostCustomField.find_by(post: post, name: "ai-sentiment-#{model}")
        expect(stored_classification).to be_present
        expect(stored_classification.value).to eq(
          { classification: SentimentInferenceStubs.model_response(model) }.to_json,
        )
      end
    end
  end
end
@@ -0,0 +1,22 @@
# frozen_string_literal: true

require "rails_helper"
require_relative "../../../support/sentiment_inference_stubs"

describe DiscourseAI::Sentiment::SentimentClassification do
  describe "#request" do
    fab!(:target) { Fabricate(:post) }

    before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" }

    it "returns the classification and the model used for it" do
      SentimentInferenceStubs.stub_classification(target)

      result = subject.request(target)

      subject.available_models.each do |model|
        expect(result[model]).to eq(SentimentInferenceStubs.model_response(model))
      end
    end
  end
end
@@ -1,48 +0,0 @@
# frozen_string_literal: true

require "rails_helper"
require_relative "../../../support/toxicity_inference_stubs"

describe DiscourseAI::Toxicity::ChatMessageClassifier do
  before { SiteSetting.ai_toxicity_flag_automatically = true }

  fab!(:chat_message) { Fabricate(:chat_message) }

  describe "#classify!" do
    it "creates a reviewable when the post is classified as toxic" do
      ToxicityInferenceStubs.stub_chat_message_classification(chat_message, toxic: true)

      subject.classify!(chat_message)

      expect(ReviewableChatMessage.where(target: chat_message).count).to eq(1)
    end

    it "doesn't create a reviewable if the post is not classified as toxic" do
      ToxicityInferenceStubs.stub_chat_message_classification(chat_message, toxic: false)

      subject.classify!(chat_message)

      expect(ReviewableChatMessage.where(target: chat_message).count).to be_zero
    end

    it "doesn't create a reviewable if flagging is disabled" do
      SiteSetting.ai_toxicity_flag_automatically = false
      ToxicityInferenceStubs.stub_chat_message_classification(chat_message, toxic: true)

      subject.classify!(chat_message)

      expect(ReviewableChatMessage.where(target: chat_message).count).to be_zero
    end

    it "stores the classification in a custom field" do
      ToxicityInferenceStubs.stub_chat_message_classification(chat_message, toxic: false)

      subject.classify!(chat_message)
      store_row = PluginStore.get("toxicity", "chat_message_#{chat_message.id}").deep_symbolize_keys

      expect(store_row[:classification]).to eq(ToxicityInferenceStubs.civilized_response)
      expect(store_row[:model]).to eq(SiteSetting.ai_toxicity_inference_service_api_model)
      expect(store_row[:date]).to be_present
    end
  end
end
@@ -1,51 +0,0 @@
# frozen_string_literal: true

require "rails_helper"
require_relative "../../../support/toxicity_inference_stubs"

describe DiscourseAI::Toxicity::PostClassifier do
  before { SiteSetting.ai_toxicity_flag_automatically = true }

  fab!(:post) { Fabricate(:post) }

  describe "#classify!" do
    it "creates a reviewable when the post is classified as toxic" do
      ToxicityInferenceStubs.stub_post_classification(post, toxic: true)

      subject.classify!(post)

      expect(ReviewableFlaggedPost.where(target: post).count).to eq(1)
    end

    it "doesn't create a reviewable if the post is not classified as toxic" do
      ToxicityInferenceStubs.stub_post_classification(post, toxic: false)

      subject.classify!(post)

      expect(ReviewableFlaggedPost.where(target: post).count).to be_zero
    end

    it "doesn't create a reviewable if flagging is disabled" do
      SiteSetting.ai_toxicity_flag_automatically = false
      ToxicityInferenceStubs.stub_post_classification(post, toxic: true)

      subject.classify!(post)

      expect(ReviewableFlaggedPost.where(target: post).count).to be_zero
    end

    it "stores the classification in a custom field" do
      ToxicityInferenceStubs.stub_post_classification(post, toxic: false)

      subject.classify!(post)
      custom_field = PostCustomField.find_by(post: post, name: "toxicity")

      expect(custom_field.value).to eq(
        {
          classification: ToxicityInferenceStubs.civilized_response,
          model: SiteSetting.ai_toxicity_inference_service_api_model,
        }.to_json,
      )
    end
  end
end
@@ -0,0 +1,56 @@
# frozen_string_literal: true

require "rails_helper"
require_relative "../../../support/toxicity_inference_stubs"

describe DiscourseAI::Toxicity::ToxicityClassification do
  describe "#request" do
    fab!(:target) { Fabricate(:post) }

    it "returns the classification and the model used for it" do
      ToxicityInferenceStubs.stub_post_classification(target, toxic: false)

      result = subject.request(target)

      expect(result[SiteSetting.ai_toxicity_inference_service_api_model]).to eq(
        ToxicityInferenceStubs.civilized_response,
      )
    end
  end

  describe "#should_flag_based_on?" do
    before { SiteSetting.ai_toxicity_flag_automatically = true }

    let(:toxic_response) do
      {
        SiteSetting.ai_toxicity_inference_service_api_model =>
          ToxicityInferenceStubs.toxic_response,
      }
    end

    it "returns false when toxicity flagging is disabled" do
      SiteSetting.ai_toxicity_flag_automatically = false

      should_flag = subject.should_flag_based_on?(toxic_response)

      expect(should_flag).to eq(false)
    end

    it "returns true if the response is toxic based on our thresholds" do
      should_flag = subject.should_flag_based_on?(toxic_response)

      expect(should_flag).to eq(true)
    end

    it "returns false if the response is civilized based on our thresholds" do
      civilized_response = {
        SiteSetting.ai_toxicity_inference_service_api_model =>
          ToxicityInferenceStubs.civilized_response,
      }

      should_flag = subject.should_flag_based_on?(civilized_response)

      expect(should_flag).to eq(false)
    end
  end
end
@@ -0,0 +1,42 @@
# frozen_string_literal: true

require "rails_helper"
require_relative "../support/toxicity_inference_stubs"

describe DiscourseAI::ChatMessageClassification do
  fab!(:chat_message) { Fabricate(:chat_message) }

  let(:model) { DiscourseAI::Toxicity::ToxicityClassification.new }
  let(:classification) { described_class.new(model) }

  describe "#classify!" do
    before { ToxicityInferenceStubs.stub_chat_message_classification(chat_message, toxic: true) }

    it "stores the model classification data in a custom field" do
      classification.classify!(chat_message)
      store_row = PluginStore.get("toxicity", "chat_message_#{chat_message.id}")

      classified_data =
        store_row[SiteSetting.ai_toxicity_inference_service_api_model].symbolize_keys

      expect(classified_data).to eq(ToxicityInferenceStubs.toxic_response)
      expect(store_row[:date]).to be_present
    end

    it "flags the message when the model decides we should" do
      SiteSetting.ai_toxicity_flag_automatically = true

      classification.classify!(chat_message)

      expect(ReviewableChatMessage.where(target: chat_message).count).to eq(1)
    end

    it "doesn't flag the message if the model decides we shouldn't" do
      SiteSetting.ai_toxicity_flag_automatically = false

      classification.classify!(chat_message)

      expect(ReviewableChatMessage.where(target: chat_message).count).to be_zero
    end
  end
end
@@ -0,0 +1,44 @@
# frozen_string_literal: true

require "rails_helper"
require_relative "../support/toxicity_inference_stubs"

describe DiscourseAI::PostClassification do
  fab!(:post) { Fabricate(:post) }

  let(:model) { DiscourseAI::Toxicity::ToxicityClassification.new }
  let(:classification) { described_class.new(model) }

  describe "#classify!" do
    before { ToxicityInferenceStubs.stub_post_classification(post, toxic: true) }

    it "stores the model classification data in a custom field" do
      classification.classify!(post)
      custom_field = PostCustomField.find_by(post: post, name: model.type)

      expect(custom_field.value).to eq(
        {
          SiteSetting.ai_toxicity_inference_service_api_model =>
            ToxicityInferenceStubs.toxic_response,
        }.to_json,
      )
    end

    it "flags the message and hides the post when the model decides we should" do
      SiteSetting.ai_toxicity_flag_automatically = true

      classification.classify!(post)

      expect(ReviewableFlaggedPost.where(target: post).count).to eq(1)
      expect(post.reload.hidden?).to eq(true)
    end

    it "doesn't flag the message if the model decides we shouldn't" do
      SiteSetting.ai_toxicity_flag_automatically = false

      classification.classify!(post)

      expect(ReviewableFlaggedPost.where(target: post).count).to be_zero
    end
  end
end
@@ -15,7 +15,7 @@ class SentimentInferenceStubs
    def stub_classification(post)
      content = post.post_number == 1 ? "#{post.topic.title}\n#{post.raw}" : post.raw

      DiscourseAI::Sentiment::PostClassifier.new.available_models.each do |model|
      DiscourseAI::Sentiment::SentimentClassification.new.available_models.each do |model|
        WebMock
          .stub_request(:post, endpoint)
          .with(body: JSON.dump(model: model, content: content))