UX: Re-introduce embedding settings validations (#457)
* Revert "Revert "UX: Validate embeddings settings (#455)" (#456)"
This reverts commit 392e2e8aef
.
* Resstore previous default
This commit is contained in:
parent
392e2e8aef
commit
fba9c1bf2c
|
@ -251,3 +251,12 @@ en:
|
|||
configuration_hint:
|
||||
one: "Make sure the `%{settings}` setting was configured."
|
||||
other: "Make sure these settings were configured: %{settings}"
|
||||
|
||||
embeddings:
|
||||
configuration:
|
||||
disable_embeddings: "You have to disable 'ai embeddings enabled' first."
|
||||
choose_model: "Set 'ai embeddings model' first."
|
||||
model_unreachable: "We failed to generate a test embedding with this model. Check your settings are correct."
|
||||
hint:
|
||||
one: "Make sure the `%{settings}` setting was configured."
|
||||
other: "Make sure the settings of the provider you want were configured. Options are: %{settings}"
|
||||
|
|
|
@ -216,6 +216,7 @@ discourse_ai:
|
|||
ai_embeddings_enabled:
|
||||
default: false
|
||||
client: true
|
||||
validator: "DiscourseAi::Configuration::EmbeddingsModuleValidator"
|
||||
ai_embeddings_discourse_service_api_endpoint: ""
|
||||
ai_embeddings_discourse_service_api_endpoint_srv:
|
||||
default: ""
|
||||
|
@ -225,7 +226,6 @@ discourse_ai:
|
|||
secret: true
|
||||
ai_embeddings_model:
|
||||
type: enum
|
||||
list_type: compact
|
||||
default: "bge-large-en"
|
||||
allow_any: false
|
||||
choices:
|
||||
|
@ -236,6 +236,7 @@ discourse_ai:
|
|||
- multilingual-e5-large
|
||||
- bge-large-en
|
||||
- gemini
|
||||
validator: "DiscourseAi::Configuration::EmbeddingsModelValidator"
|
||||
ai_embeddings_per_post_enabled:
|
||||
default: false
|
||||
hidden: true
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Configuration
|
||||
class EmbeddingsModelValidator
|
||||
def initialize(opts = {})
|
||||
@opts = opts
|
||||
end
|
||||
|
||||
def valid_value?(val)
|
||||
return true if Rails.env.test?
|
||||
|
||||
representation =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(val)
|
||||
|
||||
return false if representation.nil?
|
||||
|
||||
if !representation.correctly_configured?
|
||||
@representation = representation
|
||||
return false
|
||||
end
|
||||
|
||||
if !can_generate_embeddings?(val)
|
||||
@unreachable = true
|
||||
return false
|
||||
end
|
||||
|
||||
true
|
||||
end
|
||||
|
||||
def error_message
|
||||
return(I18n.t("discourse_ai.embeddings.configuration.model_unreachable")) if @unreachable
|
||||
|
||||
@representation&.configuration_hint
|
||||
end
|
||||
|
||||
def can_generate_embeddings?(val)
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base
|
||||
.find_representation(val)
|
||||
.new(DiscourseAi::Embeddings::Strategies::Truncation.new)
|
||||
.vector_from("this is a test")
|
||||
.present?
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,51 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Configuration
|
||||
class EmbeddingsModuleValidator
|
||||
def initialize(opts = {})
|
||||
@opts = opts
|
||||
end
|
||||
|
||||
def valid_value?(val)
|
||||
return true if val == "f"
|
||||
return true if Rails.env.test?
|
||||
|
||||
chosen_model = SiteSetting.ai_embeddings_model
|
||||
|
||||
return false if !chosen_model
|
||||
|
||||
representation =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(chosen_model)
|
||||
|
||||
return false if representation.nil?
|
||||
|
||||
if !representation.correctly_configured?
|
||||
@representation = representation
|
||||
return false
|
||||
end
|
||||
|
||||
if !can_generate_embeddings?(chosen_model)
|
||||
@unreachable = true
|
||||
return false
|
||||
end
|
||||
|
||||
true
|
||||
end
|
||||
|
||||
def error_message
|
||||
return(I18n.t("discourse_ai.embeddings.configuration.model_unreachable")) if @unreachable
|
||||
|
||||
@representation&.configuration_hint
|
||||
end
|
||||
|
||||
def can_generate_embeddings?(val)
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base
|
||||
.find_representation(val)
|
||||
.new(DiscourseAi::Embeddings::Strategies::Truncation.new)
|
||||
.vector_from("this is a test")
|
||||
.present?
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -4,19 +4,34 @@ module DiscourseAi
|
|||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class AllMpnetBaseV2 < Base
|
||||
class << self
|
||||
def name
|
||||
"all-mpnet-base-v2"
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[
|
||||
ai_embeddings_discourse_service_api_key
|
||||
ai_embeddings_discourse_service_api_endpoint_srv
|
||||
ai_embeddings_discourse_service_api_endpoint
|
||||
]
|
||||
end
|
||||
end
|
||||
|
||||
def vector_from(text)
|
||||
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
||||
"#{discourse_embeddings_endpoint}/api/v1/classify",
|
||||
name,
|
||||
self.class.name,
|
||||
text,
|
||||
SiteSetting.ai_embeddings_discourse_service_api_key,
|
||||
)
|
||||
end
|
||||
|
||||
def name
|
||||
"all-mpnet-base-v2"
|
||||
end
|
||||
|
||||
def dimensions
|
||||
768
|
||||
end
|
||||
|
|
|
@ -4,7 +4,8 @@ module DiscourseAi
|
|||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class Base
|
||||
def self.current_representation(strategy)
|
||||
class << self
|
||||
def find_representation(model_name)
|
||||
# we are explicit here cause the loader may have not
|
||||
# loaded the subclasses yet
|
||||
[
|
||||
|
@ -15,7 +16,29 @@ module DiscourseAi
|
|||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
|
||||
].map { _1.new(strategy) }.find { _1.name == SiteSetting.ai_embeddings_model }
|
||||
].find { _1.name == model_name }
|
||||
end
|
||||
|
||||
def current_representation(strategy)
|
||||
find_representation(SiteSetting.ai_embeddings_model).new(strategy)
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def configuration_hint
|
||||
settings = dependant_setting_names
|
||||
I18n.t(
|
||||
"discourse_ai.embeddings.configuration.hint",
|
||||
settings: settings.join(", "),
|
||||
count: settings.length,
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
def initialize(strategy)
|
||||
|
|
|
@ -4,6 +4,32 @@ module DiscourseAi
|
|||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class BgeLargeEn < Base
|
||||
class << self
|
||||
def name
|
||||
"bge-large-en"
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
SiteSetting.ai_cloudflare_workers_api_token.present? ||
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? ||
|
||||
(
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||
)
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[
|
||||
ai_cloudflare_workers_api_token
|
||||
ai_hugging_face_tei_endpoint_srv
|
||||
ai_hugging_face_tei_endpoint
|
||||
ai_embeddings_discourse_service_api_key
|
||||
ai_embeddings_discourse_service_api_endpoint_srv
|
||||
ai_embeddings_discourse_service_api_endpoint
|
||||
]
|
||||
end
|
||||
end
|
||||
|
||||
def vector_from(text)
|
||||
if SiteSetting.ai_cloudflare_workers_api_token.present?
|
||||
DiscourseAi::Inference::CloudflareWorkersAi
|
||||
|
@ -25,10 +51,6 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
def name
|
||||
"bge-large-en"
|
||||
end
|
||||
|
||||
def inference_model_name
|
||||
"baai/bge-large-en-v1.5"
|
||||
end
|
||||
|
|
|
@ -4,6 +4,20 @@ module DiscourseAi
|
|||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class Gemini < Base
|
||||
class << self
|
||||
def name
|
||||
"gemini"
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
SiteSetting.ai_gemini_api_key.present?
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[ai_gemini_api_key]
|
||||
end
|
||||
end
|
||||
|
||||
def id
|
||||
5
|
||||
end
|
||||
|
@ -12,10 +26,6 @@ module DiscourseAi
|
|||
1
|
||||
end
|
||||
|
||||
def name
|
||||
"gemini"
|
||||
end
|
||||
|
||||
def dimensions
|
||||
768
|
||||
end
|
||||
|
|
|
@ -4,6 +4,30 @@ module DiscourseAi
|
|||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class MultilingualE5Large < Base
|
||||
class << self
|
||||
def name
|
||||
"multilingual-e5-large"
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? ||
|
||||
(
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
|
||||
)
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[
|
||||
ai_hugging_face_tei_endpoint_srv
|
||||
ai_hugging_face_tei_endpoint
|
||||
ai_embeddings_discourse_service_api_key
|
||||
ai_embeddings_discourse_service_api_endpoint_srv
|
||||
ai_embeddings_discourse_service_api_endpoint
|
||||
]
|
||||
end
|
||||
end
|
||||
|
||||
def vector_from(text)
|
||||
if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
|
||||
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
|
||||
|
@ -11,7 +35,7 @@ module DiscourseAi
|
|||
elsif discourse_embeddings_endpoint.present?
|
||||
DiscourseAi::Inference::DiscourseClassifier.perform!(
|
||||
"#{discourse_embeddings_endpoint}/api/v1/classify",
|
||||
name,
|
||||
self.class.name,
|
||||
"query: #{text}",
|
||||
SiteSetting.ai_embeddings_discourse_service_api_key,
|
||||
)
|
||||
|
@ -28,10 +52,6 @@ module DiscourseAi
|
|||
1
|
||||
end
|
||||
|
||||
def name
|
||||
"multilingual-e5-large"
|
||||
end
|
||||
|
||||
def dimensions
|
||||
1024
|
||||
end
|
||||
|
|
|
@ -4,6 +4,20 @@ module DiscourseAi
|
|||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class TextEmbedding3Large < Base
|
||||
class << self
|
||||
def name
|
||||
"text-embedding-3-large"
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
SiteSetting.ai_openai_api_key.present?
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[ai_openai_api_key]
|
||||
end
|
||||
end
|
||||
|
||||
def id
|
||||
7
|
||||
end
|
||||
|
@ -12,10 +26,6 @@ module DiscourseAi
|
|||
1
|
||||
end
|
||||
|
||||
def name
|
||||
"text-embedding-3-large"
|
||||
end
|
||||
|
||||
def dimensions
|
||||
# real dimentions are 3072, but we only support up to 2000 in the
|
||||
# indexes, so we downsample to 2000 via API
|
||||
|
@ -38,7 +48,7 @@ module DiscourseAi
|
|||
response =
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.perform!(
|
||||
text,
|
||||
model: name,
|
||||
model: self.clas.name,
|
||||
dimensions: dimensions,
|
||||
)
|
||||
response[:data].first[:embedding]
|
||||
|
|
|
@ -4,6 +4,20 @@ module DiscourseAi
|
|||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class TextEmbedding3Small < Base
|
||||
class << self
|
||||
def name
|
||||
"text-embedding-3-small"
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
SiteSetting.ai_openai_api_key.present?
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[ai_openai_api_key]
|
||||
end
|
||||
end
|
||||
|
||||
def id
|
||||
6
|
||||
end
|
||||
|
@ -12,10 +26,6 @@ module DiscourseAi
|
|||
1
|
||||
end
|
||||
|
||||
def name
|
||||
"text-embedding-3-small"
|
||||
end
|
||||
|
||||
def dimensions
|
||||
1536
|
||||
end
|
||||
|
@ -33,7 +43,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def vector_from(text)
|
||||
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: name)
|
||||
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
|
||||
response[:data].first[:embedding]
|
||||
end
|
||||
|
||||
|
|
|
@ -4,6 +4,20 @@ module DiscourseAi
|
|||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class TextEmbeddingAda002 < Base
|
||||
class << self
|
||||
def name
|
||||
"text-embedding-ada-002"
|
||||
end
|
||||
|
||||
def correctly_configured?
|
||||
SiteSetting.ai_openai_api_key.present?
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[ai_openai_api_key]
|
||||
end
|
||||
end
|
||||
|
||||
def id
|
||||
2
|
||||
end
|
||||
|
@ -12,10 +26,6 @@ module DiscourseAi
|
|||
1
|
||||
end
|
||||
|
||||
def name
|
||||
"text-embedding-ada-002"
|
||||
end
|
||||
|
||||
def dimensions
|
||||
1536
|
||||
end
|
||||
|
@ -33,7 +43,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def vector_from(text)
|
||||
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: name)
|
||||
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
|
||||
response[:data].first[:embedding]
|
||||
end
|
||||
|
||||
|
|
|
@ -7,7 +7,6 @@ RSpec.describe Jobs::GenerateEmbeddings do
|
|||
before do
|
||||
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
SiteSetting.ai_embeddings_model = "bge-large-en"
|
||||
end
|
||||
|
||||
fab!(:topic) { Fabricate(:topic) }
|
||||
|
@ -27,7 +26,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
|
|||
vector_rep.tokenizer,
|
||||
vector_rep.max_sequence_length - 2,
|
||||
)
|
||||
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding)
|
||||
|
||||
job.execute(target_id: topic.id, target_type: "Topic")
|
||||
|
||||
|
@ -39,7 +38,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
|
|||
|
||||
text =
|
||||
truncation.prepare_text_from(post, vector_rep.tokenizer, vector_rep.max_sequence_length - 2)
|
||||
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding)
|
||||
|
||||
job.execute(target_id: post.id, target_type: "Post")
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2 do
|
|||
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.discourse_service(described_class.name, text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embedding using with vector representation"
|
||||
|
|
|
@ -11,7 +11,7 @@ RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Lar
|
|||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.discourse_service(
|
||||
vector_rep.name,
|
||||
described_class.name,
|
||||
"query: #{text}",
|
||||
expected_embedding,
|
||||
)
|
||||
|
|
|
@ -8,7 +8,7 @@ RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda0
|
|||
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.openai_service(vector_rep.name, text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.openai_service(described_class.name, text, expected_embedding)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embedding using with vector representation"
|
||||
|
|
Loading…
Reference in New Issue