From f5cf1019fb6267b87f7b78b46d31ee5f0cdd9d24 Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Tue, 21 Jan 2025 12:23:19 -0300 Subject: [PATCH] FEATURE: configurable embeddings (#1049) * Use AR model for embeddings features * endpoints * Embeddings CRUD UI * Add presets. Hide a couple more settings * system specs * Seed embedding definition from old settings * Generate search bit index on the fly. cleanup orphaned data * support for seeded models * Fix run test for new embedding * fix selected model not set correctly --- ...ugins-show-discourse-ai-embeddings-edit.js | 21 + ...lugins-show-discourse-ai-embeddings-new.js | 17 + ...in-plugins-show-discourse-ai-embeddings.js | 7 + .../show/discourse-ai-embeddings/edit.hbs | 4 + .../show/discourse-ai-embeddings/index.hbs | 1 + .../show/discourse-ai-embeddings/new.hbs | 4 + .../admin/ai_embeddings_controller.rb | 130 ++++++ app/jobs/regular/digest_rag_upload.rb | 2 +- .../manage_embedding_def_search_index.rb | 13 + .../scheduled/remove_orphaned_embeddings.rb | 11 + app/models/embedding_definition.rb | 231 +++++++++++ .../ai_embedding_definition_serializer.rb | 29 ++ .../admin-discourse-ai-plugin-route-map.js | 9 + .../discourse/admin/adapters/ai-embedding.js | 21 + .../discourse/admin/models/ai-embedding.js | 38 ++ .../components/ai-embedding-editor.gjs | 376 ++++++++++++++++++ .../components/ai-embeddings-list-editor.gjs | 114 ++++++ .../admin-plugin-configuration-nav.js | 4 + .../common/ai-embedding-editor.scss | 26 ++ config/locales/client.en.yml | 43 ++ config/locales/server.en.yml | 19 +- config/routes.rb | 7 + config/settings.yml | 69 ++-- ...1217164540_create_embedding_definitions.rb | 19 + ...0114305_embedding_config_data_migration.rb | 204 ++++++++++ lib/ai_bot/entry_point.rb | 2 +- lib/ai_bot/personas/persona.rb | 2 +- lib/ai_helper/semantic_categorizer.rb | 2 +- .../embedding_defs_enumerator.rb | 20 + lib/configuration/embedding_defs_validator.rb | 19 + .../embeddings_module_validator.rb | 34 +- 
lib/embeddings/schema.rb | 91 ++++- lib/embeddings/semantic_search.rb | 4 +- lib/embeddings/vector.rb | 11 +- .../all_mpnet_base_v2.rb | 56 --- lib/embeddings/vector_representations/base.rb | 103 ----- .../vector_representations/bge_large_en.rb | 80 ---- .../vector_representations/bge_m3.rb | 51 --- .../vector_representations/gemini.rb | 54 --- .../multilingual_e5_large.rb | 88 ---- .../text_embedding_3_large.rb | 56 --- .../text_embedding_3_small.rb | 51 --- .../text_embedding_ada_002.rb | 51 --- lib/inference/cloudflare_workers_ai.rb | 19 +- lib/inference/discourse_classifier.rb | 47 --- lib/inference/gemini_embeddings.rb | 14 +- lib/inference/hugging_face_text_embeddings.rb | 15 +- lib/inference/open_ai_embeddings.rb | 8 +- plugin.rb | 1 + .../embeddings_model_validator_spec.rb | 17 - .../embedding_definition_fabricator.rb | 40 ++ spec/jobs/regular/digest_rag_upload_spec.rb | 16 +- .../regular/generate_rag_embeddings_spec.rb | 21 +- .../manage_embedding_def_search_index_spec.rb | 56 +++ .../scheduled/embeddings_backfill_spec.rb | 12 +- .../remove_orphaned_embeddings_spec.rb | 71 ++++ .../embeddings_module_validator_spec.rb | 15 - .../inference/cloudflare_workers_ai_spec.rb | 2 +- .../modules/ai_bot/personas/persona_spec.rb | 19 +- spec/lib/modules/ai_bot/tools/search_spec.rb | 13 +- .../ai_helper/semantic_categorizer_spec.rb | 12 +- .../jobs/generate_embeddings_spec.rb | 13 +- spec/lib/modules/embeddings/schema_spec.rb | 6 +- .../embeddings/semantic_related_spec.rb | 22 +- .../embeddings/semantic_search_spec.rb | 18 +- .../embeddings/semantic_topic_query_spec.rb | 4 + .../embeddings/strategies/truncation_spec.rb | 8 +- spec/lib/modules/embeddings/vector_spec.rb | 87 ++-- spec/models/ai_tool_spec.rb | 10 +- spec/models/rag_document_fragment_spec.rb | 28 +- .../admin/ai_embeddings_controller_spec.rb | 184 +++++++++ .../admin/ai_personas_controller_spec.rb | 1 - .../rag_document_fragments_controller_spec.rb | 5 +- .../embeddings/embeddings_controller_spec.rb | 14 
+- .../inference/openai_embeddings_spec.rb | 34 +- spec/support/embeddings_generation_stubs.rb | 19 +- .../ai_embedding_definition_spec.rb | 87 ++++ .../system/embeddings/semantic_search_spec.rb | 7 +- 78 files changed, 2131 insertions(+), 1008 deletions(-) create mode 100644 admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-embeddings-edit.js create mode 100644 admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-embeddings-new.js create mode 100644 admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-embeddings.js create mode 100644 admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-embeddings/edit.hbs create mode 100644 admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-embeddings/index.hbs create mode 100644 admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-embeddings/new.hbs create mode 100644 app/controllers/discourse_ai/admin/ai_embeddings_controller.rb create mode 100644 app/jobs/regular/manage_embedding_def_search_index.rb create mode 100644 app/jobs/scheduled/remove_orphaned_embeddings.rb create mode 100644 app/models/embedding_definition.rb create mode 100644 app/serializers/ai_embedding_definition_serializer.rb create mode 100644 assets/javascripts/discourse/admin/adapters/ai-embedding.js create mode 100644 assets/javascripts/discourse/admin/models/ai-embedding.js create mode 100644 assets/javascripts/discourse/components/ai-embedding-editor.gjs create mode 100644 assets/javascripts/discourse/components/ai-embeddings-list-editor.gjs create mode 100644 assets/stylesheets/modules/embeddings/common/ai-embedding-editor.scss create mode 100644 db/migrate/20241217164540_create_embedding_definitions.rb create mode 100644 db/migrate/20250110114305_embedding_config_data_migration.rb create mode 100644 lib/configuration/embedding_defs_enumerator.rb create mode 100644 lib/configuration/embedding_defs_validator.rb delete 
mode 100644 lib/embeddings/vector_representations/all_mpnet_base_v2.rb delete mode 100644 lib/embeddings/vector_representations/base.rb delete mode 100644 lib/embeddings/vector_representations/bge_large_en.rb delete mode 100644 lib/embeddings/vector_representations/bge_m3.rb delete mode 100644 lib/embeddings/vector_representations/gemini.rb delete mode 100644 lib/embeddings/vector_representations/multilingual_e5_large.rb delete mode 100644 lib/embeddings/vector_representations/text_embedding_3_large.rb delete mode 100644 lib/embeddings/vector_representations/text_embedding_3_small.rb delete mode 100644 lib/embeddings/vector_representations/text_embedding_ada_002.rb delete mode 100644 lib/inference/discourse_classifier.rb delete mode 100644 spec/configuration/embeddings_model_validator_spec.rb create mode 100644 spec/fabricators/embedding_definition_fabricator.rb create mode 100644 spec/jobs/regular/manage_embedding_def_search_index_spec.rb create mode 100644 spec/jobs/scheduled/remove_orphaned_embeddings_spec.rb delete mode 100644 spec/lib/configuration/embeddings_module_validator_spec.rb create mode 100644 spec/requests/admin/ai_embeddings_controller_spec.rb create mode 100644 spec/system/embeddings/ai_embedding_definition_spec.rb diff --git a/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-embeddings-edit.js b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-embeddings-edit.js new file mode 100644 index 00000000..46a7bc30 --- /dev/null +++ b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-embeddings-edit.js @@ -0,0 +1,21 @@ +import DiscourseRoute from "discourse/routes/discourse"; + +export default class AdminPluginsShowDiscourseAiEmbeddingsEdit extends DiscourseRoute { + async model(params) { + const allEmbeddings = this.modelFor( + "adminPlugins.show.discourse-ai-embeddings" + ); + const id = parseInt(params.id, 10); + const record = allEmbeddings.findBy("id", id); + 
record.provider_params = record.provider_params || {}; + return record; + } + + setupController(controller, model) { + super.setupController(controller, model); + controller.set( + "allEmbeddings", + this.modelFor("adminPlugins.show.discourse-ai-embeddings") + ); + } +} diff --git a/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-embeddings-new.js b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-embeddings-new.js new file mode 100644 index 00000000..ea8a9239 --- /dev/null +++ b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-embeddings-new.js @@ -0,0 +1,17 @@ +import DiscourseRoute from "discourse/routes/discourse"; + +export default class AdminPluginsShowDiscourseAiEmbeddingsNew extends DiscourseRoute { + async model() { + const record = this.store.createRecord("ai-embedding"); + record.provider_params = {}; + return record; + } + + setupController(controller, model) { + super.setupController(controller, model); + controller.set( + "allEmbeddings", + this.modelFor("adminPlugins.show.discourse-ai-embeddings") + ); + } +} diff --git a/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-embeddings.js b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-embeddings.js new file mode 100644 index 00000000..475e6f88 --- /dev/null +++ b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-embeddings.js @@ -0,0 +1,7 @@ +import DiscourseRoute from "discourse/routes/discourse"; + +export default class DiscourseAiAiEmbeddingsRoute extends DiscourseRoute { + model() { + return this.store.findAll("ai-embedding"); + } +} diff --git a/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-embeddings/edit.hbs b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-embeddings/edit.hbs new file mode 100644 index 00000000..8ec8776f --- /dev/null +++ 
b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-embeddings/edit.hbs @@ -0,0 +1,4 @@ + \ No newline at end of file diff --git a/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-embeddings/index.hbs b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-embeddings/index.hbs new file mode 100644 index 00000000..8226d03c --- /dev/null +++ b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-embeddings/index.hbs @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-embeddings/new.hbs b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-embeddings/new.hbs new file mode 100644 index 00000000..8ec8776f --- /dev/null +++ b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-embeddings/new.hbs @@ -0,0 +1,4 @@ + \ No newline at end of file diff --git a/app/controllers/discourse_ai/admin/ai_embeddings_controller.rb b/app/controllers/discourse_ai/admin/ai_embeddings_controller.rb new file mode 100644 index 00000000..7000aea4 --- /dev/null +++ b/app/controllers/discourse_ai/admin/ai_embeddings_controller.rb @@ -0,0 +1,130 @@ +# frozen_string_literal: true + +module DiscourseAi + module Admin + class AiEmbeddingsController < ::Admin::AdminController + requires_plugin ::DiscourseAi::PLUGIN_NAME + + def index + embedding_defs = EmbeddingDefinition.all.order(:display_name) + + render json: { + ai_embeddings: + ActiveModel::ArraySerializer.new( + embedding_defs, + each_serializer: AiEmbeddingDefinitionSerializer, + root: false, + ).as_json, + meta: { + provider_params: EmbeddingDefinition.provider_params, + providers: EmbeddingDefinition.provider_names, + distance_functions: EmbeddingDefinition.distance_functions, + tokenizers: + EmbeddingDefinition.tokenizer_names.map { |tn| + { id: tn, name: tn.split("::").last } + }, + presets: 
EmbeddingDefinition.presets, + }, + } + end + + def new + end + + def edit + embedding_def = EmbeddingDefinition.find(params[:id]) + render json: AiEmbeddingDefinitionSerializer.new(embedding_def) + end + + def create + embedding_def = EmbeddingDefinition.new(ai_embeddings_params) + + if embedding_def.save + render json: AiEmbeddingDefinitionSerializer.new(embedding_def), status: :created + else + render_json_error embedding_def + end + end + + def update + embedding_def = EmbeddingDefinition.find(params[:id]) + + if embedding_def.seeded? + return( + render_json_error(I18n.t("discourse_ai.embeddings.cannot_edit_builtin"), status: 403) + ) + end + + if embedding_def.update(ai_embeddings_params.except(:dimensions)) + render json: AiEmbeddingDefinitionSerializer.new(embedding_def) + else + render_json_error embedding_def + end + end + + def destroy + embedding_def = EmbeddingDefinition.find(params[:id]) + + if embedding_def.seeded? + return( + render_json_error(I18n.t("discourse_ai.embeddings.cannot_edit_builtin"), status: 403) + ) + end + + if embedding_def.id == SiteSetting.ai_embeddings_selected_model.to_i + return render_json_error(I18n.t("discourse_ai.embeddings.delete_failed"), status: 409) + end + + if embedding_def.destroy + head :no_content + else + render_json_error embedding_def + end + end + + def test + RateLimiter.new( + current_user, + "ai_embeddings_test_#{current_user.id}", + 3, + 1.minute, + ).performed! 
+ + embedding_def = EmbeddingDefinition.new(ai_embeddings_params) + DiscourseAi::Embeddings::Vector.new(embedding_def).vector_from("this is a test") + + render json: { success: true } + rescue Net::HTTPBadResponse => e + render json: { success: false, error: e.message } + end + + private + + def ai_embeddings_params + permitted = + params.require(:ai_embedding).permit( + :display_name, + :dimensions, + :max_sequence_length, + :pg_function, + :provider, + :url, + :api_key, + :tokenizer_class, + ) + + extra_field_names = EmbeddingDefinition.provider_params.dig(permitted[:provider]&.to_sym) + if extra_field_names.present? + received_prov_params = + params.dig(:ai_embedding, :provider_params)&.slice(*extra_field_names.keys) + + if received_prov_params.present? + permitted[:provider_params] = received_prov_params.permit! + end + end + + permitted + end + end + end +end diff --git a/app/jobs/regular/digest_rag_upload.rb b/app/jobs/regular/digest_rag_upload.rb index 76b9ee65..bfc2ac4b 100644 --- a/app/jobs/regular/digest_rag_upload.rb +++ b/app/jobs/regular/digest_rag_upload.rb @@ -18,7 +18,7 @@ module ::Jobs target = target_type.constantize.find_by(id: target_id) return if !target - vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation + vector_rep = DiscourseAi::Embeddings::Vector.instance tokenizer = vector_rep.tokenizer chunk_tokens = target.rag_chunk_tokens diff --git a/app/jobs/regular/manage_embedding_def_search_index.rb b/app/jobs/regular/manage_embedding_def_search_index.rb new file mode 100644 index 00000000..5baa526c --- /dev/null +++ b/app/jobs/regular/manage_embedding_def_search_index.rb @@ -0,0 +1,13 @@ +# frozen_string_literal: true + +module ::Jobs + class ManageEmbeddingDefSearchIndex < ::Jobs::Base + def execute(args) + embedding_def = EmbeddingDefinition.find_by(id: args[:id]) + return if embedding_def.nil? 
+ return if DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_def) + + DiscourseAi::Embeddings::Schema.prepare_search_indexes(embedding_def) + end + end +end diff --git a/app/jobs/scheduled/remove_orphaned_embeddings.rb b/app/jobs/scheduled/remove_orphaned_embeddings.rb new file mode 100644 index 00000000..d5fddcf6 --- /dev/null +++ b/app/jobs/scheduled/remove_orphaned_embeddings.rb @@ -0,0 +1,11 @@ +# frozen_string_literal: true + +module Jobs + class RemoveOrphanedEmbeddings < ::Jobs::Scheduled + every 1.week + + def execute(_args) + DiscourseAi::Embeddings::Schema.remove_orphaned_data + end + end +end diff --git a/app/models/embedding_definition.rb b/app/models/embedding_definition.rb new file mode 100644 index 00000000..48fdf3a5 --- /dev/null +++ b/app/models/embedding_definition.rb @@ -0,0 +1,231 @@ +# frozen_string_literal: true + +class EmbeddingDefinition < ActiveRecord::Base + CLOUDFLARE = "cloudflare" + HUGGING_FACE = "hugging_face" + OPEN_AI = "open_ai" + GOOGLE = "google" + + class << self + def provider_names + [CLOUDFLARE, HUGGING_FACE, OPEN_AI, GOOGLE] + end + + def distance_functions + %w[<#> <=>] + end + + def tokenizer_names + [ + DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer, + DiscourseAi::Tokenizer::BgeLargeEnTokenizer, + DiscourseAi::Tokenizer::BgeM3Tokenizer, + DiscourseAi::Tokenizer::OpenAiTokenizer, + DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer, + DiscourseAi::Tokenizer::OpenAiTokenizer, + ].map(&:name) + end + + def provider_params + { open_ai: { model_name: :text } } + end + + def presets + @presets ||= + begin + [ + { + preset_id: "bge-large-en", + display_name: "bge-large-en", + dimensions: 1024, + max_sequence_length: 512, + pg_function: "<#>", + tokenizer_class: "DiscourseAi::Tokenizer::BgeLargeEnTokenizer", + provider: HUGGING_FACE, + }, + { + preset_id: "bge-m3", + display_name: "bge-m3", + dimensions: 1024, + max_sequence_length: 8192, + pg_function: "<#>", + tokenizer_class: 
"DiscourseAi::Tokenizer::BgeM3Tokenizer", + provider: HUGGING_FACE, + }, + { + preset_id: "gemini-embedding-001", + display_name: "Gemini's embedding-001", + dimensions: 768, + max_sequence_length: 1536, + pg_function: "<=>", + url: + "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent", + tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer", + provider: GOOGLE, + }, + { + preset_id: "multilingual-e5-large", + display_name: "multilingual-e5-large", + dimensions: 1024, + max_sequence_length: 512, + pg_function: "<=>", + tokenizer_class: "DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer", + provider: HUGGING_FACE, + }, + { + preset_id: "text-embedding-3-large", + display_name: "OpenAI's text-embedding-3-large", + dimensions: 2000, + max_sequence_length: 8191, + pg_function: "<=>", + tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer", + url: "https://api.openai.com/v1/embeddings", + provider: OPEN_AI, + provider_params: { + model_name: "text-embedding-3-large", + }, + }, + { + preset_id: "text-embedding-3-small", + display_name: "OpenAI's text-embedding-3-small", + dimensions: 1536, + max_sequence_length: 8191, + pg_function: "<=>", + tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer", + url: "https://api.openai.com/v1/embeddings", + provider: OPEN_AI, + provider_params: { + model_name: "text-embedding-3-small", + }, + }, + { + preset_id: "text-embedding-ada-002", + display_name: "OpenAI's text-embedding-ada-002", + dimensions: 1536, + max_sequence_length: 8191, + pg_function: "<=>", + tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer", + url: "https://api.openai.com/v1/embeddings", + provider: OPEN_AI, + provider_params: { + model_name: "text-embedding-ada-002", + }, + }, + ] + end + end + end + + validates :provider, presence: true, inclusion: provider_names + validates :display_name, presence: true, length: { maximum: 100 } + validates :tokenizer_class, presence: true, inclusion: tokenizer_names + 
validates_presence_of :url, :api_key, :dimensions, :max_sequence_length, :pg_function + + after_create :create_indexes + + def create_indexes + Jobs.enqueue(:manage_embedding_def_search_index, id: self.id) + end + + def tokenizer + tokenizer_class.constantize + end + + def inference_client + case provider + when CLOUDFLARE + cloudflare_client + when HUGGING_FACE + hugging_face_client + when OPEN_AI + open_ai_client + when GOOGLE + gemini_client + else + raise "Uknown embeddings provider" + end + end + + def lookup_custom_param(key) + provider_params&.dig(key) + end + + def endpoint_url + return url if !url.starts_with?("srv://") + + service = DiscourseAi::Utils::DnsSrv.lookup(url.sub("srv://", "")) + "https://#{service.target}:#{service.port}" + end + + def prepare_query_text(text, asymetric: false) + strategy.prepare_query_text(text, self, asymetric: asymetric) + end + + def prepare_target_text(target) + strategy.prepare_target_text(target, self) + end + + def strategy_id + strategy.id + end + + def strategy_version + strategy.version + end + + def api_key + if seeded? 
+ env_key = "DISCOURSE_AI_SEEDED_EMBEDDING_API_KEY" + ENV[env_key] || self[:api_key] + else + self[:api_key] + end + end + + private + + def strategy + @strategy ||= DiscourseAi::Embeddings::Strategies::Truncation.new + end + + def cloudflare_client + DiscourseAi::Inference::CloudflareWorkersAi.new(endpoint_url, api_key) + end + + def hugging_face_client + DiscourseAi::Inference::HuggingFaceTextEmbeddings.new(endpoint_url, api_key) + end + + def open_ai_client + DiscourseAi::Inference::OpenAiEmbeddings.new( + endpoint_url, + api_key, + lookup_custom_param("model_name"), + dimensions, + ) + end + + def gemini_client + DiscourseAi::Inference::GeminiEmbeddings.new(endpoint_url, api_key) + end +end + +# == Schema Information +# +# Table name: embedding_definitions +# +# id :bigint not null, primary key +# display_name :string not null +# dimensions :integer not null +# max_sequence_length :integer not null +# version :integer default(1), not null +# pg_function :string not null +# provider :string not null +# tokenizer_class :string not null +# url :string not null +# api_key :string +# seeded :boolean default(FALSE), not null +# provider_params :jsonb +# created_at :datetime not null +# updated_at :datetime not null +# diff --git a/app/serializers/ai_embedding_definition_serializer.rb b/app/serializers/ai_embedding_definition_serializer.rb new file mode 100644 index 00000000..53a75d47 --- /dev/null +++ b/app/serializers/ai_embedding_definition_serializer.rb @@ -0,0 +1,29 @@ +# frozen_string_literal: true + +class AiEmbeddingDefinitionSerializer < ApplicationSerializer + root "ai_embedding" + + attributes :id, + :display_name, + :dimensions, + :max_sequence_length, + :pg_function, + :provider, + :url, + :api_key, + :seeded, + :tokenizer_class, + :provider_params + + def api_key + object.seeded? ? "********" : object.api_key + end + + def url + object.seeded? ? "********" : object.url + end + + def provider + object.seeded? ? 
"CDCK" : object.provider + end +end diff --git a/assets/javascripts/discourse/admin-discourse-ai-plugin-route-map.js b/assets/javascripts/discourse/admin-discourse-ai-plugin-route-map.js index e4c49570..3e798a18 100644 --- a/assets/javascripts/discourse/admin-discourse-ai-plugin-route-map.js +++ b/assets/javascripts/discourse/admin-discourse-ai-plugin-route-map.js @@ -20,5 +20,14 @@ export default { }); this.route("discourse-ai-spam", { path: "ai-spam" }); this.route("discourse-ai-usage", { path: "ai-usage" }); + + this.route( + "discourse-ai-embeddings", + { path: "ai-embeddings" }, + function () { + this.route("new"); + this.route("edit", { path: "/:id/edit" }); + } + ); }, }; diff --git a/assets/javascripts/discourse/admin/adapters/ai-embedding.js b/assets/javascripts/discourse/admin/adapters/ai-embedding.js new file mode 100644 index 00000000..5aa1b48c --- /dev/null +++ b/assets/javascripts/discourse/admin/adapters/ai-embedding.js @@ -0,0 +1,21 @@ +import RestAdapter from "discourse/adapters/rest"; + +export default class Adapter extends RestAdapter { + jsonMode = true; + + basePath() { + return "/admin/plugins/discourse-ai/"; + } + + pathFor(store, type, findArgs) { + // removes underscores which are implemented in base + let path = + this.basePath(store, type, findArgs) + + store.pluralize(this.apiNameFor(type)); + return this.appendQueryParams(path, findArgs); + } + + apiNameFor() { + return "ai-embedding"; + } +} diff --git a/assets/javascripts/discourse/admin/models/ai-embedding.js b/assets/javascripts/discourse/admin/models/ai-embedding.js new file mode 100644 index 00000000..b1896afa --- /dev/null +++ b/assets/javascripts/discourse/admin/models/ai-embedding.js @@ -0,0 +1,38 @@ +import { ajax } from "discourse/lib/ajax"; +import RestModel from "discourse/models/rest"; + +export default class AiEmbedding extends RestModel { + createProperties() { + return this.getProperties( + "id", + "display_name", + "dimensions", + "provider", + "tokenizer_class", + 
"dimensions", + "url", + "api_key", + "max_sequence_length", + "provider_params", + "pg_function" + ); + } + + updateProperties() { + const attrs = this.createProperties(); + attrs.id = this.id; + + return attrs; + } + + async testConfig() { + return await ajax(`/admin/plugins/discourse-ai/ai-embeddings/test.json`, { + data: { ai_embedding: this.createProperties() }, + }); + } + + workingCopy() { + const attrs = this.createProperties(); + return this.store.createRecord("ai-embedding", attrs); + } +} diff --git a/assets/javascripts/discourse/components/ai-embedding-editor.gjs b/assets/javascripts/discourse/components/ai-embedding-editor.gjs new file mode 100644 index 00000000..f6a98b83 --- /dev/null +++ b/assets/javascripts/discourse/components/ai-embedding-editor.gjs @@ -0,0 +1,376 @@ +import Component from "@glimmer/component"; +import { tracked } from "@glimmer/tracking"; +import { Input } from "@ember/component"; +import { concat, get } from "@ember/helper"; +import { on } from "@ember/modifier"; +import { action, computed } from "@ember/object"; +import didInsert from "@ember/render-modifiers/modifiers/did-insert"; +import didUpdate from "@ember/render-modifiers/modifiers/did-update"; +import { later } from "@ember/runloop"; +import { service } from "@ember/service"; +import BackButton from "discourse/components/back-button"; +import DButton from "discourse/components/d-button"; +import icon from "discourse/helpers/d-icon"; +import { popupAjaxError } from "discourse/lib/ajax-error"; +import { i18n } from "discourse-i18n"; +import ComboBox from "select-kit/components/combo-box"; +import DTooltip from "float-kit/components/d-tooltip"; +import not from "truth-helpers/helpers/not"; + +export default class AiEmbeddingEditor extends Component { + @service toasts; + @service router; + @service dialog; + @service store; + + @tracked isSaving = false; + @tracked selectedPreset = null; + + @tracked testRunning = false; + @tracked testResult = null; + @tracked testError = 
null; + @tracked apiKeySecret = true; + @tracked editingModel = null; + + get selectedProviders() { + const t = (provName) => { + return i18n(`discourse_ai.embeddings.providers.${provName}`); + }; + + return this.args.embeddings.resultSetMeta.providers.map((prov) => { + return { id: prov, name: t(prov) }; + }); + } + + get distanceFunctions() { + const t = (df) => { + return i18n(`discourse_ai.embeddings.distance_functions.${df}`); + }; + + return this.args.embeddings.resultSetMeta.distance_functions.map((df) => { + return { id: df, name: t(df) }; + }); + } + + get presets() { + const presets = this.args.embeddings.resultSetMeta.presets.map((preset) => { + return { + name: preset.display_name, + id: preset.preset_id, + }; + }); + + presets.pushObject({ + name: i18n("discourse_ai.embeddings.configure_manually"), + id: "manual", + }); + + return presets; + } + + get showPresets() { + return !this.selectedPreset && this.args.model.isNew; + } + + @computed("editingModel.provider") + get metaProviderParams() { + return ( + this.args.embeddings.resultSetMeta.provider_params[ + this.editingModel?.provider + ] || {} + ); + } + + get testErrorMessage() { + return i18n("discourse_ai.llms.tests.failure", { error: this.testError }); + } + + get displayTestResult() { + return this.testRunning || this.testResult !== null; + } + + @action + configurePreset() { + this.selectedPreset = + this.args.embeddings.resultSetMeta.presets.findBy( + "preset_id", + this.presetId + ) || {}; + + this.editingModel = this.store + .createRecord("ai-embedding", this.selectedPreset) + .workingCopy(); + } + + @action + updateModel() { + this.editingModel = this.args.model.workingCopy(); + } + + @action + makeApiKeySecret() { + this.apiKeySecret = true; + } + + @action + toggleApiKeySecret() { + this.apiKeySecret = !this.apiKeySecret; + } + + @action + async save() { + this.isSaving = true; + const isNew = this.args.model.isNew; + + try { + await this.editingModel.save(); + + if (isNew) { + 
this.args.embeddings.addObject(this.editingModel); + this.router.transitionTo( + "adminPlugins.show.discourse-ai-embeddings.index" + ); + } else { + this.toasts.success({ + data: { message: i18n("discourse_ai.embeddings.saved") }, + duration: 2000, + }); + } + } catch (e) { + popupAjaxError(e); + } finally { + later(() => { + this.isSaving = false; + }, 1000); + } + } + + @action + async test() { + this.testRunning = true; + + try { + const configTestResult = await this.editingModel.testConfig(); + this.testResult = configTestResult.success; + + if (this.testResult) { + this.testError = null; + } else { + this.testError = configTestResult.error; + } + } catch (e) { + popupAjaxError(e); + } finally { + later(() => { + this.testRunning = false; + }, 1000); + } + } + + @action + delete() { + return this.dialog.confirm({ + message: i18n("discourse_ai.embeddings.confirm_delete"), + didConfirm: () => { + return this.args.model + .destroyRecord() + .then(() => { + this.args.llms.removeObject(this.args.model); + this.router.transitionTo( + "adminPlugins.show.discourse-ai-embeddings.index" + ); + }) + .catch(popupAjaxError); + }, + }); + } + + +} diff --git a/assets/javascripts/discourse/components/ai-embeddings-list-editor.gjs b/assets/javascripts/discourse/components/ai-embeddings-list-editor.gjs new file mode 100644 index 00000000..473f9fae --- /dev/null +++ b/assets/javascripts/discourse/components/ai-embeddings-list-editor.gjs @@ -0,0 +1,114 @@ +import Component from "@glimmer/component"; +import { concat } from "@ember/helper"; +import { service } from "@ember/service"; +import DBreadcrumbsItem from "discourse/components/d-breadcrumbs-item"; +import DButton from "discourse/components/d-button"; +import DPageSubheader from "discourse/components/d-page-subheader"; +import { i18n } from "discourse-i18n"; +import AdminConfigAreaEmptyList from "admin/components/admin-config-area-empty-list"; +import DTooltip from "float-kit/components/d-tooltip"; +import AiEmbeddingEditor 
from "./ai-embedding-editor"; + +export default class AiEmbeddingsListEditor extends Component { + @service adminPluginNavManager; + + get hasEmbeddingElements() { + return this.args.embeddings.length !== 0; + } + + +} diff --git a/assets/javascripts/initializers/admin-plugin-configuration-nav.js b/assets/javascripts/initializers/admin-plugin-configuration-nav.js index 119744a1..54499728 100644 --- a/assets/javascripts/initializers/admin-plugin-configuration-nav.js +++ b/assets/javascripts/initializers/admin-plugin-configuration-nav.js @@ -12,6 +12,10 @@ export default { withPluginApi("1.1.0", (api) => { api.addAdminPluginConfigurationNav("discourse-ai", PLUGIN_NAV_MODE_TOP, [ + { + label: "discourse_ai.embeddings.short_title", + route: "adminPlugins.show.discourse-ai-embeddings", + }, { label: "discourse_ai.llms.short_title", route: "adminPlugins.show.discourse-ai-llms", diff --git a/assets/stylesheets/modules/embeddings/common/ai-embedding-editor.scss b/assets/stylesheets/modules/embeddings/common/ai-embedding-editor.scss new file mode 100644 index 00000000..29af6bcd --- /dev/null +++ b/assets/stylesheets/modules/embeddings/common/ai-embedding-editor.scss @@ -0,0 +1,26 @@ +.ai-embedding-editor { + padding-left: 0.5em; + + .ai-embedding-editor-input { + width: 350px; + } + + .ai-embedding-editor-tests { + &__failure { + color: var(--danger); + } + + &__success { + color: var(--success); + } + } + + &__api-key { + margin-right: 0.5em; + } + + &__secret-api-key-group { + display: flex; + align-items: center; + } +} diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index 8d3a2ab5..bf2af706 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -502,6 +502,49 @@ en: accuracy: "Accuracy:" embeddings: + short_title: "Embeddings" + description: "Embeddings are a crucial component of the Discourse AI plugin, enabling features like related topics and semantic search." 
+ new: "New embedding" + back: "Back" + save: "Save" + saved: "Embedding configuration saved" + delete: "Delete" + confirm_delete: Are you sure you want to remove this embedding configuration? + empty: "You haven't setup embeddings yet" + presets: "Select a preset..." + configure_manually: "Configure manually" + edit: "Edit" + seeded_warning: "This is pre-configured on your site and cannot be edited." + tests: + title: "Run test" + running: "Running test..." + success: "Success!" + failure: "Attempting to generate an embedding resulted in: %{error}" + hints: + dimensions_warning: "Once saved, this value can't be changed." + + display_name: "Name" + provider: "Provider" + url: "Embeddings service URL" + api_key: "Embeddings service API Key" + tokenizer: "Tokenizer" + dimensions: "Embedding dimensions" + max_sequence_length: "Sequence length" + + distance_function: "Distance function" + distance_functions: + <#>: "Negative inner product (<#>)" + <=>: "Cosine distance (<=>)" + providers: + hugging_face: "Hugging Face" + open_ai: "OpenAI" + google: "Google" + cloudflare: "Cloudflare" + CDCK: "CDCK" + provider_fields: + model_name: "Model name" + + semantic_search: "Topics (Semantic)" semantic_search_loading: "Searching for more results using AI" semantic_search_results: diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index 616bca3a..5b4dee40 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -49,10 +49,7 @@ en: ai_nsfw_flag_threshold_sexy: "Threshold for an image classified as sexy to be considered NSFW." ai_nsfw_models: "Models to use for NSFW inference." - ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)" - ai_openai_api_key: "API key for OpenAI API. ONLY used for embeddings and Dall-E. 
For GPT use the LLM config tab" - ai_hugging_face_tei_endpoint: URL where the API is running for the Hugging Face text embeddings inference - ai_hugging_face_tei_api_key: API key for Hugging Face text embeddings inference + ai_openai_api_key: "API key for OpenAI API. ONLY used for Dall-E. For GPT use the LLM config tab" ai_helper_enabled: "Enable the AI helper." composer_ai_helper_allowed_groups: "Users on these groups will see the AI helper button in the composer." @@ -67,15 +64,11 @@ en: ai_helper_image_caption_model: "Select the model to use for generating image captions" ai_auto_image_caption_allowed_groups: "Users on these groups can toggle automatic image captioning." - ai_embeddings_enabled: "Enable the embeddings module." - ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module" - ai_embeddings_discourse_service_api_key: "API key for the embeddings API" - ai_embeddings_model: "Use all-mpnet-base-v2 for local and fast inference in english, text-embedding-ada-002 to use OpenAI API (need API key) and multilingual-e5-large for local multilingual embeddings" + ai_embeddings_selected_model: "Use the selected model for generating embeddings." ai_embeddings_generate_for_pms: "Generate embeddings for personal messages." ai_embeddings_semantic_related_topics_enabled: "Use Semantic Search for related topics." ai_embeddings_semantic_related_topics: "Maximum number of topics to show in related topic section." ai_embeddings_backfill_batch_size: "Number of embeddings to backfill every 15 minutes." - ai_embeddings_pg_connection_string: "PostgreSQL connection string for the embeddings module. Needs pgvector extension enabled and a series of tables created. See docs for more info." ai_embeddings_semantic_search_enabled: "Enable full-page semantic search." ai_embeddings_semantic_quick_search_enabled: "Enable semantic search option in search menu popup." 
ai_embeddings_semantic_related_include_closed_topics: "Include closed topics in semantic search results" @@ -437,13 +430,11 @@ en: cannot_edit_builtin: "You can't edit a built-in model." embeddings: + delete_failed: "This model is currently in use. Update the `ai embeddings selected model` first." + cannot_edit_builtin: "You can't edit a built-in model." configuration: disable_embeddings: "You have to disable 'ai embeddings enabled' first." - choose_model: "Set 'ai embeddings model' first." - model_unreachable: "We failed to generate a test embedding with this model. Check your settings are correct." - hint: - one: "Make sure the `%{settings}` setting was configured." - other: "Make sure the settings of the provider you want were configured. Options are: %{settings}" + choose_model: "Set 'ai embeddings selected model' first." llm_models: missing_provider_param: "%{param} can't be blank" diff --git a/config/routes.rb b/config/routes.rb index 2034d473..44dd0988 100644 --- a/config/routes.rb +++ b/config/routes.rb @@ -96,6 +96,13 @@ Discourse::Application.routes.draw do controller: "discourse_ai/admin/ai_llm_quotas", path: "quotas", only: %i[index create update destroy] + + resources :ai_embeddings, + only: %i[index new create edit update destroy], + path: "ai-embeddings", + controller: "discourse_ai/admin/ai_embeddings" do + collection { get :test } + end end end diff --git a/config/settings.yml b/config/settings.yml index b1b86be5..7c0ea066 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -28,11 +28,14 @@ discourse_ai: ai_openai_dall_e_3_url: "https://api.openai.com/v1/images/generations" - ai_openai_embeddings_url: "https://api.openai.com/v1/embeddings" + ai_openai_embeddings_url: + hidden: true + default: "https://api.openai.com/v1/embeddings" ai_openai_organization: default: "" hidden: true ai_openai_api_key: + hidden: true default: "" secret: true ai_stability_api_key: @@ -50,11 +53,14 @@ discourse_ai: - "stable-diffusion-768-v2-1" - 
"stable-diffusion-v1-5" ai_hugging_face_tei_endpoint: + hidden: true default: "" ai_hugging_face_tei_endpoint_srv: default: "" hidden: true - ai_hugging_face_tei_api_key: "" + ai_hugging_face_tei_api_key: + default: "" + hidden: true ai_hugging_face_tei_reranker_endpoint: default: "" ai_hugging_face_tei_reranker_endpoint_srv: @@ -69,12 +75,14 @@ discourse_ai: ai_cloudflare_workers_account_id: default: "" secret: true + hidden: true ai_cloudflare_workers_api_token: default: "" secret: true + hidden: true ai_gemini_api_key: default: "" - hidden: false + hidden: true ai_strict_token_counting: default: false hidden: true @@ -158,27 +166,12 @@ discourse_ai: default: false client: true validator: "DiscourseAi::Configuration::EmbeddingsModuleValidator" - ai_embeddings_discourse_service_api_endpoint: "" - ai_embeddings_discourse_service_api_endpoint_srv: - default: "" - hidden: true - ai_embeddings_discourse_service_api_key: - default: "" - secret: true - ai_embeddings_model: + ai_embeddings_selected_model: type: enum - default: "bge-large-en" + default: "" allow_any: false - choices: - - all-mpnet-base-v2 - - text-embedding-ada-002 - - text-embedding-3-small - - text-embedding-3-large - - multilingual-e5-large - - bge-large-en - - gemini - - bge-m3 - validator: "DiscourseAi::Configuration::EmbeddingsModelValidator" + enum: "DiscourseAi::Configuration::EmbeddingDefsEnumerator" + validator: "DiscourseAi::Configuration::EmbeddingDefsValidator" ai_embeddings_per_post_enabled: default: false hidden: true @@ -191,9 +184,6 @@ discourse_ai: ai_embeddings_backfill_batch_size: default: 250 hidden: true - ai_embeddings_pg_connection_string: - default: "" - hidden: true ai_embeddings_semantic_search_enabled: default: false client: true @@ -213,6 +203,35 @@ discourse_ai: default: false client: true hidden: true + + ai_embeddings_discourse_service_api_endpoint: + default: "" + hidden: true + ai_embeddings_discourse_service_api_endpoint_srv: + default: "" + hidden: true + 
ai_embeddings_discourse_service_api_key: + hidden: true + default: "" + secret: true + ai_embeddings_model: + hidden: true + type: enum + default: "bge-large-en" + allow_any: false + choices: + - all-mpnet-base-v2 + - text-embedding-ada-002 + - text-embedding-3-small + - text-embedding-3-large + - multilingual-e5-large + - bge-large-en + - gemini + - bge-m3 + ai_embeddings_pg_connection_string: + default: "" + hidden: true + ai_summarization_enabled: default: false client: true diff --git a/db/migrate/20241217164540_create_embedding_definitions.rb b/db/migrate/20241217164540_create_embedding_definitions.rb new file mode 100644 index 00000000..517ef407 --- /dev/null +++ b/db/migrate/20241217164540_create_embedding_definitions.rb @@ -0,0 +1,19 @@ +# frozen_string_literal: true +class CreateEmbeddingDefinitions < ActiveRecord::Migration[7.2] + def change + create_table :embedding_definitions do |t| + t.string :display_name, null: false + t.integer :dimensions, null: false + t.integer :max_sequence_length, null: false + t.integer :version, null: false, default: 1 + t.string :pg_function, null: false + t.string :provider, null: false + t.string :tokenizer_class, null: false + t.string :url, null: false + t.string :api_key + t.boolean :seeded, null: false, default: false + t.jsonb :provider_params + t.timestamps + end + end +end diff --git a/db/migrate/20250110114305_embedding_config_data_migration.rb b/db/migrate/20250110114305_embedding_config_data_migration.rb new file mode 100644 index 00000000..3125bc99 --- /dev/null +++ b/db/migrate/20250110114305_embedding_config_data_migration.rb @@ -0,0 +1,204 @@ +# frozen_string_literal: true + +class EmbeddingConfigDataMigration < ActiveRecord::Migration[7.0] + def up + current_model = fetch_setting("ai_embeddings_model") || "bge-large-en" + provider = provider_for(current_model) + + if provider.present? + attrs = creds_for(provider) + + if attrs.present? 
+ attrs = attrs.merge(model_attrs(current_model)) + attrs[:display_name] = current_model + attrs[:provider] = provider + persist_config(attrs) + end + end + end + + def down + end + + # Utils + + def fetch_setting(name) + DB.query_single( + "SELECT value FROM site_settings WHERE name = :setting_name", + setting_name: name, + ).first || ENV["DISCOURSE_#{name&.upcase}"] + end + + def provider_for(model) + cloudflare_api_token = fetch_setting("ai_cloudflare_workers_api_token") + + return "cloudflare" if model == "bge-large-en" && cloudflare_api_token.present? + + tei_models = %w[bge-large-en bge-m3 multilingual-e5-large] + return "hugging_face" if tei_models.include?(model) + + return "google" if model == "gemini" + + if %w[text-embedding-3-large text-embedding-3-small text-embedding-ada-002].include?(model) + return "open_ai" + end + + nil + end + + def creds_for(provider) + # CF + if provider == "cloudflare" + api_key = fetch_setting("ai_cloudflare_workers_api_token") + account_id = fetch_setting("ai_cloudflare_workers_account_id") + + return if api_key.blank? || account_id.blank? + + { + url: + "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/baai/bge-large-en-v1.5", + api_key: api_key, + } + # TEI + elsif provider == "hugging_face" + seeded = false + endpoint = fetch_setting("ai_hugging_face_tei_endpoint") + + if endpoint.blank? + endpoint = fetch_setting("ai_hugging_face_tei_endpoint_srv") + if endpoint.present? + endpoint = "srv://#{endpoint}" + seeded = true + end + end + + api_key = fetch_setting("ai_hugging_face_tei_api_key") + + return if endpoint.blank? || api_key.blank? + + { url: endpoint, api_key: api_key, seeded: seeded } + # Gemini + elsif provider == "google" + api_key = fetch_setting("ai_gemini_api_key") + + return if api_key.blank? 
+ + { + url: "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent", + api_key: api_key, + } + + # Open AI + elsif provider == "open_ai" + endpoint = fetch_setting("ai_openai_embeddings_url") + api_key = fetch_setting("ai_openai_api_key") + + return if endpoint.blank? || api_key.blank? + + { url: endpoint, api_key: api_key } + else + nil + end + end + + def model_attrs(model_name) + if model_name == "bge-large-en" + { + dimensions: 1024, + max_sequence_length: 512, + id: 4, + pg_function: "<#>", + tokenizer_class: "DiscourseAi::Tokenizer::BgeLargeEnTokenizer", + } + elsif model_name == "bge-m3" + { + dimensions: 1024, + max_sequence_length: 8192, + id: 8, + pg_function: "<#>", + tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer", + } + elsif model_name == "gemini" + { + dimensions: 768, + max_sequence_length: 1536, + id: 5, + pg_function: "<=>", + tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer", + } + elsif model_name == "multilingual-e5-large" + { + dimensions: 1024, + max_sequence_length: 512, + id: 3, + pg_function: "<=>", + tokenizer_class: "DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer", + } + elsif model_name == "text-embedding-3-large" + { + dimensions: 2000, + max_sequence_length: 8191, + id: 7, + pg_function: "<=>", + tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer", + provider_params: { + model_name: "text-embedding-3-large", + }, + } + elsif model_name == "text-embedding-3-small" + { + dimensions: 1536, + max_sequence_length: 8191, + id: 6, + pg_function: "<=>", + tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer", + provider_params: { + model_name: "text-embedding-3-small", + }, + } + else + { + dimensions: 1536, + max_sequence_length: 8191, + id: 2, + pg_function: "<=>", + tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer", + provider_params: { + model_name: "text-embedding-ada-002", + }, + } + end + end + + def persist_config(attrs) + DB.exec( + <<~SQL, + INSERT INTO 
embedding_definitions (id, display_name, dimensions, max_sequence_length, version, pg_function, provider, tokenizer_class, url, api_key, provider_params, seeded, created_at, updated_at) + VALUES (:id, :display_name, :dimensions, :max_sequence_length, 1, :pg_function, :provider, :tokenizer_class, :url, :api_key, :provider_params, :seeded, :now, :now) + SQL + id: attrs[:id], + display_name: attrs[:display_name], + dimensions: attrs[:dimensions], + max_sequence_length: attrs[:max_sequence_length], + pg_function: attrs[:pg_function], + provider: attrs[:provider], + tokenizer_class: attrs[:tokenizer_class], + url: attrs[:url], + api_key: attrs[:api_key], + provider_params: attrs[:provider_params], + seeded: !!attrs[:seeded], + now: Time.zone.now, + ) + + # We hardcoded the ID to match with already generated embeddings. Let's restart the seq to avoid conflicts. + DB.exec( + "ALTER SEQUENCE embedding_definitions_id_seq RESTART WITH :new_seq", + new_seq: attrs[:id].to_i + 1, + ) + + DB.exec(<<~SQL, new_value: attrs[:id]) + INSERT INTO site_settings(name, data_type, value, created_at, updated_at) + VALUES ('ai_embeddings_selected_model', 3, :new_value, NOW(), NOW()) + SQL + end +end diff --git a/lib/ai_bot/entry_point.rb b/lib/ai_bot/entry_point.rb index 4076b8c4..77a802fe 100644 --- a/lib/ai_bot/entry_point.rb +++ b/lib/ai_bot/entry_point.rb @@ -196,7 +196,7 @@ module DiscourseAi ) plugin.on(:site_setting_changed) do |name, old_value, new_value| - if name == :ai_embeddings_model && SiteSetting.ai_embeddings_enabled? && + if name == :ai_embeddings_selected_model && SiteSetting.ai_embeddings_enabled? 
&& new_value != old_value RagDocumentFragment.delete_all UploadReference diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index 1569ce45..63220e71 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -327,7 +327,7 @@ module DiscourseAi rag_conversation_chunks end - schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector.vdef) + schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment) candidate_fragment_ids = schema diff --git a/lib/ai_helper/semantic_categorizer.rb b/lib/ai_helper/semantic_categorizer.rb index effb4018..842e49c3 100644 --- a/lib/ai_helper/semantic_categorizer.rb +++ b/lib/ai_helper/semantic_categorizer.rb @@ -93,7 +93,7 @@ module DiscourseAi def nearest_neighbors(limit: 100) vector = DiscourseAi::Embeddings::Vector.instance - schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef) + schema = DiscourseAi::Embeddings::Schema.for(Topic) raw_vector = vector.vector_from(@text) diff --git a/lib/configuration/embedding_defs_enumerator.rb b/lib/configuration/embedding_defs_enumerator.rb new file mode 100644 index 00000000..b4adac1b --- /dev/null +++ b/lib/configuration/embedding_defs_enumerator.rb @@ -0,0 +1,20 @@ +# frozen_string_literal: true + +require "enum_site_setting" + +module DiscourseAi + module Configuration + class EmbeddingDefsEnumerator < ::EnumSiteSetting + def self.valid_value?(val) + true + end + + def self.values + DB.query_hash(<<~SQL).map(&:symbolize_keys) + SELECT display_name AS name, id AS value + FROM embedding_definitions + SQL + end + end + end +end diff --git a/lib/configuration/embedding_defs_validator.rb b/lib/configuration/embedding_defs_validator.rb new file mode 100644 index 00000000..600cf759 --- /dev/null +++ b/lib/configuration/embedding_defs_validator.rb @@ -0,0 +1,19 @@ +# frozen_string_literal: true + +module DiscourseAi + module Configuration + class EmbeddingDefsValidator + def initialize(opts = {}) + 
@opts = opts + end + + def valid_value?(val) + val.blank? || EmbeddingDefinition.exists?(id: val) + end + + def error_message + "" + end + end + end +end diff --git a/lib/configuration/embeddings_module_validator.rb b/lib/configuration/embeddings_module_validator.rb index bd5a38bc..cb320b30 100644 --- a/lib/configuration/embeddings_module_validator.rb +++ b/lib/configuration/embeddings_module_validator.rb @@ -11,41 +11,11 @@ module DiscourseAi return true if val == "f" return true if Rails.env.test? - chosen_model = SiteSetting.ai_embeddings_model - - return false if !chosen_model - - representation = - DiscourseAi::Embeddings::VectorRepresentations::Base.find_representation(chosen_model) - - return false if representation.nil? - - if !representation.correctly_configured? - @representation = representation - return false - end - - if !can_generate_embeddings?(chosen_model) - @unreachable = true - return false - end - - true + SiteSetting.ai_embeddings_selected_model.present? end def error_message - return(I18n.t("discourse_ai.embeddings.configuration.model_unreachable")) if @unreachable - - @representation&.configuration_hint - end - - def can_generate_embeddings?(val) - DiscourseAi::Embeddings::VectorRepresentations::Base - .find_representation(val) - .new - .inference_client - .perform!("this is a test") - .present? 
+ I18n.t("discourse_ai.embeddings.configuration.choose_model") end end end diff --git a/lib/embeddings/schema.rb b/lib/embeddings/schema.rb index af4785fe..013c460e 100644 --- a/lib/embeddings/schema.rb +++ b/lib/embeddings/schema.rb @@ -12,19 +12,80 @@ module DiscourseAi POSTS_TABLE = "ai_posts_embeddings" RAG_DOCS_TABLE = "ai_document_fragments_embeddings" - def self.for( - target_klass, - vector_def: DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation - ) - case target_klass&.name - when "Topic" - new(TOPICS_TABLE, "topic_id", vector_def) - when "Post" - new(POSTS_TABLE, "post_id", vector_def) - when "RagDocumentFragment" - new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector_def) - else - raise ArgumentError, "Invalid target type for embeddings" + EMBEDDING_TARGETS = %w[topics posts document_fragments] + EMBEDDING_TABLES = [TOPICS_TABLE, POSTS_TABLE, RAG_DOCS_TABLE] + + MissingEmbeddingError = Class.new(StandardError) + + class << self + def for(target_klass) + vector_def = EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_selected_model) + raise "Invalid embeddings selected model" if vector_def.nil? 
+ + case target_klass&.name + when "Topic" + new(TOPICS_TABLE, "topic_id", vector_def) + when "Post" + new(POSTS_TABLE, "post_id", vector_def) + when "RagDocumentFragment" + new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector_def) + else + raise ArgumentError, "Invalid target type for embeddings" + end + end + + def search_index_name(table, def_id) + "ai_#{table}_embeddings_#{def_id}_1_search_bit" + end + + def prepare_search_indexes(vector_def) + EMBEDDING_TARGETS.each { |target| DB.exec <<~SQL } + CREATE INDEX IF NOT EXISTS #{search_index_name(target, vector_def.id)} ON ai_#{target}_embeddings + USING hnsw ((binary_quantize(embeddings)::bit(#{vector_def.dimensions})) bit_hamming_ops) + WHERE model_id = #{vector_def.id} AND strategy_id = 1; + SQL + end + + def correctly_indexed?(vector_def) + index_names = EMBEDDING_TARGETS.map { |t| search_index_name(t, vector_def.id) } + indexdefs = + DB.query_single( + "SELECT indexdef FROM pg_indexes WHERE indexname IN (:names)", + names: index_names, + ) + + return false if indexdefs.length < index_names.length + + indexdefs.all? do |defs| + defs.include? 
"(binary_quantize(embeddings))::bit(#{vector_def.dimensions})" + end + end + + def remove_orphaned_data + removed_defs_ids = + DB.query_single( + "SELECT DISTINCT(model_id) FROM #{TOPICS_TABLE} te LEFT JOIN embedding_definitions ed ON te.model_id = ed.id WHERE ed.id IS NULL", + ) + + EMBEDDING_TABLES.each do |t| + DB.exec( + "DELETE FROM #{t} WHERE model_id IN (:removed_defs)", + removed_defs: removed_defs_ids, + ) + end + + drop_index_statement = + EMBEDDING_TARGETS + .reduce([]) do |memo, et| + removed_defs_ids.each do |rdi| + memo << "DROP INDEX IF EXISTS #{search_index_name(et, rdi)};" + end + + memo + end + .join("\n") + + DB.exec(drop_index_statement) end end @@ -117,7 +178,7 @@ module DiscourseAi offset: offset, ) rescue PG::Error => e - Rails.logger.error("Error #{e} querying embeddings for model #{name}") + Rails.logger.error("Error #{e} querying embeddings for model #{vector_def.display_name}") raise MissingEmbeddingError end @@ -168,7 +229,7 @@ module DiscourseAi builder.query(vid: vector_def.id, vsid: vector_def.strategy_id, target_id: record.id) rescue PG::Error => e - Rails.logger.error("Error #{e} querying embeddings for model #{name}") + Rails.logger.error("Error #{e} querying embeddings for model #{vector_def.display_name}") raise MissingEmbeddingError end diff --git a/lib/embeddings/semantic_search.rb b/lib/embeddings/semantic_search.rb index 490772fa..b28e9ef9 100644 --- a/lib/embeddings/semantic_search.rb +++ b/lib/embeddings/semantic_search.rb @@ -82,7 +82,7 @@ module DiscourseAi over_selection_limit = limit * OVER_SELECTION_FACTOR - schema = DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector.vdef) + schema = DiscourseAi::Embeddings::Schema.for(Topic) candidate_topic_ids = schema.asymmetric_similarity_search( @@ -132,7 +132,7 @@ module DiscourseAi candidate_post_ids = DiscourseAi::Embeddings::Schema - .for(Post, vector_def: vector.vdef) + .for(Post) .asymmetric_similarity_search( search_term_embedding, limit: 
max_semantic_results_per_page, diff --git a/lib/embeddings/vector.rb b/lib/embeddings/vector.rb index 2fe8c72c..582de7fb 100644 --- a/lib/embeddings/vector.rb +++ b/lib/embeddings/vector.rb @@ -4,13 +4,18 @@ module DiscourseAi module Embeddings class Vector def self.instance - new(DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation) + vector_def = EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_selected_model) + raise "Invalid embeddings selected model" if vector_def.nil? + + new(vector_def) end def initialize(vector_definition) @vdef = vector_definition end + delegate :tokenizer, to: :vdef + def gen_bulk_reprensentations(relation) http_pool_size = 100 pool = @@ -20,7 +25,7 @@ module DiscourseAi idletime: 30, ) - schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector_def: vdef) + schema = DiscourseAi::Embeddings::Schema.for(relation.first.class) embedding_gen = vdef.inference_client promised_embeddings = @@ -53,7 +58,7 @@ module DiscourseAi text = vdef.prepare_target_text(target) return if text.blank? - schema = DiscourseAi::Embeddings::Schema.for(target.class, vector_def: vdef) + schema = DiscourseAi::Embeddings::Schema.for(target.class) new_digest = OpenSSL::Digest::SHA1.hexdigest(text) return if schema.find_by_target(target)&.digest == new_digest diff --git a/lib/embeddings/vector_representations/all_mpnet_base_v2.rb b/lib/embeddings/vector_representations/all_mpnet_base_v2.rb deleted file mode 100644 index a89aab8b..00000000 --- a/lib/embeddings/vector_representations/all_mpnet_base_v2.rb +++ /dev/null @@ -1,56 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Embeddings - module VectorRepresentations - class AllMpnetBaseV2 < Base - class << self - def name - "all-mpnet-base-v2" - end - - def correctly_configured? - SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? || - SiteSetting.ai_embeddings_discourse_service_api_endpoint.present? 
- end - - def dependant_setting_names - %w[ - ai_embeddings_discourse_service_api_key - ai_embeddings_discourse_service_api_endpoint_srv - ai_embeddings_discourse_service_api_endpoint - ] - end - end - - def dimensions - 768 - end - - def max_sequence_length - 384 - end - - def id - 1 - end - - def version - 1 - end - - def pg_function - "<#>" - end - - def tokenizer - DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer - end - - def inference_client - DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name) - end - end - end - end -end diff --git a/lib/embeddings/vector_representations/base.rb b/lib/embeddings/vector_representations/base.rb deleted file mode 100644 index 8670404b..00000000 --- a/lib/embeddings/vector_representations/base.rb +++ /dev/null @@ -1,103 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Embeddings - module VectorRepresentations - class Base - class << self - def find_representation(model_name) - # we are explicit here cause the loader may have not - # loaded the subclasses yet - [ - DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2, - DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn, - DiscourseAi::Embeddings::VectorRepresentations::BgeM3, - DiscourseAi::Embeddings::VectorRepresentations::Gemini, - DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large, - DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large, - DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small, - DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002, - ].find { _1.name == model_name } - end - - def current_representation - find_representation(SiteSetting.ai_embeddings_model).new - end - - def correctly_configured? 
- raise NotImplementedError - end - - def dependant_setting_names - raise NotImplementedError - end - - def configuration_hint - settings = dependant_setting_names - I18n.t( - "discourse_ai.embeddings.configuration.hint", - settings: settings.join(", "), - count: settings.length, - ) - end - end - - def name - raise NotImplementedError - end - - def dimensions - raise NotImplementedError - end - - def max_sequence_length - raise NotImplementedError - end - - def id - raise NotImplementedError - end - - def pg_function - raise NotImplementedError - end - - def version - raise NotImplementedError - end - - def tokenizer - raise NotImplementedError - end - - def asymmetric_query_prefix - "" - end - - def strategy_id - strategy.id - end - - def strategy_version - strategy.version - end - - def prepare_query_text(text, asymetric: false) - strategy.prepare_query_text(text, self, asymetric: asymetric) - end - - def prepare_target_text(target) - strategy.prepare_target_text(target, self) - end - - def strategy - @strategy ||= DiscourseAi::Embeddings::Strategies::Truncation.new - end - - def inference_client - raise NotImplementedError - end - end - end - end -end diff --git a/lib/embeddings/vector_representations/bge_large_en.rb b/lib/embeddings/vector_representations/bge_large_en.rb deleted file mode 100644 index 9006ebbe..00000000 --- a/lib/embeddings/vector_representations/bge_large_en.rb +++ /dev/null @@ -1,80 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Embeddings - module VectorRepresentations - class BgeLargeEn < Base - class << self - def name - "bge-large-en" - end - - def correctly_configured? - SiteSetting.ai_cloudflare_workers_api_token.present? || - DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? || - ( - SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? || - SiteSetting.ai_embeddings_discourse_service_api_endpoint.present? 
- ) - end - - def dependant_setting_names - %w[ - ai_cloudflare_workers_api_token - ai_hugging_face_tei_endpoint_srv - ai_hugging_face_tei_endpoint - ai_embeddings_discourse_service_api_key - ai_embeddings_discourse_service_api_endpoint_srv - ai_embeddings_discourse_service_api_endpoint - ] - end - end - - def dimensions - 1024 - end - - def max_sequence_length - 512 - end - - def id - 4 - end - - def version - 1 - end - - def pg_function - "<#>" - end - - def tokenizer - DiscourseAi::Tokenizer::BgeLargeEnTokenizer - end - - def asymmetric_query_prefix - "Represent this sentence for searching relevant passages:" - end - - def inference_client - inference_model_name = "baai/bge-large-en-v1.5" - - if SiteSetting.ai_cloudflare_workers_api_token.present? - DiscourseAi::Inference::CloudflareWorkersAi.instance(inference_model_name) - elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? - DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance - elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? || - SiteSetting.ai_embeddings_discourse_service_api_endpoint.present? - DiscourseAi::Inference::DiscourseClassifier.instance( - inference_model_name.split("/").last, - ) - else - raise "No inference endpoint configured" - end - end - end - end - end -end diff --git a/lib/embeddings/vector_representations/bge_m3.rb b/lib/embeddings/vector_representations/bge_m3.rb deleted file mode 100644 index d3625e41..00000000 --- a/lib/embeddings/vector_representations/bge_m3.rb +++ /dev/null @@ -1,51 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Embeddings - module VectorRepresentations - class BgeM3 < Base - class << self - def name - "bge-m3" - end - - def correctly_configured? - DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? 
- end - - def dependant_setting_names - %w[ai_hugging_face_tei_endpoint_srv ai_hugging_face_tei_endpoint] - end - end - - def dimensions - 1024 - end - - def max_sequence_length - 8192 - end - - def id - 8 - end - - def version - 1 - end - - def pg_function - "<#>" - end - - def tokenizer - DiscourseAi::Tokenizer::BgeM3Tokenizer - end - - def inference_client - DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance - end - end - end - end -end diff --git a/lib/embeddings/vector_representations/gemini.rb b/lib/embeddings/vector_representations/gemini.rb deleted file mode 100644 index 110b8b4c..00000000 --- a/lib/embeddings/vector_representations/gemini.rb +++ /dev/null @@ -1,54 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Embeddings - module VectorRepresentations - class Gemini < Base - class << self - def name - "gemini" - end - - def correctly_configured? - SiteSetting.ai_gemini_api_key.present? - end - - def dependant_setting_names - %w[ai_gemini_api_key] - end - end - - def id - 5 - end - - def version - 1 - end - - def dimensions - 768 - end - - def max_sequence_length - 1536 # Gemini has a max sequence length of 2048, but the API has a limit of 10000 bytes, hence the lower value - end - - def pg_function - "<=>" - end - - # There is no public tokenizer for Gemini, and from the ones we already ship in the plugin - # OpenAI gets the closest results. Gemini Tokenizer results in ~10% less tokens, so it's safe - # to use OpenAI tokenizer since it will overestimate the number of tokens. 
- def tokenizer - DiscourseAi::Tokenizer::OpenAiTokenizer - end - - def inference_client - DiscourseAi::Inference::GeminiEmbeddings.instance - end - end - end - end -end diff --git a/lib/embeddings/vector_representations/multilingual_e5_large.rb b/lib/embeddings/vector_representations/multilingual_e5_large.rb deleted file mode 100644 index 7d6894a8..00000000 --- a/lib/embeddings/vector_representations/multilingual_e5_large.rb +++ /dev/null @@ -1,88 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Embeddings - module VectorRepresentations - class MultilingualE5Large < Base - class << self - def name - "multilingual-e5-large" - end - - def correctly_configured? - DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? || - ( - SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? || - SiteSetting.ai_embeddings_discourse_service_api_endpoint.present? - ) - end - - def dependant_setting_names - %w[ - ai_hugging_face_tei_endpoint_srv - ai_hugging_face_tei_endpoint - ai_embeddings_discourse_service_api_key - ai_embeddings_discourse_service_api_endpoint_srv - ai_embeddings_discourse_service_api_endpoint - ] - end - end - - def id - 3 - end - - def version - 1 - end - - def dimensions - 1024 - end - - def max_sequence_length - 512 - end - - def pg_function - "<=>" - end - - def tokenizer - DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer - end - - def inference_client - if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? - DiscourseAi::Inference::HuggingFaceTextEmbeddings.instance - elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? || - SiteSetting.ai_embeddings_discourse_service_api_endpoint.present? - DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name) - else - raise "No inference endpoint configured" - end - end - - def prepare_text(text, asymetric: false) - prepared_text = super(text, asymetric: asymetric) - - if prepared_text.present? 
&& inference_client.class.name.include?("DiscourseClassifier") - return "query: #{prepared_text}" - end - - prepared_text - end - - def prepare_target_text(target) - prepared_text = super(target) - - if prepared_text.present? && inference_client.class.name.include?("DiscourseClassifier") - return "query: #{prepared_text}" - end - - prepared_text - end - end - end - end -end diff --git a/lib/embeddings/vector_representations/text_embedding_3_large.rb b/lib/embeddings/vector_representations/text_embedding_3_large.rb deleted file mode 100644 index d73f4cee..00000000 --- a/lib/embeddings/vector_representations/text_embedding_3_large.rb +++ /dev/null @@ -1,56 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Embeddings - module VectorRepresentations - class TextEmbedding3Large < Base - class << self - def name - "text-embedding-3-large" - end - - def correctly_configured? - SiteSetting.ai_openai_api_key.present? - end - - def dependant_setting_names - %w[ai_openai_api_key] - end - end - - def id - 7 - end - - def version - 1 - end - - def dimensions - # real dimentions are 3072, but we only support up to 2000 in the - # indexes, so we downsample to 2000 via API - 2000 - end - - def max_sequence_length - 8191 - end - - def pg_function - "<=>" - end - - def tokenizer - DiscourseAi::Tokenizer::OpenAiTokenizer - end - - def inference_client - DiscourseAi::Inference::OpenAiEmbeddings.instance( - model: self.class.name, - dimensions: dimensions, - ) - end - end - end - end -end diff --git a/lib/embeddings/vector_representations/text_embedding_3_small.rb b/lib/embeddings/vector_representations/text_embedding_3_small.rb deleted file mode 100644 index 90a3f790..00000000 --- a/lib/embeddings/vector_representations/text_embedding_3_small.rb +++ /dev/null @@ -1,51 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Embeddings - module VectorRepresentations - class TextEmbedding3Small < Base - class << self - def name - "text-embedding-3-small" 
- end - - def correctly_configured? - SiteSetting.ai_openai_api_key.present? - end - - def dependant_setting_names - %w[ai_openai_api_key] - end - end - - def id - 6 - end - - def version - 1 - end - - def dimensions - 1536 - end - - def max_sequence_length - 8191 - end - - def pg_function - "<=>" - end - - def tokenizer - DiscourseAi::Tokenizer::OpenAiTokenizer - end - - def inference_client - DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name) - end - end - end - end -end diff --git a/lib/embeddings/vector_representations/text_embedding_ada_002.rb b/lib/embeddings/vector_representations/text_embedding_ada_002.rb deleted file mode 100644 index f5340918..00000000 --- a/lib/embeddings/vector_representations/text_embedding_ada_002.rb +++ /dev/null @@ -1,51 +0,0 @@ -# frozen_string_literal: true - -module DiscourseAi - module Embeddings - module VectorRepresentations - class TextEmbeddingAda002 < Base - class << self - def name - "text-embedding-ada-002" - end - - def correctly_configured? - SiteSetting.ai_openai_api_key.present? 
- end - - def dependant_setting_names - %w[ai_openai_api_key] - end - end - - def id - 2 - end - - def version - 1 - end - - def dimensions - 1536 - end - - def max_sequence_length - 8191 - end - - def pg_function - "<=>" - end - - def tokenizer - DiscourseAi::Tokenizer::OpenAiTokenizer - end - - def inference_client - DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name) - end - end - end - end -end diff --git a/lib/inference/cloudflare_workers_ai.rb b/lib/inference/cloudflare_workers_ai.rb index d6e0dfb3..360725e7 100644 --- a/lib/inference/cloudflare_workers_ai.rb +++ b/lib/inference/cloudflare_workers_ai.rb @@ -3,22 +3,13 @@ module ::DiscourseAi module Inference class CloudflareWorkersAi - def initialize(account_id, api_token, model, referer = Discourse.base_url) - @account_id = account_id + def initialize(endpoint, api_token, referer = Discourse.base_url) + @endpoint = endpoint @api_token = api_token - @model = model @referer = referer end - def self.instance(model) - new( - SiteSetting.ai_cloudflare_workers_account_id, - SiteSetting.ai_cloudflare_workers_api_token, - model, - ) - end - - attr_reader :account_id, :api_token, :model, :referer + attr_reader :endpoint, :api_token, :referer def perform!(content) headers = { @@ -29,8 +20,6 @@ module ::DiscourseAi payload = { text: [content] } - endpoint = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/#{model}" - conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter } response = conn.post(endpoint, payload.to_json, headers) @@ -43,7 +32,7 @@ module ::DiscourseAi Rails.logger.warn( "Cloudflare Workers AI Embeddings failed with status: #{response.status} body: #{response.body}", ) - raise Net::HTTPBadResponse + raise Net::HTTPBadResponse.new(response.body.to_s) end end end diff --git a/lib/inference/discourse_classifier.rb b/lib/inference/discourse_classifier.rb deleted file mode 100644 index 46f912dd..00000000 --- a/lib/inference/discourse_classifier.rb +++ 
/dev/null @@ -1,47 +0,0 @@ -# frozen_string_literal: true - -module ::DiscourseAi - module Inference - class DiscourseClassifier - def initialize(endpoint, api_key, model, referer = Discourse.base_url) - @endpoint = endpoint - @api_key = api_key - @model = model - @referer = referer - end - - def self.instance(model) - endpoint = - if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? - service = - DiscourseAi::Utils::DnsSrv.lookup( - SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv, - ) - "https://#{service.target}:#{service.port}" - else - SiteSetting.ai_embeddings_discourse_service_api_endpoint - end - - new( - "#{endpoint}/api/v1/classify", - SiteSetting.ai_embeddings_discourse_service_api_key, - model, - ) - end - - attr_reader :endpoint, :api_key, :model, :referer - - def perform!(content) - headers = { "Referer" => referer, "Content-Type" => "application/json" } - headers["X-API-KEY"] = api_key if api_key.present? - - conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter } - response = conn.post(endpoint, { model: model, content: content }.to_json, headers) - - raise Net::HTTPBadResponse if ![200, 415].include?(response.status) - - JSON.parse(response.body, symbolize_names: true) - end - end - end -end diff --git a/lib/inference/gemini_embeddings.rb b/lib/inference/gemini_embeddings.rb index 2684dcd4..6250f5a0 100644 --- a/lib/inference/gemini_embeddings.rb +++ b/lib/inference/gemini_embeddings.rb @@ -3,21 +3,17 @@ module ::DiscourseAi module Inference class GeminiEmbeddings - def self.instance - new(SiteSetting.ai_gemini_api_key) - end - - def initialize(api_key, referer = Discourse.base_url) + def initialize(embedding_url, api_key, referer = Discourse.base_url) @api_key = api_key + @embedding_url = embedding_url @referer = referer end - attr_reader :api_key, :referer + attr_reader :embedding_url, :api_key, :referer def perform!(content) headers = { "Referer" => referer, "Content-Type" => "application/json" } - 
url = - "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{api_key}" + url = "#{embedding_url}\?key\=#{api_key}" body = { content: { parts: [{ text: content }] } } conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter } @@ -32,7 +28,7 @@ module ::DiscourseAi Rails.logger.warn( "Google Gemini Embeddings failed with status: #{response.status} body: #{response.body}", ) - raise Net::HTTPBadResponse + raise Net::HTTPBadResponse.new(response.body.to_s) end end end diff --git a/lib/inference/hugging_face_text_embeddings.rb b/lib/inference/hugging_face_text_embeddings.rb index 3881cfc3..954e2a30 100644 --- a/lib/inference/hugging_face_text_embeddings.rb +++ b/lib/inference/hugging_face_text_embeddings.rb @@ -12,19 +12,6 @@ module ::DiscourseAi attr_reader :endpoint, :key, :referer class << self - def instance - endpoint = - if SiteSetting.ai_hugging_face_tei_endpoint_srv.present? - service = - DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv) - "https://#{service.target}:#{service.port}" - else - SiteSetting.ai_hugging_face_tei_endpoint - end - - new(endpoint, SiteSetting.ai_hugging_face_tei_api_key) - end - def configured? SiteSetting.ai_hugging_face_tei_endpoint.present? || SiteSetting.ai_hugging_face_tei_endpoint_srv.present? 
@@ -100,7 +87,7 @@ module ::DiscourseAi conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter } response = conn.post(endpoint, body, headers) - raise Net::HTTPBadResponse if ![200].include?(response.status) + raise Net::HTTPBadResponse.new(response.body.to_s) if ![200].include?(response.status) JSON.parse(response.body, symbolize_names: true).first end diff --git a/lib/inference/open_ai_embeddings.rb b/lib/inference/open_ai_embeddings.rb index e3e6551c..5e7820cf 100644 --- a/lib/inference/open_ai_embeddings.rb +++ b/lib/inference/open_ai_embeddings.rb @@ -12,10 +12,6 @@ module ::DiscourseAi attr_reader :endpoint, :api_key, :model, :dimensions - def self.instance(model:, dimensions: nil) - new(SiteSetting.ai_openai_embeddings_url, SiteSetting.ai_openai_api_key, model, dimensions) - end - def perform!(content) headers = { "Content-Type" => "application/json" } @@ -29,7 +25,7 @@ module ::DiscourseAi payload[:dimensions] = dimensions if dimensions.present? conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter } - response = conn.post(SiteSetting.ai_openai_embeddings_url, payload.to_json, headers) + response = conn.post(endpoint, payload.to_json, headers) case response.status when 200 @@ -40,7 +36,7 @@ module ::DiscourseAi Rails.logger.warn( "OpenAI Embeddings failed with status: #{response.status} body: #{response.body}", ) - raise Net::HTTPBadResponse + raise Net::HTTPBadResponse.new(response.body.to_s) end end end diff --git a/plugin.rb b/plugin.rb index e02a102f..66fc60a5 100644 --- a/plugin.rb +++ b/plugin.rb @@ -35,6 +35,7 @@ register_asset "stylesheets/modules/embeddings/common/semantic-search.scss" register_asset "stylesheets/modules/sentiment/common/dashboard.scss" register_asset "stylesheets/modules/llms/common/ai-llms-editor.scss" +register_asset "stylesheets/modules/embeddings/common/ai-embedding-editor.scss" register_asset "stylesheets/modules/llms/common/usage.scss" register_asset "stylesheets/modules/llms/common/spam.scss" diff 
--git a/spec/configuration/embeddings_model_validator_spec.rb b/spec/configuration/embeddings_model_validator_spec.rb deleted file mode 100644 index f28ea56c..00000000 --- a/spec/configuration/embeddings_model_validator_spec.rb +++ /dev/null @@ -1,17 +0,0 @@ -# frozen_string_literal: true - -require_relative "../support/embeddings_generation_stubs" - -RSpec.describe DiscourseAi::Configuration::EmbeddingsModelValidator do - before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" } - - describe "#can_generate_embeddings?" do - it "works" do - discourse_model = "all-mpnet-base-v2" - - EmbeddingsGenerationStubs.discourse_service(discourse_model, "this is a test", [1] * 1024) - - expect(subject.can_generate_embeddings?(discourse_model)).to eq(true) - end - end -end diff --git a/spec/fabricators/embedding_definition_fabricator.rb b/spec/fabricators/embedding_definition_fabricator.rb new file mode 100644 index 00000000..5daee977 --- /dev/null +++ b/spec/fabricators/embedding_definition_fabricator.rb @@ -0,0 +1,40 @@ +# frozen_string_literal: true + +Fabricator(:embedding_definition) do + display_name "Multilingual E5 Large" + provider "hugging_face" + tokenizer_class "DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer" + api_key "123" + url "https://test.com/embeddings" + provider_params nil + pg_function "<=>" + max_sequence_length 512 + dimensions 1024 +end + +Fabricator(:cloudflare_embedding_def, from: :embedding_definition) do + display_name "BGE Large EN" + provider "cloudflare" + pg_function "<#>" + tokenizer_class "DiscourseAi::Tokenizer::BgeLargeEnTokenizer" + provider_params nil +end + +Fabricator(:open_ai_embedding_def, from: :embedding_definition) do + display_name "ADA 002" + provider "open_ai" + url "https://api.openai.com/v1/embeddings" + tokenizer_class "DiscourseAi::Tokenizer::OpenAiTokenizer" + provider_params { { model_name: "text-embedding-ada-002" } } + max_sequence_length 8191 + dimensions 1536 +end + 
+Fabricator(:gemini_embedding_def, from: :embedding_definition) do + display_name "Gemini's embedding-001" + provider "google" + dimensions 768 + max_sequence_length 1536 + tokenizer_class "DiscourseAi::Tokenizer::OpenAiTokenizer" + url "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent" +end diff --git a/spec/jobs/regular/digest_rag_upload_spec.rb b/spec/jobs/regular/digest_rag_upload_spec.rb index eed03fc8..97da3b79 100644 --- a/spec/jobs/regular/digest_rag_upload_spec.rb +++ b/spec/jobs/regular/digest_rag_upload_spec.rb @@ -6,9 +6,8 @@ RSpec.describe Jobs::DigestRagUpload do let(:document_file) { StringIO.new("some text" * 200) } - let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } - - let(:expected_embedding) { [0.0038493] * vector_rep.dimensions } + fab!(:cloudflare_embedding_def) + let(:expected_embedding) { [0.0038493] * cloudflare_embedding_def.dimensions } let(:document_with_metadata) { plugin_file_from_fixtures("doc_with_metadata.txt", "rag") } @@ -21,15 +20,14 @@ RSpec.describe Jobs::DigestRagUpload do end before do + SiteSetting.ai_embeddings_selected_model = cloudflare_embedding_def.id SiteSetting.ai_embeddings_enabled = true - SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" - SiteSetting.ai_embeddings_model = "bge-large-en" SiteSetting.authorized_extensions = "txt" - WebMock.stub_request( - :post, - "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", - ).to_return(status: 200, body: JSON.dump(expected_embedding)) + WebMock.stub_request(:post, cloudflare_embedding_def.url).to_return( + status: 200, + body: JSON.dump(expected_embedding), + ) end describe "#execute" do diff --git a/spec/jobs/regular/generate_rag_embeddings_spec.rb b/spec/jobs/regular/generate_rag_embeddings_spec.rb index 1cba9d06..10558745 100644 --- a/spec/jobs/regular/generate_rag_embeddings_spec.rb +++ 
b/spec/jobs/regular/generate_rag_embeddings_spec.rb @@ -2,23 +2,26 @@ RSpec.describe Jobs::GenerateRagEmbeddings do describe "#execute" do - let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } + fab!(:vector_def) { Fabricate(:embedding_definition) } - let(:expected_embedding) { [0.0038493] * vector_rep.dimensions } + let(:expected_embedding) { [0.0038493] * vector_def.dimensions } fab!(:ai_persona) - fab!(:rag_document_fragment_1) { Fabricate(:rag_document_fragment, target: ai_persona) } - fab!(:rag_document_fragment_2) { Fabricate(:rag_document_fragment, target: ai_persona) } + let(:rag_document_fragment_1) { Fabricate(:rag_document_fragment, target: ai_persona) } + let(:rag_document_fragment_2) { Fabricate(:rag_document_fragment, target: ai_persona) } before do + SiteSetting.ai_embeddings_selected_model = vector_def.id SiteSetting.ai_embeddings_enabled = true - SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" - WebMock.stub_request( - :post, - "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", - ).to_return(status: 200, body: JSON.dump(expected_embedding)) + rag_document_fragment_1 + rag_document_fragment_2 + + WebMock.stub_request(:post, vector_def.url).to_return( + status: 200, + body: JSON.dump(expected_embedding), + ) end it "generates a new vector for each fragment" do diff --git a/spec/jobs/regular/manage_embedding_def_search_index_spec.rb b/spec/jobs/regular/manage_embedding_def_search_index_spec.rb new file mode 100644 index 00000000..21c9bb6d --- /dev/null +++ b/spec/jobs/regular/manage_embedding_def_search_index_spec.rb @@ -0,0 +1,56 @@ +# frozen_string_literal: true + +RSpec.describe Jobs::ManageEmbeddingDefSearchIndex do + fab!(:embedding_definition) + + describe "#execute" do + context "when there is no embedding def" do + it "does nothing" do + invalid_id = 999_999_999 + + subject.execute(id: invalid_id) + + expect( + 
DiscourseAi::Embeddings::Schema.correctly_indexed?( + EmbeddingDefinition.new(id: invalid_id), + ), + ).to eq(false) + end + end + + context "when the embedding def is fresh" do + it "creates the indexes" do + subject.execute(id: embedding_definition.id) + + expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq(true) + end + + it "creates them only once" do + subject.execute(id: embedding_definition.id) + subject.execute(id: embedding_definition.id) + + expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq(true) + end + + context "when one of the idxs is missing" do + it "automatically recovers by creating it" do + DB.exec <<~SQL + CREATE INDEX IF NOT EXISTS ai_topics_embeddings_#{embedding_definition.id}_1_search_bit ON ai_topics_embeddings + USING hnsw ((binary_quantize(embeddings)::bit(#{embedding_definition.dimensions})) bit_hamming_ops) + WHERE model_id = #{embedding_definition.id} AND strategy_id = 1; + SQL + + expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq( + false, + ) + + subject.execute(id: embedding_definition.id) + + expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq( + true, + ) + end + end + end + end +end diff --git a/spec/jobs/scheduled/embeddings_backfill_spec.rb b/spec/jobs/scheduled/embeddings_backfill_spec.rb index c7fcc436..bd74e1f3 100644 --- a/spec/jobs/scheduled/embeddings_backfill_spec.rb +++ b/spec/jobs/scheduled/embeddings_backfill_spec.rb @@ -19,11 +19,11 @@ RSpec.describe Jobs::EmbeddingsBackfill do topic end - let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } + fab!(:vector_def) { Fabricate(:embedding_definition) } before do + SiteSetting.ai_embeddings_selected_model = vector_def.id SiteSetting.ai_embeddings_enabled = true - SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" SiteSetting.ai_embeddings_backfill_batch_size = 1 
SiteSetting.ai_embeddings_per_post_enabled = true Jobs.run_immediately! @@ -32,10 +32,10 @@ RSpec.describe Jobs::EmbeddingsBackfill do it "backfills topics based on bumped_at date" do embedding = Array.new(1024) { 1 } - WebMock.stub_request( - :post, - "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", - ).to_return(status: 200, body: JSON.dump(embedding)) + WebMock.stub_request(:post, "https://test.com/embeddings").to_return( + status: 200, + body: JSON.dump(embedding), + ) Jobs::EmbeddingsBackfill.new.execute({}) diff --git a/spec/jobs/scheduled/remove_orphaned_embeddings_spec.rb b/spec/jobs/scheduled/remove_orphaned_embeddings_spec.rb new file mode 100644 index 00000000..1ecee441 --- /dev/null +++ b/spec/jobs/scheduled/remove_orphaned_embeddings_spec.rb @@ -0,0 +1,71 @@ +# frozen_string_literal: true + +RSpec.describe Jobs::RemoveOrphanedEmbeddings do + describe "#execute" do + fab!(:embedding_definition) + fab!(:embedding_definition_2) { Fabricate(:embedding_definition) } + fab!(:topic) + fab!(:post) + + before do + DiscourseAi::Embeddings::Schema.prepare_search_indexes(embedding_definition) + DiscourseAi::Embeddings::Schema.prepare_search_indexes(embedding_definition_2) + + # Seed embeddings. One of each def x target classes. + [embedding_definition, embedding_definition_2].each do |edef| + SiteSetting.ai_embeddings_selected_model = edef.id + + [topic, post].each do |target| + schema = DiscourseAi::Embeddings::Schema.for(target.class) + schema.store(target, [1] * edef.dimensions, "test") + end + end + + embedding_definition.destroy! 
+ end + + def find_all_embeddings_of(target, table, target_column) + DB.query_single("SELECT model_id FROM #{table} WHERE #{target_column} = #{target.id}") + end + + it "delete embeddings without an existing embedding definition" do + expect(find_all_embeddings_of(post, "ai_posts_embeddings", "post_id")).to contain_exactly( + embedding_definition.id, + embedding_definition_2.id, + ) + expect(find_all_embeddings_of(topic, "ai_topics_embeddings", "topic_id")).to contain_exactly( + embedding_definition.id, + embedding_definition_2.id, + ) + + subject.execute({}) + + expect(find_all_embeddings_of(topic, "ai_topics_embeddings", "topic_id")).to contain_exactly( + embedding_definition_2.id, + ) + expect(find_all_embeddings_of(post, "ai_posts_embeddings", "post_id")).to contain_exactly( + embedding_definition_2.id, + ) + end + + it "deletes orphaned indexes" do + expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition)).to eq(true) + expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition_2)).to eq(true) + + subject.execute({}) + + index_names = + DiscourseAi::Embeddings::Schema::EMBEDDING_TARGETS.map do |t| + "ai_#{t}_embeddings_#{embedding_definition.id}_1_search_bit" + end + indexnames = + DB.query_single( + "SELECT indexname FROM pg_indexes WHERE indexname IN (:names)", + names: index_names, + ) + + expect(indexnames).to be_empty + expect(DiscourseAi::Embeddings::Schema.correctly_indexed?(embedding_definition_2)).to eq(true) + end + end +end diff --git a/spec/lib/configuration/embeddings_module_validator_spec.rb b/spec/lib/configuration/embeddings_module_validator_spec.rb deleted file mode 100644 index 0b1e5713..00000000 --- a/spec/lib/configuration/embeddings_module_validator_spec.rb +++ /dev/null @@ -1,15 +0,0 @@ -# frozen_string_literal: true - -RSpec.describe DiscourseAi::Configuration::EmbeddingsModuleValidator do - let(:validator) { described_class.new } - - describe "#can_generate_embeddings?" 
do - it "returns true if embeddings can be generated" do - stub_request( - :post, - "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent?key=", - ).to_return(status: 200, body: { embedding: { values: [1, 2, 3] } }.to_json) - expect(validator.can_generate_embeddings?("gemini")).to eq(true) - end - end -end diff --git a/spec/lib/inference/cloudflare_workers_ai_spec.rb b/spec/lib/inference/cloudflare_workers_ai_spec.rb index eb939f4a..1199111d 100644 --- a/spec/lib/inference/cloudflare_workers_ai_spec.rb +++ b/spec/lib/inference/cloudflare_workers_ai_spec.rb @@ -4,7 +4,7 @@ require "rails_helper" require "webmock/rspec" RSpec.describe DiscourseAi::Inference::CloudflareWorkersAi do - subject { described_class.new(account_id, api_token, model) } + subject { described_class.new(endpoint, api_token) } let(:account_id) { "test_account_id" } let(:api_token) { "test_api_token" } diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb index c58158d3..34d056a8 100644 --- a/spec/lib/modules/ai_bot/personas/persona_spec.rb +++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb @@ -297,9 +297,11 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do end describe "#craft_prompt" do + fab!(:vector_def) { Fabricate(:embedding_definition) } + before do Group.refresh_automatic_groups! 
- SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" + SiteSetting.ai_embeddings_selected_model = vector_def.id SiteSetting.ai_embeddings_enabled = true end @@ -326,13 +328,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do fab!(:llm_model) { Fabricate(:fake_model) } it "will run the question consolidator" do - vector_def = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation context_embedding = vector_def.dimensions.times.map { rand(-1.0...1.0) } - EmbeddingsGenerationStubs.discourse_service( - SiteSetting.ai_embeddings_model, - consolidated_question, - context_embedding, - ) + EmbeddingsGenerationStubs.hugging_face_service(consolidated_question, context_embedding) custom_ai_persona = Fabricate( @@ -373,14 +370,11 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do end context "when a persona has RAG uploads" do - let(:vector_def) do - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation - end let(:embedding_value) { 0.04381 } let(:prompt_cc_embeddings) { [embedding_value] * vector_def.dimensions } def stub_fragments(fragment_count, persona: ai_persona) - schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector_def: vector_def) + schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment) fragment_count.times do |i| fragment = @@ -403,8 +397,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do stored_ai_persona = AiPersona.find(ai_persona.id) UploadReference.ensure_exist!(target: stored_ai_persona, upload_ids: [upload.id]) - EmbeddingsGenerationStubs.discourse_service( - SiteSetting.ai_embeddings_model, + EmbeddingsGenerationStubs.hugging_face_service( with_cc.dig(:conversation_context, 0, :content), prompt_cc_embeddings, ) diff --git a/spec/lib/modules/ai_bot/tools/search_spec.rb b/spec/lib/modules/ai_bot/tools/search_spec.rb index 04aa4d3f..976d9c42 100644 --- a/spec/lib/modules/ai_bot/tools/search_spec.rb +++ 
b/spec/lib/modules/ai_bot/tools/search_spec.rb @@ -108,16 +108,13 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do it "supports semantic search when enabled" do assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model) + vector_def = Fabricate(:embedding_definition) + SiteSetting.ai_embeddings_selected_model = vector_def.id SiteSetting.ai_embeddings_semantic_search_enabled = true - SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" - vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation - hyde_embedding = [0.049382] * vector_rep.dimensions - EmbeddingsGenerationStubs.discourse_service( - SiteSetting.ai_embeddings_model, - query, - hyde_embedding, - ) + hyde_embedding = [0.049382] * vector_def.dimensions + + EmbeddingsGenerationStubs.hugging_face_service(query, hyde_embedding) post1 = Fabricate(:post, topic: topic_with_tags) search = diff --git a/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb b/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb index 8b89d800..bbbfe6af 100644 --- a/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb +++ b/spec/lib/modules/ai_helper/semantic_categorizer_spec.rb @@ -1,6 +1,7 @@ # frozen_string_literal: true RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do + fab!(:vector_def) { Fabricate(:embedding_definition) } fab!(:user) fab!(:muted_category) { Fabricate(:category) } fab!(:category_mute) do @@ -19,14 +20,13 @@ RSpec.describe DiscourseAi::AiHelper::SemanticCategorizer do let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions } before do + SiteSetting.ai_embeddings_selected_model = vector_def.id SiteSetting.ai_embeddings_enabled = true - SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" - SiteSetting.ai_embeddings_model = "bge-large-en" - WebMock.stub_request( - :post, - "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", - ).to_return(status: 200, body: 
JSON.dump(expected_embedding)) + WebMock.stub_request(:post, vector_def.url).to_return( + status: 200, + body: JSON.dump([expected_embedding]), + ) vector.generate_representation_from(topic) vector.generate_representation_from(muted_topic) diff --git a/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb b/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb index 56384d59..069248fa 100644 --- a/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb +++ b/spec/lib/modules/embeddings/jobs/generate_embeddings_spec.rb @@ -3,25 +3,26 @@ RSpec.describe Jobs::GenerateEmbeddings do subject(:job) { described_class.new } + fab!(:vector_def) { Fabricate(:embedding_definition) } + describe "#execute" do before do - SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" + SiteSetting.ai_embeddings_selected_model = vector_def.id SiteSetting.ai_embeddings_enabled = true end fab!(:topic) fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) } - let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } - let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vector_def) } - let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector_def: vector_def) } + let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic) } + let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post) } it "works for topics" do expected_embedding = [0.0038493] * vector_def.dimensions text = vector_def.prepare_target_text(topic) - EmbeddingsGenerationStubs.discourse_service(vector_def.class.name, text, expected_embedding) + EmbeddingsGenerationStubs.hugging_face_service(text, expected_embedding) job.execute(target_id: topic.id, target_type: "Topic") @@ -32,7 +33,7 @@ RSpec.describe Jobs::GenerateEmbeddings do expected_embedding = [0.0038493] * vector_def.dimensions text = vector_def.prepare_target_text(post) - EmbeddingsGenerationStubs.discourse_service(vector_def.class.name, 
text, expected_embedding) + EmbeddingsGenerationStubs.hugging_face_service(text, expected_embedding) job.execute(target_id: post.id, target_type: "Post") diff --git a/spec/lib/modules/embeddings/schema_spec.rb b/spec/lib/modules/embeddings/schema_spec.rb index 107428aa..13a4fcb0 100644 --- a/spec/lib/modules/embeddings/schema_spec.rb +++ b/spec/lib/modules/embeddings/schema_spec.rb @@ -1,12 +1,14 @@ # frozen_string_literal: true RSpec.describe DiscourseAi::Embeddings::Schema do - subject(:posts_schema) { described_class.for(Post, vector_def: vector_def) } + subject(:posts_schema) { described_class.for(Post) } + fab!(:vector_def) { Fabricate(:cloudflare_embedding_def) } let(:embeddings) { [0.0038490295] * vector_def.dimensions } fab!(:post) { Fabricate(:post, post_number: 1) } let(:digest) { OpenSSL::Digest.hexdigest("SHA1", "test") } - let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new } + + before { SiteSetting.ai_embeddings_selected_model = vector_def.id } before { posts_schema.store(post, embeddings, digest) } diff --git a/spec/lib/modules/embeddings/semantic_related_spec.rb b/spec/lib/modules/embeddings/semantic_related_spec.rb index 0349abad..f829a948 100644 --- a/spec/lib/modules/embeddings/semantic_related_spec.rb +++ b/spec/lib/modules/embeddings/semantic_related_spec.rb @@ -13,7 +13,13 @@ describe DiscourseAi::Embeddings::SemanticRelated do fab!(:secured_category_topic) { Fabricate(:topic, category: secured_category) } fab!(:closed_topic) { Fabricate(:topic, closed: true) } - before { SiteSetting.ai_embeddings_semantic_related_topics_enabled = true } + fab!(:vector_def) { Fabricate(:embedding_definition) } + + before do + SiteSetting.ai_embeddings_semantic_related_topics_enabled = true + SiteSetting.ai_embeddings_selected_model = vector_def.id + SiteSetting.ai_embeddings_enabled = true + end describe "#related_topic_ids_for" do context "when embeddings do not exist" do @@ -24,21 +30,15 @@ describe 
DiscourseAi::Embeddings::SemanticRelated do topic end - let(:vector_rep) do - DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation - end - it "properly generates embeddings if missing" do - SiteSetting.ai_embeddings_enabled = true - SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" Jobs.run_immediately! embedding = Array.new(1024) { 1 } - WebMock.stub_request( - :post, - "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", - ).to_return(status: 200, body: JSON.dump(embedding)) + WebMock.stub_request(:post, vector_def.url).to_return( + status: 200, + body: JSON.dump([embedding]), + ) # miss first ids = semantic_related.related_topic_ids_for(topic) diff --git a/spec/lib/modules/embeddings/semantic_search_spec.rb b/spec/lib/modules/embeddings/semantic_search_spec.rb index 2ed0bae9..e77da502 100644 --- a/spec/lib/modules/embeddings/semantic_search_spec.rb +++ b/spec/lib/modules/embeddings/semantic_search_spec.rb @@ -7,22 +7,18 @@ RSpec.describe DiscourseAi::Embeddings::SemanticSearch do let(:query) { "test_query" } let(:subject) { described_class.new(Guardian.new(user)) } - before { assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model) } + fab!(:vector_def) { Fabricate(:embedding_definition) } + + before do + SiteSetting.ai_embeddings_selected_model = vector_def.id + assign_fake_provider_to(:ai_embeddings_semantic_search_hyde_model) + end describe "#search_for_topics" do let(:hypothetical_post) { "This is an hypothetical post generated from the keyword test_query" } - let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation } let(:hyde_embedding) { [0.049382] * vector_def.dimensions } - before do - SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" - - EmbeddingsGenerationStubs.discourse_service( - SiteSetting.ai_embeddings_model, - hypothetical_post, - hyde_embedding, - ) - end + before { 
EmbeddingsGenerationStubs.hugging_face_service(hypothetical_post, hyde_embedding) } after { described_class.clear_cache_for(query) } diff --git a/spec/lib/modules/embeddings/semantic_topic_query_spec.rb b/spec/lib/modules/embeddings/semantic_topic_query_spec.rb index 9c83f42e..ffc35837 100644 --- a/spec/lib/modules/embeddings/semantic_topic_query_spec.rb +++ b/spec/lib/modules/embeddings/semantic_topic_query_spec.rb @@ -9,6 +9,10 @@ describe DiscourseAi::Embeddings::EntryPoint do fab!(:target) { Fabricate(:topic) } + fab!(:vector_def) { Fabricate(:cloudflare_embedding_def) } + + before { SiteSetting.ai_embeddings_selected_model = vector_def.id } + # The Distance gap to target increases for each element of topics. def seed_embeddings(topics) schema = DiscourseAi::Embeddings::Schema.for(Topic) diff --git a/spec/lib/modules/embeddings/strategies/truncation_spec.rb b/spec/lib/modules/embeddings/strategies/truncation_spec.rb index 0b31fb11..0792e5d4 100644 --- a/spec/lib/modules/embeddings/strategies/truncation_spec.rb +++ b/spec/lib/modules/embeddings/strategies/truncation_spec.rb @@ -19,13 +19,13 @@ RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do ) end fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) } - - let(:vector_def) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new } + fab!(:open_ai_embedding_def) it "truncates a topic" do - prepared_text = truncation.prepare_target_text(topic, vector_def) + prepared_text = truncation.prepare_target_text(topic, open_ai_embedding_def) - expect(vector_def.tokenizer.size(prepared_text)).to be <= vector_def.max_sequence_length + expect(open_ai_embedding_def.tokenizer.size(prepared_text)).to be <= + open_ai_embedding_def.max_sequence_length end end end diff --git a/spec/lib/modules/embeddings/vector_spec.rb b/spec/lib/modules/embeddings/vector_spec.rb index 6542323d..c6e8b8c5 100644 --- a/spec/lib/modules/embeddings/vector_spec.rb +++ 
b/spec/lib/modules/embeddings/vector_spec.rb @@ -7,8 +7,10 @@ RSpec.describe DiscourseAi::Embeddings::Vector do let(:expected_embedding_1) { [0.0038493] * vdef.dimensions } let(:expected_embedding_2) { [0.0037684] * vdef.dimensions } - let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic, vector_def: vdef) } - let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post, vector_def: vdef) } + before { SiteSetting.ai_embeddings_selected_model = vdef.id } + + let(:topics_schema) { DiscourseAi::Embeddings::Schema.for(Topic) } + let(:posts_schema) { DiscourseAi::Embeddings::Schema.for(Post) } fab!(:topic) fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) } @@ -84,63 +86,16 @@ RSpec.describe DiscourseAi::Embeddings::Vector do end end - context "with text-embedding-ada-002" do - let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002.new } - - def stub_vector_mapping(text, expected_embedding) - EmbeddingsGenerationStubs.openai_service(vdef.class.name, text, expected_embedding) - end - - it_behaves_like "generates and store embeddings using a vector definition" - end - - context "with all all-mpnet-base-v2" do - let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2.new } - - before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" } - - def stub_vector_mapping(text, expected_embedding) - EmbeddingsGenerationStubs.discourse_service(vdef.class.name, text, expected_embedding) - end - - it_behaves_like "generates and store embeddings using a vector definition" - end - - context "with gemini" do - let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::Gemini.new } - let(:api_key) { "test-123" } - - before { SiteSetting.ai_gemini_api_key = api_key } - - def stub_vector_mapping(text, expected_embedding) - EmbeddingsGenerationStubs.gemini_service(api_key, text, expected_embedding) - end - - it_behaves_like "generates and store embeddings using a vector definition" - end - - 
context "with multilingual-e5-large" do - let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large.new } - - before { SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" } - - def stub_vector_mapping(text, expected_embedding) - EmbeddingsGenerationStubs.discourse_service(vdef.class.name, text, expected_embedding) - end - - it_behaves_like "generates and store embeddings using a vector definition" - end - - context "with text-embedding-3-large" do - let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large.new } + context "with open_ai as the provider" do + fab!(:vdef) { Fabricate(:open_ai_embedding_def) } def stub_vector_mapping(text, expected_embedding) EmbeddingsGenerationStubs.openai_service( - vdef.class.name, + vdef.lookup_custom_param("model_name"), text, expected_embedding, extra_args: { - dimensions: 2000, + dimensions: vdef.dimensions, }, ) end @@ -148,11 +103,31 @@ RSpec.describe DiscourseAi::Embeddings::Vector do it_behaves_like "generates and store embeddings using a vector definition" end - context "with text-embedding-3-small" do - let(:vdef) { DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small.new } + context "with hugging_face as the provider" do + fab!(:vdef) { Fabricate(:embedding_definition) } def stub_vector_mapping(text, expected_embedding) - EmbeddingsGenerationStubs.openai_service(vdef.class.name, text, expected_embedding) + EmbeddingsGenerationStubs.hugging_face_service(text, expected_embedding) + end + + it_behaves_like "generates and store embeddings using a vector definition" + end + + context "with google as the provider" do + fab!(:vdef) { Fabricate(:gemini_embedding_def) } + + def stub_vector_mapping(text, expected_embedding) + EmbeddingsGenerationStubs.gemini_service(vdef.api_key, text, expected_embedding) + end + + it_behaves_like "generates and store embeddings using a vector definition" + end + + context "with cloudflare as the provider" do + 
fab!(:vdef) { Fabricate(:cloudflare_embedding_def) } + + def stub_vector_mapping(text, expected_embedding) + EmbeddingsGenerationStubs.cloudflare_service(text, expected_embedding) end it_behaves_like "generates and store embeddings using a vector definition" diff --git a/spec/models/ai_tool_spec.rb b/spec/models/ai_tool_spec.rb index a0d9fc01..fda5351f 100644 --- a/spec/models/ai_tool_spec.rb +++ b/spec/models/ai_tool_spec.rb @@ -204,12 +204,12 @@ RSpec.describe AiTool do end context "when defining RAG fragments" do + fab!(:cloudflare_embedding_def) + before do SiteSetting.authorized_extensions = "txt" + SiteSetting.ai_embeddings_selected_model = cloudflare_embedding_def.id SiteSetting.ai_embeddings_enabled = true - SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" - SiteSetting.ai_embeddings_model = "bge-large-en" - Jobs.run_immediately! end @@ -228,9 +228,9 @@ RSpec.describe AiTool do # this is a trick, we get ever increasing embeddings, this gives us in turn # 100% consistent search results @counter = 0 - stub_request(:post, "http://test.com/api/v1/classify").to_return( + stub_request(:post, cloudflare_embedding_def.url).to_return( status: 200, - body: lambda { |req| ([@counter += 1] * 1024).to_json }, + body: lambda { |req| { result: { data: [([@counter += 1] * 1024)] } }.to_json }, headers: { }, ) diff --git a/spec/models/rag_document_fragment_spec.rb b/spec/models/rag_document_fragment_spec.rb index afb95d83..bb77f3de 100644 --- a/spec/models/rag_document_fragment_spec.rb +++ b/spec/models/rag_document_fragment_spec.rb @@ -4,11 +4,9 @@ RSpec.describe RagDocumentFragment do fab!(:persona) { Fabricate(:ai_persona) } fab!(:upload_1) { Fabricate(:upload) } fab!(:upload_2) { Fabricate(:upload) } + fab!(:vector_def) { Fabricate(:embedding_definition) } - before do - SiteSetting.ai_embeddings_enabled = true - SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" - end + before { SiteSetting.ai_embeddings_enabled = 
true } describe ".link_uploads_and_persona" do it "does nothing if there is no persona" do @@ -76,25 +74,25 @@ RSpec.describe RagDocumentFragment do describe ".indexing_status" do let(:vector) { DiscourseAi::Embeddings::Vector.instance } - fab!(:rag_document_fragment_1) do + let(:rag_document_fragment_1) do Fabricate(:rag_document_fragment, upload: upload_1, target: persona) end - fab!(:rag_document_fragment_2) do + let(:rag_document_fragment_2) do Fabricate(:rag_document_fragment, upload: upload_1, target: persona) end - let(:expected_embedding) { [0.0038493] * vector.vdef.dimensions } + let(:expected_embedding) { [0.0038493] * vector_def.dimensions } before do - SiteSetting.ai_embeddings_enabled = true - SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" - SiteSetting.ai_embeddings_model = "bge-large-en" + SiteSetting.ai_embeddings_selected_model = vector_def.id + rag_document_fragment_1 + rag_document_fragment_2 - WebMock.stub_request( - :post, - "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", - ).to_return(status: 200, body: JSON.dump(expected_embedding)) + WebMock.stub_request(:post, "https://test.com/embeddings").to_return( + status: 200, + body: JSON.dump(expected_embedding), + ) vector.generate_representation_from(rag_document_fragment_1) end @@ -106,7 +104,7 @@ RSpec.describe RagDocumentFragment do UploadReference.create!(upload_id: upload_2.id, target: persona) Sidekiq::Testing.fake! 
do - SiteSetting.ai_embeddings_model = "all-mpnet-base-v2" + SiteSetting.ai_embeddings_selected_model = Fabricate(:open_ai_embedding_def).id expect(RagDocumentFragment.exists?(old_id)).to eq(false) expect(Jobs::DigestRagUpload.jobs.size).to eq(2) end diff --git a/spec/requests/admin/ai_embeddings_controller_spec.rb b/spec/requests/admin/ai_embeddings_controller_spec.rb new file mode 100644 index 00000000..97f38ac6 --- /dev/null +++ b/spec/requests/admin/ai_embeddings_controller_spec.rb @@ -0,0 +1,184 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::Admin::AiEmbeddingsController do + fab!(:admin) + + before { sign_in(admin) } + + let(:valid_attrs) do + { + display_name: "Embedding config test", + dimensions: 1001, + max_sequence_length: 234, + pg_function: "<#>", + provider: "hugging_face", + url: "https://test.com/api/v1/embeddings", + api_key: "test", + tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer", + } + end + + describe "POST #create" do + context "with valid attrs" do + it "creates a new embedding definition" do + post "/admin/plugins/discourse-ai/ai-embeddings.json", params: { ai_embedding: valid_attrs } + + created_def = EmbeddingDefinition.last + + expect(response.status).to eq(201) + expect(created_def.display_name).to eq(valid_attrs[:display_name]) + end + + it "stores provider-specific config params" do + post "/admin/plugins/discourse-ai/ai-embeddings.json", + params: { + ai_embedding: + valid_attrs.merge( + provider: "open_ai", + provider_params: { + model_name: "embeddings-v1", + }, + ), + } + + created_def = EmbeddingDefinition.last + + expect(response.status).to eq(201) + expect(created_def.provider_params["model_name"]).to eq("embeddings-v1") + end + + it "ignores parameters not associated with that provider" do + post "/admin/plugins/discourse-ai/ai-embeddings.json", + params: { + ai_embedding: valid_attrs.merge(provider_params: { custom: "custom" }), + } + + created_def = EmbeddingDefinition.last + + 
expect(response.status).to eq(201)
+        expect(created_def.lookup_custom_param("custom")).to be_nil
+      end
+    end
+
+    context "with invalid attrs" do
+      it "doesn't create a new embedding definition" do
+        post "/admin/plugins/discourse-ai/ai-embeddings.json",
+             params: {
+               ai_embedding: valid_attrs.except(:provider),
+             }
+
+        created_def = EmbeddingDefinition.last
+
+        expect(created_def).to be_nil
+      end
+    end
+  end
+
+  describe "PUT #update" do
+    fab!(:embedding_definition)
+
+    context "with valid update params" do
+      let(:update_attrs) { { provider: "open_ai" } }
+
+      it "updates the model" do
+        put "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json",
+            params: {
+              ai_embedding: update_attrs,
+            }
+
+        expect(response.status).to eq(200)
+        expect(embedding_definition.reload.provider).to eq(update_attrs[:provider])
+      end
+
+      it "returns a 404 if there is no model with the given Id" do
+        put "/admin/plugins/discourse-ai/ai-embeddings/9999999.json"
+
+        expect(response.status).to eq(404)
+      end
+
+      it "doesn't allow dimensions to be updated" do
+        new_dimensions = 200
+
+        put "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json",
+            params: {
+              ai_embedding: {
+                dimensions: new_dimensions,
+              },
+            }
+
+        expect(response.status).to eq(200)
+        expect(embedding_definition.reload.dimensions).not_to eq(new_dimensions)
+      end
+    end
+
+    context "with invalid update params" do
+      it "doesn't update the model" do
+        put "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json",
+            params: {
+              ai_embedding: {
+                url: "",
+              },
+            }
+
+        expect(response.status).to eq(422)
+      end
+    end
+  end
+
+  describe "DELETE #destroy" do
+    fab!(:embedding_definition)
+
+    it "destroys the embedding definition" do
+      expect {
+        delete "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json"
+
+        expect(response).to have_http_status(:no_content)
+      }.to change(EmbeddingDefinition, :count).by(-1)
+    end
+
+    it "validates the model is not in 
use" do + SiteSetting.ai_embeddings_selected_model = embedding_definition.id + + delete "/admin/plugins/discourse-ai/ai-embeddings/#{embedding_definition.id}.json" + + expect(response.status).to eq(409) + expect(embedding_definition.reload).to eq(embedding_definition) + end + end + + describe "GET #test" do + context "when we can generate an embedding" do + it "returns a success true flag" do + WebMock.stub_request(:post, valid_attrs[:url]).to_return(status: 200, body: [[1]].to_json) + + get "/admin/plugins/discourse-ai/ai-embeddings/test.json", + params: { + ai_embedding: valid_attrs, + } + + expect(response).to be_successful + expect(response.parsed_body["success"]).to eq(true) + end + end + + context "when we cannot generate an embedding" do + it "returns a success false flag and the error message" do + error_message = { error: "Embedding generation failed." } + + WebMock.stub_request(:post, valid_attrs[:url]).to_return( + status: 422, + body: error_message.to_json, + ) + + get "/admin/plugins/discourse-ai/ai-embeddings/test.json", + params: { + ai_embedding: valid_attrs, + } + + expect(response).to be_successful + expect(response.parsed_body["success"]).to eq(false) + expect(response.parsed_body["error"]).to eq(error_message.to_json) + end + end + end +end diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb index 96677212..4edceb27 100644 --- a/spec/requests/admin/ai_personas_controller_spec.rb +++ b/spec/requests/admin/ai_personas_controller_spec.rb @@ -8,7 +8,6 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do sign_in(admin) SiteSetting.ai_embeddings_enabled = true - SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" end describe "GET #index" do diff --git a/spec/requests/admin/rag_document_fragments_controller_spec.rb b/spec/requests/admin/rag_document_fragments_controller_spec.rb index 06b4d778..c0e7d4df 100644 --- 
a/spec/requests/admin/rag_document_fragments_controller_spec.rb +++ b/spec/requests/admin/rag_document_fragments_controller_spec.rb @@ -4,11 +4,12 @@ RSpec.describe DiscourseAi::Admin::RagDocumentFragmentsController do fab!(:admin) fab!(:ai_persona) + fab!(:vector_def) { Fabricate(:embedding_definition) } + before do sign_in(admin) - + SiteSetting.ai_embeddings_selected_model = vector_def.id SiteSetting.ai_embeddings_enabled = true - SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" end describe "GET #indexing_status_check" do diff --git a/spec/requests/embeddings/embeddings_controller_spec.rb b/spec/requests/embeddings/embeddings_controller_spec.rb index e79039fb..9ad04d5b 100644 --- a/spec/requests/embeddings/embeddings_controller_spec.rb +++ b/spec/requests/embeddings/embeddings_controller_spec.rb @@ -2,9 +2,11 @@ describe DiscourseAi::Embeddings::EmbeddingsController do context "when performing a topic search" do + fab!(:vector_def) { Fabricate(:open_ai_embedding_def) } + before do SiteSetting.min_search_term_length = 3 - SiteSetting.ai_embeddings_model = "text-embedding-3-small" + SiteSetting.ai_embeddings_selected_model = vector_def.id DiscourseAi::Embeddings::SemanticSearch.clear_cache_for("test") SearchIndexer.enable end @@ -31,7 +33,15 @@ describe DiscourseAi::Embeddings::EmbeddingsController do def stub_embedding(query) embedding = [0.049382] * 1536 - EmbeddingsGenerationStubs.openai_service(SiteSetting.ai_embeddings_model, query, embedding) + + EmbeddingsGenerationStubs.openai_service( + vector_def.lookup_custom_param("model_name"), + query, + embedding, + extra_args: { + dimensions: vector_def.dimensions, + }, + ) end def create_api_key(user) diff --git a/spec/shared/inference/openai_embeddings_spec.rb b/spec/shared/inference/openai_embeddings_spec.rb index 7db19a7e..0040b30f 100644 --- a/spec/shared/inference/openai_embeddings_spec.rb +++ b/spec/shared/inference/openai_embeddings_spec.rb @@ -1,10 +1,13 @@ # 
frozen_string_literal: true describe DiscourseAi::Inference::OpenAiEmbeddings do + let(:api_key) { "123456" } + let(:dimensions) { 1000 } + let(:model) { "text-embedding-ada-002" } + it "supports azure embeddings" do - SiteSetting.ai_openai_embeddings_url = + azure_url = "https://my-company.openai.azure.com/openai/deployments/embeddings-deployment/embeddings?api-version=2023-05-15" - SiteSetting.ai_openai_api_key = "123456" body_json = { usage: { @@ -14,28 +17,22 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do data: [{ object: "embedding", embedding: [0.0, 0.1] }], }.to_json - stub_request( - :post, - "https://my-company.openai.azure.com/openai/deployments/embeddings-deployment/embeddings?api-version=2023-05-15", - ).with( + stub_request(:post, azure_url).with( body: "{\"model\":\"text-embedding-ada-002\",\"input\":\"hello\"}", headers: { - "Api-Key" => "123456", + "Api-Key" => api_key, "Content-Type" => "application/json", }, ).to_return(status: 200, body: body_json, headers: {}) result = - DiscourseAi::Inference::OpenAiEmbeddings.instance(model: "text-embedding-ada-002").perform!( - "hello", - ) + DiscourseAi::Inference::OpenAiEmbeddings.new(azure_url, api_key, model, nil).perform!("hello") expect(result).to eq([0.0, 0.1]) end it "supports openai embeddings" do - SiteSetting.ai_openai_api_key = "123456" - + url = "https://api.openai.com/v1/embeddings" body_json = { usage: { prompt_tokens: 1, @@ -44,21 +41,20 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do data: [{ object: "embedding", embedding: [0.0, 0.1] }], }.to_json - body = { model: "text-embedding-ada-002", input: "hello", dimensions: 1000 }.to_json + body = { model: model, input: "hello", dimensions: dimensions }.to_json - stub_request(:post, "https://api.openai.com/v1/embeddings").with( + stub_request(:post, url).with( body: body, headers: { - "Authorization" => "Bearer 123456", + "Authorization" => "Bearer #{api_key}", "Content-Type" => "application/json", }, ).to_return(status: 200, 
body: body_json, headers: {}) result = - DiscourseAi::Inference::OpenAiEmbeddings.instance( - model: "text-embedding-ada-002", - dimensions: 1000, - ).perform!("hello") + DiscourseAi::Inference::OpenAiEmbeddings.new(url, api_key, model, dimensions).perform!( + "hello", + ) expect(result).to eq([0.0, 0.1]) end diff --git a/spec/support/embeddings_generation_stubs.rb b/spec/support/embeddings_generation_stubs.rb index 283ed7d9..48da06eb 100644 --- a/spec/support/embeddings_generation_stubs.rb +++ b/spec/support/embeddings_generation_stubs.rb @@ -2,15 +2,11 @@ class EmbeddingsGenerationStubs class << self - def discourse_service(model, string, embedding) - model = "bge-large-en-v1.5" if model == "bge-large-en" + def hugging_face_service(string, embedding) WebMock - .stub_request( - :post, - "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify", - ) - .with(body: JSON.dump({ model: model, content: string })) - .to_return(status: 200, body: JSON.dump(embedding)) + .stub_request(:post, "https://test.com/embeddings") + .with(body: JSON.dump({ inputs: string, truncate: true })) + .to_return(status: 200, body: JSON.dump([embedding])) end def openai_service(model, string, embedding, extra_args: {}) @@ -29,5 +25,12 @@ class EmbeddingsGenerationStubs .with(body: JSON.dump({ content: { parts: [{ text: string }] } })) .to_return(status: 200, body: JSON.dump({ embedding: { values: embedding } })) end + + def cloudflare_service(string, embedding) + WebMock + .stub_request(:post, "https://test.com/embeddings") + .with(body: JSON.dump({ text: [string] })) + .to_return(status: 200, body: JSON.dump({ result: { data: [embedding] } })) + end end end diff --git a/spec/system/embeddings/ai_embedding_definition_spec.rb b/spec/system/embeddings/ai_embedding_definition_spec.rb new file mode 100644 index 00000000..2c7347f5 --- /dev/null +++ b/spec/system/embeddings/ai_embedding_definition_spec.rb @@ -0,0 +1,87 @@ +# frozen_string_literal: true + +RSpec.describe 
"Managing Embeddings configurations", type: :system, js: true do + fab!(:admin) + let(:page_header) { PageObjects::Components::DPageHeader.new } + + before { sign_in(admin) } + + it "correctly sets defaults" do + preset = "text-embedding-3-small" + api_key = "abcd" + + visit "/admin/plugins/discourse-ai/ai-embeddings" + + find(".ai-embeddings-list-editor__new-button").click() + select_kit = PageObjects::Components::SelectKit.new(".ai-embedding-editor__presets") + select_kit.expand + select_kit.select_row_by_value(preset) + find(".ai-embedding-editor__next").click + find("input.ai-embedding-editor__api-key").fill_in(with: api_key) + find(".ai-embedding-editor__save").click() + + expect(page).to have_current_path("/admin/plugins/discourse-ai/ai-embeddings") + + embedding_def = EmbeddingDefinition.order(:id).last + expect(embedding_def.api_key).to eq(api_key) + + preset = EmbeddingDefinition.presets.find { |p| p[:preset_id] == preset } + + expect(embedding_def.display_name).to eq(preset[:display_name]) + expect(embedding_def.url).to eq(preset[:url]) + expect(embedding_def.tokenizer_class).to eq(preset[:tokenizer_class]) + expect(embedding_def.dimensions).to eq(preset[:dimensions]) + expect(embedding_def.max_sequence_length).to eq(preset[:max_sequence_length]) + expect(embedding_def.pg_function).to eq(preset[:pg_function]) + expect(embedding_def.provider).to eq(preset[:provider]) + expect(embedding_def.provider_params.symbolize_keys).to eq(preset[:provider_params]) + end + + it "supports manual config" do + api_key = "abcd" + + visit "/admin/plugins/discourse-ai/ai-embeddings" + + find(".ai-embeddings-list-editor__new-button").click() + select_kit = PageObjects::Components::SelectKit.new(".ai-embedding-editor__presets") + select_kit.expand + select_kit.select_row_by_value("manual") + find(".ai-embedding-editor__next").click + + find("input.ai-embedding-editor__display-name").fill_in(with: "OpenAI's text-embedding-3-small") + + select_kit = 
PageObjects::Components::SelectKit.new(".ai-embedding-editor__provider") + select_kit.expand + select_kit.select_row_by_value(EmbeddingDefinition::OPEN_AI) + + find("input.ai-embedding-editor__url").fill_in(with: "https://api.openai.com/v1/embeddings") + find("input.ai-embedding-editor__api-key").fill_in(with: api_key) + + select_kit = PageObjects::Components::SelectKit.new(".ai-embedding-editor__tokenizer") + select_kit.expand + select_kit.select_row_by_value("DiscourseAi::Tokenizer::OpenAiTokenizer") + + find("input.ai-embedding-editor__dimensions").fill_in(with: 1536) + find("input.ai-embedding-editor__max_sequence_length").fill_in(with: 8191) + + select_kit = PageObjects::Components::SelectKit.new(".ai-embedding-editor__distance_functions") + select_kit.expand + select_kit.select_row_by_value("<=>") + find(".ai-embedding-editor__save").click() + + expect(page).to have_current_path("/admin/plugins/discourse-ai/ai-embeddings") + + embedding_def = EmbeddingDefinition.order(:id).last + expect(embedding_def.api_key).to eq(api_key) + + preset = EmbeddingDefinition.presets.find { |p| p[:preset_id] == "text-embedding-3-small" } + + expect(embedding_def.display_name).to eq(preset[:display_name]) + expect(embedding_def.url).to eq(preset[:url]) + expect(embedding_def.tokenizer_class).to eq(preset[:tokenizer_class]) + expect(embedding_def.dimensions).to eq(preset[:dimensions]) + expect(embedding_def.max_sequence_length).to eq(preset[:max_sequence_length]) + expect(embedding_def.pg_function).to eq(preset[:pg_function]) + expect(embedding_def.provider).to eq(preset[:provider]) + end +end diff --git a/spec/system/embeddings/semantic_search_spec.rb b/spec/system/embeddings/semantic_search_spec.rb index fcf124c3..4805989e 100644 --- a/spec/system/embeddings/semantic_search_spec.rb +++ b/spec/system/embeddings/semantic_search_spec.rb @@ -10,7 +10,6 @@ RSpec.describe "AI Composer helper", type: :system, js: true do fab!(:post) { Fabricate(:post, topic: topic, raw: "Apple pie is a 
delicious dessert to eat") } before do - SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com" prompt = DiscourseAi::Embeddings::HydeGenerators::OpenAi.new.prompt(query) OpenAiCompletionsInferenceStubs.stub_response( prompt, @@ -21,11 +20,7 @@ RSpec.describe "AI Composer helper", type: :system, js: true do ) hyde_embedding = [0.049382, 0.9999] - EmbeddingsGenerationStubs.discourse_service( - SiteSetting.ai_embeddings_model, - hypothetical_post, - hyde_embedding, - ) + EmbeddingsGenerationStubs.hugging_face_service(hypothetical_post, hyde_embedding) SearchIndexer.enable SearchIndexer.index(topic, force: true)