REFACTOR: Tidy up embedding endpoints config.
Two changes worth mentioning: `#instance` returns a fully configured embedding endpoint, ready to use. All endpoints respond to the same method with the same signature — `perform!(text)`. This makes it easier to reuse them when generating embeddings in bulk.
This commit is contained in:
parent
1a10680818
commit
b0ab9ccc48
|
@ -313,7 +313,7 @@ module DiscourseAi
|
|||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings
|
||||
reranker = DiscourseAi::Inference::HuggingFaceText
|
||||
|
||||
interactions_vector = vector_rep.vector_from(consolidated_question)
|
||||
|
||||
|
@ -344,7 +344,7 @@ module DiscourseAi
|
|||
if reranker.reranker_configured?
|
||||
guidance = fragments.map { |fragment, _metadata| fragment }
|
||||
ranks =
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings
|
||||
DiscourseAi::Inference::HuggingFaceText
|
||||
.rerank(conversation_context.last[:content], guidance)
|
||||
.to_a
|
||||
.take(rag_conversation_chunks)
|
||||
|
|
|
@ -159,10 +159,7 @@ module DiscourseAi
|
|||
.map { _1.truncate(2000, omission: "") }
|
||||
|
||||
reranked_results =
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.rerank(
|
||||
search_term,
|
||||
rerank_posts_payload,
|
||||
)
|
||||
DiscourseAi::Inference::HuggingFaceText.rerank(search_term, rerank_posts_payload)
|
||||
|
||||
reordered_ids = reranked_results.map { _1[:index] }.map { filtered_results[_1].id }.take(5)
|
||||
|
||||
|
|
|
@ -24,12 +24,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def vector_from(text, asymetric: false)
|
||||
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
||||
"#{discourse_embeddings_endpoint}/api/v1/classify",
|
||||
self.class.name,
|
||||
text,
|
||||
SiteSetting.ai_embeddings_discourse_service_api_key,
|
||||
)
|
||||
inference_client.perform!(text)
|
||||
end
|
||||
|
||||
def dimensions
|
||||
|
@ -59,6 +54,10 @@ module DiscourseAi
|
|||
def tokenizer
|
||||
DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
|
||||
end
|
||||
|
||||
def inference_client
|
||||
DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -426,16 +426,8 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
def discourse_embeddings_endpoint
|
||||
if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present?
|
||||
service =
|
||||
DiscourseAi::Utils::DnsSrv.lookup(
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv,
|
||||
)
|
||||
"https://#{service.target}:#{service.port}"
|
||||
else
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint
|
||||
end
|
||||
def inference_client
|
||||
raise NotImplementedError
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -11,7 +11,7 @@ module DiscourseAi
|
|||
|
||||
def correctly_configured?
|
||||
SiteSetting.ai_cloudflare_workers_api_token.present? ||
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? ||
|
||||
DiscourseAi::Inference::TeiEmbeddings.configured? ||
|
||||
(
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||
|
@ -33,24 +33,12 @@ module DiscourseAi
|
|||
def vector_from(text, asymetric: false)
|
||||
text = "#{asymmetric_query_prefix} #{text}" if asymetric
|
||||
|
||||
if SiteSetting.ai_cloudflare_workers_api_token.present?
|
||||
DiscourseAi::Inference::CloudflareWorkersAi
|
||||
.perform!(inference_model_name, { text: text })
|
||||
.dig(:result, :data)
|
||||
.first
|
||||
elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
||||
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
|
||||
elsif discourse_embeddings_endpoint.present?
|
||||
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
||||
"#{discourse_embeddings_endpoint}/api/v1/classify",
|
||||
inference_model_name.split("/").last,
|
||||
text,
|
||||
SiteSetting.ai_embeddings_discourse_service_api_key,
|
||||
)
|
||||
else
|
||||
raise "No inference endpoint configured"
|
||||
end
|
||||
client = inference_client
|
||||
|
||||
needs_truncation = client.class.name.include?("TeiEmbeddings")
|
||||
text = tokenizer.truncate(text, max_sequence_length - 2) if needs_truncation
|
||||
|
||||
inference_client.perform!(text)
|
||||
end
|
||||
|
||||
def inference_model_name
|
||||
|
@ -88,6 +76,21 @@ module DiscourseAi
|
|||
def asymmetric_query_prefix
|
||||
"Represent this sentence for searching relevant passages:"
|
||||
end
|
||||
|
||||
def inference_client
|
||||
if SiteSetting.ai_cloudflare_workers_api_token.present?
|
||||
DiscourseAi::Inference::CloudflareWorkersAi.instance(inference_model_name)
|
||||
elsif DiscourseAi::Inference::TeiEmbeddings.configured?
|
||||
DiscourseAi::Inference::TeiEmbeddings.instance
|
||||
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||
DiscourseAi::Inference::DiscourseClassifier.instance(
|
||||
inference_model_name.split("/").last,
|
||||
)
|
||||
else
|
||||
raise "No inference endpoint configured"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -10,7 +10,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def correctly_configured?
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
||||
DiscourseAi::Inference::TeiEmbeddings.configured?
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
|
@ -20,7 +20,7 @@ module DiscourseAi
|
|||
|
||||
def vector_from(text, asymetric: false)
|
||||
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
|
||||
inference_client.perform!(truncated_text)
|
||||
end
|
||||
|
||||
def dimensions
|
||||
|
@ -50,6 +50,10 @@ module DiscourseAi
|
|||
def tokenizer
|
||||
DiscourseAi::Tokenizer::BgeM3Tokenizer
|
||||
end
|
||||
|
||||
def inference_client
|
||||
DiscourseAi::Inference::TeiEmbeddings.instance
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -43,8 +43,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def vector_from(text, asymetric: false)
|
||||
response = DiscourseAi::Inference::GeminiEmbeddings.perform!(text)
|
||||
response[:embedding][:values]
|
||||
inference_client.perform!(text).dig(:embedding, :values)
|
||||
end
|
||||
|
||||
# There is no public tokenizer for Gemini, and from the ones we already ship in the plugin
|
||||
|
@ -53,6 +52,10 @@ module DiscourseAi
|
|||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
|
||||
def inference_client
|
||||
DiscourseAi::Inference::GeminiEmbeddings.instance
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -10,7 +10,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def correctly_configured?
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? ||
|
||||
DiscourseAi::Inference::TeiEmbeddings.configured? ||
|
||||
(
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||
|
@ -29,19 +29,16 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def vector_from(text, asymetric: false)
|
||||
if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
||||
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
|
||||
elsif discourse_embeddings_endpoint.present?
|
||||
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
||||
"#{discourse_embeddings_endpoint}/api/v1/classify",
|
||||
self.class.name,
|
||||
"query: #{text}",
|
||||
SiteSetting.ai_embeddings_discourse_service_api_key,
|
||||
)
|
||||
client = inference_client
|
||||
|
||||
needs_truncation = client.class.name.include?("TeiEmbeddings")
|
||||
if needs_truncation
|
||||
text = tokenizer.truncate(text, max_sequence_length - 2)
|
||||
else
|
||||
raise "No inference endpoint configured"
|
||||
text = "query: #{text}"
|
||||
end
|
||||
|
||||
client.perform!(text)
|
||||
end
|
||||
|
||||
def id
|
||||
|
@ -71,6 +68,18 @@ module DiscourseAi
|
|||
def tokenizer
|
||||
DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
|
||||
end
|
||||
|
||||
def inference_client
|
||||
if DiscourseAi::Inference::TeiEmbeddings.configured?
|
||||
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
|
||||
DiscourseAi::Inference::TeiEmbeddings.instance
|
||||
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||
DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name)
|
||||
else
|
||||
raise "No inference endpoint configured"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -45,18 +45,19 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def vector_from(text, asymetric: false)
|
||||
response =
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.perform!(
|
||||
text,
|
||||
model: self.class.name,
|
||||
dimensions: dimensions,
|
||||
)
|
||||
response[:data].first[:embedding]
|
||||
inference_client.perform!(text)
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
|
||||
def inference_client
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.instance(
|
||||
model: self.class.name,
|
||||
dimensions: dimensions,
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -43,13 +43,16 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def vector_from(text, asymetric: false)
|
||||
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
|
||||
response[:data].first[:embedding]
|
||||
inference_client.perform!(text)
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
|
||||
def inference_client
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -43,13 +43,16 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def vector_from(text, asymetric: false)
|
||||
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
|
||||
response[:data].first[:embedding]
|
||||
inference_client.perform!(text)
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
|
||||
def inference_client
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -3,25 +3,38 @@
|
|||
module ::DiscourseAi
|
||||
module Inference
|
||||
class CloudflareWorkersAi
|
||||
def self.perform!(model, content)
|
||||
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
||||
def initialize(account_id, api_token, model, referer = Discourse.base_url)
|
||||
@account_id = account_id
|
||||
@api_token = api_token
|
||||
@model = model
|
||||
@referer = referer
|
||||
end
|
||||
|
||||
account_id = SiteSetting.ai_cloudflare_workers_account_id
|
||||
token = SiteSetting.ai_cloudflare_workers_api_token
|
||||
def self.instance(model)
|
||||
new(
|
||||
SiteSetting.ai_cloudflare_workers_account_id,
|
||||
SiteSetting.ai_cloudflare_workers_api_token,
|
||||
model,
|
||||
)
|
||||
end
|
||||
|
||||
base_url = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/"
|
||||
headers["Authorization"] = "Bearer #{token}"
|
||||
attr_reader :account_id, :api_token, :model, :referer
|
||||
|
||||
endpoint = "#{base_url}#{model}"
|
||||
def perform!(content)
|
||||
headers = {
|
||||
"Referer" => Discourse.base_url,
|
||||
"Content-Type" => "application/json",
|
||||
"Authorization" => "Bearer #{api_token}",
|
||||
}
|
||||
|
||||
endpoint = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/#{model}"
|
||||
|
||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
||||
response = conn.post(endpoint, content.to_json, headers)
|
||||
|
||||
raise Net::HTTPBadResponse if ![200].include?(response.status)
|
||||
|
||||
case response.status
|
||||
when 200
|
||||
JSON.parse(response.body, symbolize_names: true)
|
||||
JSON.parse(response.body, symbolize_names: true).dig(:result, :data).first
|
||||
when 429
|
||||
# TODO add a AdminDashboard Problem?
|
||||
else
|
||||
|
|
|
@ -3,9 +3,36 @@
|
|||
module ::DiscourseAi
|
||||
module Inference
|
||||
class DiscourseClassifier
|
||||
def self.perform!(endpoint, model, content, api_key)
|
||||
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
||||
def initialize(endpoint, api_key, model, referer = Discourse.base_url)
|
||||
@endpoint = endpoint
|
||||
@api_key = api_key
|
||||
@model = model
|
||||
@referer = referer
|
||||
end
|
||||
|
||||
def self.instance(model)
|
||||
endpoint =
|
||||
if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present?
|
||||
service =
|
||||
DiscourseAi::Utils::DnsSrv.lookup(
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv,
|
||||
)
|
||||
"https://#{service.target}:#{service.port}"
|
||||
else
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint
|
||||
end
|
||||
|
||||
new(
|
||||
"#{endpoint}/api/v1/classify",
|
||||
SiteSetting.ai_embeddings_discourse_service_api_key,
|
||||
model,
|
||||
)
|
||||
end
|
||||
|
||||
attr_reader :endpoint, :api_key, :model, :referer
|
||||
|
||||
def perform!(content)
|
||||
headers = { "Referer" => referer, "Content-Type" => "application/json" }
|
||||
headers["X-API-KEY"] = api_key if api_key.present?
|
||||
|
||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
||||
|
|
|
@ -3,12 +3,17 @@
|
|||
module ::DiscourseAi
|
||||
module Inference
|
||||
class GeminiEmbeddings
|
||||
def self.perform!(content)
|
||||
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
||||
def initialize(api_key, referer = Discourse.base_url)
|
||||
@api_key = api_key
|
||||
@referer = referer
|
||||
end
|
||||
|
||||
attr_reader :api_key, :referer
|
||||
|
||||
def perform!(content)
|
||||
headers = { "Referer" => referer, "Content-Type" => "application/json" }
|
||||
url =
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{SiteSetting.ai_gemini_api_key}"
|
||||
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{api_key}"
|
||||
body = { content: { parts: [{ text: content }] } }
|
||||
|
||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
||||
|
|
|
@ -2,33 +2,8 @@
|
|||
|
||||
module ::DiscourseAi
|
||||
module Inference
|
||||
class HuggingFaceTextEmbeddings
|
||||
class HuggingFaceText
|
||||
class << self
|
||||
def perform!(content)
|
||||
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
||||
body = { inputs: content, truncate: true }.to_json
|
||||
|
||||
if SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
|
||||
service =
|
||||
DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv)
|
||||
api_endpoint = "https://#{service.target}:#{service.port}"
|
||||
else
|
||||
api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint
|
||||
end
|
||||
|
||||
if SiteSetting.ai_hugging_face_tei_api_key.present?
|
||||
headers["X-API-KEY"] = SiteSetting.ai_hugging_face_tei_api_key
|
||||
headers["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_tei_api_key}"
|
||||
end
|
||||
|
||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
||||
response = conn.post(api_endpoint, body, headers)
|
||||
|
||||
raise Net::HTTPBadResponse if ![200].include?(response.status)
|
||||
|
||||
JSON.parse(response.body, symbolize_names: true)
|
||||
end
|
||||
|
||||
def rerank(content, candidates)
|
||||
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
||||
body = { query: content, texts: candidates, truncate: true }.to_json
|
||||
|
@ -85,11 +60,6 @@ module ::DiscourseAi
|
|||
SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? ||
|
||||
SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?
|
||||
end
|
||||
|
||||
def configured?
|
||||
SiteSetting.ai_hugging_face_tei_endpoint.present? ||
|
||||
SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -3,13 +3,26 @@
|
|||
module ::DiscourseAi
|
||||
module Inference
|
||||
class OpenAiEmbeddings
|
||||
def self.perform!(content, model:, dimensions: nil)
|
||||
def initialize(endpoint, api_key, model, dimensions)
|
||||
@endpoint = endpoint
|
||||
@api_key = api_key
|
||||
@model = model
|
||||
@dimensions = dimensions
|
||||
end
|
||||
|
||||
attr_reader :endpoint, :api_key, :model, :dimensions
|
||||
|
||||
def self.instance(model:, dimensions: nil)
|
||||
new(SiteSetting.ai_openai_embeddings_url, SiteSetting.ai_openai_api_key, model, dimensions)
|
||||
end
|
||||
|
||||
def perform!(content)
|
||||
headers = { "Content-Type" => "application/json" }
|
||||
|
||||
if SiteSetting.ai_openai_embeddings_url.include?("azure")
|
||||
headers["api-key"] = SiteSetting.ai_openai_api_key
|
||||
if endpoint.include?("azure")
|
||||
headers["api-key"] = api_key
|
||||
else
|
||||
headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
|
||||
headers["Authorization"] = "Bearer #{api_key}"
|
||||
end
|
||||
|
||||
payload = { model: model, input: content }
|
||||
|
@ -20,7 +33,7 @@ module ::DiscourseAi
|
|||
|
||||
case response.status
|
||||
when 200
|
||||
JSON.parse(response.body, symbolize_names: true)
|
||||
JSON.parse(response.body, symbolize_names: true).dig(:data, 0, :embedding)
|
||||
when 429
|
||||
# TODO add a AdminDashboard Problem?
|
||||
else
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Inference
|
||||
class TeiEmbeddings
|
||||
def initialize(endpoint, key, referer = Discourse.base_url)
|
||||
@endpoint = endpoint
|
||||
@key = key
|
||||
@referer = referer
|
||||
end
|
||||
|
||||
attr_reader :endpoint, :key, :referer
|
||||
|
||||
class << self
|
||||
def configured?
|
||||
SiteSetting.ai_hugging_face_tei_endpoint.present? ||
|
||||
SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
|
||||
end
|
||||
|
||||
def instance
|
||||
if SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
|
||||
service =
|
||||
DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv)
|
||||
endpoint = "https://#{service.target}:#{service.port}"
|
||||
else
|
||||
endpoint = SiteSetting.ai_hugging_face_tei_endpoint
|
||||
end
|
||||
|
||||
new(endpoint, SiteSetting.ai_hugging_face_tei_api_key)
|
||||
end
|
||||
end
|
||||
|
||||
def perform!(content)
|
||||
headers = { "Referer" => referer, "Content-Type" => "application/json" }
|
||||
body = { inputs: content, truncate: true }.to_json
|
||||
|
||||
if key.present?
|
||||
headers["X-API-KEY"] = key
|
||||
headers["Authorization"] = "Bearer #{key}"
|
||||
end
|
||||
|
||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
||||
response = conn.post(endpoint, body, headers)
|
||||
|
||||
raise Net::HTTPBadResponse if ![200].include?(response.status)
|
||||
|
||||
JSON.parse(response.body, symbolize_names: true).first
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -54,12 +54,11 @@ module DiscourseAi
|
|||
upload_url = Discourse.store.cdn_url(upload.url)
|
||||
upload_url = "#{Discourse.base_url_no_prefix}#{upload_url}" if upload_url.starts_with?("/")
|
||||
|
||||
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
||||
DiscourseAi::Inference::DiscourseClassifier.new(
|
||||
"#{endpoint}/api/v1/classify",
|
||||
model,
|
||||
upload_url,
|
||||
SiteSetting.ai_nsfw_inference_service_api_key,
|
||||
)
|
||||
model,
|
||||
).perform!(upload_url)
|
||||
end
|
||||
|
||||
def available_models
|
||||
|
|
|
@ -45,7 +45,7 @@ module DiscourseAi
|
|||
private
|
||||
|
||||
def request_with(content, model_config)
|
||||
result = ::DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, model_config)
|
||||
result = ::DiscourseAi::Inference::HuggingFaceText.classify(content, model_config)
|
||||
transform_result(result)
|
||||
end
|
||||
|
||||
|
|
|
@ -42,12 +42,11 @@ module DiscourseAi
|
|||
|
||||
def request(target_to_classify)
|
||||
data =
|
||||
::DiscourseAi::Inference::DiscourseClassifier.perform!(
|
||||
::DiscourseAi::Inference::DiscourseClassifier.new(
|
||||
"#{endpoint}/api/v1/classify",
|
||||
SiteSetting.ai_toxicity_inference_service_api_model,
|
||||
content_of(target_to_classify),
|
||||
SiteSetting.ai_toxicity_inference_service_api_key,
|
||||
)
|
||||
SiteSetting.ai_toxicity_inference_service_api_model,
|
||||
).perform!(content_of(target_to_classify))
|
||||
|
||||
{ available_model => data }
|
||||
end
|
||||
|
|
|
@ -26,10 +26,11 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do
|
|||
).to_return(status: 200, body: body_json, headers: {})
|
||||
|
||||
result =
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.perform!("hello", model: "text-embedding-ada-002")
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.instance(model: "text-embedding-ada-002").perform!(
|
||||
"hello",
|
||||
)
|
||||
|
||||
expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 })
|
||||
expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] })
|
||||
expect(result).to eq([0.0, 0.1])
|
||||
end
|
||||
|
||||
it "supports openai embeddings" do
|
||||
|
@ -54,13 +55,11 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do
|
|||
).to_return(status: 200, body: body_json, headers: {})
|
||||
|
||||
result =
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.perform!(
|
||||
"hello",
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.instance(
|
||||
model: "text-embedding-ada-002",
|
||||
dimensions: 1000,
|
||||
)
|
||||
).perform!("hello")
|
||||
|
||||
expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 })
|
||||
expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] })
|
||||
expect(result).to eq([0.0, 0.1])
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue