From 8fcba12faed98854fd608508384676b865ebaa16 Mon Sep 17 00:00:00 2001 From: Rafael dos Santos Silva Date: Wed, 10 Jan 2024 19:23:07 -0300 Subject: [PATCH] FEATURE: Support for SRV records for Discourse services (#414) This allows admins to configure services with multiple backends using DNS SRV records. This PR also adds support for shared secret auth via headers for TEI and vLLM endpoints, so they are inline with the other ones. --- config/settings.yml | 17 +++++++++++++++++ lib/completions/endpoints/vllm.rb | 3 +++ .../vector_representations/all_mpnet_base_v2.rb | 2 +- lib/embeddings/vector_representations/base.rb | 12 ++++++++++++ .../vector_representations/bge_large_en.rb | 4 ++-- .../multilingual_e5_large.rb | 4 ++-- lib/inference/hugging_face_text_embeddings.rb | 4 ++++ lib/nsfw/classification.rb | 14 +++++++++++++- lib/sentiment/sentiment_classification.rb | 14 +++++++++++++- .../strategies/truncate_content.rb | 14 +++++++++++++- lib/toxicity/toxicity_classification.rb | 14 +++++++++++++- 11 files changed, 93 insertions(+), 9 deletions(-) diff --git a/config/settings.yml b/config/settings.yml index d6723eb0..4409503d 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -8,6 +8,9 @@ discourse_ai: client: true ai_toxicity_inference_service_api_endpoint: default: "https://disorder-testing.demo-by-discourse.com" + ai_toxicity_inference_service_api_endpoint_srv: + default: "" + hidden: true ai_toxicity_inference_service_api_key: default: "" secret: true @@ -55,6 +58,9 @@ discourse_ai: client: true ai_sentiment_inference_service_api_endpoint: default: "https://sentiment-testing.demo-by-discourse.com" + ai_sentiment_inference_service_api_endpoint_srv: + default: "" + hidden: true ai_sentiment_inference_service_api_key: default: "" secret: true @@ -70,6 +76,9 @@ discourse_ai: ai_nsfw_detection_enabled: false ai_nsfw_inference_service_api_endpoint: default: "https://nsfw-testing.demo-by-discourse.com" + ai_nsfw_inference_service_api_endpoint_srv: + default: "" + hidden: true ai_nsfw_inference_service_api_key: default: "" secret: true @@ -128,6 +137,7 @@ discourse_ai: ai_hugging_face_tei_endpoint_srv: default: "" hidden: true + ai_hugging_face_tei_api_key: "" ai_google_custom_search_api_key: default: "" secret: true @@ -155,6 +165,7 @@ discourse_ai: ai_vllm_endpoint_srv: default: "" hidden: true + ai_vllm_api_key: "" composer_ai_helper_enabled: default: false @@ -211,6 +222,9 @@ discourse_ai: default: false client: true ai_embeddings_discourse_service_api_endpoint: "" + ai_embeddings_discourse_service_api_endpoint_srv: + default: "" + hidden: true ai_embeddings_discourse_service_api_key: default: "" secret: true @@ -257,6 +271,9 @@ discourse_ai: - mistralai/Mistral-7B-Instruct-v0.2 ai_summarization_discourse_service_api_endpoint: "" + ai_summarization_discourse_service_api_endpoint_srv: + default: "" + hidden: true ai_summarization_discourse_service_api_key: default: "" secret: true diff --git a/lib/completions/endpoints/vllm.rb b/lib/completions/endpoints/vllm.rb index 71385e94..1ea69bbd 100644 --- a/lib/completions/endpoints/vllm.rb +++ b/lib/completions/endpoints/vllm.rb @@ -50,6 +50,9 @@ module DiscourseAi def prepare_request(payload) headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" } + + headers["X-API-KEY"] = SiteSetting.ai_vllm_api_key if SiteSetting.ai_vllm_api_key.present? + Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } end diff --git a/lib/embeddings/vector_representations/all_mpnet_base_v2.rb b/lib/embeddings/vector_representations/all_mpnet_base_v2.rb index 8dfb2a47..a8bbe86c 100644 --- a/lib/embeddings/vector_representations/all_mpnet_base_v2.rb +++ b/lib/embeddings/vector_representations/all_mpnet_base_v2.rb @@ -6,7 +6,7 @@ module DiscourseAi class AllMpnetBaseV2 < Base def vector_from(text) DiscourseAi::Inference::DiscourseClassifier.perform!( - "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", + "#{discourse_embeddings_endpoint}/api/v1/classify", name, text, SiteSetting.ai_embeddings_discourse_service_api_key, diff --git a/lib/embeddings/vector_representations/base.rb b/lib/embeddings/vector_representations/base.rb index 6b7f8ff0..b73f91c8 100644 --- a/lib/embeddings/vector_representations/base.rb +++ b/lib/embeddings/vector_representations/base.rb @@ -308,6 +308,18 @@ module DiscourseAi raise ArgumentError, "Invalid target type" end end + + def discourse_embeddings_endpoint + if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? + service = + DiscourseAi::Utils::DnsSrv.lookup( + SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv, + ) + "https://#{service.target}:#{service.port}" + else + SiteSetting.ai_embeddings_discourse_service_api_endpoint + end + end end end end diff --git a/lib/embeddings/vector_representations/bge_large_en.rb b/lib/embeddings/vector_representations/bge_large_en.rb index af600b7e..f3e24c48 100644 --- a/lib/embeddings/vector_representations/bge_large_en.rb +++ b/lib/embeddings/vector_representations/bge_large_en.rb @@ -13,9 +13,9 @@ module DiscourseAi elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? truncated_text = tokenizer.truncate(text, max_sequence_length - 2) DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first - elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present? + elsif discourse_embeddings_endpoint.present? DiscourseAi::Inference::DiscourseClassifier.perform!( - "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", + "#{discourse_embeddings_endpoint}/api/v1/classify", inference_model_name.split("/").last, text, SiteSetting.ai_embeddings_discourse_service_api_key, diff --git a/lib/embeddings/vector_representations/multilingual_e5_large.rb b/lib/embeddings/vector_representations/multilingual_e5_large.rb index d7fcab4d..55dfc448 100644 --- a/lib/embeddings/vector_representations/multilingual_e5_large.rb +++ b/lib/embeddings/vector_representations/multilingual_e5_large.rb @@ -8,9 +8,9 @@ module DiscourseAi if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? truncated_text = tokenizer.truncate(text, max_sequence_length - 2) DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first - elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint.present? + elsif discourse_embeddings_endpoint.present? DiscourseAi::Inference::DiscourseClassifier.perform!( - "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", + "#{discourse_embeddings_endpoint}/api/v1/classify", name, "query: #{text}", SiteSetting.ai_embeddings_discourse_service_api_key, diff --git a/lib/inference/hugging_face_text_embeddings.rb b/lib/inference/hugging_face_text_embeddings.rb index fbc4ddec..09118bcb 100644 --- a/lib/inference/hugging_face_text_embeddings.rb +++ b/lib/inference/hugging_face_text_embeddings.rb @@ -14,6 +14,10 @@ module ::DiscourseAi api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint end + if SiteSetting.ai_hugging_face_tei_api_key.present? + headers["X-API-KEY"] = SiteSetting.ai_hugging_face_tei_api_key + end + response = Faraday.post(api_endpoint, body, headers) raise Net::HTTPBadResponse if ![200].include?(response.status) diff --git a/lib/nsfw/classification.rb b/lib/nsfw/classification.rb index 2c2623a2..5c36e566 100644 --- a/lib/nsfw/classification.rb +++ b/lib/nsfw/classification.rb @@ -55,7 +55,7 @@ module DiscourseAi upload_url = "#{Discourse.base_url_no_prefix}#{upload_url}" if upload_url.starts_with?("/") DiscourseAi::Inference::DiscourseClassifier.perform!( - "#{SiteSetting.ai_nsfw_inference_service_api_endpoint}/api/v1/classify", + "#{endpoint}/api/v1/classify", model, upload_url, SiteSetting.ai_nsfw_inference_service_api_key, @@ -79,6 +79,18 @@ module DiscourseAi value.to_i >= SiteSetting.send("ai_nsfw_flag_threshold_#{key}") end end + + def endpoint + if SiteSetting.ai_nsfw_inference_service_api_endpoint_srv.present? + service = + DiscourseAi::Utils::DnsSrv.lookup( + SiteSetting.ai_nsfw_inference_service_api_endpoint_srv, + ) + "https://#{service.target}:#{service.port}" + else + SiteSetting.ai_nsfw_inference_service_api_endpoint + end + end end end end diff --git a/lib/sentiment/sentiment_classification.rb b/lib/sentiment/sentiment_classification.rb index 00993d01..dc5b2e87 100644 --- a/lib/sentiment/sentiment_classification.rb +++ b/lib/sentiment/sentiment_classification.rb @@ -40,7 +40,7 @@ module DiscourseAi def request_with(model, content) ::DiscourseAi::Inference::DiscourseClassifier.perform!( - "#{SiteSetting.ai_sentiment_inference_service_api_endpoint}/api/v1/classify", + "#{endpoint}/api/v1/classify", model, content, SiteSetting.ai_sentiment_inference_service_api_key, @@ -54,6 +54,18 @@ module DiscourseAi target_to_classify.raw end end + + def endpoint + if SiteSetting.ai_sentiment_inference_service_api_endpoint_srv.present? + service = + DiscourseAi::Utils::DnsSrv.lookup( + SiteSetting.ai_sentiment_inference_service_api_endpoint_srv, + ) + "https://#{service.target}:#{service.port}" + else + SiteSetting.ai_sentiment_inference_service_api_endpoint + end + end end end end diff --git a/lib/summarization/strategies/truncate_content.rb b/lib/summarization/strategies/truncate_content.rb index 1b4cbab6..afbfa5f9 100644 --- a/lib/summarization/strategies/truncate_content.rb +++ b/lib/summarization/strategies/truncate_content.rb @@ -44,12 +44,24 @@ module DiscourseAi def completion(prompt) ::DiscourseAi::Inference::DiscourseClassifier.perform!( - "#{SiteSetting.ai_summarization_discourse_service_api_endpoint}/api/v1/classify", + "#{endpoint}/api/v1/classify", completion_model.model, prompt, SiteSetting.ai_summarization_discourse_service_api_key, ).dig(:summary_text) end + + def endpoint + if SiteSetting.ai_summarization_discourse_service_api_endpoint_srv.present? + service = + DiscourseAi::Utils::DnsSrv.lookup( + SiteSetting.ai_summarization_discourse_service_api_endpoint_srv, + ) + "https://#{service.target}:#{service.port}" + else + SiteSetting.ai_summarization_discourse_service_api_endpoint + end + end end end end diff --git a/lib/toxicity/toxicity_classification.rb b/lib/toxicity/toxicity_classification.rb index 4425e63e..8bb5f788 100644 --- a/lib/toxicity/toxicity_classification.rb +++ b/lib/toxicity/toxicity_classification.rb @@ -43,7 +43,7 @@ module DiscourseAi def request(target_to_classify) data = ::DiscourseAi::Inference::DiscourseClassifier.perform!( - "#{SiteSetting.ai_toxicity_inference_service_api_endpoint}/api/v1/classify", + "#{endpoint}/api/v1/classify", SiteSetting.ai_toxicity_inference_service_api_model, content_of(target_to_classify), SiteSetting.ai_toxicity_inference_service_api_key, @@ -67,6 +67,18 @@ module DiscourseAi target_to_classify.raw end end + + def endpoint + if SiteSetting.ai_toxicity_inference_service_api_endpoint_srv.present? + service = + DiscourseAi::Utils::DnsSrv.lookup( + SiteSetting.ai_toxicity_inference_service_api_endpoint_srv, + ) + "https://#{service.target}:#{service.port}" + else + SiteSetting.ai_toxicity_inference_service_api_endpoint + end + end end end end