REFACTOR: Tidy-up embedding endpoints config.

Two changes worth mentioning:

1. `#instance` returns a fully configured embedding endpoint, ready to use.
2. All endpoints respond to the same method, with the same signature: `perform!(text)`.

This makes it easier to reuse them when generating embeddings in bulk.
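
For illustration, a minimal sketch of the resulting API (the caller code, input texts, and model choice below are hypothetical, not part of this commit):

    # Every endpoint class exposes .instance, which returns a configured client,
    # and every client responds to #perform!(text).
    client = DiscourseAi::Inference::OpenAiEmbeddings.instance(model: "text-embedding-ada-002")

    texts = ["first post", "second post"] # hypothetical input
    embeddings = texts.map { |text| client.perform!(text) } # one client, many calls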
Roman Rizzi 2024-11-20 18:24:19 -03:00
parent 1a10680818
commit b0ab9ccc48
21 changed files with 227 additions and 137 deletions


@@ -313,7 +313,7 @@ module DiscourseAi
         strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
         vector_rep =
           DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
-        reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings
+        reranker = DiscourseAi::Inference::HuggingFaceText
         interactions_vector = vector_rep.vector_from(consolidated_question)
@@ -344,7 +344,7 @@ module DiscourseAi
         if reranker.reranker_configured?
           guidance = fragments.map { |fragment, _metadata| fragment }
           ranks =
-            DiscourseAi::Inference::HuggingFaceTextEmbeddings
+            DiscourseAi::Inference::HuggingFaceText
               .rerank(conversation_context.last[:content], guidance)
              .to_a
              .take(rag_conversation_chunks)


@@ -159,10 +159,7 @@ module DiscourseAi
             .map { _1.truncate(2000, omission: "") }

         reranked_results =
-          DiscourseAi::Inference::HuggingFaceTextEmbeddings.rerank(
-            search_term,
-            rerank_posts_payload,
-          )
+          DiscourseAi::Inference::HuggingFaceText.rerank(search_term, rerank_posts_payload)

         reordered_ids = reranked_results.map { _1[:index] }.map { filtered_results[_1].id }.take(5)


@@ -24,12 +24,7 @@ module DiscourseAi
         end

         def vector_from(text, asymetric: false)
-          DiscourseAi::Inference::DiscourseClassifier.perform!(
-            "#{discourse_embeddings_endpoint}/api/v1/classify",
-            self.class.name,
-            text,
-            SiteSetting.ai_embeddings_discourse_service_api_key,
-          )
+          inference_client.perform!(text)
         end

         def dimensions
@@ -59,6 +54,10 @@ module DiscourseAi
         def tokenizer
           DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer
         end
+
+        def inference_client
+          DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name)
+        end
       end
     end
   end


@@ -426,16 +426,8 @@ module DiscourseAi
           end
         end

-        def discourse_embeddings_endpoint
-          if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present?
-            service =
-              DiscourseAi::Utils::DnsSrv.lookup(
-                SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv,
-              )
-            "https://#{service.target}:#{service.port}"
-          else
-            SiteSetting.ai_embeddings_discourse_service_api_endpoint
-          end
+        def inference_client
+          raise NotImplementedError
         end
       end
     end


@@ -11,7 +11,7 @@ module DiscourseAi
         def correctly_configured?
           SiteSetting.ai_cloudflare_workers_api_token.present? ||
-            DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? ||
+            DiscourseAi::Inference::TeiEmbeddings.configured? ||
             (
               SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
                 SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
@@ -33,24 +33,12 @@ module DiscourseAi
         def vector_from(text, asymetric: false)
           text = "#{asymmetric_query_prefix} #{text}" if asymetric

-          if SiteSetting.ai_cloudflare_workers_api_token.present?
-            DiscourseAi::Inference::CloudflareWorkersAi
-              .perform!(inference_model_name, { text: text })
-              .dig(:result, :data)
-              .first
-          elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
-            truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
-            DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
-          elsif discourse_embeddings_endpoint.present?
-            DiscourseAi::Inference::DiscourseClassifier.perform!(
-              "#{discourse_embeddings_endpoint}/api/v1/classify",
-              inference_model_name.split("/").last,
-              text,
-              SiteSetting.ai_embeddings_discourse_service_api_key,
-            )
-          else
-            raise "No inference endpoint configured"
-          end
+          client = inference_client
+          needs_truncation = client.class.name.include?("TeiEmbeddings")
+          text = tokenizer.truncate(text, max_sequence_length - 2) if needs_truncation
+
+          inference_client.perform!(text)
         end

         def inference_model_name
@@ -88,6 +76,21 @@ module DiscourseAi
         def asymmetric_query_prefix
           "Represent this sentence for searching relevant passages:"
         end
+
+        def inference_client
+          if SiteSetting.ai_cloudflare_workers_api_token.present?
+            DiscourseAi::Inference::CloudflareWorkersAi.instance(inference_model_name)
+          elsif DiscourseAi::Inference::TeiEmbeddings.configured?
+            DiscourseAi::Inference::TeiEmbeddings.instance
+          elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
+                SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
+            DiscourseAi::Inference::DiscourseClassifier.instance(
+              inference_model_name.split("/").last,
+            )
+          else
+            raise "No inference endpoint configured"
+          end
+        end
       end
     end
   end


@@ -10,7 +10,7 @@ module DiscourseAi
         end

         def correctly_configured?
-          DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
+          DiscourseAi::Inference::TeiEmbeddings.configured?
         end

         def dependant_setting_names
@@ -20,7 +20,7 @@ module DiscourseAi
         def vector_from(text, asymetric: false)
           truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
-          DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
+          inference_client.perform!(truncated_text)
         end

         def dimensions
@@ -50,6 +50,10 @@ module DiscourseAi
         def tokenizer
           DiscourseAi::Tokenizer::BgeM3Tokenizer
         end
+
+        def inference_client
+          DiscourseAi::Inference::TeiEmbeddings.instance
+        end
       end
     end
   end


@@ -43,8 +43,7 @@ module DiscourseAi
         end

         def vector_from(text, asymetric: false)
-          response = DiscourseAi::Inference::GeminiEmbeddings.perform!(text)
-          response[:embedding][:values]
+          inference_client.perform!(text).dig(:embedding, :values)
         end

         # There is no public tokenizer for Gemini, and from the ones we already ship in the plugin
@@ -53,6 +52,10 @@ module DiscourseAi
         def tokenizer
           DiscourseAi::Tokenizer::OpenAiTokenizer
         end
+
+        def inference_client
+          DiscourseAi::Inference::GeminiEmbeddings.instance
+        end
       end
     end
   end


@@ -10,7 +10,7 @@ module DiscourseAi
         end

         def correctly_configured?
-          DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? ||
+          DiscourseAi::Inference::TeiEmbeddings.configured? ||
             (
               SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
                 SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
@@ -29,19 +29,16 @@ module DiscourseAi
         end

         def vector_from(text, asymetric: false)
-          if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
-            truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
-            DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
-          elsif discourse_embeddings_endpoint.present?
-            DiscourseAi::Inference::DiscourseClassifier.perform!(
-              "#{discourse_embeddings_endpoint}/api/v1/classify",
-              self.class.name,
-              "query: #{text}",
-              SiteSetting.ai_embeddings_discourse_service_api_key,
-            )
+          client = inference_client
+          needs_truncation = client.class.name.include?("TeiEmbeddings")
+
+          if needs_truncation
+            text = tokenizer.truncate(text, max_sequence_length - 2)
           else
-            raise "No inference endpoint configured"
+            text = "query: #{text}"
           end
+
+          client.perform!(text)
         end

         def id
@@ -71,6 +68,18 @@ module DiscourseAi
         def tokenizer
           DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer
         end
+
+        def inference_client
+          if DiscourseAi::Inference::TeiEmbeddings.configured?
+            DiscourseAi::Inference::TeiEmbeddings.instance
+          elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
+                SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
+            DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name)
+          else
+            raise "No inference endpoint configured"
+          end
+        end
       end
     end
   end


@@ -45,18 +45,19 @@ module DiscourseAi
         end

         def vector_from(text, asymetric: false)
-          response =
-            DiscourseAi::Inference::OpenAiEmbeddings.perform!(
-              text,
-              model: self.class.name,
-              dimensions: dimensions,
-            )
-          response[:data].first[:embedding]
+          inference_client.perform!(text)
         end

         def tokenizer
           DiscourseAi::Tokenizer::OpenAiTokenizer
         end
+
+        def inference_client
+          DiscourseAi::Inference::OpenAiEmbeddings.instance(
+            model: self.class.name,
+            dimensions: dimensions,
+          )
+        end
       end
     end
   end


@@ -43,13 +43,16 @@ module DiscourseAi
         end

         def vector_from(text, asymetric: false)
-          response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
-          response[:data].first[:embedding]
+          inference_client.perform!(text)
         end

         def tokenizer
           DiscourseAi::Tokenizer::OpenAiTokenizer
         end
+
+        def inference_client
+          DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name)
+        end
       end
     end
   end


@@ -43,13 +43,16 @@ module DiscourseAi
         end

         def vector_from(text, asymetric: false)
-          response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
-          response[:data].first[:embedding]
+          inference_client.perform!(text)
         end

         def tokenizer
           DiscourseAi::Tokenizer::OpenAiTokenizer
         end
+
+        def inference_client
+          DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name)
+        end
       end
     end
   end


@@ -3,25 +3,38 @@
 module ::DiscourseAi
   module Inference
     class CloudflareWorkersAi
-      def self.perform!(model, content)
-        headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
+      def initialize(account_id, api_token, model, referer = Discourse.base_url)
+        @account_id = account_id
+        @api_token = api_token
+        @model = model
+        @referer = referer
+      end

-        account_id = SiteSetting.ai_cloudflare_workers_account_id
-        token = SiteSetting.ai_cloudflare_workers_api_token
+      def self.instance(model)
+        new(
+          SiteSetting.ai_cloudflare_workers_account_id,
+          SiteSetting.ai_cloudflare_workers_api_token,
+          model,
+        )
+      end

-        base_url = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/"
-        headers["Authorization"] = "Bearer #{token}"
+      attr_reader :account_id, :api_token, :model, :referer

-        endpoint = "#{base_url}#{model}"
+      def perform!(content)
+        headers = {
+          "Referer" => Discourse.base_url,
+          "Content-Type" => "application/json",
+          "Authorization" => "Bearer #{api_token}",
+        }
+        endpoint = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/#{model}"

         conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
         response = conn.post(endpoint, content.to_json, headers)
         raise Net::HTTPBadResponse if ![200].include?(response.status)

         case response.status
         when 200
-          JSON.parse(response.body, symbolize_names: true)
+          JSON.parse(response.body, symbolize_names: true).dig(:result, :data).first
         when 429
           # TODO add a AdminDashboard Problem?
         else


@@ -3,9 +3,36 @@
 module ::DiscourseAi
   module Inference
     class DiscourseClassifier
-      def self.perform!(endpoint, model, content, api_key)
-        headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
+      def initialize(endpoint, api_key, model, referer = Discourse.base_url)
+        @endpoint = endpoint
+        @api_key = api_key
+        @model = model
+        @referer = referer
+      end
+
+      def self.instance(model)
+        endpoint =
+          if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present?
+            service =
+              DiscourseAi::Utils::DnsSrv.lookup(
+                SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv,
+              )
+            "https://#{service.target}:#{service.port}"
+          else
+            SiteSetting.ai_embeddings_discourse_service_api_endpoint
+          end
+
+        new(
+          "#{endpoint}/api/v1/classify",
+          SiteSetting.ai_embeddings_discourse_service_api_key,
+          model,
+        )
+      end
+
+      attr_reader :endpoint, :api_key, :model, :referer
+
+      def perform!(content)
+        headers = { "Referer" => referer, "Content-Type" => "application/json" }
         headers["X-API-KEY"] = api_key if api_key.present?

         conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }


@@ -3,12 +3,17 @@
 module ::DiscourseAi
   module Inference
     class GeminiEmbeddings
-      def self.perform!(content)
-        headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
+      def initialize(api_key, referer = Discourse.base_url)
+        @api_key = api_key
+        @referer = referer
+      end
+
+      attr_reader :api_key, :referer
+
+      def perform!(content)
+        headers = { "Referer" => referer, "Content-Type" => "application/json" }

         url =
-          "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{SiteSetting.ai_gemini_api_key}"
+          "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{api_key}"

         body = { content: { parts: [{ text: content }] } }

         conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }


@@ -2,33 +2,8 @@
 module ::DiscourseAi
   module Inference
-    class HuggingFaceTextEmbeddings
+    class HuggingFaceText
       class << self
-        def perform!(content)
-          headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
-          body = { inputs: content, truncate: true }.to_json
-
-          if SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
-            service =
-              DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv)
-            api_endpoint = "https://#{service.target}:#{service.port}"
-          else
-            api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint
-          end
-
-          if SiteSetting.ai_hugging_face_tei_api_key.present?
-            headers["X-API-KEY"] = SiteSetting.ai_hugging_face_tei_api_key
-            headers["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_tei_api_key}"
-          end
-
-          conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
-          response = conn.post(api_endpoint, body, headers)
-
-          raise Net::HTTPBadResponse if ![200].include?(response.status)
-
-          JSON.parse(response.body, symbolize_names: true)
-        end
-
         def rerank(content, candidates)
           headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
           body = { query: content, texts: candidates, truncate: true }.to_json
@@ -85,11 +60,6 @@ module ::DiscourseAi
             SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? ||
             SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?
         end
-
-        def configured?
-          SiteSetting.ai_hugging_face_tei_endpoint.present? ||
-            SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
-        end
       end
     end
   end


@@ -3,13 +3,26 @@
 module ::DiscourseAi
   module Inference
     class OpenAiEmbeddings
-      def self.perform!(content, model:, dimensions: nil)
+      def initialize(endpoint, api_key, model, dimensions)
+        @endpoint = endpoint
+        @api_key = api_key
+        @model = model
+        @dimensions = dimensions
+      end
+
+      attr_reader :endpoint, :api_key, :model, :dimensions
+
+      def self.instance(model:, dimensions: nil)
+        new(SiteSetting.ai_openai_embeddings_url, SiteSetting.ai_openai_api_key, model, dimensions)
+      end
+
+      def perform!(content)
         headers = { "Content-Type" => "application/json" }

-        if SiteSetting.ai_openai_embeddings_url.include?("azure")
-          headers["api-key"] = SiteSetting.ai_openai_api_key
+        if endpoint.include?("azure")
+          headers["api-key"] = api_key
         else
-          headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
+          headers["Authorization"] = "Bearer #{api_key}"
         end

         payload = { model: model, input: content }
@@ -20,7 +33,7 @@ module ::DiscourseAi
         case response.status
         when 200
-          JSON.parse(response.body, symbolize_names: true)
+          JSON.parse(response.body, symbolize_names: true).dig(:data, 0, :embedding)
         when 429
           # TODO add a AdminDashboard Problem?
         else


@@ -0,0 +1,51 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Inference
+    class TeiEmbeddings
+      def initialize(endpoint, key, referer = Discourse.base_url)
+        @endpoint = endpoint
+        @key = key
+        @referer = referer
+      end
+
+      attr_reader :endpoint, :key, :referer
+
+      class << self
+        def configured?
+          SiteSetting.ai_hugging_face_tei_endpoint.present? ||
+            SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
+        end
+
+        def instance
+          if SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
+            service =
+              DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv)
+            endpoint = "https://#{service.target}:#{service.port}"
+          else
+            endpoint = SiteSetting.ai_hugging_face_tei_endpoint
+          end
+
+          new(endpoint, SiteSetting.ai_hugging_face_tei_api_key)
+        end
+      end
+
+      def perform!(content)
+        headers = { "Referer" => referer, "Content-Type" => "application/json" }
+        body = { inputs: content, truncate: true }.to_json
+
+        if key.present?
+          headers["X-API-KEY"] = key
+          headers["Authorization"] = "Bearer #{key}"
+        end
+
+        conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
+        response = conn.post(endpoint, body, headers)
+
+        raise Net::HTTPBadResponse if ![200].include?(response.status)
+
+        JSON.parse(response.body, symbolize_names: true).first
+      end
+    end
+  end
+end


@@ -54,12 +54,11 @@ module DiscourseAi
        upload_url = Discourse.store.cdn_url(upload.url)
        upload_url = "#{Discourse.base_url_no_prefix}#{upload_url}" if upload_url.starts_with?("/")

-       DiscourseAi::Inference::DiscourseClassifier.perform!(
+       DiscourseAi::Inference::DiscourseClassifier.new(
          "#{endpoint}/api/v1/classify",
-         model,
-         upload_url,
          SiteSetting.ai_nsfw_inference_service_api_key,
-       )
+         model,
+       ).perform!(upload_url)
      end

      def available_models


@@ -45,7 +45,7 @@ module DiscourseAi
      private

      def request_with(content, model_config)
-       result = ::DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, model_config)
+       result = ::DiscourseAi::Inference::HuggingFaceText.classify(content, model_config)

        transform_result(result)
      end


@@ -42,12 +42,11 @@ module DiscourseAi
      def request(target_to_classify)
        data =
-         ::DiscourseAi::Inference::DiscourseClassifier.perform!(
+         ::DiscourseAi::Inference::DiscourseClassifier.new(
            "#{endpoint}/api/v1/classify",
-           SiteSetting.ai_toxicity_inference_service_api_model,
-           content_of(target_to_classify),
            SiteSetting.ai_toxicity_inference_service_api_key,
-         )
+           SiteSetting.ai_toxicity_inference_service_api_model,
+         ).perform!(content_of(target_to_classify))

        { available_model => data }
      end


@@ -26,10 +26,11 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do
    ).to_return(status: 200, body: body_json, headers: {})

    result =
-     DiscourseAi::Inference::OpenAiEmbeddings.perform!("hello", model: "text-embedding-ada-002")
+     DiscourseAi::Inference::OpenAiEmbeddings.instance(model: "text-embedding-ada-002").perform!(
+       "hello",
+     )

-   expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 })
-   expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] })
+   expect(result).to eq([0.0, 0.1])
  end

  it "supports openai embeddings" do
@@ -54,13 +55,11 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do
    ).to_return(status: 200, body: body_json, headers: {})

    result =
-     DiscourseAi::Inference::OpenAiEmbeddings.perform!(
-       "hello",
+     DiscourseAi::Inference::OpenAiEmbeddings.instance(
        model: "text-embedding-ada-002",
        dimensions: 1000,
-     )
+     ).perform!("hello")

-   expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 })
-   expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] })
+   expect(result).to eq([0.0, 0.1])
  end
end