UX: Validate embeddings settings (#455)

This commit is contained in:
Roman Rizzi 2024-02-01 13:05:38 -03:00 committed by GitHub
parent cec4251b00
commit 85fca89e01
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 314 additions and 66 deletions

View File

@ -251,3 +251,12 @@ en:
configuration_hint:
one: "Make sure the `%{settings}` setting was configured."
other: "Make sure these settings were configured: %{settings}"
embeddings:
configuration:
disable_embeddings: "You have to disable 'ai embeddings enabled' first."
choose_model: "Set 'ai embeddings model' first."
model_unreachable: "We failed to generate a test embedding with this model. Check your settings are correct."
hint:
one: "Make sure the `%{settings}` setting was configured."
other: "Make sure the settings of the provider you want were configured. Options are: %{settings}"

View File

@ -216,6 +216,7 @@ discourse_ai:
ai_embeddings_enabled:
default: false
client: true
validator: "DiscourseAi::Configuration::EmbeddingsDependencyValidator"
ai_embeddings_discourse_service_api_endpoint: ""
ai_embeddings_discourse_service_api_endpoint_srv:
default: ""
@ -225,17 +226,10 @@ discourse_ai:
secret: true
ai_embeddings_model:
type: enum
list_type: compact
default: "bge-large-en"
default: ""
allow_any: false
choices:
- all-mpnet-base-v2
- text-embedding-ada-002
- text-embedding-3-small
- text-embedding-3-large
- multilingual-e5-large
- bge-large-en
- gemini
enum: "DiscourseAi::Configuration::EmbeddingsModelEnumerator"
validator: "DiscourseAi::Configuration::EmbeddingsModelValidator"
ai_embeddings_per_post_enabled:
default: false
hidden: true

View File

@ -0,0 +1,21 @@
# frozen_string_literal: true
module DiscourseAi
module Configuration
class EmbeddingsDependencyValidator
def initialize(opts = {})
@opts = opts
end
def valid_value?(val)
return true if val == "f"
SiteSetting.ai_embeddings_model.present?
end
def error_message
I18n.t("discourse_ai.embeddings.configuration.choose_model")
end
end
end
end

View File

@ -0,0 +1,25 @@
# frozen_string_literal: true
require "enum_site_setting"
module DiscourseAi
module Configuration
class EmbeddingsModelEnumerator < ::EnumSiteSetting
def self.valid_value?(val)
true
end
def self.values
%w[
all-mpnet-base-v2
text-embedding-ada-002
text-embedding-3-small
text-embedding-3-large
multilingual-e5-large
bge-large-en
gemini
]
end
end
end
end

View File

@ -0,0 +1,56 @@
# frozen_string_literal: true
module DiscourseAi
module Configuration
class EmbeddingsModelValidator
def initialize(opts = {})
@opts = opts
end
def valid_value?(val)
if val == ""
@embeddings_enabled = SiteSetting.ai_embeddings_enabled
return !@embeddings_enabled
end
representation =
DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(val)
return false if representation.nil?
# Skip config for tests. We stub embeddings generation anyway.
return true if Rails.env.test? && val
if !representation.correctly_configured?
@representation = representation
return false
end
if !can_generate_embeddings?(val)
@unreachable = true
return false
end
true
end
def error_message
if @embeddings_enabled
return(I18n.t("discourse_ai.embeddings.configuration.disable_embeddings"))
end
return(I18n.t("discourse_ai.embeddings.configuration.model_unreachable")) if @unreachable
@representation&.configuration_hint
end
def can_generate_embeddings?(val)
DiscourseAi::Embeddings::VectorRepresentations::Base
.find_representation(val)
.new(DiscourseAi::Embeddings::Strategies::Truncation.new)
.vector_from("this is a test")
.present?
end
end
end
end

View File

@ -4,19 +4,34 @@ module DiscourseAi
module Embeddings
module VectorRepresentations
class AllMpnetBaseV2 < Base
class << self
def name
"all-mpnet-base-v2"
end
def correctly_configured?
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
end
def dependant_setting_names
%w[
ai_embeddings_discourse_service_api_key
ai_embeddings_discourse_service_api_endpoint_srv
ai_embeddings_discourse_service_api_endpoint
]
end
end
def vector_from(text)
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{discourse_embeddings_endpoint}/api/v1/classify",
name,
self.class.name,
text,
SiteSetting.ai_embeddings_discourse_service_api_key,
)
end
def name
"all-mpnet-base-v2"
end
def dimensions
768
end

