From 66bf4c74c695e03c7b45b9413e5803f8c18d5ca4 Mon Sep 17 00:00:00 2001
From: Rafael dos Santos Silva
Date: Thu, 11 May 2023 15:35:39 -0300
Subject: [PATCH] FEATURE: Handle invalid media in NSFW module (#57)

* FEATURE: Handle invalid media in NSFW module

* fix lint
---
 lib/modules/nsfw/nsfw_classification.rb           | 14 ++++++--------
 lib/shared/inference/discourse_classifier.rb      |  2 +-
 spec/lib/modules/nsfw/nsfw_classification_spec.rb | 10 ++++++++++
 spec/support/nsfw_inference_stubs.rb              | 12 ++++++++++++
 4 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/lib/modules/nsfw/nsfw_classification.rb b/lib/modules/nsfw/nsfw_classification.rb
index a21a3a86..c8187bb9 100644
--- a/lib/modules/nsfw/nsfw_classification.rb
+++ b/lib/modules/nsfw/nsfw_classification.rb
@@ -35,9 +35,11 @@ module DiscourseAi
 
       available_models.reduce({}) do |memo, model|
         memo[model] = uploads_to_classify.reduce({}) do |upl_memo, upload|
-          upl_memo[upload.id] = evaluate_with_model(model, upload).merge(
-            target_classified_type: upload.class.name,
-          )
+          classification =
+            evaluate_with_model(model, upload).merge(target_classified_type: upload.class.name)
+
+          # 415 denotes that the image is not supported by the model, so we skip it
+          upl_memo[upload.id] = classification if classification.dig(:status) != 415
 
           upl_memo
         end
@@ -65,11 +67,7 @@ module DiscourseAi
       end
 
       def content_of(target_to_classify)
-        target_to_classify
-          .uploads
-          .where(extension: %w[png jpeg jpg PNG JPEG JPG])
-          .to_a
-          .select { |u| FileHelper.is_supported_image?(u.url) }
+        target_to_classify.uploads.to_a.select { |u| FileHelper.is_supported_image?(u.url) }
       end
 
       def opennsfw2_verdict?(clasification)
diff --git a/lib/shared/inference/discourse_classifier.rb b/lib/shared/inference/discourse_classifier.rb
index 7ab2f188..041b746c 100644
--- a/lib/shared/inference/discourse_classifier.rb
+++ b/lib/shared/inference/discourse_classifier.rb
@@ -10,7 +10,7 @@ module ::DiscourseAi
 
       response = Faraday.post(endpoint, { model: model, content: content }.to_json, headers)
 
-      raise Net::HTTPBadResponse unless response.status == 200
+      raise Net::HTTPBadResponse if ![200, 415].include?(response.status)
 
       JSON.parse(response.body, symbolize_names: true)
     end
diff --git a/spec/lib/modules/nsfw/nsfw_classification_spec.rb b/spec/lib/modules/nsfw/nsfw_classification_spec.rb
index 9edfc82e..1d3b99e1 100644
--- a/spec/lib/modules/nsfw/nsfw_classification_spec.rb
+++ b/spec/lib/modules/nsfw/nsfw_classification_spec.rb
@@ -59,6 +59,16 @@ describe DiscourseAi::NSFW::NSFWClassification do
 
         assert_correctly_classified(classification, expected_classification)
       end
+
+      it "correctly skips unsupported uploads" do
+        NSFWInferenceStubs.positive(upload_1)
+        NSFWInferenceStubs.unsupported(upload_2)
+        expected_classification = build_expected_classification(upload_1)
+
+        classification = subject.request(post)
+
+        assert_correctly_classified(classification, expected_classification)
+      end
     end
   end
 end
diff --git a/spec/support/nsfw_inference_stubs.rb b/spec/support/nsfw_inference_stubs.rb
index 49edbd82..3fe45b30 100644
--- a/spec/support/nsfw_inference_stubs.rb
+++ b/spec/support/nsfw_inference_stubs.rb
@@ -46,5 +46,17 @@ class NSFWInferenceStubs
         .with(body: JSON.dump(model: "opennsfw2", content: upload_url(upload)))
         .to_return(status: 200, body: JSON.dump(negative_result("opennsfw2")))
     end
+
+    def unsupported(upload)
+      WebMock
+        .stub_request(:post, endpoint)
+        .with(body: JSON.dump(model: "nsfw_detector", content: upload_url(upload)))
+        .to_return(status: 415, body: JSON.dump({ error: "Unsupported image type", status: 415 }))
+
+      WebMock
+        .stub_request(:post, endpoint)
+        .with(body: JSON.dump(model: "opennsfw2", content: upload_url(upload)))
+        .to_return(status: 415, body: JSON.dump({ error: "Unsupported image type", status: 415 }))
+    end
   end
 end