FEATURE: DNS SRV support for TEI (#363)
This commit is contained in:
parent
ba09582d7c
commit
4d7ccdda2f
|
@ -128,6 +128,9 @@ discourse_ai:
|
||||||
default: ""
|
default: ""
|
||||||
ai_hugging_face_tei_endpoint:
|
ai_hugging_face_tei_endpoint:
|
||||||
default: ""
|
default: ""
|
||||||
|
ai_hugging_face_tei_endpoint_srv:
|
||||||
|
default: ""
|
||||||
|
hidden: true
|
||||||
ai_google_custom_search_api_key:
|
ai_google_custom_search_api_key:
|
||||||
default: ""
|
default: ""
|
||||||
secret: true
|
secret: true
|
||||||
|
|
|
@ -10,7 +10,7 @@ module DiscourseAi
|
||||||
.perform!(inference_model_name, { text: text })
|
.perform!(inference_model_name, { text: text })
|
||||||
.dig(:result, :data)
|
.dig(:result, :data)
|
||||||
.first
|
.first
|
||||||
elsif SiteSetting.ai_hugging_face_tei_endpoint.present?
|
elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
||||||
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
|
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
|
||||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
|
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
|
||||||
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||||
|
|
|
@ -5,7 +5,7 @@ module DiscourseAi
|
||||||
module VectorRepresentations
|
module VectorRepresentations
|
||||||
class MultilingualE5Large < Base
|
class MultilingualE5Large < Base
|
||||||
def vector_from(text)
|
def vector_from(text)
|
||||||
if SiteSetting.ai_hugging_face_tei_endpoint.present?
|
if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
||||||
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
|
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
|
||||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
|
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
|
||||||
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||||
|
|
|
@ -7,7 +7,12 @@ module ::DiscourseAi
|
||||||
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
||||||
body = { inputs: content, truncate: true }.to_json
|
body = { inputs: content, truncate: true }.to_json
|
||||||
|
|
||||||
|
if SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
|
||||||
|
service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv)
|
||||||
|
api_endpoint = "https://#{service.target}:#{service.port}"
|
||||||
|
else
|
||||||
api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint
|
api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint
|
||||||
|
end
|
||||||
|
|
||||||
response = Faraday.post(api_endpoint, body, headers)
|
response = Faraday.post(api_endpoint, body, headers)
|
||||||
|
|
||||||
|
@ -15,6 +20,11 @@ module ::DiscourseAi
|
||||||
|
|
||||||
JSON.parse(response.body, symbolize_names: true)
|
JSON.parse(response.body, symbolize_names: true)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def self.configured?
|
||||||
|
SiteSetting.ai_hugging_face_tei_endpoint.present? ||
|
||||||
|
SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
require "resolv"
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Utils
|
||||||
|
module DnsSrv
|
||||||
|
def self.lookup(domain)
|
||||||
|
Discourse
|
||||||
|
.cache
|
||||||
|
.fetch("dns_srv_lookup:#{domain}", expires_in: 5.minutes) do
|
||||||
|
resources = dns_srv_lookup_for_domain(domain)
|
||||||
|
|
||||||
|
select_server(resources)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
def self.dns_srv_lookup_for_domain(domain)
|
||||||
|
resolver = Resolv::DNS.new
|
||||||
|
resources = resolver.getresources(domain, Resolv::DNS::Resource::IN::SRV)
|
||||||
|
end
|
||||||
|
|
||||||
|
def self.select_server(resources)
|
||||||
|
priority = resources.group_by(&:priority).keys.min
|
||||||
|
|
||||||
|
priority_resources = resources.select { |r| r.priority == priority }
|
||||||
|
|
||||||
|
total_weight = priority_resources.map(&:weight).sum
|
||||||
|
|
||||||
|
random_weight = rand(total_weight)
|
||||||
|
|
||||||
|
priority_resources.each do |resource|
|
||||||
|
random_weight -= resource.weight
|
||||||
|
|
||||||
|
return resource if random_weight < 0
|
||||||
|
end
|
||||||
|
|
||||||
|
# fallback
|
||||||
|
resources.first
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,34 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
describe DiscourseAi::Utils::DnsSrv do
|
||||||
|
let(:domain) { "example.com" }
|
||||||
|
let(:weighted_dns_results) do
|
||||||
|
[
|
||||||
|
Resolv::DNS::Resource::IN::SRV.new(1, 1, 443, "service1.example.com"),
|
||||||
|
Resolv::DNS::Resource::IN::SRV.new(1, 2, 443, "service2.example.com"),
|
||||||
|
Resolv::DNS::Resource::IN::SRV.new(1, 2, 443, "service3.example.com"),
|
||||||
|
Resolv::DNS::Resource::IN::SRV.new(2, 1, 443, "service4.example.com"),
|
||||||
|
Resolv::DNS::Resource::IN::SRV.new(2, 1, 443, "service5.example.com"),
|
||||||
|
]
|
||||||
|
end
|
||||||
|
|
||||||
|
context "when there are several servers with the same priority" do
|
||||||
|
before do
|
||||||
|
Resolv::DNS.any_instance.stubs(:getresources).returns(weighted_dns_results)
|
||||||
|
|
||||||
|
Discourse.cache.delete("dns_srv_lookup:#{domain}")
|
||||||
|
end
|
||||||
|
|
||||||
|
it "picks a server" do
|
||||||
|
selected_server = DiscourseAi::Utils::DnsSrv.lookup(domain)
|
||||||
|
|
||||||
|
expect(weighted_dns_results).to include(selected_server)
|
||||||
|
expect(selected_server.port).to eq(443)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "doesn't pick a server with lower priority" do
|
||||||
|
selected_server = DiscourseAi::Utils::DnsSrv.lookup(domain)
|
||||||
|
expect(weighted_dns_results.filter { |r| r.priority == 1 }).to include(selected_server)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
Loading…
Reference in New Issue