DEV: Dedicated table for saving classification results (#1)

This commit is contained in:
Roman Rizzi 2023-02-27 16:21:40 -03:00 committed by GitHub
parent 5f9597474c
commit b9a650fde4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 210 additions and 73 deletions

View File

@ -0,0 +1,23 @@
# frozen_string_literal: true
# Persists the output of an AI classification model run against a record.
# One row per (target, model_used) pair — see the unique index in the
# schema annotation below; rows are written via upsert_all by the
# Classification#store_classification step.
class ClassificationResult < ActiveRecord::Base
  # The record that was classified — polymorphic; from this commit the
  # callers store Post and ChatMessage targets.
  belongs_to :target, polymorphic: true
end
# == Schema Information
#
# Table name: classification_results
#
# id :bigint not null, primary key
# model_used :string
# classification_type :string
# target_id :integer
# target_type :string
# classification :jsonb
# created_at :datetime not null
# updated_at :datetime not null
#
# Indexes
#
# unique_classification_target_per_type (target_id,target_type,model_used) UNIQUE
#

View File

View File

@ -0,0 +1,19 @@
# frozen_string_literal: true
# Adds the classification_results table, a dedicated home for AI model
# verdicts (toxicity / NSFW / sentiment) previously kept in custom fields
# and the plugin store.
class CreateClassificationResultsTable < ActiveRecord::Migration[7.0]
  def change
    create_table :classification_results do |t|
      t.string :model_used
      t.string :classification_type
      t.integer :target_id
      t.string :target_type
      t.jsonb :classification
      t.timestamps

      # Exactly one stored result per target per model; the upsert in
      # Classification#store_classification conflicts on this index.
      t.index %i[target_id target_type model_used],
              unique: true,
              name: "unique_classification_target_per_type"
    end
  end
end

View File

@ -14,20 +14,23 @@ module DiscourseAI
def should_flag_based_on?(classification_data) def should_flag_based_on?(classification_data)
return false if !SiteSetting.ai_nsfw_flag_automatically return false if !SiteSetting.ai_nsfw_flag_automatically
# Flat representation of each model classification of each upload. classification_data.any? do |model_name, classifications|
# Each element looks like [model_name, data] classifications.values.any? do |data|
all_classifications = classification_data.values.flatten.map { |x| x.to_a.flatten } send("#{model_name}_verdict?", data.except(:neutral, :target_classified_type))
end
all_classifications.any? { |(model_name, data)| send("#{model_name}_verdict?", data) } end
end end
def request(target_to_classify) def request(target_to_classify)
uploads_to_classify = content_of(target_to_classify) uploads_to_classify = content_of(target_to_classify)
uploads_to_classify.reduce({}) do |memo, upload| available_models.reduce({}) do |memo, model|
memo[upload.id] = available_models.reduce({}) do |per_model, model| memo[model] = uploads_to_classify.reduce({}) do |upl_memo, upload|
per_model[model] = evaluate_with_model(model, upload) upl_memo[upload.id] = evaluate_with_model(model, upload).merge(
per_model target_classified_type: upload.class.name,
)
upl_memo
end end
memo memo
@ -61,11 +64,9 @@ module DiscourseAI
end end
def nsfw_detector_verdict?(classification) def nsfw_detector_verdict?(classification)
classification.each do |key, value| classification.any? do |key, value|
next if key == :neutral value.to_i >= SiteSetting.send("ai_nsfw_flag_threshold_#{key}")
return true if value.to_i >= SiteSetting.send("ai_nsfw_flag_threshold_#{key}") end
end
false
end end
end end
end end

View File

@ -42,11 +42,15 @@ module DiscourseAI
SiteSetting.ai_toxicity_inference_service_api_key, SiteSetting.ai_toxicity_inference_service_api_key,
) )
{ SiteSetting.ai_toxicity_inference_service_api_model => data } { available_model => data }
end end
private private
def available_model
SiteSetting.ai_toxicity_inference_service_api_model
end
def content_of(target_to_classify) def content_of(target_to_classify)
return target_to_classify.message if target_to_classify.is_a?(ChatMessage) return target_to_classify.message if target_to_classify.is_a?(ChatMessage)

View File

