FEATURE: Support for SRV records for Discourse services (#414)

This allows admins to configure services with multiple backends using DNS SRV records. This PR also adds support for shared-secret auth via headers for the TEI and vLLM endpoints, bringing them in line with the other services.
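
The hunks below call a DiscourseAi::Utils::DnsSrv.lookup helper that is added elsewhere in this commit but not shown in this excerpt. A minimal sketch of what such a helper could look like, built on Ruby's standard resolv library (the record-selection policy and class layout here are assumptions, not the committed implementation):

```ruby
# frozen_string_literal: true

require "resolv"

module DiscourseAi
  module Utils
    class DnsSrv
      # Resolve an SRV name such as "_embeddings._tcp.example.com" and return a
      # single Resolv::DNS::Resource::IN::SRV record (lowest priority first,
      # then highest weight). Sketch only - the committed helper may cache
      # results or select records differently.
      def self.lookup(domain)
        Resolv::DNS.open do |dns|
          dns
            .getresources(domain, Resolv::DNS::Resource::IN::SRV)
            .min_by { |record| [record.priority, -record.weight] }
        end
      end
    end
  end
end
```

The callers below only read `target` and `port` from the returned record, so any policy that yields a single SRV record satisfies them.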
Rafael dos Santos Silva 2024-01-10 19:23:07 -03:00 committed by GitHub
parent 9d8bbe32a9
commit 8fcba12fae
11 changed files with 93 additions and 9 deletions

View File

@@ -8,6 +8,9 @@ discourse_ai:
client: true
ai_toxicity_inference_service_api_endpoint:
default: "https://disorder-testing.demo-by-discourse.com"
ai_toxicity_inference_service_api_endpoint_srv:
default: ""
hidden: true
ai_toxicity_inference_service_api_key:
default: ""
secret: true
@@ -55,6 +58,9 @@ discourse_ai:
client: true
ai_sentiment_inference_service_api_endpoint:
default: "https://sentiment-testing.demo-by-discourse.com"
ai_sentiment_inference_service_api_endpoint_srv:
default: ""
hidden: true
ai_sentiment_inference_service_api_key:
default: ""
secret: true
@@ -70,6 +76,9 @@ discourse_ai:
ai_nsfw_detection_enabled: false
ai_nsfw_inference_service_api_endpoint:
default: "https://nsfw-testing.demo-by-discourse.com"
ai_nsfw_inference_service_api_endpoint_srv:
default: ""
hidden: true
ai_nsfw_inference_service_api_key:
default: ""
secret: true
@@ -128,6 +137,7 @@ discourse_ai:
ai_hugging_face_tei_endpoint_srv:
default: ""
hidden: true
ai_hugging_face_tei_api_key: ""
ai_google_custom_search_api_key:
default: ""
secret: true
@@ -155,6 +165,7 @@ discourse_ai:
ai_vllm_endpoint_srv:
default: ""
hidden: true
ai_vllm_api_key: ""
composer_ai_helper_enabled:
default: false
@@ -211,6 +222,9 @@ discourse_ai:
default: false
client: true
ai_embeddings_discourse_service_api_endpoint: ""
ai_embeddings_discourse_service_api_endpoint_srv:
default: ""
hidden: true
ai_embeddings_discourse_service_api_key:
default: ""
secret: true
@@ -257,6 +271,9 @@ discourse_ai:
- mistralai/Mistral-7B-Instruct-v0.2
ai_summarization_discourse_service_api_endpoint: ""
ai_summarization_discourse_service_api_endpoint_srv:
default: ""
hidden: true
ai_summarization_discourse_service_api_key:
default: ""
secret: true
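
With these settings in place, an operator can point a service at a DNS SRV name instead of a fixed URL and set the new shared secrets, for example from the Rails console (the SRV name and secret values below are illustrative):

```ruby
# Hidden setting: the endpoint resolves to https://<target>:<port> of the SRV record.
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv = "_embeddings._tcp.ai.example.com"

# Shared secrets sent as an X-API-KEY header to the TEI and vLLM endpoints.
SiteSetting.ai_hugging_face_tei_api_key = "example-shared-secret"
SiteSetting.ai_vllm_api_key = "example-shared-secret"
```

Because the `_srv` settings are declared `hidden`, they do not show up in the admin UI and are intended to be set from the console or during provisioning.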

View File

@@ -50,6 +50,9 @@ module DiscourseAi
def prepare_request(payload)
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
headers["X-API-KEY"] = SiteSetting.ai_vllm_api_key if SiteSetting.ai_vllm_api_key.present?
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end

View File

@@ -6,7 +6,7 @@ module DiscourseAi
class AllMpnetBaseV2 < Base
def vector_from(text)
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
"#{discourse_embeddings_endpoint}/api/v1/classify",
name,
text,
SiteSetting.ai_embeddings_discourse_service_api_key,

View File

@@ -308,6 +308,18 @@ module DiscourseAi
raise ArgumentError, "Invalid target type"
end
end

def discourse_embeddings_endpoint
if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present?
service =
DiscourseAi::Utils::DnsSrv.lookup(
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv,
)
"https://#{service.target}:#{service.port}"
else
SiteSetting.ai_embeddings_discourse_service_api_endpoint
end
end
end
end
end
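
Every module touched by this commit gains the same small helper: when the corresponding `_srv` setting is present, the SRV record is resolved and the endpoint becomes `https://<target>:<port>`; otherwise the plain endpoint setting is used unchanged. Roughly (the resolved values are illustrative):

```ruby
# Suppose the SRV lookup returns a record with target "embeddings-1.internal" and port 8443.
service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv)
endpoint = "https://#{service.target}:#{service.port}"
# => "https://embeddings-1.internal:8443"
# Requests are then issued against "#{endpoint}/api/v1/classify", the same path as before.
```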

View File

@@ -13,9 +13,9 @@ module DiscourseAi
elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
elsif discourse_embeddings_endpoint.present?
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
"#{discourse_embeddings_endpoint}/api/v1/classify",
inference_model_name.split("/").last,
text,
SiteSetting.ai_embeddings_discourse_service_api_key,

View File

@@ -8,9 +8,9 @@ module DiscourseAi
if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
elsif discourse_embeddings_endpoint.present?
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
"#{discourse_embeddings_endpoint}/api/v1/classify",
name,
"query: #{text}",
SiteSetting.ai_embeddings_discourse_service_api_key,

View File

@@ -14,6 +14,10 @@ module ::DiscourseAi
api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint
end
if SiteSetting.ai_hugging_face_tei_api_key.present?
headers["X-API-KEY"] = SiteSetting.ai_hugging_face_tei_api_key
end
response = Faraday.post(api_endpoint, body, headers)
raise Net::HTTPBadResponse if ![200].include?(response.status)
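
The TEI client now carries the shared secret the same way as the vLLM change above: an X-API-KEY header that is only added when the key setting is present. An illustrative request, with a placeholder endpoint and payload rather than values from this diff:

```ruby
headers = { "Content-Type" => "application/json" }
if SiteSetting.ai_hugging_face_tei_api_key.present?
  headers["X-API-KEY"] = SiteSetting.ai_hugging_face_tei_api_key
end

# Hypothetical TEI endpoint and body, for illustration only.
response = Faraday.post("https://tei.internal:8080", { inputs: "hello world" }.to_json, headers)
raise Net::HTTPBadResponse if response.status != 200
```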

View File

@@ -55,7 +55,7 @@ module DiscourseAi
upload_url = "#{Discourse.base_url_no_prefix}#{upload_url}" if upload_url.starts_with?("/")
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_nsfw_inference_service_api_endpoint}/api/v1/classify",
"#{endpoint}/api/v1/classify",
model,
upload_url,
SiteSetting.ai_nsfw_inference_service_api_key,
@@ -79,6 +79,18 @@ module DiscourseAi
value.to_i >= SiteSetting.send("ai_nsfw_flag_threshold_#{key}")
end
end

def endpoint
if SiteSetting.ai_nsfw_inference_service_api_endpoint_srv.present?
service =
DiscourseAi::Utils::DnsSrv.lookup(
SiteSetting.ai_nsfw_inference_service_api_endpoint_srv,
)
"https://#{service.target}:#{service.port}"
else
SiteSetting.ai_nsfw_inference_service_api_endpoint
end
end
end
end
end

View File

@@ -40,7 +40,7 @@ module DiscourseAi
def request_with(model, content)
::DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify",
"#{endpoint}/api/v1/classify",
model,
content,
SiteSetting.ai_sentiment_inference_service_api_key,
@@ -54,6 +54,18 @@ module DiscourseAi
target_to_classify.raw
end
end

def endpoint
if SiteSetting.ai_sentiment_inference_service_api_endpoint_srv.present?
service =
DiscourseAi::Utils::DnsSrv.lookup(
SiteSetting.ai_sentiment_inference_service_api_endpoint_srv,
)
"https://#{service.target}:#{service.port}"
else
SiteSetting.ai_sentiment_inference_service_api_endpoint
end
end
end
end
end

View File

@@ -44,12 +44,24 @@ module DiscourseAi
def completion(prompt)
::DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_summarization_discourse_service_api_endpoint}/api/v1/classify",
"#{endpoint}/api/v1/classify",
completion_model.model,
prompt,
SiteSetting.ai_summarization_discourse_service_api_key,
).dig(:summary_text)
end

def endpoint
if SiteSetting.ai_summarization_discourse_service_api_endpoint_srv.present?
service =
DiscourseAi::Utils::DnsSrv.lookup(
SiteSetting.ai_summarization_discourse_service_api_endpoint_srv,
)
"https://#{service.target}:#{service.port}"
else
SiteSetting.ai_summarization_discourse_service_api_endpoint
end
end
end
end
end

View File

@@ -43,7 +43,7 @@ module DiscourseAi
def request(target_to_classify)
data =
::DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_toxicity_inference_service_api_endpoint}/api/v1/classify",
"#{endpoint}/api/v1/classify",
SiteSetting.ai_toxicity_inference_service_api_model,
content_of(target_to_classify),
SiteSetting.ai_toxicity_inference_service_api_key,
@@ -67,6 +67,18 @@ module DiscourseAi
target_to_classify.raw
end
end

def endpoint
if SiteSetting.ai_toxicity_inference_service_api_endpoint_srv.present?
service =
DiscourseAi::Utils::DnsSrv.lookup(
SiteSetting.ai_toxicity_inference_service_api_endpoint_srv,
)
"https://#{service.target}:#{service.port}"
else
SiteSetting.ai_toxicity_inference_service_api_endpoint
end
end
end
end
end