REFACTOR: Tidy-up embedding endpoints config.
Two changes are worth mentioning. First, `#instance` returns a fully configured embedding endpoint that is ready to use. Second, all endpoints respond to the same method with the same signature, `perform!(text)`. This makes it easier to reuse them when generating embeddings in bulk.
This commit is contained in:
parent
1a10680818
commit
b0ab9ccc48
|
@ -313,7 +313,7 @@ module DiscourseAi
|
||||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||||
vector_rep =
|
vector_rep =
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||||
reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings
|
reranker = DiscourseAi::Inference::HuggingFaceText
|
||||||
|
|
||||||
interactions_vector = vector_rep.vector_from(consolidated_question)
|
interactions_vector = vector_rep.vector_from(consolidated_question)
|
||||||
|
|
||||||
|
@ -344,7 +344,7 @@ module DiscourseAi
|
||||||
if reranker.reranker_configured?
|
if reranker.reranker_configured?
|
||||||
guidance = fragments.map { |fragment, _metadata| fragment }
|
guidance = fragments.map { |fragment, _metadata| fragment }
|
||||||
ranks =
|
ranks =
|
||||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings
|
DiscourseAi::Inference::HuggingFaceText
|
||||||
.rerank(conversation_context.last[:content], guidance)
|
.rerank(conversation_context.last[:content], guidance)
|
||||||
.to_a
|
.to_a
|
||||||
.take(rag_conversation_chunks)
|
.take(rag_conversation_chunks)
|
||||||
|
|
|
@ -159,10 +159,7 @@ module DiscourseAi
|
||||||
.map { _1.truncate(2000, omission: "") }
|
.map { _1.truncate(2000, omission: "") }
|
||||||
|
|
||||||
reranked_results =
|
reranked_results =
|
||||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.rerank(
|
DiscourseAi::Inference::HuggingFaceText.rerank(search_term, rerank_posts_payload)
|
||||||
search_term,
|
|
||||||
rerank_posts_payload,
|
|
||||||
)
|
|
||||||
|
|
||||||
reordered_ids = reranked_results.map { _1[:index] }.map { filtered_results[_1].id }.take(5)
|
reordered_ids = reranked_results.map { _1[:index] }.map { filtered_results[_1].id }.take(5)
|
||||||
|
|
||||||
|
|
|
@ -24,12 +24,7 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def vector_from(text, asymetric: false)
|
def vector_from(text, asymetric: false)
|
||||||
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
inference_client.perform!(text)
|
||||||
"#{discourse_embeddings_endpoint}/api/v1/classify",
|
|
||||||
self.class.name,
|
|
||||||
text,
|
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_key,
|
|
||||||
)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
def dimensions
|
def dimensions
|
||||||
|
@ -59,6 +54,10 @@ module DiscourseAi
|
||||||
def tokenizer
|
def tokenizer
|
||||||
DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
|
DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def inference_client
|
||||||
|
DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -426,16 +426,8 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
def discourse_embeddings_endpoint
|
def inference_client
|
||||||
if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present?
|
raise NotImplementedError
|
||||||
service =
|
|
||||||
DiscourseAi::Utils::DnsSrv.lookup(
|
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv,
|
|
||||||
)
|
|
||||||
"https://#{service.target}:#{service.port}"
|
|
||||||
else
|
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -11,7 +11,7 @@ module DiscourseAi
|
||||||
|
|
||||||
def correctly_configured?
|
def correctly_configured?
|
||||||
SiteSetting.ai_cloudflare_workers_api_token.present? ||
|
SiteSetting.ai_cloudflare_workers_api_token.present? ||
|
||||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? ||
|
DiscourseAi::Inference::TeiEmbeddings.configured? ||
|
||||||
(
|
(
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||||
|
@ -33,24 +33,12 @@ module DiscourseAi
|
||||||
def vector_from(text, asymetric: false)
|
def vector_from(text, asymetric: false)
|
||||||
text = "#{asymmetric_query_prefix} #{text}" if asymetric
|
text = "#{asymmetric_query_prefix} #{text}" if asymetric
|
||||||
|
|
||||||
if SiteSetting.ai_cloudflare_workers_api_token.present?
|
client = inference_client
|
||||||
DiscourseAi::Inference::CloudflareWorkersAi
|
|
||||||
.perform!(inference_model_name, { text: text })
|
needs_truncation = client.class.name.include?("TeiEmbeddings")
|
||||||
.dig(:result, :data)
|
text = tokenizer.truncate(text, max_sequence_length - 2) if needs_truncation
|
||||||
.first
|
|
||||||
elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
inference_client.perform!(text)
|
||||||
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
|
|
||||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
|
|
||||||
elsif discourse_embeddings_endpoint.present?
|
|
||||||
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
|
||||||
"#{discourse_embeddings_endpoint}/api/v1/classify",
|
|
||||||
inference_model_name.split("/").last,
|
|
||||||
text,
|
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_key,
|
|
||||||
)
|
|
||||||
else
|
|
||||||
raise "No inference endpoint configured"
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
def inference_model_name
|
def inference_model_name
|
||||||
|
@ -88,6 +76,21 @@ module DiscourseAi
|
||||||
def asymmetric_query_prefix
|
def asymmetric_query_prefix
|
||||||
"Represent this sentence for searching relevant passages:"
|
"Represent this sentence for searching relevant passages:"
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def inference_client
|
||||||
|
if SiteSetting.ai_cloudflare_workers_api_token.present?
|
||||||
|
DiscourseAi::Inference::CloudflareWorkersAi.instance(inference_model_name)
|
||||||
|
elsif DiscourseAi::Inference::TeiEmbeddings.configured?
|
||||||
|
DiscourseAi::Inference::TeiEmbeddings.instance
|
||||||
|
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
||||||
|
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||||
|
DiscourseAi::Inference::DiscourseClassifier.instance(
|
||||||
|
inference_model_name.split("/").last,
|
||||||
|
)
|
||||||
|
else
|
||||||
|
raise "No inference endpoint configured"
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -10,7 +10,7 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def correctly_configured?
|
def correctly_configured?
|
||||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
DiscourseAi::Inference::TeiEmbeddings.configured?
|
||||||
end
|
end
|
||||||
|
|
||||||
def dependant_setting_names
|
def dependant_setting_names
|
||||||
|
@ -20,7 +20,7 @@ module DiscourseAi
|
||||||
|
|
||||||
def vector_from(text, asymetric: false)
|
def vector_from(text, asymetric: false)
|
||||||
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
|
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
|
||||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
|
inference_client.perform!(truncated_text)
|
||||||
end
|
end
|
||||||
|
|
||||||
def dimensions
|
def dimensions
|
||||||
|
@ -50,6 +50,10 @@ module DiscourseAi
|
||||||
def tokenizer
|
def tokenizer
|
||||||
DiscourseAi::Tokenizer::BgeM3Tokenizer
|
DiscourseAi::Tokenizer::BgeM3Tokenizer
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def inference_client
|
||||||
|
DiscourseAi::Inference::TeiEmbeddings.instance
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -43,8 +43,7 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def vector_from(text, asymetric: false)
|
def vector_from(text, asymetric: false)
|
||||||
response = DiscourseAi::Inference::GeminiEmbeddings.perform!(text)
|
inference_client.perform!(text).dig(:embedding, :values)
|
||||||
response[:embedding][:values]
|
|
||||||
end
|
end
|
||||||
|
|
||||||
# There is no public tokenizer for Gemini, and from the ones we already ship in the plugin
|
# There is no public tokenizer for Gemini, and from the ones we already ship in the plugin
|
||||||
|
@ -53,6 +52,10 @@ module DiscourseAi
|
||||||
def tokenizer
|
def tokenizer
|
||||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def inference_client
|
||||||
|
DiscourseAi::Inference::GeminiEmbeddings.instance
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -10,7 +10,7 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def correctly_configured?
|
def correctly_configured?
|
||||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? ||
|
DiscourseAi::Inference::TeiEmbeddings.configured? ||
|
||||||
(
|
(
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||||
|
@ -29,19 +29,16 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def vector_from(text, asymetric: false)
|
def vector_from(text, asymetric: false)
|
||||||
if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
client = inference_client
|
||||||
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
|
|
||||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
|
needs_truncation = client.class.name.include?("TeiEmbeddings")
|
||||||
elsif discourse_embeddings_endpoint.present?
|
if needs_truncation
|
||||||
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
text = tokenizer.truncate(text, max_sequence_length - 2)
|
||||||
"#{discourse_embeddings_endpoint}/api/v1/classify",
|
|
||||||
self.class.name,
|
|
||||||
"query: #{text}",
|
|
||||||
SiteSetting.ai_embeddings_discourse_service_api_key,
|
|
||||||
)
|
|
||||||
else
|
else
|
||||||
raise "No inference endpoint configured"
|
text = "query: #{text}"
|
||||||
end
|
end
|
||||||
|
|
||||||
|
client.perform!(text)
|
||||||
end
|
end
|
||||||
|
|
||||||
def id
|
def id
|
||||||
|
@ -71,6 +68,18 @@ module DiscourseAi
|
||||||
def tokenizer
|
def tokenizer
|
||||||
DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
|
DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def inference_client
|
||||||
|
if DiscourseAi::Inference::TeiEmbeddings.configured?
|
||||||
|
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
|
||||||
|
DiscourseAi::Inference::TeiEmbeddings.instance
|
||||||
|
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
||||||
|
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||||
|
DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name)
|
||||||
|
else
|
||||||
|
raise "No inference endpoint configured"
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -45,18 +45,19 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def vector_from(text, asymetric: false)
|
def vector_from(text, asymetric: false)
|
||||||
response =
|
inference_client.perform!(text)
|
||||||
DiscourseAi::Inference::OpenAiEmbeddings.perform!(
|
|
||||||
text,
|
|
||||||
model: self.class.name,
|
|
||||||
dimensions: dimensions,
|
|
||||||
)
|
|
||||||
response[:data].first[:embedding]
|
|
||||||
end
|
end
|
||||||
|
|
||||||
def tokenizer
|
def tokenizer
|
||||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def inference_client
|
||||||
|
DiscourseAi::Inference::OpenAiEmbeddings.instance(
|
||||||
|
model: self.class.name,
|
||||||
|
dimensions: dimensions,
|
||||||
|
)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -43,13 +43,16 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def vector_from(text, asymetric: false)
|
def vector_from(text, asymetric: false)
|
||||||
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
|
inference_client.perform!(text)
|
||||||
response[:data].first[:embedding]
|
|
||||||
end
|
end
|
||||||
|
|
||||||
def tokenizer
|
def tokenizer
|
||||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def inference_client
|
||||||
|
DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -43,13 +43,16 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def vector_from(text, asymetric: false)
|
def vector_from(text, asymetric: false)
|
||||||
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
|
inference_client.perform!(text)
|
||||||
response[:data].first[:embedding]
|
|
||||||
end
|
end
|
||||||
|
|
||||||
def tokenizer
|
def tokenizer
|
||||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def inference_client
|
||||||
|
DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -3,25 +3,38 @@
|
||||||
module ::DiscourseAi
|
module ::DiscourseAi
|
||||||
module Inference
|
module Inference
|
||||||
class CloudflareWorkersAi
|
class CloudflareWorkersAi
|
||||||
def self.perform!(model, content)
|
def initialize(account_id, api_token, model, referer = Discourse.base_url)
|
||||||
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
@account_id = account_id
|
||||||
|
@api_token = api_token
|
||||||
|
@model = model
|
||||||
|
@referer = referer
|
||||||
|
end
|
||||||
|
|
||||||
account_id = SiteSetting.ai_cloudflare_workers_account_id
|
def self.instance(model)
|
||||||
token = SiteSetting.ai_cloudflare_workers_api_token
|
new(
|
||||||
|
SiteSetting.ai_cloudflare_workers_account_id,
|
||||||
|
SiteSetting.ai_cloudflare_workers_api_token,
|
||||||
|
model,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
base_url = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/"
|
attr_reader :account_id, :api_token, :model, :referer
|
||||||
headers["Authorization"] = "Bearer #{token}"
|
|
||||||
|
|
||||||
endpoint = "#{base_url}#{model}"
|
def perform!(content)
|
||||||
|
headers = {
|
||||||
|
"Referer" => Discourse.base_url,
|
||||||
|
"Content-Type" => "application/json",
|
||||||
|
"Authorization" => "Bearer #{api_token}",
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/#{model}"
|
||||||
|
|
||||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
||||||
response = conn.post(endpoint, content.to_json, headers)
|
response = conn.post(endpoint, content.to_json, headers)
|
||||||
|
|
||||||
raise Net::HTTPBadResponse if ![200].include?(response.status)
|
|
||||||
|
|
||||||
case response.status
|
case response.status
|
||||||
when 200
|
when 200
|
||||||
JSON.parse(response.body, symbolize_names: true)
|
JSON.parse(response.body, symbolize_names: true).dig(:result, :data).first
|
||||||
when 429
|
when 429
|
||||||
# TODO add a AdminDashboard Problem?
|
# TODO add a AdminDashboard Problem?
|
||||||
else
|
else
|
||||||
|
|
|
@ -3,9 +3,36 @@
|
||||||
module ::DiscourseAi
|
module ::DiscourseAi
|
||||||
module Inference
|
module Inference
|
||||||
class DiscourseClassifier
|
class DiscourseClassifier
|
||||||
def self.perform!(endpoint, model, content, api_key)
|
def initialize(endpoint, api_key, model, referer = Discourse.base_url)
|
||||||
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
@endpoint = endpoint
|
||||||
|
@api_key = api_key
|
||||||
|
@model = model
|
||||||
|
@referer = referer
|
||||||
|
end
|
||||||
|
|
||||||
|
def self.instance(model)
|
||||||
|
endpoint =
|
||||||
|
if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present?
|
||||||
|
service =
|
||||||
|
DiscourseAi::Utils::DnsSrv.lookup(
|
||||||
|
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv,
|
||||||
|
)
|
||||||
|
"https://#{service.target}:#{service.port}"
|
||||||
|
else
|
||||||
|
SiteSetting.ai_embeddings_discourse_service_api_endpoint
|
||||||
|
end
|
||||||
|
|
||||||
|
new(
|
||||||
|
"#{endpoint}/api/v1/classify",
|
||||||
|
SiteSetting.ai_embeddings_discourse_service_api_key,
|
||||||
|
model,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
attr_reader :endpoint, :api_key, :model, :referer
|
||||||
|
|
||||||
|
def perform!(content)
|
||||||
|
headers = { "Referer" => referer, "Content-Type" => "application/json" }
|
||||||
headers["X-API-KEY"] = api_key if api_key.present?
|
headers["X-API-KEY"] = api_key if api_key.present?
|
||||||
|
|
||||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
||||||
|
|
|
@ -3,12 +3,17 @@
|
||||||
module ::DiscourseAi
|
module ::DiscourseAi
|
||||||
module Inference
|
module Inference
|
||||||
class GeminiEmbeddings
|
class GeminiEmbeddings
|
||||||
def self.perform!(content)
|
def initialize(api_key, referer = Discourse.base_url)
|
||||||
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
@api_key = api_key
|
||||||
|
@referer = referer
|
||||||
|
end
|
||||||
|
|
||||||
|
attr_reader :api_key, :referer
|
||||||
|
|
||||||
|
def perform!(content)
|
||||||
|
headers = { "Referer" => referer, "Content-Type" => "application/json" }
|
||||||
url =
|
url =
|
||||||
"https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{SiteSetting.ai_gemini_api_key}"
|
"https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{api_key}"
|
||||||
|
|
||||||
body = { content: { parts: [{ text: content }] } }
|
body = { content: { parts: [{ text: content }] } }
|
||||||
|
|
||||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
||||||
|
|
|
@ -2,33 +2,8 @@
|
||||||
|
|
||||||
module ::DiscourseAi
|
module ::DiscourseAi
|
||||||
module Inference
|
module Inference
|
||||||
class HuggingFaceTextEmbeddings
|
class HuggingFaceText
|
||||||
class << self
|
class << self
|
||||||
def perform!(content)
|
|
||||||
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
|
||||||
body = { inputs: content, truncate: true }.to_json
|
|
||||||
|
|
||||||
if SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
|
|
||||||
service =
|
|
||||||
DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv)
|
|
||||||
api_endpoint = "https://#{service.target}:#{service.port}"
|
|
||||||
else
|
|
||||||
api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint
|
|
||||||
end
|
|
||||||
|
|
||||||
if SiteSetting.ai_hugging_face_tei_api_key.present?
|
|
||||||
headers["X-API-KEY"] = SiteSetting.ai_hugging_face_tei_api_key
|
|
||||||
headers["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_tei_api_key}"
|
|
||||||
end
|
|
||||||
|
|
||||||
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
|
||||||
response = conn.post(api_endpoint, body, headers)
|
|
||||||
|
|
||||||
raise Net::HTTPBadResponse if ![200].include?(response.status)
|
|
||||||
|
|
||||||
JSON.parse(response.body, symbolize_names: true)
|
|
||||||
end
|
|
||||||
|
|
||||||
def rerank(content, candidates)
|
def rerank(content, candidates)
|
||||||
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
||||||
body = { query: content, texts: candidates, truncate: true }.to_json
|
body = { query: content, texts: candidates, truncate: true }.to_json
|
||||||
|
@ -85,11 +60,6 @@ module ::DiscourseAi
|
||||||
SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? ||
|
SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? ||
|
||||||
SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?
|
SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?
|
||||||
end
|
end
|
||||||
|
|
||||||
def configured?
|
|
||||||
SiteSetting.ai_hugging_face_tei_endpoint.present? ||
|
|
||||||
SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
|
@ -3,13 +3,26 @@
|
||||||
module ::DiscourseAi
|
module ::DiscourseAi
|
||||||
module Inference
|
module Inference
|
||||||
class OpenAiEmbeddings
|
class OpenAiEmbeddings
|
||||||
def self.perform!(content, model:, dimensions: nil)
|
def initialize(endpoint, api_key, model, dimensions)
|
||||||
|
@endpoint = endpoint
|
||||||
|
@api_key = api_key
|
||||||
|
@model = model
|
||||||
|
@dimensions = dimensions
|
||||||
|
end
|
||||||
|
|
||||||
|
attr_reader :endpoint, :api_key, :model, :dimensions
|
||||||
|
|
||||||
|
def self.instance(model:, dimensions: nil)
|
||||||
|
new(SiteSetting.ai_openai_embeddings_url, SiteSetting.ai_openai_api_key, model, dimensions)
|
||||||
|
end
|
||||||
|
|
||||||
|
def perform!(content)
|
||||||
headers = { "Content-Type" => "application/json" }
|
headers = { "Content-Type" => "application/json" }
|
||||||
|
|
||||||
if SiteSetting.ai_openai_embeddings_url.include?("azure")
|
if endpoint.include?("azure")
|
||||||
headers["api-key"] = SiteSetting.ai_openai_api_key
|
headers["api-key"] = api_key
|
||||||
else
|
else
|
||||||
headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
|
headers["Authorization"] = "Bearer #{api_key}"
|
||||||
end
|
end
|
||||||
|
|
||||||
payload = { model: model, input: content }
|
payload = { model: model, input: content }
|
||||||
|
@ -20,7 +33,7 @@ module ::DiscourseAi
|
||||||
|
|
||||||
case response.status
|
case response.status
|
||||||
when 200
|
when 200
|
||||||
JSON.parse(response.body, symbolize_names: true)
|
JSON.parse(response.body, symbolize_names: true).dig(:data, 0, :embedding)
|
||||||
when 429
|
when 429
|
||||||
# TODO add a AdminDashboard Problem?
|
# TODO add a AdminDashboard Problem?
|
||||||
else
|
else
|
||||||
|
|
|
@ -0,0 +1,51 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Inference
|
||||||
|
class TeiEmbeddings
|
||||||
|
def initialize(endpoint, key, referer = Discourse.base_url)
|
||||||
|
@endpoint = endpoint
|
||||||
|
@key = key
|
||||||
|
@referer = referer
|
||||||
|
end
|
||||||
|
|
||||||
|
attr_reader :endpoint, :key, :referer
|
||||||
|
|
||||||
|
class << self
|
||||||
|
def configured?
|
||||||
|
SiteSetting.ai_hugging_face_tei_endpoint.present? ||
|
||||||
|
SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
|
||||||
|
end
|
||||||
|
|
||||||
|
def instance
|
||||||
|
if SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
|
||||||
|
service =
|
||||||
|
DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv)
|
||||||
|
endpoint = "https://#{service.target}:#{service.port}"
|
||||||
|
else
|
||||||
|
endpoint = SiteSetting.ai_hugging_face_tei_endpoint
|
||||||
|
end
|
||||||
|
|
||||||
|
new(endpoint, SiteSetting.ai_hugging_face_tei_api_key)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def perform!(content)
|
||||||
|
headers = { "Referer" => referer, "Content-Type" => "application/json" }
|
||||||
|
body = { inputs: content, truncate: true }.to_json
|
||||||
|
|
||||||
|
if key.present?
|
||||||
|
headers["X-API-KEY"] = key
|
||||||
|
headers["Authorization"] = "Bearer #{key}"
|
||||||
|
end
|
||||||
|
|
||||||
|
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
|
||||||
|
response = conn.post(endpoint, body, headers)
|
||||||
|
|
||||||
|
raise Net::HTTPBadResponse if ![200].include?(response.status)
|
||||||
|
|
||||||
|
JSON.parse(response.body, symbolize_names: true).first
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -54,12 +54,11 @@ module DiscourseAi
|
||||||
upload_url = Discourse.store.cdn_url(upload.url)
|
upload_url = Discourse.store.cdn_url(upload.url)
|
||||||
upload_url = "#{Discourse.base_url_no_prefix}#{upload_url}" if upload_url.starts_with?("/")
|
upload_url = "#{Discourse.base_url_no_prefix}#{upload_url}" if upload_url.starts_with?("/")
|
||||||
|
|
||||||
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
DiscourseAi::Inference::DiscourseClassifier.new(
|
||||||
"#{endpoint}/api/v1/classify",
|
"#{endpoint}/api/v1/classify",
|
||||||
model,
|
|
||||||
upload_url,
|
|
||||||
SiteSetting.ai_nsfw_inference_service_api_key,
|
SiteSetting.ai_nsfw_inference_service_api_key,
|
||||||
)
|
model,
|
||||||
|
).perform!(upload_url)
|
||||||
end
|
end
|
||||||
|
|
||||||
def available_models
|
def available_models
|
||||||
|
|
|
@ -45,7 +45,7 @@ module DiscourseAi
|
||||||
private
|
private
|
||||||
|
|
||||||
def request_with(content, model_config)
|
def request_with(content, model_config)
|
||||||
result = ::DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, model_config)
|
result = ::DiscourseAi::Inference::HuggingFaceText.classify(content, model_config)
|
||||||
transform_result(result)
|
transform_result(result)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -42,12 +42,11 @@ module DiscourseAi
|
||||||
|
|
||||||
def request(target_to_classify)
|
def request(target_to_classify)
|
||||||
data =
|
data =
|
||||||
::DiscourseAi::Inference::DiscourseClassifier.perform!(
|
::DiscourseAi::Inference::DiscourseClassifier.new(
|
||||||
"#{endpoint}/api/v1/classify",
|
"#{endpoint}/api/v1/classify",
|
||||||
SiteSetting.ai_toxicity_inference_service_api_model,
|
|
||||||
content_of(target_to_classify),
|
|
||||||
SiteSetting.ai_toxicity_inference_service_api_key,
|
SiteSetting.ai_toxicity_inference_service_api_key,
|
||||||
)
|
SiteSetting.ai_toxicity_inference_service_api_model,
|
||||||
|
).perform!(content_of(target_to_classify))
|
||||||
|
|
||||||
{ available_model => data }
|
{ available_model => data }
|
||||||
end
|
end
|
||||||
|
|
|
@ -26,10 +26,11 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do
|
||||||
).to_return(status: 200, body: body_json, headers: {})
|
).to_return(status: 200, body: body_json, headers: {})
|
||||||
|
|
||||||
result =
|
result =
|
||||||
DiscourseAi::Inference::OpenAiEmbeddings.perform!("hello", model: "text-embedding-ada-002")
|
DiscourseAi::Inference::OpenAiEmbeddings.instance(model: "text-embedding-ada-002").perform!(
|
||||||
|
"hello",
|
||||||
|
)
|
||||||
|
|
||||||
expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 })
|
expect(result).to eq([0.0, 0.1])
|
||||||
expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] })
|
|
||||||
end
|
end
|
||||||
|
|
||||||
it "supports openai embeddings" do
|
it "supports openai embeddings" do
|
||||||
|
@ -54,13 +55,11 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do
|
||||||
).to_return(status: 200, body: body_json, headers: {})
|
).to_return(status: 200, body: body_json, headers: {})
|
||||||
|
|
||||||
result =
|
result =
|
||||||
DiscourseAi::Inference::OpenAiEmbeddings.perform!(
|
DiscourseAi::Inference::OpenAiEmbeddings.instance(
|
||||||
"hello",
|
|
||||||
model: "text-embedding-ada-002",
|
model: "text-embedding-ada-002",
|
||||||
dimensions: 1000,
|
dimensions: 1000,
|
||||||
)
|
).perform!("hello")
|
||||||
|
|
||||||
expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 })
|
expect(result).to eq([0.0, 0.1])
|
||||||
expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] })
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in New Issue