FEATURE: DNS SRV support for TEI (#363)
This commit is contained in:
parent
ba09582d7c
commit
4d7ccdda2f
|
@ -128,6 +128,9 @@ discourse_ai:
|
|||
default: ""
|
||||
ai_hugging_face_tei_endpoint:
|
||||
default: ""
|
||||
ai_hugging_face_tei_endpoint_srv:
|
||||
default: ""
|
||||
hidden: true
|
||||
ai_google_custom_search_api_key:
|
||||
default: ""
|
||||
secret: true
|
||||
|
|
|
@ -10,7 +10,7 @@ module DiscourseAi
|
|||
.perform!(inference_model_name, { text: text })
|
||||
.dig(:result, :data)
|
||||
.first
|
||||
elsif SiteSetting.ai_hugging_face_tei_endpoint.present?
|
||||
elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
||||
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
|
||||
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||
|
|
|
@ -5,7 +5,7 @@ module DiscourseAi
|
|||
module VectorRepresentations
|
||||
class MultilingualE5Large < Base
|
||||
def vector_from(text)
|
||||
if SiteSetting.ai_hugging_face_tei_endpoint.present?
|
||||
if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
||||
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
|
||||
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||
|
|
|
@ -7,7 +7,12 @@ module ::DiscourseAi
|
|||
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
||||
body = { inputs: content, truncate: true }.to_json
|
||||
|
||||
if SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
|
||||
service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv)
|
||||
api_endpoint = "https://#{service.target}:#{service.port}"
|
||||
else
|
||||
api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint
|
||||
end
|
||||
|
||||
response = Faraday.post(api_endpoint, body, headers)
|
||||
|
||||
|
@ -15,6 +20,11 @@ module ::DiscourseAi
|
|||
|
||||
JSON.parse(response.body, symbolize_names: true)
|
||||
end
|
||||
|
||||
# Reports whether this inference backend has an endpoint to talk to:
# either the `ai_hugging_face_tei_endpoint` site setting or the
# `ai_hugging_face_tei_endpoint_srv` site setting is present.
def self.configured?
  return true if SiteSetting.ai_hugging_face_tei_endpoint.present?

  SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require "resolv"
|
||||
|
||||
module DiscourseAi
  module Utils
    # Resolves DNS SRV records for a name and picks one record to connect to.
    module DnsSrv
      # Looks up the SRV records for +domain+ and returns the selected one.
      # The selection is stored in Discourse.cache under
      # "dns_srv_lookup:<domain>" with a 5 minute expiry, so repeated calls
      # within that window reuse the same server.
      #
      # @param domain [String] name to query for SRV records
      # @return [Resolv::DNS::Resource::IN::SRV, nil] the chosen record, or
      #   nil when the query returned no records
      def self.lookup(domain)
        Discourse
          .cache
          .fetch("dns_srv_lookup:#{domain}", expires_in: 5.minutes) do
            select_server(dns_srv_lookup_for_domain(domain))
          end
      end

      # Queries DNS for the SRV records of +domain+.
      #
      # @param domain [String]
      # @return [Array<Resolv::DNS::Resource::IN::SRV>]
      def self.dns_srv_lookup_for_domain(domain)
        # The block form closes the resolver once the query is done.
        Resolv::DNS.open do |resolver|
          resolver.getresources(domain, Resolv::DNS::Resource::IN::SRV)
        end
      end

      # Picks one record among those sharing the lowest priority number,
      # at random in proportion to each record's weight.
      #
      # @param resources [Array<Resolv::DNS::Resource::IN::SRV>]
      # @return [Resolv::DNS::Resource::IN::SRV, nil] nil for an empty list
      def self.select_server(resources)
        priority = resources.map(&:priority).min
        priority_resources = resources.select { |r| r.priority == priority }
        total_weight = priority_resources.sum(&:weight)

        # rand(0) would return a Float in [0, 1), so only draw when there is
        # actual weight to distribute.
        if total_weight.positive?
          random_weight = rand(total_weight)

          priority_resources.each do |resource|
            random_weight -= resource.weight
            return resource if random_weight < 0
          end
        end

        # Fallback (e.g. all weights are zero): stay within the best
        # priority group rather than returning an arbitrary record.
        priority_resources.first
      end

      # A bare `private` does not apply to `def self.` methods.
      private_class_method :dns_srv_lookup_for_domain, :select_server
    end
  end
end
|
|
@ -0,0 +1,34 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
# Exercises DnsSrv.lookup against a stubbed resolver that returns three
# priority-1 records and two priority-2 records, all on port 443.
describe DiscourseAi::Utils::DnsSrv do
  let(:domain) { "example.com" }
  let(:weighted_dns_results) do
    # [priority, weight, target]
    [
      [1, 1, "service1.example.com"],
      [1, 2, "service2.example.com"],
      [1, 2, "service3.example.com"],
      [2, 1, "service4.example.com"],
      [2, 1, "service5.example.com"],
    ].map do |priority, weight, target|
      Resolv::DNS::Resource::IN::SRV.new(priority, weight, 443, target)
    end
  end

  context "when there are several servers with the same priority" do
    before do
      Resolv::DNS.any_instance.stubs(:getresources).returns(weighted_dns_results)

      # Drop any cached selection so each example performs a fresh lookup.
      Discourse.cache.delete("dns_srv_lookup:#{domain}")
    end

    it "picks a server" do
      selected_server = described_class.lookup(domain)

      expect(weighted_dns_results).to include(selected_server)
      expect(selected_server.port).to eq(443)
    end

    it "doesn't pick a server with lower priority" do
      top_priority_servers = weighted_dns_results.select { |r| r.priority == 1 }

      expect(top_priority_servers).to include(described_class.lookup(domain))
    end
  end
end
|
Loading…
Reference in New Issue