FEATURE: Support for SRV records for Discourse services (#414)

This allows admins to configure services with multiple backends using DNS SRV records. This PR also adds support for shared-secret auth via headers for TEI and vLLM endpoints, so they are in line with the other ones.
This commit is contained in:
Rafael dos Santos Silva 2024-01-10 19:23:07 -03:00 committed by GitHub
parent 9d8bbe32a9
commit 8fcba12fae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 93 additions and 9 deletions

View File

@ -8,6 +8,9 @@ discourse_ai:
client: true client: true
ai_toxicity_inference_service_api_endpoint: ai_toxicity_inference_service_api_endpoint:
default: "https://disorder-testing.demo-by-discourse.com" default: "https://disorder-testing.demo-by-discourse.com"
ai_toxicity_inference_service_api_endpoint_srv:
default: ""
hidden: true
ai_toxicity_inference_service_api_key: ai_toxicity_inference_service_api_key:
default: "" default: ""
secret: true secret: true
@ -55,6 +58,9 @@ discourse_ai:
client: true client: true
ai_sentiment_inference_service_api_endpoint: ai_sentiment_inference_service_api_endpoint:
default: "https://sentiment-testing.demo-by-discourse.com" default: "https://sentiment-testing.demo-by-discourse.com"
ai_sentiment_inference_service_api_endpoint_srv:
default: ""
hidden: true
ai_sentiment_inference_service_api_key: ai_sentiment_inference_service_api_key:
default: "" default: ""
secret: true secret: true
@ -70,6 +76,9 @@ discourse_ai:
ai_nsfw_detection_enabled: false ai_nsfw_detection_enabled: false
ai_nsfw_inference_service_api_endpoint: ai_nsfw_inference_service_api_endpoint:
default: "https://nsfw-testing.demo-by-discourse.com" default: "https://nsfw-testing.demo-by-discourse.com"
ai_nsfw_inference_service_api_endpoint_srv:
default: ""
hidden: true
ai_nsfw_inference_service_api_key: ai_nsfw_inference_service_api_key:
default: "" default: ""
secret: true secret: true
@ -128,6 +137,7 @@ discourse_ai:
ai_hugging_face_tei_endpoint_srv: ai_hugging_face_tei_endpoint_srv:
default: "" default: ""
hidden: true hidden: true
# Shared secret sent as the X-API-KEY header on TEI requests.
# Declared with `secret: true` to match the other *_api_key settings in this file.
ai_hugging_face_tei_api_key:
  default: ""
  secret: true
ai_google_custom_search_api_key: ai_google_custom_search_api_key:
default: "" default: ""
secret: true secret: true
@ -155,6 +165,7 @@ discourse_ai:
ai_vllm_endpoint_srv: ai_vllm_endpoint_srv:
default: "" default: ""
hidden: true hidden: true
# Shared secret sent as the X-API-KEY header on vLLM requests.
# Declared with `secret: true` to match the other *_api_key settings in this file.
ai_vllm_api_key:
  default: ""
  secret: true
composer_ai_helper_enabled: composer_ai_helper_enabled:
default: false default: false
@ -211,6 +222,9 @@ discourse_ai:
default: false default: false
client: true client: true
ai_embeddings_discourse_service_api_endpoint: "" ai_embeddings_discourse_service_api_endpoint: ""
ai_embeddings_discourse_service_api_endpoint_srv:
default: ""
hidden: true
ai_embeddings_discourse_service_api_key: ai_embeddings_discourse_service_api_key:
default: "" default: ""
secret: true secret: true
@ -257,6 +271,9 @@ discourse_ai:
- mistralai/Mistral-7B-Instruct-v0.2 - mistralai/Mistral-7B-Instruct-v0.2
ai_summarization_discourse_service_api_endpoint: "" ai_summarization_discourse_service_api_endpoint: ""
ai_summarization_discourse_service_api_endpoint_srv:
default: ""
hidden: true
ai_summarization_discourse_service_api_key: ai_summarization_discourse_service_api_key:
default: "" default: ""
secret: true secret: true

View File

@ -50,6 +50,9 @@ module DiscourseAi
def prepare_request(payload) def prepare_request(payload)
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" } headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
headers["X-API-KEY"] = SiteSetting.ai_vllm_api_key if SiteSetting.ai_vllm_api_key.present?
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end end

View File

@ -6,7 +6,7 @@ module DiscourseAi
class AllMpnetBaseV2 < Base class AllMpnetBaseV2 < Base
def vector_from(text) def vector_from(text)
DiscourseAi::Inference::DiscourseClassifier.perform!( DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", "#{discourse_embeddings_endpoint}/api/v1/classify",
name, name,
text, text,
SiteSetting.ai_embeddings_discourse_service_api_key, SiteSetting.ai_embeddings_discourse_service_api_key,

View File

@ -308,6 +308,18 @@ module DiscourseAi
raise ArgumentError, "Invalid target type" raise ArgumentError, "Invalid target type"
end end
end end
# Base URL of the Discourse embeddings inference service.
#
# When the `_srv` site setting is filled in, it is resolved through
# DiscourseAi::Utils::DnsSrv.lookup and an https URL is assembled from the
# target and port of the returned service. Otherwise the plain endpoint
# site setting is returned as-is.
#
# @return [String] base URL without the "/api/v1/classify" path
def discourse_embeddings_endpoint
  srv_record = SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv
  return SiteSetting.ai_embeddings_discourse_service_api_endpoint unless srv_record.present?

  resolved = DiscourseAi::Utils::DnsSrv.lookup(srv_record)
  "https://#{resolved.target}:#{resolved.port}"
end
end end
end end
end end

View File

@ -13,9 +13,9 @@ module DiscourseAi
elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
truncated_text = tokenizer.truncate(text, max_sequence_length - 2) truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present? elsif discourse_embeddings_endpoint.present?
DiscourseAi::Inference::DiscourseClassifier.perform!( DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", "#{discourse_embeddings_endpoint}/api/v1/classify",
inference_model_name.split("/").last, inference_model_name.split("/").last,
text, text,
SiteSetting.ai_embeddings_discourse_service_api_key, SiteSetting.ai_embeddings_discourse_service_api_key,

View File

@ -8,9 +8,9 @@ module DiscourseAi
if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
truncated_text = tokenizer.truncate(text, max_sequence_length - 2) truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first
elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present? elsif discourse_embeddings_endpoint.present?
DiscourseAi::Inference::DiscourseClassifier.perform!( DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", "#{discourse_embeddings_endpoint}/api/v1/classify",
name, name,
"query: #{text}", "query: #{text}",
SiteSetting.ai_embeddings_discourse_service_api_key, SiteSetting.ai_embeddings_discourse_service_api_key,

View File

@ -14,6 +14,10 @@ module ::DiscourseAi
api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint
end end
if SiteSetting.ai_hugging_face_tei_api_key.present?
headers["X-API-KEY"] = SiteSetting.ai_hugging_face_tei_api_key
end
response = Faraday.post(api_endpoint, body, headers) response = Faraday.post(api_endpoint, body, headers)
raise Net::HTTPBadResponse if ![200].include?(response.status) raise Net::HTTPBadResponse if ![200].include?(response.status)

View File

@ -55,7 +55,7 @@ module DiscourseAi
upload_url = "#{Discourse.base_url_no_prefix}#{upload_url}" if upload_url.starts_with?("/") upload_url = "#{Discourse.base_url_no_prefix}#{upload_url}" if upload_url.starts_with?("/")
DiscourseAi::Inference::DiscourseClassifier.perform!( DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_nsfw_inference_service_api_endpoint}/api/v1/classify", "#{endpoint}/api/v1/classify",
model, model,
upload_url, upload_url,
SiteSetting.ai_nsfw_inference_service_api_key, SiteSetting.ai_nsfw_inference_service_api_key,
@ -79,6 +79,18 @@ module DiscourseAi
value.to_i >= SiteSetting.send("ai_nsfw_flag_threshold_#{key}") value.to_i >= SiteSetting.send("ai_nsfw_flag_threshold_#{key}")
end end
end end
# Base URL of the NSFW inference service.
#
# A non-blank `_srv` site setting is resolved through
# DiscourseAi::Utils::DnsSrv.lookup and turned into an https URL made of the
# returned target and port; otherwise the plain endpoint setting is used.
#
# @return [String] base URL without the "/api/v1/classify" path
def endpoint
  srv_record = SiteSetting.ai_nsfw_inference_service_api_endpoint_srv
  return SiteSetting.ai_nsfw_inference_service_api_endpoint unless srv_record.present?

  resolved = DiscourseAi::Utils::DnsSrv.lookup(srv_record)
  "https://#{resolved.target}:#{resolved.port}"
end
end end
end end
end end

View File

@ -40,7 +40,7 @@ module DiscourseAi
def request_with(model, content) def request_with(model, content)
::DiscourseAi::Inference::DiscourseClassifier.perform!( ::DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify", "#{endpoint}/api/v1/classify",
model, model,
content, content,
SiteSetting.ai_sentiment_inference_service_api_key, SiteSetting.ai_sentiment_inference_service_api_key,
@ -54,6 +54,18 @@ module DiscourseAi
target_to_classify.raw target_to_classify.raw
end end
end end
# Base URL of the sentiment inference service.
#
# A non-blank `_srv` site setting is resolved through
# DiscourseAi::Utils::DnsSrv.lookup and turned into an https URL made of the
# returned target and port; otherwise the plain endpoint setting is used.
#
# @return [String] base URL without the "/api/v1/classify" path
def endpoint
  srv_record = SiteSetting.ai_sentiment_inference_service_api_endpoint_srv
  return SiteSetting.ai_sentiment_inference_service_api_endpoint unless srv_record.present?

  resolved = DiscourseAi::Utils::DnsSrv.lookup(srv_record)
  "https://#{resolved.target}:#{resolved.port}"
end
end end
end end
end end

View File

@ -44,12 +44,24 @@ module DiscourseAi
def completion(prompt) def completion(prompt)
::DiscourseAi::Inference::DiscourseClassifier.perform!( ::DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_summarization_discourse_service_api_endpoint}/api/v1/classify", "#{endpoint}/api/v1/classify",
completion_model.model, completion_model.model,
prompt, prompt,
SiteSetting.ai_summarization_discourse_service_api_key, SiteSetting.ai_summarization_discourse_service_api_key,
).dig(:summary_text) ).dig(:summary_text)
end end
# Base URL of the Discourse summarization service.
#
# A non-blank `_srv` site setting is resolved through
# DiscourseAi::Utils::DnsSrv.lookup and turned into an https URL made of the
# returned target and port; otherwise the plain endpoint setting is used.
#
# @return [String] base URL without the "/api/v1/classify" path
def endpoint
  srv_record = SiteSetting.ai_summarization_discourse_service_api_endpoint_srv
  return SiteSetting.ai_summarization_discourse_service_api_endpoint unless srv_record.present?

  resolved = DiscourseAi::Utils::DnsSrv.lookup(srv_record)
  "https://#{resolved.target}:#{resolved.port}"
end
end end
end end
end end

View File

@ -43,7 +43,7 @@ module DiscourseAi
def request(target_to_classify) def request(target_to_classify)
data = data =
::DiscourseAi::Inference::DiscourseClassifier.perform!( ::DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{SiteSetting.ai_toxicity_inference_service_api_endpoint}/api/v1/classify", "#{endpoint}/api/v1/classify",
SiteSetting.ai_toxicity_inference_service_api_model, SiteSetting.ai_toxicity_inference_service_api_model,
content_of(target_to_classify), content_of(target_to_classify),
SiteSetting.ai_toxicity_inference_service_api_key, SiteSetting.ai_toxicity_inference_service_api_key,
@ -67,6 +67,18 @@ module DiscourseAi
target_to_classify.raw target_to_classify.raw
end end
end end
# Base URL of the toxicity inference service.
#
# A non-blank `_srv` site setting is resolved through
# DiscourseAi::Utils::DnsSrv.lookup and turned into an https URL made of the
# returned target and port; otherwise the plain endpoint setting is used.
#
# @return [String] base URL without the "/api/v1/classify" path
def endpoint
  srv_record = SiteSetting.ai_toxicity_inference_service_api_endpoint_srv
  return SiteSetting.ai_toxicity_inference_service_api_endpoint unless srv_record.present?

  resolved = DiscourseAi::Utils::DnsSrv.lookup(srv_record)
  "https://#{resolved.target}:#{resolved.port}"
end
end end
end end
end end