UX: Re-introduce embedding settings validations (#457)

* Revert "Revert "UX: Validate embeddings settings (#455)" (#456)"

This reverts commit 392e2e8aef.

* Resstore previous default
This commit is contained in:
Roman Rizzi 2024-02-01 16:54:09 -03:00 committed by GitHub
parent 392e2e8aef
commit fba9c1bf2c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 278 additions and 52 deletions

View File

@ -251,3 +251,12 @@ en:
configuration_hint:
one: "Make sure the `%{settings}` setting was configured."
other: "Make sure these settings were configured: %{settings}"
embeddings:
configuration:
disable_embeddings: "You have to disable 'ai embeddings enabled' first."
choose_model: "Set 'ai embeddings model' first."
model_unreachable: "We failed to generate a test embedding with this model. Check your settings are correct."
hint:
one: "Make sure the `%{settings}` setting was configured."
other: "Make sure the settings of the provider you want were configured. Options are: %{settings}"

View File

@ -216,6 +216,7 @@ discourse_ai:
ai_embeddings_enabled:
default: false
client: true
validator: "DiscourseAi::Configuration::EmbeddingsModuleValidator"
ai_embeddings_discourse_service_api_endpoint: ""
ai_embeddings_discourse_service_api_endpoint_srv:
default: ""
@ -225,7 +226,6 @@ discourse_ai:
secret: true
ai_embeddings_model:
type: enum
list_type: compact
default: "bge-large-en"
allow_any: false
choices:
@ -236,6 +236,7 @@ discourse_ai:
- multilingual-e5-large
- bge-large-en
- gemini
validator: "DiscourseAi::Configuration::EmbeddingsModelValidator"
ai_embeddings_per_post_enabled:
default: false
hidden: true

View File

@ -0,0 +1,46 @@
# frozen_string_literal: true
module DiscourseAi
module Configuration
class EmbeddingsModelValidator
def initialize(opts = {})
@opts = opts
end
def valid_value?(val)
return true if Rails.env.test?
representation =
DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(val)
return false if representation.nil?
if !representation.correctly_configured?
@representation = representation
return false
end
if !can_generate_embeddings?(val)
@unreachable = true
return false
end
true
end
def error_message
return(I18n.t("discourse_ai.embeddings.configuration.model_unreachable")) if @unreachable
@representation&.configuration_hint
end
def can_generate_embeddings?(val)
DiscourseAi::Embeddings::VectorRepresentations::Base
.find_representation(val)
.new(DiscourseAi::Embeddings::Strategies::Truncation.new)
.vector_from("this is a test")
.present?
end
end
end
end

View File

@ -0,0 +1,51 @@
# frozen_string_literal: true
module DiscourseAi
module Configuration
class EmbeddingsModuleValidator
def initialize(opts = {})
@opts = opts
end
def valid_value?(val)
return true if val == "f"
return true if Rails.env.test?
chosen_model = SiteSetting.ai_embeddings_model
return false if !chosen_model
representation =
DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(chosen_model)
return false if representation.nil?
if !representation.correctly_configured?
@representation = representation
return false
end
if !can_generate_embeddings?(chosen_model)
@unreachable = true
return false
end
true
end
def error_message
return(I18n.t("discourse_ai.embeddings.configuration.model_unreachable")) if @unreachable
@representation&.configuration_hint
end
def can_generate_embeddings?(val)
DiscourseAi::Embeddings::VectorRepresentations::Base
.find_representation(val)
.new(DiscourseAi::Embeddings::Strategies::Truncation.new)
.vector_from("this is a test")
.present?
end
end
end
end

View File

@ -4,19 +4,34 @@ module DiscourseAi
module Embeddings
module VectorRepresentations
class AllMpnetBaseV2 < Base
class << self
def name
"all-mpnet-base-v2"
end
def correctly_configured?
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
end
def dependant_setting_names
%w[
ai_embeddings_discourse_service_api_key
ai_embeddings_discourse_service_api_endpoint_srv
ai_embeddings_discourse_service_api_endpoint
]
end
end
def vector_from(text)
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{discourse_embeddings_endpoint}/api/v1/classify",
name,
self.class.name,
text,
SiteSetting.ai_embeddings_discourse_service_api_key,
)
end
def name
"all-mpnet-base-v2"
end
def dimensions
768
end

View File

@ -4,7 +4,8 @@ module DiscourseAi
module Embeddings
module VectorRepresentations
class Base
def self.current_representation(strategy)
class << self
def find_representation(model_name)
# we are explicit here cause the loader may have not
# loaded the subclasses yet
[
@ -15,7 +16,29 @@ module DiscourseAi
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
].map { _1.new(strategy) }.find { _1.name == SiteSetting.ai_embeddings_model }
].find { _1.name == model_name }
end
def current_representation(strategy)
find_representation(SiteSetting.ai_embeddings_model).new(strategy)
end
def correctly_configured?
raise NotImplementedError
end
def dependant_setting_names
raise NotImplementedError
end
def configuration_hint
settings = dependant_setting_names
I18n.t(
"discourse_ai.embeddings.configuration.hint",
settings: settings.join(", "),
count: settings.length,
)
end
end
def initialize(strategy)

View File