View File

@ -4,18 +4,41 @@ module DiscourseAi
module Embeddings
module VectorRepresentations
class Base
def self.current_representation(strategy)
# we are explicit here cause the loader may have not
# loaded the subclasses yet
[
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
DiscourseAi::Embeddings::VectorRepresentations::Gemini,
DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
].map { _1.new(strategy) }.find { _1.name == SiteSetting.ai_embeddings_model }
class << self
def find_representation(model_name)
# we are explicit here cause the loader may have not
# loaded the subclasses yet
[
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
DiscourseAi::Embeddings::VectorRepresentations::Gemini,
DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
].find { _1.name == model_name }
end
def current_representation(strategy)
find_representation(SiteSetting.ai_embeddings_model).new(strategy)
end
def correctly_configured?
raise NotImplementedError
end
def dependant_setting_names
raise NotImplementedError
end
def configuration_hint
settings = dependant_setting_names
I18n.t(
"discourse_ai.embeddings.configuration.hint",
settings: settings.join(", "),
count: settings.length,
)
end
end
def initialize(strategy)

View File

@ -4,6 +4,32 @@ module DiscourseAi
module Embeddings
module VectorRepresentations
class BgeLargeEn < Base
class << self
def name
"bge-large-en"
end
def correctly_configured?
SiteSetting.ai_cloudflare_workers_api_token.present? ||
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? ||
(
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
)
end
def dependant_setting_names
%w[
ai_cloudflare_workers_api_token
ai_hugging_face_tei_endpoint_srv
ai_hugging_face_tei_endpoint
ai_embeddings_discourse_service_api_key
ai_embeddings_discourse_service_api_endpoint_srv
ai_embeddings_discourse_service_api_endpoint
]
end
end
def vector_from(text)
if SiteSetting.ai_cloudflare_workers_api_token.present?
DiscourseAi::Inference::CloudflareWorkersAi
@ -25,10 +51,6 @@ module DiscourseAi
end
end
def name
"bge-large-en"
end
def inference_model_name
"baai/bge-large-en-v1.5"
end

View File

@ -4,6 +4,20 @@ module DiscourseAi
module Embeddings
module VectorRepresentations
class Gemini < Base
class << self
def name
"gemini"
end
def correctly_configured?
SiteSetting.ai_gemini_api_key.present?
end
def dependant_setting_names
%w[ai_gemini_api_key]
end
end
def id
5
end
@ -12,10 +26,6 @@ module DiscourseAi
1
end
def name
"gemini"
end
def dimensions
768
end

View File

@ -4,6 +4,30 @@ module DiscourseAi
module Embeddings
module VectorRepresentations
class MultilingualE5Large < Base
class << self
def name
"multilingual-e5-large"
end
def correctly_configured?
DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? ||
(
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? ||
SiteSetting.ai_embeddings_discourse_service_api_endpoint.present?
)
end
def dependant_setting_names
%w[
ai_hugging_face_tei_endpoint_srv
ai_hugging_face_tei_endpoint
ai_embeddings_discourse_service_api_key
ai_embeddings_discourse_service_api_endpoint_srv
ai_embeddings_discourse_service_api_endpoint
]
end
end
def vector_from(text)
if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured?
truncated_text = tokenizer.truncate(text, max_sequence_length - 2)
@ -11,7 +35,7 @@ module DiscourseAi
elsif discourse_embeddings_endpoint.present?
DiscourseAi::Inference::DiscourseClassifier.perform!(
"#{discourse_embeddings_endpoint}/api/v1/classify",
name,
self.class.name,
"query: #{text}",
SiteSetting.ai_embeddings_discourse_service_api_key,
)
@ -28,10 +52,6 @@ module DiscourseAi
1
end
def name
"multilingual-e5-large"
end
def dimensions
1024
end

