From e2e753d73c303e7239cfff280f91a1fab0317a46 Mon Sep 17 00:00:00 2001
From: Roman Rizzi <roman@discourse.org>
Date: Wed, 22 Jan 2025 11:26:46 -0300
Subject: [PATCH] FEATURE: Formalize support for matryoshka dimensions. (#1083)

This adds an explicit `matryoshka_dimensions` flag that signals we are shortening the embeddings of a model.
It is currently only used for OpenAI's text-embedding-3-* models, but we plan to use it for other services.
---
 .../admin/ai_embeddings_controller.rb         |  1 +
 app/models/embedding_definition.rb            | 39 ++++++++++---------
 .../ai_embedding_definition_serializer.rb     |  1 +
 .../discourse/admin/models/ai-embedding.js    |  3 +-
 .../components/ai-embedding-editor.gjs        | 10 +++++
 .../common/ai-embedding-editor.scss           |  5 +++
 config/locales/client.en.yml                  |  1 +
 ...122131007_matryoshka_dimensions_support.rb | 22 +++++++++++
 spec/lib/modules/embeddings/vector_spec.rb    |  9 +----
 .../admin/ai_embeddings_controller_spec.rb    |  2 +
 10 files changed, 66 insertions(+), 27 deletions(-)
 create mode 100644 db/migrate/20250122131007_matryoshka_dimensions_support.rb

diff --git a/app/controllers/discourse_ai/admin/ai_embeddings_controller.rb b/app/controllers/discourse_ai/admin/ai_embeddings_controller.rb
index c3f21caf..66a9c7f3 100644
--- a/app/controllers/discourse_ai/admin/ai_embeddings_controller.rb
+++ b/app/controllers/discourse_ai/admin/ai_embeddings_controller.rb
@@ -113,6 +113,7 @@ module DiscourseAi
             :tokenizer_class,
             :embed_prompt,
             :search_prompt,
+            :matryoshka_dimensions,
           )
 
         extra_field_names = EmbeddingDefinition.provider_params.dig(permitted[:provider]&.to_sym)
diff --git a/app/models/embedding_definition.rb b/app/models/embedding_definition.rb
index 0c7cbda4..ba6dd362 100644
--- a/app/models/embedding_definition.rb
+++ b/app/models/embedding_definition.rb
@@ -84,6 +84,7 @@ class EmbeddingDefinition < ActiveRecord::Base
               tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
               url: "https://api.openai.com/v1/embeddings",
               provider: OPEN_AI,
+              matryoshka_dimensions: true,
               provider_params: {
                 model_name: "text-embedding-3-large",
               },
@@ -97,6 +98,7 @@ class EmbeddingDefinition < ActiveRecord::Base
               tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
               url: "https://api.openai.com/v1/embeddings",
               provider: OPEN_AI,
+              matryoshka_dimensions: true,
               provider_params: {
                 model_name: "text-embedding-3-small",
               },
@@ -200,9 +202,7 @@ class EmbeddingDefinition < ActiveRecord::Base
   end
 
   def open_ai_client
-    model_name = lookup_custom_param("model_name")
-    can_shorten_dimensions = %w[text-embedding-3-small text-embedding-3-large].include?(model_name)
-    client_dimensions = can_shorten_dimensions ? dimensions : nil
+    client_dimensions = matryoshka_dimensions ? dimensions : nil
 
     DiscourseAi::Inference::OpenAiEmbeddings.new(
       endpoint_url,
@@ -221,20 +221,21 @@ end
 #
 # Table name: embedding_definitions
 #
-#  id                  :bigint           not null, primary key
-#  display_name        :string           not null
-#  dimensions          :integer          not null
-#  max_sequence_length :integer          not null
-#  version             :integer          default(1), not null
-#  pg_function         :string           not null
-#  provider            :string           not null
-#  tokenizer_class     :string           not null
-#  url                 :string           not null
-#  api_key             :string
-#  seeded              :boolean          default(FALSE), not null
-#  provider_params     :jsonb
-#  created_at          :datetime         not null
-#  updated_at          :datetime         not null
-#  embed_prompt        :string           default(""), not null
-#  search_prompt       :string           default(""), not null
+#  id                    :bigint           not null, primary key
+#  display_name          :string           not null
+#  dimensions            :integer          not null
+#  max_sequence_length   :integer          not null
+#  version               :integer          default(1), not null
+#  pg_function           :string           not null
+#  provider              :string           not null
+#  tokenizer_class       :string           not null
+#  url                   :string           not null
+#  api_key               :string
+#  seeded                :boolean          default(FALSE), not null
+#  provider_params       :jsonb
+#  created_at            :datetime         not null
+#  updated_at            :datetime         not null
+#  embed_prompt          :string           default(""), not null
+#  search_prompt         :string           default(""), not null
+#  matryoshka_dimensions :boolean          default(FALSE), not null
 #
diff --git a/app/serializers/ai_embedding_definition_serializer.rb b/app/serializers/ai_embedding_definition_serializer.rb
index a15adcf2..8c5b17b3 100644
--- a/app/serializers/ai_embedding_definition_serializer.rb
+++ b/app/serializers/ai_embedding_definition_serializer.rb
@@ -15,6 +15,7 @@ class AiEmbeddingDefinitionSerializer < ApplicationSerializer
              :tokenizer_class,
              :embed_prompt,
              :search_prompt,
+             :matryoshka_dimensions,
              :provider_params
 
   def api_key
diff --git a/assets/javascripts/discourse/admin/models/ai-embedding.js b/assets/javascripts/discourse/admin/models/ai-embedding.js
index ea312f25..d3d620fe 100644
--- a/assets/javascripts/discourse/admin/models/ai-embedding.js
+++ b/assets/javascripts/discourse/admin/models/ai-embedding.js
@@ -16,7 +16,8 @@ export default class AiEmbedding extends RestModel {
       "provider_params",
       "pg_function",
       "embed_prompt",
-      "search_prompt"
+      "search_prompt",
+      "matryoshka_dimensions"
     );
   }
 
