# frozen_string_literal: true

module ::DiscourseAi
  # Runs a classification model against a target, persists the per-model
  # results, and hands the target to #flag! when the model's verdicts ask
  # for it. #flag! has no usable default, so subclasses are expected to
  # override it.
  class Classificator
    # @param classification_model [Object] model wrapper. This class calls
    #   `can_classify?`, `request`, `get_verdicts`, `should_flag_based_on?`
    #   and `type` on it.
    def initialize(classification_model)
      @classification_model = classification_model
    end

    # Classifies +target+, stores the result, and flags the target when the
    # model says so.
    #
    # @param target [Object] record to classify; must respond to `id` and its
    #   class to `sti_name` (both are read when storing the result).
    # @return [Symbol, Object] `:cannot_classify` when the model rejects the
    #   target, otherwise whatever `classification_model.request` returned.
    def classify!(target)
      return :cannot_classify unless classification_model.can_classify?(target)

      classification = classification_model.request(target)
      store_classification(target, classification)

      verdicts = classification_model.get_verdicts(classification)
      if classification_model.should_flag_based_on?(verdicts)
        accuracies = get_model_accuracies(verdicts.keys)
        flag!(target, classification, verdicts, accuracies)
      end

      classification
    end

    protected

    attr_reader :classification_model

    # Hook invoked by #classify! when the verdicts warrant flagging.
    #
    # @raise [NotImplementedError] always, until a subclass overrides it.
    def flag!(_target, _classification, _verdicts, _accuracies)
      # `NotImplemented` is not a Ruby constant and would surface as a
      # NameError; NotImplementedError is the built-in for abstract hooks.
      raise NotImplementedError, "#{self.class} must implement #flag!"
    end

    # Looks up (creating when missing) the ModelAccuracy row of every model
    # for the current classification type.
    #
    # @param models [Enumerable] model names.
    # @return [Hash] model name => result of `calculate_accuracy`.
    def get_model_accuracies(models)
      models.to_h do |name|
        accuracy =
          ModelAccuracy.find_or_create_by(
            model: name,
            classification_type: classification_model.type,
          )

        [name, accuracy.calculate_accuracy]
      end
    end

    # Adds an `inappropriate` score authored by the system user to
    # +reviewable+, forcing it into review.
    #
    # @param reviewable [Object] object responding to `add_score`.
    def add_score(reviewable)
      reviewable.add_score(
        Discourse.system_user,
        ReviewableScore.types[:inappropriate],
        reason: "flagged_by_#{classification_model.type}",
        force_review: true,
      )
    end

    # Upserts one ClassificationResult row per model found in
    # +classification+.
    #
    # @param target [Object] classified record (`id` and `class.sti_name`
    #   are stored).
    # @param classification [Enumerable] iterated as
    #   (model name, model result) pairs - presumably a Hash.
    def store_classification(target, classification)
      # One clock reading for the whole batch so created_at and updated_at
      # agree within a row and across rows.
      now = Time.now

      attrs =
        classification.map do |model_name, classifications|
          {
            model_used: model_name,
            target_id: target.id,
            target_type: target.class.sti_name,
            classification_type: classification_model.type,
            classification: classifications,
            updated_at: now,
            created_at: now,
          }
        end

      # Nothing to persist. This also avoids handing upsert_all an empty
      # list, which older Rails releases reject with ArgumentError.
      return if attrs.empty?

      ClassificationResult.upsert_all(
        attrs,
        unique_by: %i[target_id target_type model_used],
        update_only: %i[classification],
      )
    end

    # @return [Object] the user on whose behalf flags are raised
    #   (`Discourse.system_user`).
    def flagger
      Discourse.system_user
    end
  end
end
|