View File

@ -4,6 +4,20 @@ module DiscourseAi
module Embeddings
module VectorRepresentations
class TextEmbedding3Large < Base
class << self
def name
"text-embedding-3-large"
end
def correctly_configured?
SiteSetting.ai_openai_api_key.present?
end
def dependant_setting_names
%w[ai_openai_api_key]
end
end
def id
7
end
@ -12,10 +26,6 @@ module DiscourseAi
1
end
def name
"text-embedding-3-large"
end
def dimensions
# real dimentions are 3072, but we only support up to 2000 in the
# indexes, so we downsample to 2000 via API
@ -38,7 +48,7 @@ module DiscourseAi
response =
DiscourseAi::Inference::OpenAiEmbeddings.perform!(
text,
model: name,
model: self.clas.name,
dimensions: dimensions,
)
response[:data].first[:embedding]

View File

@ -4,6 +4,20 @@ module DiscourseAi
module Embeddings
module VectorRepresentations
class TextEmbedding3Small < Base
class << self
def name
"text-embedding-3-small"
end
def correctly_configured?
SiteSetting.ai_openai_api_key.present?
end
def dependant_setting_names
%w[ai_openai_api_key]
end
end
def id
6
end
@ -12,10 +26,6 @@ module DiscourseAi
1
end
def name
"text-embedding-3-small"
end
def dimensions
1536
end
@ -33,7 +43,7 @@ module DiscourseAi
end
def vector_from(text)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: name)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
response[:data].first[:embedding]
end

View File

@ -4,6 +4,20 @@ module DiscourseAi
module Embeddings
module VectorRepresentations
class TextEmbeddingAda002 < Base
class << self
def name
"text-embedding-ada-002"
end
def correctly_configured?
SiteSetting.ai_openai_api_key.present?
end
def dependant_setting_names
%w[ai_openai_api_key]
end
end
def id
2
end
@ -12,10 +26,6 @@ module DiscourseAi
1
end
def name
"text-embedding-ada-002"
end
def dimensions
1536
end
@ -33,7 +43,7 @@ module DiscourseAi
end
def vector_from(text)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: name)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name)
response[:data].first[:embedding]
end

View File

@ -25,6 +25,7 @@ RSpec.describe Jobs::EmbeddingsBackfill do
end
it "backfills topics based on bumped_at date" do
SiteSetting.ai_embeddings_model = "bge-large-en"
SiteSetting.ai_embeddings_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
SiteSetting.ai_embeddings_backfill_batch_size = 1

View File

@ -70,6 +70,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
SiteSetting.ai_embeddings_semantic_search_hyde_model = "fake:fake"
SiteSetting.ai_embeddings_semantic_search_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
SiteSetting.ai_embeddings_model = "bge-large-en"
hyde_embedding = [0.049382, 0.9999]
EmbeddingsGenerationStubs.discourse_service(

View File

@ -13,6 +13,8 @@ describe DiscourseAi::Embeddings::EntryPoint do
)
end
before { SiteSetting.ai_embeddings_model = "bge-large-en" }
it "queues a job on create if embeddings is enabled" do
SiteSetting.ai_embeddings_enabled = true

View File

@ -6,8 +6,8 @@ RSpec.describe Jobs::GenerateEmbeddings do
describe "#execute" do
before do
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
SiteSetting.ai_embeddings_enabled = true
SiteSetting.ai_embeddings_model = "bge-large-en"
SiteSetting.ai_embeddings_enabled = true
end
fab!(:topic) { Fabricate(:topic) }
@ -27,7 +27,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding)
job.execute(target_id: topic.id, target_type: "Topic")
@ -39,7 +39,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
text =
truncation.prepare_text_from(post, vector_rep.tokenizer, vector_rep.max_sequence_length - 2)
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(vector_rep.class.name, text, expected_embedding)
job.execute(target_id: post.id, target_type: "Post")