diff --git a/assets/javascripts/discourse/components/ai-embedding-editor.gjs b/assets/javascripts/discourse/components/ai-embedding-editor.gjs
index 22a6a4dc..2ba63c92 100644
--- a/assets/javascripts/discourse/components/ai-embedding-editor.gjs
+++ b/assets/javascripts/discourse/components/ai-embedding-editor.gjs
@@ -290,6 +290,16 @@ export default class AiEmbeddingEditor extends Component {
           {{/if}}
         </div>
 
+        <div class="control-group ai-embedding-editor__matryoshka_dimensions">
+          <Input
+            @type="checkbox"
+            @checked={{this.editingModel.matryoshka_dimensions}}
+          />
+          <label>{{i18n
+              "discourse_ai.embeddings.matryoshka_dimensions"
+            }}</label>
+        </div>
+
         <div class="control-group">
           <label>{{i18n "discourse_ai.embeddings.embed_prompt"}}</label>
           <Input
diff --git a/assets/stylesheets/modules/embeddings/common/ai-embedding-editor.scss b/assets/stylesheets/modules/embeddings/common/ai-embedding-editor.scss
index 29af6bcd..a735652b 100644
--- a/assets/stylesheets/modules/embeddings/common/ai-embedding-editor.scss
+++ b/assets/stylesheets/modules/embeddings/common/ai-embedding-editor.scss
@@ -23,4 +23,9 @@
     display: flex;
     align-items: center;
   }
+
+  &__matryoshka_dimensions { // flex row holding the checkbox and its label, top-aligned
+    display: flex;
+    align-items: flex-start;
+  }
 }
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index 4b4318f6..22f1c0ee 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -532,6 +532,7 @@ en:
         max_sequence_length: "Sequence length"
         embed_prompt: "Embed prompt"
         search_prompt: "Search prompt"
+        matryoshka_dimensions: "Matryoshka dimensions"
 
         distance_function: "Distance function"
         distance_functions:
diff --git a/db/migrate/20250122131007_matryoshka_dimensions_support.rb b/db/migrate/20250122131007_matryoshka_dimensions_support.rb
new file mode 100644
index 00000000..a8a38174
--- /dev/null
+++ b/db/migrate/20250122131007_matryoshka_dimensions_support.rb
@@ -0,0 +1,22 @@
+# frozen_string_literal: true
+class MatryoshkaDimensionsSupport < ActiveRecord::Migration[7.2]
+  def change
+    add_column :embedding_definitions, :matryoshka_dimensions, :boolean, null: false, default: false
+
+    execute <<~SQL
+      UPDATE embedding_definitions
+      SET matryoshka_dimensions = TRUE
+      WHERE 
+        provider = 'open_ai' AND 
+        provider_params IS NOT NULL AND
+        (
+          (provider_params->>'model_name') = 'text-embedding-3-large' OR
+          (provider_params->>'model_name') = 'text-embedding-3-small'
+        )
+    SQL
+  end
+
+  def down
+    raise ActiveRecord::IrreversibleMigration
+  end
+end
diff --git a/spec/lib/modules/embeddings/vector_spec.rb b/spec/lib/modules/embeddings/vector_spec.rb
index b69073c3..d2bc4bbc 100644
--- a/spec/lib/modules/embeddings/vector_spec.rb
+++ b/spec/lib/modules/embeddings/vector_spec.rb
@@ -99,15 +99,10 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
 
     it_behaves_like "generates and store embeddings using a vector definition"
 
-    context "when working with models that support shortening embeddings" do
+    context "when matryoshka_dimensions is enabled" do
       it "passes the dimensions param" do
         shorter_dimensions = 10
-        vdef.update!(
-          dimensions: shorter_dimensions,
-          provider_params: {
-            model_name: "text-embedding-3-small",
-          },
-        )
+        vdef.update!(dimensions: shorter_dimensions, matryoshka_dimensions: true)
         text = "This is a piece of text"
         short_expected_embedding = [0.0038493] * shorter_dimensions
 
diff --git a/spec/requests/admin/ai_embeddings_controller_spec.rb b/spec/requests/admin/ai_embeddings_controller_spec.rb
index ac3f0c23..40106b5d 100644
--- a/spec/requests/admin/ai_embeddings_controller_spec.rb
+++ b/spec/requests/admin/ai_embeddings_controller_spec.rb
@@ -17,6 +17,7 @@ RSpec.describe DiscourseAi::Admin::AiEmbeddingsController do
       tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
       embed_prompt: "I come first:",
       search_prompt: "prefix for search",
+      matryoshka_dimensions: true,
     }
   end
 
@@ -31,6 +32,7 @@ RSpec.describe DiscourseAi::Admin::AiEmbeddingsController do
         expect(created_def.display_name).to eq(valid_attrs[:display_name])
         expect(created_def.embed_prompt).to eq(valid_attrs[:embed_prompt])
         expect(created_def.search_prompt).to eq(valid_attrs[:search_prompt])
+        expect(created_def.matryoshka_dimensions).to eq(true)
       end
 
       it "stores provider-specific config params" do