@ -4,6 +4,32 @@ module DiscourseAi
module Embeddings
module VectorRepresentations
class BgeLargeEn < Base
class << self
def name
"bge-large-en"
end
def correctly_configured?
SiteSetting.ai_cloudflare_workers_api_token.present? ||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? ||
(
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
)
end
def dependant_setting_names
%w[
ai_cloudflare_workers_api_token
ai_hugging_face_tei_endpoint_srv
ai_hugging_face_tei_endpoint
ai_embeddings_discourse_service_api_key
ai_embeddings_discourse_service_api_endpoint_srv
ai_embeddings_discourse_service_api_endpoint
]
end
end
def vector_from(text)
if SiteSetting.ai_cloudflare_workers_api_token.present?
DiscourseAi::Inference::CloudflareWorkersAi
@ -25,10 +51,6 @@ module DiscourseAi
end
end
def name
"bge-large-en"
end
def inference_model_name
"baai/bge-large-en-v1.5"
end

View File

@ -4,6 +4,20 @@ module DiscourseAi
module Embeddings
module VectorRepresentations
class Gemini < Base
class << self
def name
"gemini"
end
def correctly_configured?
SiteSetting.ai_gemini_api_key.present?
end
def dependant_setting_names
%w[ai_gemini_api_key]
end
end
def id
5
end
@ -12,10 +26,6 @@ module DiscourseAi
1
end
def name
"gemini"
end
def dimensions
768
end

View File

@ -4,6 +4,30 @@ module DiscourseAi
module Embeddings
module VectorRepresentations
class MultilingualE5Large < Base
class << self
def name
"multilingual-e5-large"
end
def correctly_configured?
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? ||
(
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
)
end
def dependant_setting_names
%w[
ai_hugging_face_tei_endpoint_srv
ai_hugging_face_tei_endpoint
ai_embeddings_discourse_service_api_key
ai_embeddings_discourse_service_api_endpoint_srv
ai_embeddings_discourse_service_api_endpoint
]
end
end
def vector_from(text)
if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
@ -11,7 +35,7 @@ module DiscourseAi
elsif discourse_embeddings_endpoint.present?
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{discourse_embeddings_endpoint}/api/v1/classify",
name,
self.class.name,
"query: #{text}",
SiteSetting.ai_embeddings_discourse_service_api_key,
)
@ -28,10 +52,6 @@ module DiscourseAi
1
end
def name
"multilingual-e5-large"
end
def dimensions
1024
end

View File

@ -4,6 +4,20 @@ module DiscourseAi
module Embeddings
module VectorRepresentations
class TextEmbedding3Large < Base
class << self
def name
"text-embedding-3-large"
end
def correctly_configured?
SiteSetting.ai_openai_api_key.present?
end
def dependant_setting_names
%w[ai_openai_api_key]
end
end
def id
7
end
@ -12,10 +26,6 @@ module DiscourseAi
1
end
def name
"text-embedding-3-large"
end
def dimensions
# real dimentions are 3072, but we only support up to 2000 in the
# indexes, so we downsample to 2000 via API
@ -38,7 +48,7 @@ module DiscourseAi
response =
DiscourseAi::Inference::OpenAiEmbeddings.perform!(
text,
model: name,
model: self.clas.name,
dimensions: dimensions,
)
response[:data].first[:embedding]

View File

@ -4,6 +4,20 @@ module DiscourseAi
module Embeddings
module VectorRepresentations
class TextEmbedding3Small < Base
class << self
def name
"text-embedding-3-small"
end
def correctly_configured?
SiteSetting.ai_openai_api_key.present?
end
def dependant_setting_names
%w[ai_openai_api_key]
end
end
def id
6
end
@ -12,10 +26,6 @@ module DiscourseAi
1
end
def name
"text-embedding-3-small"
end
def dimensions
1536
end
@ -33,7 +43,7 @@ module DiscourseAi
end
def vector_from(text)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: name)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
response[:data].first[:embedding]
end

View File

@ -4,6 +4,20 @@ module DiscourseAi
module Embeddings
module VectorRepresentations
class TextEmbeddingAda002 < Base
class << self
def name
"text-embedding-ada-002"
end
def correctly_configured?
SiteSetting.ai_openai_api_key.present?
end
def dependant_setting_names
%w[ai_openai_api_key]
end
end
def id
2
end
@ -12,10 +26,6 @@ module DiscourseAi
1
end
def name
"text-embedding-ada-002"
end
def dimensions
1536
end
@ -33,7 +43,7 @@ module DiscourseAi
end
def vector_from(text)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: name)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
response[:data].first[:embedding]
end

View File

@ -7,7 +7,6 @@ RSpec.describe Jobs::GenerateEmbeddings do
before do
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
SiteSetting.ai_embeddings_enabled = true
SiteSetting.ai_embeddings_model = "bge-large-en"
end
fab!(:topic) { Fabricate(:topic) }
@ -27,7 +26,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding)
job.execute(target_id: topic.id, target_type: "Topic")
@ -39,7 +38,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
text =
truncation.prepare_text_from(post, vector_rep.tokenizer, vector_rep.max_sequence_length - 2)
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding)
job.execute(target_id: post.id, target_type: "Post")

View File

@ -10,7 +10,7 @@ RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2 do
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(described_class.name, text, expected_embedding)
end
it_behaves_like "generates and store embedding using with vector representation"

View File

@ -11,7 +11,7 @@ RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Lar
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(
vector_rep.name,
described_class.name,
"query: #{text}",
expected_embedding,
)

View File

@ -8,7 +8,7 @@ RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda0
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(vector_rep.name, text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(described_class.name, text, expected_embedding)
end
it_behaves_like "generates and store embedding using with vector representation"