DEV: Preparation work for multiple inference providers (#5)

Rafael dos Santos Silva 2023-03-07 16:14:39 -03:00 committed by GitHub
parent a838116cd5
commit 510c6487e3
11 changed files with 109 additions and 22 deletions

@@ -84,3 +84,6 @@ plugins:
    choices:
      - opennsfw2
      - nsfw_detector
+  ai_openai_api_key:
+    default: ""

@@ -52,7 +52,7 @@ module DiscourseAI
upload_url = Discourse.store.cdn_url(upload.url)
upload_url = "#{Discourse.base_url_no_prefix}#{upload_url}" if upload_url.starts_with?("/")
-DiscourseAI::InferenceManager.perform!(
+DiscourseAI::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_nsfw_inference_service_api_endpoint}/api/v1/classify",
model,
upload_url,

@@ -39,7 +39,7 @@ module DiscourseAI
private
def request_with(model, content)
-::DiscourseAI::InferenceManager.perform!(
+::DiscourseAI::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify",
model,
content,

@@ -42,7 +42,7 @@ module DiscourseAI
def request(target_to_classify)
data =
-::DiscourseAI::InferenceManager.perform!(
+::DiscourseAI::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_toxicity_inference_service_api_endpoint}/api/v1/classify",
SiteSetting.ai_toxicity_inference_service_api_model,
content_of(target_to_classify),

@@ -0,0 +1,19 @@
# frozen_string_literal: true

module ::DiscourseAI
  module Inference
    class DiscourseClassifier
      def self.perform!(endpoint, model, content, api_key)
        headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }

        headers["X-API-KEY"] = api_key if api_key.present?

        response = Faraday.post(endpoint, { model: model, content: content }.to_json, headers)

        raise Net::HTTPBadResponse unless response.status == 200

        JSON.parse(response.body, symbolize_names: true)
      end
    end
  end
end
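
For reference, a minimal usage sketch of the new DiscourseClassifier helper (not part of this commit; the model name, text, and key below are illustrative placeholders):

# Hypothetical call against a self-hosted classification service.
result =
  ::DiscourseAI::Inference::DiscourseClassifier.perform!(
    "#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify",
    "sentiment",                   # placeholder model name
    "I really enjoyed this topic", # text to classify
    "example-api-key",             # sent as X-API-KEY only when present
  )
# => Hash parsed from the service's JSON response, with symbolized keys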

@@ -0,0 +1,24 @@
# frozen_string_literal: true

module ::DiscourseAI
  module Inference
    class DiscourseReranker
      def self.perform!(endpoint, model, content, candidates, api_key)
        headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }

        headers["X-API-KEY"] = api_key if api_key.present?

        response =
          Faraday.post(
            endpoint,
            { model: model, content: content, candidates: candidates }.to_json,
            headers,
          )

        raise Net::HTTPBadResponse unless response.status == 200

        JSON.parse(response.body, symbolize_names: true)
      end
    end
  end
end
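
This commit adds no call sites for the reranker; a hedged sketch of how the new perform! signature might be invoked (every value below is made up):

# Hypothetical reranking call: ask the service to order candidates by relevance.
ranked =
  ::DiscourseAI::Inference::DiscourseReranker.perform!(
    "https://inference.example.com/api/v1/rerank",               # placeholder endpoint
    "reranker-model",                                            # placeholder model name
    "How do I reset my password?",                               # query content
    ["Use the forgot-password link", "Check your spam folder"],  # candidate texts
    "example-api-key",
  )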

@@ -0,0 +1,27 @@
# frozen_string_literal: true

module ::DiscourseAI
  module Inference
    class OpenAICompletions
      def self.perform!(model, content, api_key)
        headers = {
          "Authorization" => "Bearer #{SiteSetting.ai_openai_api_key}",
          "Content-Type" => "application/json",
        }

        model ||= "gpt-3.5-turbo"

        response =
          Faraday.post(
            "https://api.openai.com/v1/chat/completions",
            { model: model, messages: content }.to_json,
            headers,
          )

        raise Net::HTTPBadResponse unless response.status == 200

        JSON.parse(response.body, symbolize_names: true)
      end
    end
  end
end
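
A hedged usage sketch for the chat completions wrapper (not in the commit; the message array and response access follow OpenAI's chat completions API, and note that the request is authorized with SiteSetting.ai_openai_api_key regardless of the api_key argument):

# Hypothetical call; assumes ai_openai_api_key is configured.
completion =
  ::DiscourseAI::Inference::OpenAICompletions.perform!(
    nil, # falls back to "gpt-3.5-turbo"
    [{ role: "user", content: "Summarize this topic in one sentence." }],
    nil, # currently unused; the Authorization header reads the site setting
  )
completion.dig(:choices, 0, :message, :content)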

@@ -0,0 +1,27 @@
# frozen_string_literal: true

module ::DiscourseAI
  module Inference
    class OpenAIEmbeddings
      def self.perform!(content, model = nil)
        headers = {
          "Authorization" => "Bearer #{SiteSetting.ai_openai_api_key}",
          "Content-Type" => "application/json",
        }

        model ||= "text-embedding-ada-002"

        response =
          Faraday.post(
            "https://api.openai.com/v1/embeddings",
            { model: model, input: content }.to_json,
            headers,
          )

        raise Net::HTTPBadResponse unless response.status == 200

        JSON.parse(response.body, symbolize_names: true)
      end
    end
  end
end
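
And a similar sketch for the embeddings wrapper (illustrative; the response access assumes OpenAI's embeddings API shape):

# Hypothetical call; the model argument defaults to "text-embedding-ada-002".
response = ::DiscourseAI::Inference::OpenAIEmbeddings.perform!("password reset help")
vector = response.dig(:data, 0, :embedding) # => Array of floats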

@@ -1,17 +0,0 @@
# frozen_string_literal: true

module ::DiscourseAI
  class InferenceManager
    def self.perform!(endpoint, model, content, api_key)
      headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }

      headers["X-API-KEY"] = api_key if api_key.present?

      response = Faraday.post(endpoint, { model: model, content: content }.to_json, headers)

      raise Net::HTTPBadResponse unless response.status == 200

      JSON.parse(response.body, symbolize_names: true)
    end
  end
end

@@ -1,7 +1,7 @@
# frozen_string_literal: true
def classify(content)
-::DiscourseAI::InferenceManager.perform!(
+::DiscourseAI::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_toxicity_inference_service_api_endpoint}/api/v1/classify",
SiteSetting.ai_toxicity_inference_service_api_model,
content,

@@ -16,7 +16,11 @@ after_initialize do
PLUGIN_NAME = "discourse-ai"
end
require_relative "lib/shared/inference_manager"
require_relative "lib/shared/inference/discourse_classifier"
require_relative "lib/shared/inference/discourse_reranker"
require_relative "lib/shared/inference/openai_completions"
require_relative "lib/shared/inference/openai_embeddings"
require_relative "lib/shared/classificator"
require_relative "lib/shared/post_classificator"
require_relative "lib/shared/chat_message_classificator"