DEV: Dedicated table for saving classification results (#1)

This commit is contained in:
Roman Rizzi 2023-02-27 16:21:40 -03:00 committed by GitHub
parent 5f9597474c
commit b9a650fde4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 210 additions and 73 deletions

View File

@ -0,0 +1,23 @@
# frozen_string_literal: true
class ClassificationResult < ActiveRecord::Base
belongs_to :target, polymorphic: true
end
# == Schema Information
#
# Table name: classification_results
#
# id :bigint not null, primary key
# model_used :string
# classification_type :string
# target_id :integer
# target_type :string
# classification :jsonb
# created_at :datetime not null
# updated_at :datetime not null
#
# Indexes
#
# unique_classification_target_per_type (target_id,target_type,model_used) UNIQUE
#

View File

View File

@ -0,0 +1,19 @@
# frozen_string_literal: true
class CreateClassificationResultsTable < ActiveRecord::Migration[7.0]
def change
create_table :classification_results do |t|
t.string :model_used, null: true
t.string :classification_type, null: true
t.integer :target_id, null: true
t.string :target_type, null: true
t.jsonb :classification, null: true
t.timestamps
end
add_index :classification_results,
%i[target_id target_type model_used],
unique: true,
name: "unique_classification_target_per_type"
end
end

View File

@ -14,20 +14,23 @@ module DiscourseAI
def should_flag_based_on?(classification_data)
return false if !SiteSetting.ai_nsfw_flag_automatically
# Flat representation of each model classification of each upload.
# Each element looks like [model_name, data]
all_classifications = classification_data.values.flatten.map { |x| x.to_a.flatten }
all_classifications.any? { |(model_name, data)| send("#{model_name}_verdict?", data) }
classification_data.any? do |model_name, classifications|
classifications.values.any? do |data|
send("#{model_name}_verdict?", data.except(:neutral, :target_classified_type))
end
end
end
def request(target_to_classify)
uploads_to_classify = content_of(target_to_classify)
uploads_to_classify.reduce({}) do |memo, upload|
memo[upload.id] = available_models.reduce({}) do |per_model, model|
per_model[model] = evaluate_with_model(model, upload)
per_model
available_models.reduce({}) do |memo, model|
memo[model] = uploads_to_classify.reduce({}) do |upl_memo, upload|
upl_memo[upload.id] = evaluate_with_model(model, upload).merge(
target_classified_type: upload.class.name,
)
upl_memo
end
memo
@ -61,11 +64,9 @@ module DiscourseAI
end
def nsfw_detector_verdict?(classification)
classification.each do |key, value|
next if key == :neutral
return true if value.to_i >= SiteSetting.send("ai_nsfw_flag_threshold_#{key}")
classification.any? do |key, value|
value.to_i >= SiteSetting.send("ai_nsfw_flag_threshold_#{key}")
end
false
end
end
end

View File

@ -42,11 +42,15 @@ module DiscourseAI
SiteSetting.ai_toxicity_inference_service_api_key,
)
{ SiteSetting.ai_toxicity_inference_service_api_model => data }
{ available_model => data }
end
private
def available_model
SiteSetting.ai_toxicity_inference_service_api_model
end
def content_of(target_to_classify)
return target_to_classify.message if target_to_classify.is_a?(ChatMessage)

View File

