FEATURE: HuggingFace Text Embeddings Inference compatibility (#323)

* FEATURE: HuggingFace Text Embeddings Inference compatibility

* lint
This commit is contained in:
Rafael dos Santos Silva 2023-11-28 17:05:26 -03:00 committed by GitHub
parent f26adf2cf6
commit fd0fb58eca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 52 additions and 10 deletions

View File

@ -122,6 +122,8 @@ discourse_ai:
default: 4096
ai_hugging_face_model_display_name:
default: ""
ai_hugging_face_tei_endpoint:
default: ""
ai_google_custom_search_api_key:
default: ""
secret: true

View File

@ -5,10 +5,23 @@ module DiscourseAi
module VectorRepresentations
class BgeLargeEn < Base
def vector_from(text)
DiscourseAi::Inference::CloudflareWorkersAi
.perform!(inference_model_name, { text: text })
.dig(:result, :data)
.first
if SiteSetting.ai_cloudflare_workers_api_token.present?
DiscourseAi::Inference::CloudflareWorkersAi
.perform!(inference_model_name, { text: text })
.dig(:result, :data)
.first
elsif SiteSetting.ai_hugging_face_tei_endpoint.present?
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(text).first
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
inference_model_name.split("/").last,
text,
SiteSetting.ai_embeddings_discourse_service_api_key,
)
else
raise "No inference endpoint configured"
end
end
def name

View File

@ -5,12 +5,18 @@ module DiscourseAi
module VectorRepresentations
class MultilingualE5Large < Base
def vector_from(text)
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
name,
"query: #{text}",
SiteSetting.ai_embeddings_discourse_service_api_key,
)
if SiteSetting.ai_hugging_face_tei_endpoint.present?
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(text).first
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
name,
"query: #{text}",
SiteSetting.ai_embeddings_discourse_service_api_key,
)
else
raise "No inference endpoint configured"
end
end
def id

View File

@ -0,0 +1,20 @@
# frozen_string_literal: true
module ::DiscourseAi
module Inference
class HuggingFaceTextEmbeddings
def self.perform!(content)
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
body = { inputs: content }.to_json
api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint
response = Faraday.post(api_endpoint, body, headers)
raise Net::HTTPBadResponse if ![200].include?(response.status)
JSON.parse(response.body, symbolize_names: true)
end
end
end
end

View File

@ -42,6 +42,7 @@ after_initialize do
require_relative "lib/shared/inference/hugging_face_text_generation"
require_relative "lib/shared/inference/amazon_bedrock_inference"
require_relative "lib/shared/inference/cloudflare_workers_ai"
require_relative "lib/shared/inference/hugging_face_text_embeddings"
require_relative "lib/shared/inference/function"
require_relative "lib/shared/inference/function_list"