DEV: Dedicated table for saving classification results (#1)
parent 5f9597474c
commit b9a650fde4
@@ -0,0 +1,23 @@
+# frozen_string_literal: true
+
+class ClassificationResult < ActiveRecord::Base
+  belongs_to :target, polymorphic: true
+end
+
+# == Schema Information
+#
+# Table name: classification_results
+#
+#  id                   :bigint           not null, primary key
+#  model_used           :string
+#  classification_type  :string
+#  target_id            :integer
+#  target_type          :string
+#  classification       :jsonb
+#  created_at           :datetime         not null
+#  updated_at           :datetime         not null
+#
+# Indexes
+#
+#  unique_classification_target_per_type  (target_id,target_type,model_used) UNIQUE
+#
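Note: this first hunk appears to add the ActiveRecord model backing the new table (plugin.rb below gains a matching require_relative "app/models/classification_result"). As a rough usage sketch that is not part of the commit, reading results back through the polymorphic target association could look like this; the post lookup and the printed output are illustrative:

post = Post.find(some_post_id)                      # some_post_id is a placeholder
results = ClassificationResult.where(target: post)  # resolved via target_id + target_type

results.each do |result|
  puts "#{result.classification_type} / #{result.model_used}: #{result.classification.inspect}"
end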
@@ -0,0 +1,19 @@
+# frozen_string_literal: true
+class CreateClassificationResultsTable < ActiveRecord::Migration[7.0]
+  def change
+    create_table :classification_results do |t|
+      t.string :model_used, null: true
+      t.string :classification_type, null: true
+      t.integer :target_id, null: true
+      t.string :target_type, null: true
+
+      t.jsonb :classification, null: true
+      t.timestamps
+    end
+
+    add_index :classification_results,
+              %i[target_id target_type model_used],
+              unique: true,
+              name: "unique_classification_target_per_type"
+  end
+end
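The composite unique index ("unique_classification_target_per_type") is what lets the plugin keep a single row per target/model pair and refresh it in place. A minimal sketch of the upsert it supports, with invented attribute values (only the unique_by/update_only options come from the commit itself):

ClassificationResult.upsert_all(
  [
    {
      model_used: "opennsfw2",                  # example model name
      classification_type: "nsfw",              # example classification type
      target_id: 1,                             # example target
      target_type: "Upload",
      classification: { nsfw_probability: 90 }, # example payload
      created_at: Time.zone.now,
      updated_at: Time.zone.now,
    },
  ],
  unique_by: %i[target_id target_type model_used],
  update_only: %i[classification],
)

Running this twice for the same target and model updates the existing row instead of inserting a duplicate.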
@@ -14,20 +14,23 @@ module DiscourseAI
       def should_flag_based_on?(classification_data)
         return false if !SiteSetting.ai_nsfw_flag_automatically
 
-        # Flat representation of each model classification of each upload.
-        # Each element looks like [model_name, data]
-        all_classifications = classification_data.values.flatten.map { |x| x.to_a.flatten }
-
-        all_classifications.any? { |(model_name, data)| send("#{model_name}_verdict?", data) }
+        classification_data.any? do |model_name, classifications|
+          classifications.values.any? do |data|
+            send("#{model_name}_verdict?", data.except(:neutral, :target_classified_type))
+          end
+        end
       end
 
       def request(target_to_classify)
         uploads_to_classify = content_of(target_to_classify)
 
-        uploads_to_classify.reduce({}) do |memo, upload|
-          memo[upload.id] = available_models.reduce({}) do |per_model, model|
-            per_model[model] = evaluate_with_model(model, upload)
-            per_model
+        available_models.reduce({}) do |memo, model|
+          memo[model] = uploads_to_classify.reduce({}) do |upl_memo, upload|
+            upl_memo[upload.id] = evaluate_with_model(model, upload).merge(
+              target_classified_type: upload.class.name,
+            )
+
+            upl_memo
           end
 
           memo
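After this change #request groups results by model first and then by upload id (previously it was upload-first), and each per-upload payload carries target_classified_type. Roughly, and with invented upload ids and score keys, the return value looks like:

{
  "opennsfw2" => {
    42 => { nsfw_probability: 97, target_classified_type: "Upload" },
  },
  "nsfw_detector" => {
    42 => { porn: 80, hentai: 5, neutral: 10, target_classified_type: "Upload" },
  },
}

This model-first shape is what store_classification later maps directly into one stored row per model.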
@@ -61,11 +64,9 @@ module DiscourseAI
       end
 
       def nsfw_detector_verdict?(classification)
-        classification.each do |key, value|
-          next if key == :neutral
-          return true if value.to_i >= SiteSetting.send("ai_nsfw_flag_threshold_#{key}")
+        classification.any? do |key, value|
+          value.to_i >= SiteSetting.send("ai_nsfw_flag_threshold_#{key}")
         end
-        false
       end
     end
   end
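The verdict helper no longer has to skip :neutral or return early, because should_flag_based_on? now strips :neutral and :target_classified_type before calling it, so a plain any? over the remaining scores is enough. A worked example with hypothetical scores and an assumed threshold of 80:

scores = { porn: 92, sexy: 11 }               # hypothetical classification payload
scores.any? { |key, value| value.to_i >= 80 } # => true, so the upload would be flagged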
@@ -42,11 +42,15 @@ module DiscourseAI
           SiteSetting.ai_toxicity_inference_service_api_key,
         )
 
-        { SiteSetting.ai_toxicity_inference_service_api_model => data }
+        { available_model => data }
       end
 
       private
 
+      def available_model
+        SiteSetting.ai_toxicity_inference_service_api_model
+      end
+
       def content_of(target_to_classify)
         return target_to_classify.message if target_to_classify.is_a?(ChatMessage)
 
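Extracting available_model keeps the key of the hash returned by #request in sync with the model name later used for the verdict check and for the stored row. The returned value is roughly of this shape (model name and scores invented):

{ "example-toxicity-model" => { toxicity: 87, severe_toxicity: 3 } }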
@@ -4,14 +4,6 @@ module ::DiscourseAI
   class ChatMessageClassification < Classification
     private
 
-    def store_classification(chat_message, type, classification_data)
-      PluginStore.set(
-        type,
-        "chat_message_#{chat_message.id}",
-        classification_data.merge(date: Time.now.utc),
-      )
-    end
-
     def flag!(chat_message, _toxic_labels)
       Chat::ChatReviewQueue.new.flag_message(
         chat_message,
@@ -12,7 +12,7 @@ module ::DiscourseAI
       classification_model
         .request(target)
         .tap do |classification|
-          store_classification(target, classification_model.type, classification)
+          store_classification(target, classification)
 
           if classification_model.should_flag_based_on?(classification)
             flag!(target, classification)
@@ -28,8 +28,25 @@ module ::DiscourseAI
       raise NotImplemented
     end
 
-    def store_classification(_target, _classification)
-      raise NotImplemented
+    def store_classification(target, classification)
+      attrs =
+        classification.map do |model_name, classifications|
+          {
+            model_used: model_name,
+            target_id: target.id,
+            target_type: target.class.name,
+            classification_type: classification_model.type,
+            classification: classifications,
+            updated_at: DateTime.now,
+            created_at: DateTime.now,
+          }
+        end
+
+      ClassificationResult.upsert_all(
+        attrs,
+        unique_by: %i[target_id target_type model_used],
+        update_only: %i[classification],
+      )
     end
 
     def flagger
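The base class now persists results itself: the model-keyed hash returned by #request becomes one attribute hash per model, written in a single upsert_all. For a post classified by two models, attrs would look something like this (model names and payloads are invented):

[
  {
    model_used: "model_a",
    target_id: 99,
    target_type: "Post",
    classification_type: :sentiment,              # classification_model.type
    classification: { positive: 72, negative: 3 },
    updated_at: DateTime.now,
    created_at: DateTime.now,
  },
  {
    model_used: "model_b",
    target_id: 99,
    target_type: "Post",
    classification_type: :sentiment,
    classification: { joy: 64, sadness: 2 },
    updated_at: DateTime.now,
    created_at: DateTime.now,
  },
]

On conflict with the unique index, update_only ensures only the classification payload is overwritten.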
@@ -4,10 +4,6 @@ module ::DiscourseAI
   class PostClassification < Classification
     private
 
-    def store_classification(post, type, classification_data)
-      PostCustomField.create!(post_id: post.id, name: type, value: classification_data.to_json)
-    end
-
     def flag!(post, classification_type)
       PostActionCreator.new(
         flagger,
@@ -14,6 +14,8 @@ after_initialize do
     PLUGIN_NAME = "discourse-ai"
   end
 
+  require_relative "app/models/classification_result"
+
   require_relative "lib/shared/inference_manager"
   require_relative "lib/shared/classification"
   require_relative "lib/shared/post_classification"
@@ -8,19 +8,15 @@ describe DiscourseAI::NSFW::NSFWClassification do
 
   let(:available_models) { SiteSetting.ai_nsfw_models.split("|") }
 
-  describe "#request" do
   fab!(:upload_1) { Fabricate(:s3_image_upload) }
   fab!(:post) { Fabricate(:post, uploads: [upload_1]) }
 
-    def assert_correctly_classified(upload, results, expected)
-      available_models.each do |model|
-        model_result = results.dig(upload.id, model)
-
-        expect(model_result).to eq(expected[model])
-      end
+  describe "#request" do
+    def assert_correctly_classified(results, expected)
+      available_models.each { |model| expect(results[model]).to eq(expected[model]) }
     end
 
-    def build_expected_classification(positive: true)
+    def build_expected_classification(target, positive: true)
       available_models.reduce({}) do |memo, model|
         model_expected =
           if positive
@@ -29,7 +25,9 @@ describe DiscourseAI::NSFW::NSFWClassification do
             NSFWInferenceStubs.negative_result(model)
           end
 
-        memo[model] = model_expected
+        memo[model] = {
+          target.id => model_expected.merge(target_classified_type: target.class.name),
+        }
         memo
       end
     end
@@ -37,11 +35,11 @@ describe DiscourseAI::NSFW::NSFWClassification do
     context "when the target has one upload" do
       it "returns the classification and the model used for it" do
         NSFWInferenceStubs.positive(upload_1)
-        expected = build_expected_classification
+        expected = build_expected_classification(upload_1)
 
         classification = subject.request(post)
 
-        assert_correctly_classified(upload_1, classification, expected)
+        assert_correctly_classified(classification, expected)
       end
 
       context "when the target has multiple uploads" do
@@ -52,13 +50,14 @@ describe DiscourseAI::NSFW::NSFWClassification do
         it "returns a classification for each one" do
           NSFWInferenceStubs.positive(upload_1)
           NSFWInferenceStubs.negative(upload_2)
-          expected_upload_1 = build_expected_classification
-          expected_upload_2 = build_expected_classification(positive: false)
+          expected_classification = build_expected_classification(upload_1)
+          expected_classification.deep_merge!(
+            build_expected_classification(upload_2, positive: false),
+          )
 
           classification = subject.request(post)
 
-          assert_correctly_classified(upload_1, classification, expected_upload_1)
-          assert_correctly_classified(upload_2, classification, expected_upload_2)
+          assert_correctly_classified(classification, expected_classification)
         end
       end
     end
@@ -69,15 +68,23 @@ describe DiscourseAI::NSFW::NSFWClassification do
 
     let(:positive_classification) do
       {
-        1 => available_models.map { |m| { m => NSFWInferenceStubs.negative_result(m) } },
-        2 => available_models.map { |m| { m => NSFWInferenceStubs.positive_result(m) } },
+        "opennsfw2" => {
+          1 => NSFWInferenceStubs.negative_result("opennsfw2"),
+          2 => NSFWInferenceStubs.positive_result("opennsfw2"),
+        },
+        "nsfw_detector" => {
+          1 => NSFWInferenceStubs.negative_result("nsfw_detector"),
+          2 => NSFWInferenceStubs.positive_result("nsfw_detector"),
+        },
       }
     end
 
     let(:negative_classification) do
       {
-        1 => available_models.map { |m| { m => NSFWInferenceStubs.negative_result(m) } },
-        2 => available_models.map { |m| { m => NSFWInferenceStubs.negative_result(m) } },
+        "opennsfw2" => {
+          1 => NSFWInferenceStubs.negative_result("opennsfw2"),
+          2 => NSFWInferenceStubs.negative_result("opennsfw2"),
+        },
       }
     end
 
@@ -18,19 +18,19 @@ describe Jobs::PostSentimentAnalysis do
 
       subject.execute({ post_id: post.id })
 
-      expect(PostCustomField.where(post: post).count).to be_zero
+      expect(ClassificationResult.where(target: post).count).to be_zero
     end
 
     it "does nothing if there's no arg called post_id" do
      subject.execute({})
 
-      expect(PostCustomField.where(post: post).count).to be_zero
+      expect(ClassificationResult.where(target: post).count).to be_zero
     end
 
     it "does nothing if no post match the given id" do
       subject.execute({ post_id: nil })
 
-      expect(PostCustomField.where(post: post).count).to be_zero
+      expect(ClassificationResult.where(target: post).count).to be_zero
     end
 
     it "does nothing if the post content is blank" do
@@ -38,7 +38,7 @@ describe Jobs::PostSentimentAnalysis do
 
       subject.execute({ post_id: post.id })
 
-      expect(PostCustomField.where(post: post).count).to be_zero
+      expect(ClassificationResult.where(target: post).count).to be_zero
     end
   end
 
@@ -48,7 +48,7 @@ describe Jobs::PostSentimentAnalysis do
 
       subject.execute({ post_id: post.id })
 
-      expect(PostCustomField.where(post: post).count).to eq(expected_analysis)
+      expect(ClassificationResult.where(target: post).count).to eq(expected_analysis)
     end
   end
 end
@@ -4,9 +4,9 @@ require "rails_helper"
 require_relative "../../../support/sentiment_inference_stubs"
 
 describe DiscourseAI::Sentiment::SentimentClassification do
-  describe "#request" do
   fab!(:target) { Fabricate(:post) }
 
+  describe "#request" do
     before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" }
 
     it "returns the classification and the model used for it" do
@@ -4,9 +4,9 @@ require "rails_helper"
 require_relative "../../../support/toxicity_inference_stubs"
 
 describe DiscourseAI::Toxicity::ToxicityClassification do
-  describe "#request" do
   fab!(:target) { Fabricate(:post) }
 
+  describe "#request" do
     it "returns the classification and the model used for it" do
       ToxicityInferenceStubs.stub_post_classification(target, toxic: false)
 
@@ -12,15 +12,14 @@ describe DiscourseAI::ChatMessageClassification do
   describe "#classify!" do
     before { ToxicityInferenceStubs.stub_chat_message_classification(chat_message, toxic: true) }
 
-    it "stores the model classification data in a custom field" do
+    it "stores the model classification data" do
       classification.classify!(chat_message)
-      store_row = PluginStore.get("toxicity", "chat_message_#{chat_message.id}")
 
-      classified_data =
-        store_row[SiteSetting.ai_toxicity_inference_service_api_model].symbolize_keys
+      result = ClassificationResult.find_by(target: chat_message, classification_type: model.type)
 
-      expect(classified_data).to eq(ToxicityInferenceStubs.toxic_response)
-      expect(store_row[:date]).to be_present
+      classification = result.classification.symbolize_keys
+
+      expect(classification).to eq(ToxicityInferenceStubs.toxic_response)
     end
 
     it "flags the message when the model decides we should" do
@@ -0,0 +1,80 @@
+# frozen_string_literal: true
+
+require "rails_helper"
+require_relative "../support/sentiment_inference_stubs"
+
+describe DiscourseAI::Classification do
+  describe "#classify!" do
+    describe "saving the classification result" do
+      let(:classification_raw_result) do
+        model
+          .available_models
+          .reduce({}) do |memo, model_name|
+            memo[model_name] = SentimentInferenceStubs.model_response(model_name)
+            memo
+          end
+      end
+
+      let(:model) { DiscourseAI::Sentiment::SentimentClassification.new }
+      let(:classification) { DiscourseAI::PostClassification.new(model) }
+      fab!(:target) { Fabricate(:post) }
+
+      before do
+        SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com"
+        SentimentInferenceStubs.stub_classification(target)
+      end
+
+      it "stores one result per model used" do
+        classification.classify!(target)
+
+        stored_results = ClassificationResult.where(target: target)
+        expect(stored_results.length).to eq(model.available_models.length)
+
+        model.available_models.each do |model_name|
+          result = stored_results.detect { |c| c.model_used == model_name }
+
+          expect(result.classification_type).to eq(model.type.to_s)
+          expect(result.created_at).to be_present
+          expect(result.updated_at).to be_present
+
+          expected_classification = SentimentInferenceStubs.model_response(model)
+
+          expect(result.classification.deep_symbolize_keys).to eq(expected_classification)
+        end
+      end
+
+      it "updates an existing classification result" do
+        original_creation = 3.days.ago
+
+        model.available_models.each do |model_name|
+          ClassificationResult.create!(
+            target: target,
+            model_used: model_name,
+            classification_type: model.type,
+            created_at: original_creation,
+            updated_at: original_creation,
+            classification: {
+            },
+          )
+        end
+
+        classification.classify!(target)
+
+        stored_results = ClassificationResult.where(target: target)
+        expect(stored_results.length).to eq(model.available_models.length)
+
+        model.available_models.each do |model_name|
+          result = stored_results.detect { |c| c.model_used == model_name }
+
+          expect(result.classification_type).to eq(model.type.to_s)
+          expect(result.updated_at).to be > original_creation
+          expect(result.created_at).to eq_time(original_creation)
+
+          expect(result.classification.deep_symbolize_keys).to eq(
+            classification_raw_result[model_name],
+          )
+        end
+      end
+    end
+  end
+end
@@ -12,16 +12,13 @@ describe DiscourseAI::PostClassification do
   describe "#classify!" do
     before { ToxicityInferenceStubs.stub_post_classification(post, toxic: true) }
 
-    it "stores the model classification data in a custom field" do
+    it "stores the model classification data" do
      classification.classify!(post)
-      custom_field = PostCustomField.find_by(post: post, name: model.type)
+      result = ClassificationResult.find_by(target: post, classification_type: model.type)
 
-      expect(custom_field.value).to eq(
-        {
-          SiteSetting.ai_toxicity_inference_service_api_model =>
-            ToxicityInferenceStubs.toxic_response,
-        }.to_json,
-      )
+      classification = result.classification.symbolize_keys
+
+      expect(classification).to eq(ToxicityInferenceStubs.toxic_response)
     end
 
     it "flags the message and hides the post when the model decides we should" do