View File

@ -13,7 +13,10 @@ describe DiscourseAi::Embeddings::SemanticRelated do
fab!(:secured_category_topic) { Fabricate(:topic, category: secured_category) }
fab!(:closed_topic) { Fabricate(:topic, closed: true) }
before { SiteSetting.ai_embeddings_semantic_related_topics_enabled = true }
before do
SiteSetting.ai_embeddings_model = "bge-large-en"
SiteSetting.ai_embeddings_semantic_related_topics_enabled = true
end
describe "#related_topic_ids_for" do
context "when embeddings do not exist" do

View File

@ -14,6 +14,7 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do
before do
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
SiteSetting.ai_embeddings_model = "bge-large-en"
hyde_embedding = [0.049382, 0.9999]
EmbeddingsGenerationStubs.discourse_service(

View File

@ -4,6 +4,8 @@ describe DiscourseAi::Embeddings::EntryPoint do
fab!(:user) { Fabricate(:user) }
describe "SemanticTopicQuery extension" do
before { SiteSetting.ai_embeddings_model = "bge-large-en" }
describe "#list_semantic_related_topics" do
subject(:topic_query) { DiscourseAi::Embeddings::SemanticTopicQuery.new(user) }

View File

@ -10,7 +10,7 @@ RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2 do
before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(described_class.name, text, expected_embedding)
end
it_behaves_like "generates and store embedding using with vector representation"

View File

@ -11,7 +11,7 @@ RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Lar
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.discourse_service(
vector_rep.name,
described_class.name,
"query: #{text}",
expected_embedding,
)

View File

@ -8,7 +8,7 @@ RSpec.describe DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda0
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(vector_rep.name, text, expected_embedding)
EmbeddingsGenerationStubs.openai_service(described_class.name, text, expected_embedding)
end
it_behaves_like "generates and store embedding using with vector representation"

View File

@ -327,7 +327,10 @@ RSpec.describe "AI Composer helper", type: :system, js: true do
end
context "when suggesting the category with AI category suggester" do
before { SiteSetting.ai_embeddings_enabled = true }
before do
SiteSetting.ai_embeddings_model = "bge-large-en"
SiteSetting.ai_embeddings_enabled = true
end
it "updates the category with the suggested category" do
response =
@ -352,7 +355,10 @@ RSpec.describe "AI Composer helper", type: :system, js: true do
end
context "when suggesting the tags with AI tag suggester" do
before { SiteSetting.ai_embeddings_enabled = true }
before do
SiteSetting.ai_embeddings_model = "bge-large-en"
SiteSetting.ai_embeddings_enabled = true
end
it "updates the tag with the suggested tag" do
response =

View File

@ -80,7 +80,10 @@ RSpec.describe "AI Post helper", type: :system, js: true do
end
context "when suggesting categories with AI category suggester" do
before { SiteSetting.ai_embeddings_enabled = true }
before do
SiteSetting.ai_embeddings_model = "bge-large-en"
SiteSetting.ai_embeddings_enabled = true
end
skip "TODO: Category suggester only loading one category in test" do
it "updates the category with the suggested category" do
@ -108,7 +111,10 @@ RSpec.describe "AI Post helper", type: :system, js: true do
end
context "when suggesting tags with AI tag suggester" do
before { SiteSetting.ai_embeddings_enabled = true }
before do
SiteSetting.ai_embeddings_model = "bge-large-en"
SiteSetting.ai_embeddings_enabled = true
end
it "update the tag with the suggested tag" do
response =

View File

@ -11,6 +11,7 @@ RSpec.describe "AI Composer helper", type: :system, js: true do
before do
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
SiteSetting.ai_embeddings_model = "bge-large-en"
prompt = DiscourseAi::Embeddings::HydeGenerators::OpenAi.new.prompt(query)
OpenAiCompletionsInferenceStubs.stub_response(
prompt,