From fba9c1bf2c53d5aabe8cc3b36a07f209e3a53358 Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Thu, 1 Feb 2024 16:54:09 -0300 Subject: [PATCH] UX: Re-introduce embedding settings validations (#457) * Revert "Revert "UX: Validate embeddings settings (#455)" (#456)" This reverts commit 392e2e8aef7d5b0d988b3c3bc5cc19f1d83c4491. * Resstore previous default --- config/locales/server.en.yml | 9 ++++ config/settings.yml | 3 +- .../embeddings_model_validator.rb | 46 +++++++++++++++++ .../embeddings_module_validator.rb | 51 +++++++++++++++++++ .../all_mpnet_base_v2.rb | 25 +++++++-- lib/embeddings/vector_representations/base.rb | 47 ++++++++++++----- .../vector_representations/bge_large_en.rb | 30 +++++++++-- .../vector_representations/gemini.rb | 18 +++++-- .../multilingual_e5_large.rb | 30 +++++++++-- .../text_embedding_3_large.rb | 20 ++++++-- .../text_embedding_3_small.rb | 20 ++++++-- .../text_embedding_ada_002.rb | 20 ++++++-- .../jobs/generate_embeddings_spec.rb | 5 +- .../all_mpnet_base_v2_spec.rb | 2 +- .../multilingual_e5_large_spec.rb | 2 +- .../text_embedding_ada_002_spec.rb | 2 +- 16 files changed, 278 insertions(+), 52 deletions(-) create mode 100644 lib/configuration/embeddings_model_validator.rb create mode 100644 lib/configuration/embeddings_module_validator.rb diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index 6efa5dc9..023ab611 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -251,3 +251,12 @@ en: configuration_hint: one: "Make sure the `%{settings}` setting was configured." other: "Make sure these settings were configured: %{settings}" + + embeddings: + configuration: + disable_embeddings: "You have to disable 'ai embeddings enabled' first." + choose_model: "Set 'ai embeddings model' first." + model_unreachable: "We failed to generate a test embedding with this model. Check your settings are correct." + hint: + one: "Make sure the `%{settings}` setting was configured." 
+ other: "Make sure the settings of the provider you want were configured. Options are: %{settings}" diff --git a/config/settings.yml b/config/settings.yml index b933477b..fe8be23b 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -216,6 +216,7 @@ discourse_ai: ai_embeddings_enabled: default: false client: true + validator: "DiscourseAi::Configuration::EmbeddingsModuleValidator" ai_embeddings_discourse_service_api_endpoint: "" ai_embeddings_discourse_service_api_endpoint_srv: default: "" @@ -225,7 +226,6 @@ discourse_ai: secret: true ai_embeddings_model: type: enum - list_type: compact default: "bge-large-en" allow_any: false choices: @@ -236,6 +236,7 @@ discourse_ai: - multilingual-e5-large - bge-large-en - gemini + validator: "DiscourseAi::Configuration::EmbeddingsModelValidator" ai_embeddings_per_post_enabled: default: false hidden: true diff --git a/lib/configuration/embeddings_model_validator.rb b/lib/configuration/embeddings_model_validator.rb new file mode 100644 index 00000000..099e6168 --- /dev/null +++ b/lib/configuration/embeddings_model_validator.rb @@ -0,0 +1,46 @@ +# frozen_string_literal: true + +module DiscourseAi + module Configuration + class EmbeddingsModelValidator + def initialize(opts = {}) + @opts = opts + end + + def valid_value?(val) + return true if Rails.env.test? + + representation = + DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(val) + + return false if representation.nil? + + if !representation.correctly_configured? 
+ @representation = representation + return false + end + + if !can_generate_embeddings?(val) + @unreachable = true + return false + end + + true + end + + def error_message + return(I18n.t("discourse_ai.embeddings.configuration.model_unreachable")) if @unreachable + + @representation&.configuration_hint + end + + def can_generate_embeddings?(val) + DiscourseAi::Embeddings::VectorRepresentations::Base + .find_representation(val) + .new(DiscourseAi::Embeddings::Strategies::Truncation.new) + .vector_from("this is a test") + .present? + end + end + end +end diff --git a/lib/configuration/embeddings_module_validator.rb b/lib/configuration/embeddings_module_validator.rb new file mode 100644 index 00000000..08db0692 --- /dev/null +++ b/lib/configuration/embeddings_module_validator.rb @@ -0,0 +1,51 @@ +# frozen_string_literal: true + +module DiscourseAi + module Configuration + class EmbeddingsModuleValidator + def initialize(opts = {}) + @opts = opts + end + + def valid_value?(val) + return true if val == "f" + return true if Rails.env.test? + + chosen_model = SiteSetting.ai_embeddings_model + + return false if !chosen_model + + representation = + DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(chosen_model) + + return false if representation.nil? + + if !representation.correctly_configured? + @representation = representation + return false + end + + if !can_generate_embeddings?(chosen_model) + @unreachable = true + return false + end + + true + end + + def error_message + return(I18n.t("discourse_ai.embeddings.configuration.model_unreachable")) if @unreachable + + @representation&.configuration_hint + end + + def can_generate_embeddings?(val) + DiscourseAi::Embeddings::VectorRepresentations::Base + .find_representation(val) + .new(DiscourseAi::Embeddings::Strategies::Truncation.new) + .vector_from("this is a test") + .present? 
+ end + end + end +end diff --git a/lib/embeddings/vector_representations/all_mpnet_base_v2.rb b/lib/embeddings/vector_representations/all_mpnet_base_v2.rb index a8bbe86c..5d5793f2 100644 --- a/lib/embeddings/vector_representations/all_mpnet_base_v2.rb +++ b/lib/embeddings/vector_representations/all_mpnet_base_v2.rb @@ -4,19 +4,34 @@ module DiscourseAi module Embeddings module VectorRepresentations class AllMpnetBaseV2 < Base + class << self + def name + "all-mpnet-base-v2" + end + + def correctly_configured? + SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? || + SiteSetting.ai_embeddings_discourse_service_api_endpoint.present? + end + + def dependant_setting_names + %w[ + ai_embeddings_discourse_service_api_key + ai_embeddings_discourse_service_api_endpoint_srv + ai_embeddings_discourse_service_api_endpoint + ] + end + end + def vector_from(text) DiscourseAi::Inference::DiscourseClassifier.perform!( "#{discourse_embeddings_endpoint}/api/v1/classify", - name, + self.class.name, text, SiteSetting.ai_embeddings_discourse_service_api_key, ) end - def name - "all-mpnet-base-v2" - end - def dimensions 768 end diff --git a/lib/embeddings/vector_representations/base.rb b/lib/embeddings/vector_representations/base.rb index 58226913..fc944cf0 100644 --- a/lib/embeddings/vector_representations/base.rb +++ b/lib/embeddings/vector_representations/base.rb @@ -4,18 +4,41 @@ module DiscourseAi module Embeddings module VectorRepresentations class Base - def self.current_representation(strategy) - # we are explicit here cause the loader may have not - # loaded the subclasses yet - [ - DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2, - DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn, - DiscourseAi::Embeddings::VectorRepresentations::Gemini, - DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large, - DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002, - 
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small, - DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large, - ].map { _1.new(strategy) }.find { _1.name == SiteSetting.ai_embeddings_model } + class << self + def find_representation(model_name) + # we are explicit here cause the loader may have not + # loaded the subclasses yet + [ + DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2, + DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn, + DiscourseAi::Embeddings::VectorRepresentations::Gemini, + DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large, + DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002, + DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small, + DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large, + ].find { _1.name == model_name } + end + + def current_representation(strategy) + find_representation(SiteSetting.ai_embeddings_model).new(strategy) + end + + def correctly_configured? + raise NotImplementedError + end + + def dependant_setting_names + raise NotImplementedError + end + + def configuration_hint + settings = dependant_setting_names + I18n.t( + "discourse_ai.embeddings.configuration.hint", + settings: settings.join(", "), + count: settings.length, + ) + end end def initialize(strategy) diff --git a/lib/embeddings/vector_representations/bge_large_en.rb b/lib/embeddings/vector_representations/bge_large_en.rb index f3e24c48..cf7adec4 100644 --- a/lib/embeddings/vector_representations/bge_large_en.rb +++ b/lib/embeddings/vector_representations/bge_large_en.rb @@ -4,6 +4,32 @@ module DiscourseAi module Embeddings module VectorRepresentations class BgeLargeEn < Base + class << self + def name + "bge-large-en" + end + + def correctly_configured? + SiteSetting.ai_cloudflare_workers_api_token.present? || + DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? 
|| + ( + SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? || + SiteSetting.ai_embeddings_discourse_service_api_endpoint.present? + ) + end + + def dependant_setting_names + %w[ + ai_cloudflare_workers_api_token + ai_hugging_face_tei_endpoint_srv + ai_hugging_face_tei_endpoint + ai_embeddings_discourse_service_api_key + ai_embeddings_discourse_service_api_endpoint_srv + ai_embeddings_discourse_service_api_endpoint + ] + end + end + def vector_from(text) if SiteSetting.ai_cloudflare_workers_api_token.present? DiscourseAi::Inference::CloudflareWorkersAi @@ -25,10 +51,6 @@ module DiscourseAi end end - def name - "bge-large-en" - end - def inference_model_name "baai/bge-large-en-v1.5" end diff --git a/lib/embeddings/vector_representations/gemini.rb b/lib/embeddings/vector_representations/gemini.rb index 4b75da49..529f5ac4 100644 --- a/lib/embeddings/vector_representations/gemini.rb +++ b/lib/embeddings/vector_representations/gemini.rb @@ -4,6 +4,20 @@ module DiscourseAi module Embeddings module VectorRepresentations class Gemini < Base + class << self + def name + "gemini" + end + + def correctly_configured? + SiteSetting.ai_gemini_api_key.present? + end + + def dependant_setting_names + %w[ai_gemini_api_key] + end + end + def id 5 end @@ -12,10 +26,6 @@ module DiscourseAi 1 end - def name - "gemini" - end - def dimensions 768 end diff --git a/lib/embeddings/vector_representations/multilingual_e5_large.rb b/lib/embeddings/vector_representations/multilingual_e5_large.rb index 55dfc448..59133263 100644 --- a/lib/embeddings/vector_representations/multilingual_e5_large.rb +++ b/lib/embeddings/vector_representations/multilingual_e5_large.rb @@ -4,6 +4,30 @@ module DiscourseAi module Embeddings module VectorRepresentations class MultilingualE5Large < Base + class << self + def name + "multilingual-e5-large" + end + + def correctly_configured? + DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? 
|| + ( + SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? || + SiteSetting.ai_embeddings_discourse_service_api_endpoint.present? + ) + end + + def dependant_setting_names + %w[ + ai_hugging_face_tei_endpoint_srv + ai_hugging_face_tei_endpoint + ai_embeddings_discourse_service_api_key + ai_embeddings_discourse_service_api_endpoint_srv + ai_embeddings_discourse_service_api_endpoint + ] + end + end + def vector_from(text) if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? truncated_text = tokenizer.truncate(text, max_sequence_length - 2) @@ -11,7 +35,7 @@ module DiscourseAi elsif discourse_embeddings_endpoint.present? DiscourseAi::Inference::DiscourseClassifier.perform!( "#{discourse_embeddings_endpoint}/api/v1/classify", - name, + self.class.name, "query: #{text}", SiteSetting.ai_embeddings_discourse_service_api_key, ) @@ -28,10 +52,6 @@ module DiscourseAi 1 end - def name - "multilingual-e5-large" - end - def dimensions 1024 end diff --git a/lib/embeddings/vector_representations/text_embedding_3_large.rb b/lib/embeddings/vector_representations/text_embedding_3_large.rb index ab88b238..f7d478bf 100644 --- a/lib/embeddings/vector_representations/text_embedding_3_large.rb +++ b/lib/embeddings/vector_representations/text_embedding_3_large.rb @@ -4,6 +4,20 @@ module DiscourseAi module Embeddings module VectorRepresentations class TextEmbedding3Large < Base + class << self + def name + "text-embedding-3-large" + end + + def correctly_configured? + SiteSetting.ai_openai_api_key.present? 
+ end + + def dependant_setting_names + %w[ai_openai_api_key] + end + end + def id 7 end @@ -12,10 +26,6 @@ module DiscourseAi 1 end - def name - "text-embedding-3-large" - end - def dimensions # real dimentions are 3072, but we only support up to 2000 in the # indexes, so we downsample to 2000 via API @@ -38,7 +48,7 @@ module DiscourseAi response = DiscourseAi::Inference::OpenAiEmbeddings.perform!( text, - model: name, + model: self.class.name, dimensions: dimensions, ) response[:data].first[:embedding] diff --git a/lib/embeddings/vector_representations/text_embedding_3_small.rb b/lib/embeddings/vector_representations/text_embedding_3_small.rb index a32544d9..842ed183 100644 --- a/lib/embeddings/vector_representations/text_embedding_3_small.rb +++ b/lib/embeddings/vector_representations/text_embedding_3_small.rb @@ -4,6 +4,20 @@ module DiscourseAi module Embeddings module VectorRepresentations class TextEmbedding3Small < Base + class << self + def name + "text-embedding-3-small" + end + + def correctly_configured? + SiteSetting.ai_openai_api_key.present? 
+ end + + def dependant_setting_names + %w[ai_openai_api_key] + end + end + def id 6 end @@ -12,10 +26,6 @@ module DiscourseAi 1 end - def name - "text-embedding-3-small" - end - def dimensions 1536 end @@ -33,7 +43,7 @@ module DiscourseAi end def vector_from(text) - response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: name) + response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name) response[:data].first[:embedding] end diff --git a/lib/embeddings/vector_representations/text_embedding_ada_002.rb b/lib/embeddings/vector_representations/text_embedding_ada_002.rb index 2bce079b..a5bbe1ac 100644 --- a/lib/embeddings/vector_representations/text_embedding_ada_002.rb +++ b/lib/embeddings/vector_representations/text_embedding_ada_002.rb @@ -4,6 +4,20 @@ module DiscourseAi module Embeddings module VectorRepresentations class TextEmbeddingAda002 < Base + class << self + def name + "text-embedding-ada-002" + end + + def correctly_configured? + SiteSetting.ai_openai_api_key.present? 
+ end + + def dependant_setting_names + %w[ai_openai_api_key] + end + end + def id 2 end @@ -12,10 +26,6 @@ module DiscourseAi 1 end - def name - "text-embedding-ada-002" - end - def dimensions 1536 end @@ -33,7 +43,7 @@ module DiscourseAi end def vector_from(text) - response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: name) + response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name) response[:data].first[:embedding] end diff --git a/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb b/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb index 78aee2ae..38f6fb0d 100644 --- a/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb +++ b/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb @@ -7,7 +7,6 @@ RSpec.describe Jobs::GenerateEmbeddings do before do SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" SiteSetting.ai_embeddings_enabled = true - SiteSetting.ai_embeddings_model = "bge-large-en" end fab!(:topic) { Fabricate(:topic) } @@ -27,7 +26,7 @@ RSpec.describe Jobs::GenerateEmbeddings do vector_rep.tokenizer, vector_rep.max_sequence_length - 2, ) - EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding) + EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding) job.execute(target_id: topic.id, target_type: "Topic") @@ -39,7 +38,7 @@ RSpec.describe Jobs::GenerateEmbeddings do text = truncation.prepare_text_from(post, vector_rep.tokenizer, vector_rep.max_sequence_length - 2) - EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding) + EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding) job.execute(target_id: post.id, target_type: "Post") diff --git a/spec/lib/modules/embeddings/vector_representations/all_mpnet_base_v2_spec.rb b/spec/lib/modules/embeddings/vector_representations/all_mpnet_base_v2_spec.rb index 
890e69c1..c1cefe04 100644 --- a/spec/lib/modules/embeddings/vector_representations/all_mpnet_base_v2_spec.rb +++ b/spec/lib/modules/embeddings/vector_representations/all_mpnet_base_v2_spec.rb @@ -10,7 +10,7 @@ RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2 do before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" } def stub_vector_mapping(text, expected_embedding) - EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding) + EmbeddingsGenerationStubs.discourse_service(described_class.name, text, expected_embedding) end it_behaves_like "generates and store embedding using with vector representation" diff --git a/spec/lib/modules/embeddings/vector_representations/multilingual_e5_large_spec.rb b/spec/lib/modules/embeddings/vector_representations/multilingual_e5_large_spec.rb index 01930028..e7af5eba 100644 --- a/spec/lib/modules/embeddings/vector_representations/multilingual_e5_large_spec.rb +++ b/spec/lib/modules/embeddings/vector_representations/multilingual_e5_large_spec.rb @@ -11,7 +11,7 @@ RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Lar def stub_vector_mapping(text, expected_embedding) EmbeddingsGenerationStubs.discourse_service( - vector_rep.name, + described_class.name, "query: #{text}", expected_embedding, ) diff --git a/spec/lib/modules/embeddings/vector_representations/text_embedding_ada_002_spec.rb b/spec/lib/modules/embeddings/vector_representations/text_embedding_ada_002_spec.rb index 9d014cb4..ed5a80fc 100644 --- a/spec/lib/modules/embeddings/vector_representations/text_embedding_ada_002_spec.rb +++ b/spec/lib/modules/embeddings/vector_representations/text_embedding_ada_002_spec.rb @@ -8,7 +8,7 @@ RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda0 let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new } def stub_vector_mapping(text, expected_embedding) - 
EmbeddingsGenerationStubs.openai_service(vector_rep.name, text, expected_embedding) + EmbeddingsGenerationStubs.openai_service(described_class.name, text, expected_embedding) end it_behaves_like "generates and store embedding using with vector representation"