From 510c6487e3fc178a5333b9fcaa119cce2314277f Mon Sep 17 00:00:00 2001
From: Rafael dos Santos Silva
Date: Tue, 7 Mar 2023 16:14:39 -0300
Subject: [PATCH] DEV: Preparation work for multiple inference providers (#5)

---
 config/settings.yml                          |  3 +++
 lib/modules/nsfw/nsfw_classification.rb      |  2 +-
 .../sentiment/sentiment_classification.rb    |  2 +-
 .../toxicity/toxicity_classification.rb      |  2 +-
 lib/shared/inference/discourse_classifier.rb | 19 +++++++++++++
 lib/shared/inference/discourse_reranker.rb   | 24 +++++++++++++++++
 lib/shared/inference/openai_completions.rb   | 27 +++++++++++++++++++
 lib/shared/inference/openai_embeddings.rb    | 27 +++++++++++++++++++
 lib/shared/inference_manager.rb              | 17 ------------
 lib/tasks/modules/toxicity/calibration.rake  |  2 +-
 plugin.rb                                    |  6 ++++-
 11 files changed, 109 insertions(+), 22 deletions(-)
 create mode 100644 lib/shared/inference/discourse_classifier.rb
 create mode 100644 lib/shared/inference/discourse_reranker.rb
 create mode 100644 lib/shared/inference/openai_completions.rb
 create mode 100644 lib/shared/inference/openai_embeddings.rb
 delete mode 100644 lib/shared/inference_manager.rb

diff --git a/config/settings.yml b/config/settings.yml
index 4fcf3a8b..71e78e96 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -84,3 +84,6 @@ plugins:
     choices:
       - opennsfw2
      - nsfw_detector
+
+  ai_openai_api_key:
+    default: ""
diff --git a/lib/modules/nsfw/nsfw_classification.rb b/lib/modules/nsfw/nsfw_classification.rb
index a63dd992..f3d5967f 100644
--- a/lib/modules/nsfw/nsfw_classification.rb
+++ b/lib/modules/nsfw/nsfw_classification.rb
@@ -52,7 +52,7 @@ module DiscourseAI
         upload_url = Discourse.store.cdn_url(upload.url)
         upload_url = "#{Discourse.base_url_no_prefix}#{upload_url}" if upload_url.starts_with?("/")
 
-        DiscourseAI::InferenceManager.perform!(
+        DiscourseAI::Inference::DiscourseClassifier.perform!(
           "#{SiteSetting.ai_nsfw_inference_service_api_endpoint}/api/v1/classify",
           model,
           upload_url,
diff --git a/lib/modules/sentiment/sentiment_classification.rb b/lib/modules/sentiment/sentiment_classification.rb
index 76a2ba3e..633563dd 100644
--- a/lib/modules/sentiment/sentiment_classification.rb
+++ b/lib/modules/sentiment/sentiment_classification.rb
@@ -39,7 +39,7 @@ module DiscourseAI
       private
 
       def request_with(model, content)
-        ::DiscourseAI::InferenceManager.perform!(
+        ::DiscourseAI::Inference::DiscourseClassifier.perform!(
           "#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify",
           model,
           content,
diff --git a/lib/modules/toxicity/toxicity_classification.rb b/lib/modules/toxicity/toxicity_classification.rb
index bf4e3679..815b9269 100644
--- a/lib/modules/toxicity/toxicity_classification.rb
+++ b/lib/modules/toxicity/toxicity_classification.rb
@@ -42,7 +42,7 @@ module DiscourseAI
 
       def request(target_to_classify)
         data =
-          ::DiscourseAI::InferenceManager.perform!(
+          ::DiscourseAI::Inference::DiscourseClassifier.perform!(
             "#{SiteSetting.ai_toxicity_inference_service_api_endpoint}/api/v1/classify",
             SiteSetting.ai_toxicity_inference_service_api_model,
             content_of(target_to_classify),
diff --git a/lib/shared/inference/discourse_classifier.rb b/lib/shared/inference/discourse_classifier.rb
new file mode 100644
index 00000000..601c3245
--- /dev/null
+++ b/lib/shared/inference/discourse_classifier.rb
@@ -0,0 +1,19 @@
+# frozen_string_literal: true
+
+module ::DiscourseAI
+  module Inference
+    class DiscourseClassifier
+      def self.perform!(endpoint, model, content, api_key)
+        headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
+
+        headers["X-API-KEY"] = api_key if api_key.present?
+
+        response = Faraday.post(endpoint, { model: model, content: content }.to_json, headers)
+
+        raise Net::HTTPBadResponse unless response.status == 200
+
+        JSON.parse(response.body, symbolize_names: true)
+      end
+    end
+  end
+end
diff --git a/lib/shared/inference/discourse_reranker.rb b/lib/shared/inference/discourse_reranker.rb
new file mode 100644
index 00000000..b50953f5
--- /dev/null
+++ b/lib/shared/inference/discourse_reranker.rb
@@ -0,0 +1,24 @@
+# frozen_string_literal: true
+
+module ::DiscourseAI
+  module Inference
+    class DiscourseReranker
+      def self.perform!(endpoint, model, content, candidates, api_key)
+        headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
+
+        headers["X-API-KEY"] = api_key if api_key.present?
+
+        response =
+          Faraday.post(
+            endpoint,
+            { model: model, content: content, candidates: candidates }.to_json,
+            headers,
+          )
+
+        raise Net::HTTPBadResponse unless response.status == 200
+
+        JSON.parse(response.body, symbolize_names: true)
+      end
+    end
+  end
+end
diff --git a/lib/shared/inference/openai_completions.rb b/lib/shared/inference/openai_completions.rb
new file mode 100644
index 00000000..49364e5f
--- /dev/null
+++ b/lib/shared/inference/openai_completions.rb
@@ -0,0 +1,27 @@
+# frozen_string_literal: true
+
+module ::DiscourseAI
+  module Inference
+    class OpenAICompletions
+      def self.perform!(model, content, api_key)
+        headers = {
+          "Authorization" => "Bearer #{SiteSetting.ai_openai_api_key}",
+          "Content-Type" => "application/json",
+        }
+
+        model ||= "gpt-3.5-turbo"
+
+        response =
+          Faraday.post(
+            "https://api.openai.com/v1/chat/completions",
+            { model: model, messages: content }.to_json,
+            headers,
+          )
+
+        raise Net::HTTPBadResponse unless response.status == 200
+
+        JSON.parse(response.body, symbolize_names: true)
+      end
+    end
+  end
+end
diff --git a/lib/shared/inference/openai_embeddings.rb b/lib/shared/inference/openai_embeddings.rb
new file mode 100644
index 00000000..116bc62e
--- /dev/null
+++ b/lib/shared/inference/openai_embeddings.rb
@@ -0,0 +1,27 @@
+# frozen_string_literal: true
+
+module ::DiscourseAI
+  module Inference
+    class OpenAIEmbeddings
+      def self.perform!(content, model = nil)
+        headers = {
+          "Authorization" => "Bearer #{SiteSetting.ai_openai_api_key}",
+          "Content-Type" => "application/json",
+        }
+
+        model ||= "text-embedding-ada-002"
+
+        response =
+          Faraday.post(
+            "https://api.openai.com/v1/embeddings",
+            { model: model, input: content }.to_json,
+            headers,
+          )
+
+        raise Net::HTTPBadResponse unless response.status == 200
+
+        JSON.parse(response.body, symbolize_names: true)
+      end
+    end
+  end
+end
diff --git a/lib/shared/inference_manager.rb b/lib/shared/inference_manager.rb
deleted file mode 100644
index 6f7b12ee..00000000
--- a/lib/shared/inference_manager.rb
+++ /dev/null
@@ -1,17 +0,0 @@
-# frozen_string_literal: true
-
-module ::DiscourseAI
-  class InferenceManager
-    def self.perform!(endpoint, model, content, api_key)
-      headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
-
-      headers["X-API-KEY"] = api_key if api_key.present?
-
-      response = Faraday.post(endpoint, { model: model, content: content }.to_json, headers)
-
-      raise Net::HTTPBadResponse unless response.status == 200
-
-      JSON.parse(response.body, symbolize_names: true)
-    end
-  end
-end
diff --git a/lib/tasks/modules/toxicity/calibration.rake b/lib/tasks/modules/toxicity/calibration.rake
index b9fe6bf3..50cf3f1c 100644
--- a/lib/tasks/modules/toxicity/calibration.rake
+++ b/lib/tasks/modules/toxicity/calibration.rake
@@ -1,7 +1,7 @@
 # frozen_string_literal: true
 
 def classify(content)
-  ::DiscourseAI::InferenceManager.perform!(
+  ::DiscourseAI::Inference::DiscourseClassifier.perform!(
     "#{SiteSetting.ai_toxicity_inference_service_api_endpoint}/api/v1/classify",
     SiteSetting.ai_toxicity_inference_service_api_model,
     content,
diff --git a/plugin.rb b/plugin.rb
index 6f0e3549..7181712e 100644
--- a/plugin.rb
+++ b/plugin.rb
@@ -16,7 +16,11 @@ after_initialize do
     PLUGIN_NAME = "discourse-ai"
   end
 
-  require_relative "lib/shared/inference_manager"
+  require_relative "lib/shared/inference/discourse_classifier"
+  require_relative "lib/shared/inference/discourse_reranker"
+  require_relative "lib/shared/inference/openai_completions"
+  require_relative "lib/shared/inference/openai_embeddings"
+
   require_relative "lib/shared/classificator"
   require_relative "lib/shared/post_classificator"
   require_relative "lib/shared/chat_message_classificator"
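
Usage sketch (illustrative only, not part of the commit): after this patch the
call sites talk to the namespaced classes directly. The endpoint URL, model
name, text, and API key below are placeholder values, not taken from the
repository.

  # Classifier-style services keep the old InferenceManager argument order:
  DiscourseAI::Inference::DiscourseClassifier.perform!(
    "https://classifier.example.com/api/v1/classify", # hypothetical endpoint
    "example-model",                                  # hypothetical model name
    "text to classify",
    "example-api-key",
  )

  # The OpenAI helpers read the new ai_openai_api_key site setting internally:
  DiscourseAI::Inference::OpenAIEmbeddings.perform!("text to embed")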