DEV: Preparation work for multiple inference providers (#5)
This commit is contained in:
parent
a838116cd5
commit
510c6487e3
|
@ -84,3 +84,6 @@ plugins:
|
|||
choices:
|
||||
- opennsfw2
|
||||
- nsfw_detector
|
||||
|
||||
ai_openai_api_key:
|
||||
default: ""
|
||||
|
|
|
@ -52,7 +52,7 @@ module DiscourseAI
|
|||
upload_url = Discourse.store.cdn_url(upload.url)
|
||||
upload_url = "#{Discourse.base_url_no_prefix}#{upload_url}" if upload_url.starts_with?("/")
|
||||
|
||||
DiscourseAI::InferenceManager.perform!(
|
||||
DiscourseAI::Inference::DiscourseClassifier.perform!(
|
||||
"#{SiteSetting.ai_nsfw_inference_service_api_endpoint}/api/v1/classify",
|
||||
model,
|
||||
upload_url,
|
||||
|
|
|
@ -39,7 +39,7 @@ module DiscourseAI
|
|||
private
|
||||
|
||||
def request_with(model, content)
|
||||
::DiscourseAI::InferenceManager.perform!(
|
||||
::DiscourseAI::Inference::DiscourseClassifier.perform!(
|
||||
"#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify",
|
||||
model,
|
||||
content,
|
||||
|
|
|
@ -42,7 +42,7 @@ module DiscourseAI
|
|||
|
||||
def request(target_to_classify)
|
||||
data =
|
||||
::DiscourseAI::InferenceManager.perform!(
|
||||
::DiscourseAI::Inference::DiscourseClassifier.perform!(
|
||||
"#{SiteSetting.ai_toxicity_inference_service_api_endpoint}/api/v1/classify",
|
||||
SiteSetting.ai_toxicity_inference_service_api_model,
|
||||
content_of(target_to_classify),
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module ::DiscourseAI
|
||||
module Inference
|
||||
class DiscourseClassifier
|
||||
def self.perform!(endpoint, model, content, api_key)
|
||||
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
||||
|
||||
headers["X-API-KEY"] = api_key if api_key.present?
|
||||
|
||||
response = Faraday.post(endpoint, { model: model, content: content }.to_json, headers)
|
||||
|
||||
raise Net::HTTPBadResponse unless response.status == 200
|
||||
|
||||
JSON.parse(response.body, symbolize_names: true)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,24 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module ::DiscourseAI
|
||||
module Inference
|
||||
class DiscourseReranker
|
||||
def self.perform!(endpoint, model, content, candidates, api_key)
|
||||
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
||||
|
||||
headers["X-API-KEY"] = api_key if api_key.present?
|
||||
|
||||
response =
|
||||
Faraday.post(
|
||||
endpoint,
|
||||
{ model: model, content: content, candidates: candidates }.to_json,
|
||||
headers,
|
||||
)
|
||||
|
||||
raise Net::HTTPBadResponse unless response.status == 200
|
||||
|
||||
JSON.parse(response.body, symbolize_names: true)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,27 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module ::DiscourseAI
|
||||
module Inference
|
||||
class OpenAICompletions
|
||||
def self.perform!(model, content, api_key)
|
||||
headers = {
|
||||
"Authorization" => "Bearer #{SiteSetting.ai_openai_api_key}",
|
||||
"Content-Type" => "application/json",
|
||||
}
|
||||
|
||||
model ||= "gpt-3.5-turbo"
|
||||
|
||||
response =
|
||||
Faraday.post(
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
{ model: model, messages: content }.to_json,
|
||||
headers,
|
||||
)
|
||||
|
||||
raise Net::HTTPBadResponse unless response.status == 200
|
||||
|
||||
JSON.parse(response.body, symbolize_names: true)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,27 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module ::DiscourseAI
|
||||
module Inference
|
||||
class OpenAIEmbeddings
|
||||
def self.perform!(content, model = nil)
|
||||
headers = {
|
||||
"Authorization" => "Bearer #{SiteSetting.ai_openai_api_key}",
|
||||
"Content-Type" => "application/json",
|
||||
}
|
||||
|
||||
model ||= "text-embedding-ada-002"
|
||||
|
||||
response =
|
||||
Faraday.post(
|
||||
"https://api.openai.com/v1/embeddings",
|
||||
{ model: model, input: content }.to_json,
|
||||
headers,
|
||||
)
|
||||
|
||||
raise Net::HTTPBadResponse unless response.status == 200
|
||||
|
||||
JSON.parse(response.body, symbolize_names: true)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,17 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module ::DiscourseAI
|
||||
class InferenceManager
|
||||
def self.perform!(endpoint, model, content, api_key)
|
||||
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
||||
|
||||
headers["X-API-KEY"] = api_key if api_key.present?
|
||||
|
||||
response = Faraday.post(endpoint, { model: model, content: content }.to_json, headers)
|
||||
|
||||
raise Net::HTTPBadResponse unless response.status == 200
|
||||
|
||||
JSON.parse(response.body, symbolize_names: true)
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,7 +1,7 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
def classify(content)
|
||||
::DiscourseAI::InferenceManager.perform!(
|
||||
::DiscourseAI::Inference::DiscourseClassifier.perform!(
|
||||
"#{SiteSetting.ai_toxicity_inference_service_api_endpoint}/api/v1/classify",
|
||||
SiteSetting.ai_toxicity_inference_service_api_model,
|
||||
content,
|
||||
|
|
|
@ -16,7 +16,11 @@ after_initialize do
|
|||
PLUGIN_NAME = "discourse-ai"
|
||||
end
|
||||
|
||||
require_relative "lib/shared/inference_manager"
|
||||
require_relative "lib/shared/inference/discourse_classifier"
|
||||
require_relative "lib/shared/inference/discourse_reranker"
|
||||
require_relative "lib/shared/inference/openai_completions"
|
||||
require_relative "lib/shared/inference/openai_embeddings"
|
||||
|
||||
require_relative "lib/shared/classificator"
|
||||
require_relative "lib/shared/post_classificator"
|
||||
require_relative "lib/shared/chat_message_classificator"
|
||||
|
|
Loading…
Reference in New Issue