FEATURE: DNS SRV support for TEI (#363)

This commit is contained in:
Rafael dos Santos Silva 2023-12-18 13:21:21 -03:00 committed by GitHub
parent ba09582d7c
commit 4d7ccdda2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 95 additions and 3 deletions

View File

@ -128,6 +128,9 @@ discourse_ai:
default: ""
ai_hugging_face_tei_endpoint:
default: ""
ai_hugging_face_tei_endpoint_srv:
default: ""
hidden: true
ai_google_custom_search_api_key:
default: ""
secret: true

View File

@ -10,7 +10,7 @@ module DiscourseAi
.perform!(inference_model_name, { text: text })
.dig(:result, :data)
.first
elsif SiteSetting.ai_hugging_face_tei_endpoint.present?
elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?

View File

@ -5,7 +5,7 @@ module DiscourseAi
module VectorRepresentations
class MultilingualE5Large < Base
def vector_from(text)
if SiteSetting.ai_hugging_face_tei_endpoint.present?
if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?

View File

@ -7,7 +7,12 @@ module ::DiscourseAi
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
body = { inputs: content, truncate: true }.to_json
if SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv)
api_endpoint = "https://#{service.target}:#{service.port}"
else
api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint
end
response = Faraday.post(api_endpoint, body, headers)
@ -15,6 +20,11 @@ module ::DiscourseAi
JSON.parse(response.body, symbolize_names: true)
end
def self.configured?
SiteSetting.ai_hugging_face_tei_endpoint.present? ||
SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
end
end
end
end

45
lib/utils/dns_srv.rb Normal file
View File

@ -0,0 +1,45 @@
# frozen_string_literal: true
require "resolv"
module DiscourseAi
module Utils
module DnsSrv
def self.lookup(domain)
Discourse
.cache
.fetch("dns_srv_lookup:#{domain}", expires_in: 5.minutes) do
resources = dns_srv_lookup_for_domain(domain)
select_server(resources)
end
end
private
def self.dns_srv_lookup_for_domain(domain)
resolver = Resolv::DNS.new
resources = resolver.getresources(domain, Resolv::DNS::Resource::IN::SRV)
end
def self.select_server(resources)
priority = resources.group_by(&:priority).keys.min
priority_resources = resources.select { |r| r.priority == priority }
total_weight = priority_resources.map(&:weight).sum
random_weight = rand(total_weight)
priority_resources.each do |resource|
random_weight -= resource.weight
return resource if random_weight < 0
end
# fallback
resources.first
end
end
end
end

View File

@ -0,0 +1,34 @@
# frozen_string_literal: true
describe DiscourseAi::Utils::DnsSrv do
let(:domain) { "example.com" }
let(:weighted_dns_results) do
[
Resolv::DNS::Resource::IN::SRV.new(1, 1, 443, "service1.example.com"),
Resolv::DNS::Resource::IN::SRV.new(1, 2, 443, "service2.example.com"),
Resolv::DNS::Resource::IN::SRV.new(1, 2, 443, "service3.example.com"),
Resolv::DNS::Resource::IN::SRV.new(2, 1, 443, "service4.example.com"),
Resolv::DNS::Resource::IN::SRV.new(2, 1, 443, "service5.example.com"),
]
end
context "when there are several servers with the same priority" do
before do
Resolv::DNS.any_instance.stubs(:getresources).returns(weighted_dns_results)
Discourse.cache.delete("dns_srv_lookup:#{domain}")
end
it "picks a server" do
selected_server = DiscourseAi::Utils::DnsSrv.lookup(domain)
expect(weighted_dns_results).to include(selected_server)
expect(selected_server.port).to eq(443)
end
it "doesn't pick a server with lower priority" do
selected_server = DiscourseAi::Utils::DnsSrv.lookup(domain)
expect(weighted_dns_results.filter { |r| r.priority == 1 }).to include(selected_server)
end
end
end