@ -4,14 +4,6 @@ module ::DiscourseAI
class ChatMessageClassification < Classification class ChatMessageClassification < Classification
private private
def store_classification(chat_message, type, classification_data)
PluginStore.set(
type,
"chat_message_#{chat_message.id}",
classification_data.merge(date: Time.now.utc),
)
end
def flag!(chat_message, _toxic_labels) def flag!(chat_message, _toxic_labels)
Chat::ChatReviewQueue.new.flag_message( Chat::ChatReviewQueue.new.flag_message(
chat_message, chat_message,

View File

@ -12,7 +12,7 @@ module ::DiscourseAI
classification_model classification_model
.request(target) .request(target)
.tap do |classification| .tap do |classification|
store_classification(target, classification_model.type, classification) store_classification(target, classification)
if classification_model.should_flag_based_on?(classification) if classification_model.should_flag_based_on?(classification)
flag!(target, classification) flag!(target, classification)
@ -28,8 +28,25 @@ module ::DiscourseAI
raise NotImplemented raise NotImplemented
end end
def store_classification(_target, _classification) def store_classification(target, classification)
raise NotImplemented attrs =
classification.map do |model_name, classifications|
{
model_used: model_name,
target_id: target.id,
target_type: target.class.name,
classification_type: classification_model.type,
classification: classifications,
updated_at: DateTime.now,
created_at: DateTime.now,
}
end
ClassificationResult.upsert_all(
attrs,
unique_by: %i[target_id target_type model_used],
update_only: %i[classification],
)
end end
def flagger def flagger

View File

@ -4,10 +4,6 @@ module ::DiscourseAI
class PostClassification < Classification class PostClassification < Classification
private private
def store_classification(post, type, classification_data)
PostCustomField.create!(post_id: post.id, name: type, value: classification_data.to_json)
end
def flag!(post, classification_type) def flag!(post, classification_type)
PostActionCreator.new( PostActionCreator.new(
flagger, flagger,

View File

@ -14,6 +14,8 @@ after_initialize do
PLUGIN_NAME = "discourse-ai" PLUGIN_NAME = "discourse-ai"
end end
require_relative "app/models/classification_result"
require_relative "lib/shared/inference_manager" require_relative "lib/shared/inference_manager"
require_relative "lib/shared/classification" require_relative "lib/shared/classification"
require_relative "lib/shared/post_classification" require_relative "lib/shared/post_classification"

View File

@ -8,19 +8,15 @@ describe DiscourseAI::NSFW::NSFWClassification do
let(:available_models) { SiteSetting.ai_nsfw_models.split("|") } let(:available_models) { SiteSetting.ai_nsfw_models.split("|") }
describe "#request" do
fab!(:upload_1) { Fabricate(:s3_image_upload) } fab!(:upload_1) { Fabricate(:s3_image_upload) }
fab!(:post) { Fabricate(:post, uploads: [upload_1]) } fab!(:post) { Fabricate(:post, uploads: [upload_1]) }
def assert_correctly_classified(upload, results, expected) describe "#request" do
available_models.each do |model| def assert_correctly_classified(results, expected)
model_result = results.dig(upload.id, model) available_models.each { |model| expect(results[model]).to eq(expected[model]) }
expect(model_result).to eq(expected[model])
end
end end
def build_expected_classification(positive: true) def build_expected_classification(target, positive: true)
available_models.reduce({}) do |memo, model| available_models.reduce({}) do |memo, model|
model_expected = model_expected =
if positive if positive
@ -29,7 +25,9 @@ describe DiscourseAI::NSFW::NSFWClassification do
NSFWInferenceStubs.negative_result(model) NSFWInferenceStubs.negative_result(model)
end end
memo[model] = model_expected memo[model] = {
target.id => model_expected.merge(target_classified_type: target.class.name),
}
memo memo
end end
end end
@ -37,11 +35,11 @@ describe DiscourseAI::NSFW::NSFWClassification do
context "when the target has one upload" do context "when the target has one upload" do
it "returns the classification and the model used for it" do it "returns the classification and the model used for it" do
NSFWInferenceStubs.positive(upload_1) NSFWInferenceStubs.positive(upload_1)
expected = build_expected_classification expected = build_expected_classification(upload_1)
classification = subject.request(post) classification = subject.request(post)
assert_correctly_classified(upload_1, classification, expected) assert_correctly_classified(classification, expected)
end end
context "when the target has multiple uploads" do context "when the target has multiple uploads" do
@ -52,13 +50,14 @@ describe DiscourseAI::NSFW::NSFWClassification do
it "returns a classification for each one" do it "returns a classification for each one" do
NSFWInferenceStubs.positive(upload_1) NSFWInferenceStubs.positive(upload_1)
NSFWInferenceStubs.negative(upload_2) NSFWInferenceStubs.negative(upload_2)
expected_upload_1 = build_expected_classification expected_classification = build_expected_classification(upload_1)
expected_upload_2 = build_expected_classification(positive: false) expected_classification.deep_merge!(
build_expected_classification(upload_2, positive: false),
)
classification = subject.request(post) classification = subject.request(post)
assert_correctly_classified(upload_1, classification, expected_upload_1) assert_correctly_classified(classification, expected_classification)
assert_correctly_classified(upload_2, classification, expected_upload_2)
end end
end end
end end
@ -69,15 +68,23 @@ describe DiscourseAI::NSFW::NSFWClassification do
let(:positive_classification) do let(:positive_classification) do
{ {
1 => available_models.map { |m| { m => NSFWInferenceStubs.negative_result(m) } }, "opennsfw2" => {
2 => available_models.map { |m| { m => NSFWInferenceStubs.positive_result(m) } }, 1 => NSFWInferenceStubs.negative_result("opennsfw2"),
2 => NSFWInferenceStubs.positive_result("opennsfw2"),
},
"nsfw_detector" => {
1 => NSFWInferenceStubs.negative_result("nsfw_detector"),
2 => NSFWInferenceStubs.positive_result("nsfw_detector"),
},
} }
end end
let(:negative_classification) do let(:negative_classification) do
{ {
1 => available_models.map { |m| { m => NSFWInferenceStubs.negative_result(m) } }, "opennsfw2" => {
2 => available_models.map { |m| { m => NSFWInferenceStubs.negative_result(m) } }, 1 => NSFWInferenceStubs.negative_result("opennsfw2"),
2 => NSFWInferenceStubs.negative_result("opennsfw2"),
},
} }
end end

View File

@ -18,19 +18,19 @@ describe Jobs::PostSentimentAnalysis do
subject.execute({ post_id: post.id }) subject.execute({ post_id: post.id })
expect(PostCustomField.where(post: post).count).to be_zero expect(ClassificationResult.where(target: post).count).to be_zero
end end
it "does nothing if there's no arg called post_id" do it "does nothing if there's no arg called post_id" do
subject.execute({}) subject.execute({})
expect(PostCustomField.where(post: post).count).to be_zero expect(ClassificationResult.where(target: post).count).to be_zero
end end
it "does nothing if no post match the given id" do it "does nothing if no post match the given id" do
subject.execute({ post_id: nil }) subject.execute({ post_id: nil })
expect(PostCustomField.where(post: post).count).to be_zero expect(ClassificationResult.where(target: post).count).to be_zero
end end
it "does nothing if the post content is blank" do it "does nothing if the post content is blank" do
@ -38,7 +38,7 @@ describe Jobs::PostSentimentAnalysis do
subject.execute({ post_id: post.id }) subject.execute({ post_id: post.id })
expect(PostCustomField.where(post: post).count).to be_zero expect(ClassificationResult.where(target: post).count).to be_zero
end end
end end
@ -48,7 +48,7 @@ describe Jobs::PostSentimentAnalysis do
subject.execute({ post_id: post.id }) subject.execute({ post_id: post.id })
expect(PostCustomField.where(post: post).count).to eq(expected_analysis) expect(ClassificationResult.where(target: post).count).to eq(expected_analysis)
end end
end end
end end

View File

@ -4,9 +4,9 @@ require "rails_helper"
require_relative "../../../support/sentiment_inference_stubs" require_relative "../../../support/sentiment_inference_stubs"
describe DiscourseAI::Sentiment::SentimentClassification do describe DiscourseAI::Sentiment::SentimentClassification do
describe "#request" do
fab!(:target) { Fabricate(:post) } fab!(:target) { Fabricate(:post) }
describe "#request" do
before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" } before { SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com" }
it "returns the classification and the model used for it" do it "returns the classification and the model used for it" do

View File

@ -4,9 +4,9 @@ require "rails_helper"
require_relative "../../../support/toxicity_inference_stubs" require_relative "../../../support/toxicity_inference_stubs"
describe DiscourseAI::Toxicity::ToxicityClassification do describe DiscourseAI::Toxicity::ToxicityClassification do
describe "#request" do
fab!(:target) { Fabricate(:post) } fab!(:target) { Fabricate(:post) }
describe "#request" do
it "returns the classification and the model used for it" do it "returns the classification and the model used for it" do
ToxicityInferenceStubs.stub_post_classification(target, toxic: false) ToxicityInferenceStubs.stub_post_classification(target, toxic: false)

View File

@ -12,15 +12,14 @@ describe DiscourseAI::ChatMessageClassification do
describe "#classify!" do describe "#classify!" do
before { ToxicityInferenceStubs.stub_chat_message_classification(chat_message, toxic: true) } before { ToxicityInferenceStubs.stub_chat_message_classification(chat_message, toxic: true) }
it "stores the model classification data in a custom field" do it "stores the model classification data" do
classification.classify!(chat_message) classification.classify!(chat_message)
store_row = PluginStore.get("toxicity", "chat_message_#{chat_message.id}")
classified_data = result = ClassificationResult.find_by(target: chat_message, classification_type: model.type)
store_row[SiteSetting.ai_toxicity_inference_service_api_model].symbolize_keys
expect(classified_data).to eq(ToxicityInferenceStubs.toxic_response) classification = result.classification.symbolize_keys
expect(store_row[:date]).to be_present
expect(classification).to eq(ToxicityInferenceStubs.toxic_response)
end end
it "flags the message when the model decides we should" do it "flags the message when the model decides we should" do

View File

@ -0,0 +1,80 @@
# frozen_string_literal: true
require "rails_helper"
require_relative "../support/sentiment_inference_stubs"
describe DiscourseAI::Classification do
  describe "#classify!" do
    describe "saving the classification result" do
      # Raw stubbed response per model, keyed by model *name* (String),
      # mirroring what SentimentInferenceStubs.stub_classification returns.
      let(:classification_raw_result) do
        model
          .available_models
          .reduce({}) do |memo, model_name|
            memo[model_name] = SentimentInferenceStubs.model_response(model_name)
            memo
          end
      end

      let(:model) { DiscourseAI::Sentiment::SentimentClassification.new }
      let(:classification) { DiscourseAI::PostClassification.new(model) }
      fab!(:target) { Fabricate(:post) }

      before do
        SiteSetting.ai_sentiment_inference_service_api_endpoint = "http://test.com"
        SentimentInferenceStubs.stub_classification(target)
      end

      it "stores one result per model used" do
        classification.classify!(target)

        stored_results = ClassificationResult.where(target: target)
        expect(stored_results.length).to eq(model.available_models.length)

        model.available_models.each do |model_name|
          result = stored_results.detect { |c| c.model_used == model_name }

          expect(result.classification_type).to eq(model.type.to_s)
          expect(result.created_at).to be_present
          expect(result.updated_at).to be_present

          # FIX: previously passed `model` (the classification-model instance)
          # to model_response, which expects a model name string. Reuse the
          # precomputed per-name results, as the example below already does.
          expect(result.classification.deep_symbolize_keys).to eq(
            classification_raw_result[model_name],
          )
        end
      end

      it "updates an existing classification result" do
        original_creation = 3.days.ago

        model.available_models.each do |model_name|
          ClassificationResult.create!(
            target: target,
            model_used: model_name,
            classification_type: model.type,
            created_at: original_creation,
            updated_at: original_creation,
            classification: {
            },
          )
        end

        classification.classify!(target)

        stored_results = ClassificationResult.where(target: target)
        expect(stored_results.length).to eq(model.available_models.length)

        model.available_models.each do |model_name|
          result = stored_results.detect { |c| c.model_used == model_name }

          expect(result.classification_type).to eq(model.type.to_s)
          # Upsert must refresh updated_at but leave created_at untouched.
          expect(result.updated_at).to be > original_creation
          expect(result.created_at).to eq_time(original_creation)

          expect(result.classification.deep_symbolize_keys).to eq(
            classification_raw_result[model_name],
          )
        end
      end
    end
  end
end

View File

@ -12,16 +12,13 @@ describe DiscourseAI::PostClassification do
describe "#classify!" do describe "#classify!" do
before { ToxicityInferenceStubs.stub_post_classification(post, toxic: true) } before { ToxicityInferenceStubs.stub_post_classification(post, toxic: true) }
it "stores the model classification data in a custom field" do it "stores the model classification data" do
classification.classify!(post) classification.classify!(post)
custom_field = PostCustomField.find_by(post: post, name: model.type) result = ClassificationResult.find_by(target: post, classification_type: model.type)
expect(custom_field.value).to eq( classification = result.classification.symbolize_keys
{
SiteSetting.ai_toxicity_inference_service_api_model => expect(classification).to eq(ToxicityInferenceStubs.toxic_response)
ToxicityInferenceStubs.toxic_response,
}.to_json,
)
end end
it "flags the message and hides the post when the model decides we should" do it "flags the message and hides the post when the model decides we should" do