@ -4,14 +4,6 @@ module ::DiscourseAI
class ChatMessageClassification < Classification
private
def store_classification(chat_message, type, classification_data)
PluginStore.set(
type,
"chat_message_#{chat_message.id}",
classification_data.merge(date: Time.now.utc),
)
end
def flag!(chat_message, _toxic_labels)
Chat::ChatReviewQueue.new.flag_message(
chat_message,

View File

@ -12,7 +12,7 @@ module ::DiscourseAI
classification_model
.request(target)
.tap do |classification|
store_classification(target, classification_model.type, classification)
store_classification(target, classification)
if classification_model.should_flag_based_on?(classification)
flag!(target, classification)
@ -28,8 +28,25 @@ module ::DiscourseAI
raise NotImplemented
end
def store_classification(_target, _classification)
raise NotImplemented
def store_classification(target, classification)
attrs =
classification.map do |model_name, classifications|
{
model_used: model_name,
target_id: target.id,
target_type: target.class.name,
classification_type: classification_model.type,
classification: classifications,
updated_at: DateTime.now,
created_at: DateTime.now,
}
end
ClassificationResult.upsert_all(
attrs,
unique_by: %i[target_id target_type model_used],
update_only: %i[classification],
)
end
def flagger

View File

@ -4,10 +4,6 @@ module ::DiscourseAI
class PostClassification < Classification
private
def store_classification(post, type, classification_data)
PostCustomField.create!(post_id: post.id, name: type, value: classification_data.to_json)
end
def flag!(post, classification_type)
PostActionCreator.new(
flagger,

View File

@ -14,6 +14,8 @@ after_initialize do
PLUGIN_NAME = "discourse-ai"
end
require_relative "app/models/classification_result"
require_relative "lib/shared/inference_manager"
require_relative "lib/shared/classification"
require_relative "lib/shared/post_classification"

View File

@ -8,19 +8,15 @@ describe DiscourseAI::NSFW::NSFWClassification do
let(:available_models) { SiteSetting.ai_nsfw_models.split("|") }
describe "#request" do
fab!(:upload_1) { Fabricate(:s3_image_upload) }
fab!(:post) { Fabricate(:post, uploads: [upload_1]) }
def assert_correctly_classified(upload, results, expected)
available_models.each do |model|
model_result = results.dig(upload.id, model)
expect(model_result).to eq(expected[model])
end
describe "#request" do
def assert_correctly_classified(results, expected)
available_models.each { |model| expect(results[model]).to eq(expected[model]) }
end
def build_expected_classification(positive: true)
def build_expected_classification(target, positive: true)
available_models.reduce({}) do |memo, model|
model_expected =
if positive
@ -29,7 +25,9 @@ describe DiscourseAI::NSFW::NSFWClassification do
NSFWInferenceStubs.negative_result(model)
end
memo[model] = model_expected
memo[model] = {
target.id => model_expected.merge(target_classified_type: target.class.name),
}
memo
end
end
@ -37,11 +35,11 @@ describe DiscourseAI::NSFW::NSFWClassification do
context "when the target has one upload" do
it "returns the classification and the model used for it" do
NSFWInferenceStubs.positive(upload_1)
expected = build_expected_classification
expected = build_expected_classification(upload_1)
classification = subject.request(post)
assert_correctly_classified(upload_1, classification, expected)
assert_correctly_classified(classification, expected)
end
context "when the target has multiple uploads" do
@ -52,13 +50,14 @@ describe DiscourseAI::NSFW::NSFWClassification do
it "returns a classification for each one" do
NSFWInferenceStubs.positive(upload_1)
NSFWInferenceStubs.negative(upload_2)
expected_upload_1 = build_expected_classification
expected_upload_2 = build_expected_classification(positive: false)
expected_classification = build_expected_classification(upload_1)
expected_classification.deep_merge!(
build_expected_classification(upload_2, positive: false),
)
classification = subject.request(post)
assert_correctly_classified(upload_1, classification, expected_upload_1)
assert_correctly_classified(upload_2, classification, expected_upload_2)
assert_correctly_classified(classification, expected_classification)
end
end
end
@ -69,15 +68,23 @@ describe DiscourseAI::NSFW::NSFWClassification do
let(:positive_classification) do
{
1 => available_models.map { |m| { m => NSFWInferenceStubs.negative_result(m) } },
2 => available_models.map { |m| { m => NSFWInferenceStubs.positive_result(m) } },
"opennsfw2" => {
1 => NSFWInferenceStubs.negative_result("opennsfw2"),
2 => NSFWInferenceStubs.positive_result("opennsfw2"),
},
"nsfw_detector" => {
1 => NSFWInferenceStubs.negative_result("nsfw_detector"),
2 => NSFWInferenceStubs.positive_result("nsfw_detector"),
},
}
end
let(:negative_classification) do
{
1 => available_models.map { |m| { m => NSFWInferenceStubs.negative_result(m) } },
2 => available_models.map { |m| { m => NSFWInferenceStubs.negative_result(m) } },
"opennsfw2" => {
1 => NSFWInferenceStubs.negative_result("opennsfw2"),
2 => NSFWInferenceStubs.negative_result("opennsfw2"),
},
}
end

View File

@ -18,19 +18,19 @@ describe Jobs::PostSentimentAnalysis do
subject.execute({ post_id: post.id })
expect(PostCustomField.where(post: post).count).to be_zero
expect(ClassificationResult.where(target: post).count).to be_zero
end
it "does nothing if there's no arg called post_id" do
subject.execute({})
expect(PostCustomField.where(post: post).count).to be_zero
expect(ClassificationResult.where(target: post).count).to be_zero
end
it "does nothing if no post match the given id" do
subject.execute({ post_id: nil })
expect(PostCustomField.where(post: post).count).to be_zero
expect(ClassificationResult.where(target: post).count).to be_zero
end
it "does nothing if the post content is blank" do
@ -38,7 +38,7 @@ describe Jobs::PostSentimentAnalysis do
subject.execute({ post_id: post.id })
expect(PostCustomField.where(post: post).count).to be_zero
expect(ClassificationResult.where(target: post).count).to be_zero
end
end
@ -48,7 +48,7 @@ describe Jobs::PostSentimentAnalysis do
subject.execute({ post_id: post.id })
expect(PostCustomField.where(post: post).count).to eq(expected_analysis)
expect(ClassificationResult.where(target: post).count).to eq(expected_analysis)
end
end
end

View File

@ -4,9 +4,9 @@ require "rails_helper"
require_relative "../../../support/sentiment_inference_stubs"
describe DiscourseAI::Sentiment::SentimentClassification do
describe "#request" do
fab!(:target) { Fabricate(:post) }
describe "#request" do
before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" }
it "returns the classification and the model used for it" do

View File

@ -4,9 +4,9 @@ require "rails_helper"
require_relative "../../../support/toxicity_inference_stubs"
describe DiscourseAI::Toxicity::ToxicityClassification do
describe "#request" do
fab!(:target) { Fabricate(:post) }
describe "#request" do
it "returns the classification and the model used for it" do
ToxicityInferenceStubs.stub_post_classification(target, toxic: false)

View File

@ -12,15 +12,14 @@ describe DiscourseAI::ChatMessageClassification do
describe "#classify!" do
before { ToxicityInferenceStubs.stub_chat_message_classification(chat_message, toxic: true) }
it "stores the model classification data in a custom field" do
it "stores the model classification data" do
classification.classify!(chat_message)
store_row = PluginStore.get("toxicity", "chat_message_#{chat_message.id}")
classified_data =
store_row[SiteSetting.ai_toxicity_inference_service_api_model].symbolize_keys
result = ClassificationResult.find_by(target: chat_message, classification_type: model.type)
expect(classified_data).to eq(ToxicityInferenceStubs.toxic_response)
expect(store_row[:date]).to be_present
classification = result.classification.symbolize_keys
expect(classification).to eq(ToxicityInferenceStubs.toxic_response)
end
it "flags the message when the model decides we should" do

View File

@ -0,0 +1,80 @@
# frozen_string_literal: true
require "rails_helper"
require_relative "../support/sentiment_inference_stubs"
describe DiscourseAI::Classification do
describe "#classify!" do
describe "saving the classification result" do
let(:classification_raw_result) do
model
.available_models
.reduce({}) do |memo, model_name|
memo[model_name] = SentimentInferenceStubs.model_response(model_name)
memo
end
end
let(:model) { DiscourseAI::Sentiment::SentimentClassification.new }
let(:classification) { DiscourseAI::PostClassification.new(model) }
fab!(:target) { Fabricate(:post) }
before do
SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com"
SentimentInferenceStubs.stub_classification(target)
end
it "stores one result per model used" do
classification.classify!(target)
stored_results = ClassificationResult.where(target: target)
expect(stored_results.length).to eq(model.available_models.length)
model.available_models.each do |model_name|
result = stored_results.detect { |c| c.model_used == model_name }
expect(result.classification_type).to eq(model.type.to_s)
expect(result.created_at).to be_present
expect(result.updated_at).to be_present
expected_classification = SentimentInferenceStubs.model_response(model)
expect(result.classification.deep_symbolize_keys).to eq(expected_classification)
end
end
it "updates an existing classification result" do
original_creation = 3.days.ago
model.available_models.each do |model_name|
ClassificationResult.create!(
target: target,
model_used: model_name,
classification_type: model.type,
created_at: original_creation,
updated_at: original_creation,
classification: {
},
)
end
classification.classify!(target)
stored_results = ClassificationResult.where(target: target)
expect(stored_results.length).to eq(model.available_models.length)
model.available_models.each do |model_name|
result = stored_results.detect { |c| c.model_used == model_name }
expect(result.classification_type).to eq(model.type.to_s)
expect(result.updated_at).to be > original_creation
expect(result.created_at).to eq_time(original_creation)
expect(result.classification.deep_symbolize_keys).to eq(
classification_raw_result[model_name],
)
end
end
end
end
end

View File

@ -12,16 +12,13 @@ describe DiscourseAI::PostClassification do
describe "#classify!" do
before { ToxicityInferenceStubs.stub_post_classification(post, toxic: true) }
it "stores the model classification data in a custom field" do
it "stores the model classification data" do
classification.classify!(post)
custom_field = PostCustomField.find_by(post: post, name: model.type)
result = ClassificationResult.find_by(target: post, classification_type: model.type)
expect(custom_field.value).to eq(
{
SiteSetting.ai_toxicity_inference_service_api_model =>
ToxicityInferenceStubs.toxic_response,
}.to_json,
)
classification = result.classification.symbolize_keys
expect(classification).to eq(ToxicityInferenceStubs.toxic_response)
end
it "flags the message and hides the post when the model decides we should" do