FEATURE: Formalize support for matryoshka dimensions. (#1083)
This adds a flag that signals a model's embeddings are being shortened (Matryoshka dimensions). It is currently only used for OpenAI's text-embedding-3-* models, but we plan to use it for other services.
This commit is contained in:
parent
654f90f1cd
commit
e2e753d73c
|
@ -113,6 +113,7 @@ module DiscourseAi
|
|||
:tokenizer_class,
|
||||
:embed_prompt,
|
||||
:search_prompt,
|
||||
:matryoshka_dimensions,
|
||||
)
|
||||
|
||||
extra_field_names = EmbeddingDefinition.provider_params.dig(permitted[:provider]&.to_sym)
|
||||
|
|
|
@ -84,6 +84,7 @@ class EmbeddingDefinition < ActiveRecord::Base
|
|||
tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
|
||||
url: "https://api.openai.com/v1/embeddings",
|
||||
provider: OPEN_AI,
|
||||
matryoshka_dimensions: true,
|
||||
provider_params: {
|
||||
model_name: "text-embedding-3-large",
|
||||
},
|
||||
|
@ -97,6 +98,7 @@ class EmbeddingDefinition < ActiveRecord::Base
|
|||
tokenizer_class: "DiscourseAi::Tokenizer::OpenAiTokenizer",
|
||||
url: "https://api.openai.com/v1/embeddings",
|
||||
provider: OPEN_AI,
|
||||
matryoshka_dimensions: true,
|
||||
provider_params: {
|
||||
model_name: "text-embedding-3-small",
|
||||
},
|
||||
|
@ -200,9 +202,7 @@ class EmbeddingDefinition < ActiveRecord::Base
|
|||
end
|
||||
|
||||
def open_ai_client
|
||||
model_name = lookup_custom_param("model_name")
|
||||
can_shorten_dimensions = %w[text-embedding-3-small text-embedding-3-large].include?(model_name)
|
||||
client_dimensions = can_shorten_dimensions ? dimensions : nil
|
||||
client_dimensions = matryoshka_dimensions ? dimensions : nil
|
||||
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.new(
|
||||
endpoint_url,
|
||||
|
@ -221,20 +221,21 @@ end
|
|||
#
|
||||
# Table name: embedding_definitions
|
||||
#
|
||||
# id :bigint not null, primary key
|
||||
# display_name :string not null
|
||||
# dimensions :integer not null
|
||||
# max_sequence_length :integer not null
|
||||
# version :integer default(1), not null
|
||||
# pg_function :string not null
|
||||
# provider :string not null
|
||||
# tokenizer_class :string not null
|
||||
# url :string not null
|
||||
# api_key :string
|
||||
# seeded :boolean default(FALSE), not null
|
||||
# provider_params :jsonb
|
||||
# created_at :datetime not null
|
||||
# updated_at :datetime not null
|
||||
# embed_prompt :string default(""), not null
|
||||
# search_prompt :string default(""), not null
|
||||
# id :bigint not null, primary key
|
||||
# display_name :string not null
|
||||
# dimensions :integer not null
|
||||
# max_sequence_length :integer not null
|
||||
# version :integer default(1), not null
|
||||
# pg_function :string not null
|
||||
# provider :string not null
|
||||
# tokenizer_class :string not null
|
||||
# url :string not null
|
||||
# api_key :string
|
||||
# seeded :boolean default(FALSE), not null
|
||||
# provider_params :jsonb
|
||||
# created_at :datetime not null
|
||||
# updated_at :datetime not null
|
||||
# embed_prompt :string default(""), not null
|
||||
# search_prompt :string default(""), not null
|
||||
# matryoshka_dimensions :boolean default(FALSE), not null
|
||||
#
|
||||
|
|
|
@ -15,6 +15,7 @@ class AiEmbeddingDefinitionSerializer < ApplicationSerializer
|
|||
:tokenizer_class,
|
||||
:embed_prompt,
|
||||
:search_prompt,
|
||||
:matryoshka_dimensions,
|
||||
:provider_params
|
||||
|
||||
def api_key
|
||||
|
|
|
@ -16,7 +16,8 @@ export default class AiEmbedding extends RestModel {
|
|||
"provider_params",
|
||||
"pg_function",
|
||||
"embed_prompt",
|
||||
"search_prompt"
|
||||
"search_prompt",
|
||||
"matryoshka_dimensions"
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -290,6 +290,16 @@ export default class AiEmbeddingEditor extends Component {
|
|||
{{/if}}
|
||||
</div>
|
||||
|
||||
<div class="control-group ai-embedding-editor__matryoshka_dimensions">
|
||||
<Input
|
||||
@type="checkbox"
|
||||
@checked={{this.editingModel.matryoshka_dimensions}}
|
||||
/>
|
||||
<label>{{i18n
|
||||
"discourse_ai.embeddings.matryoshka_dimensions"
|
||||
}}</label>
|
||||
</div>
|
||||
|
||||
<div class="control-group">
|
||||
<label>{{i18n "discourse_ai.embeddings.embed_prompt"}}</label>
|
||||
<Input
|
||||
|
|
|
@ -23,4 +23,9 @@
|
|||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
&__matryoshka_dimensions {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -532,6 +532,7 @@ en:
|
|||
max_sequence_length: "Sequence length"
|
||||
embed_prompt: "Embed prompt"
|
||||
search_prompt: "Search prompt"
|
||||
matryoshka_dimensions: "Matryoshka dimensions"
|
||||
|
||||
distance_function: "Distance function"
|
||||
distance_functions:
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
# frozen_string_literal: true
|
||||
# Adds the boolean +matryoshka_dimensions+ column to +embedding_definitions+
# (non-null, defaulting to false) and switches it on for existing OpenAI
# definitions whose +model_name+ provider param is text-embedding-3-large or
# text-embedding-3-small.
class MatryoshkaDimensionsSupport < ActiveRecord::Migration[7.2]
  # Uses +up+ rather than +change+: the raw +execute+ below cannot be
  # auto-reversed, and Active Record only consults the explicit +down+ when
  # +change+ is not defined.
  def up
    add_column :embedding_definitions, :matryoshka_dimensions, :boolean, null: false, default: false

    # Backfill the flag for existing rows that use the two named OpenAI models.
    execute <<~SQL
      UPDATE embedding_definitions
      SET matryoshka_dimensions = TRUE
      WHERE
        provider = 'open_ai' AND
        provider_params IS NOT NULL AND
        (provider_params->>'model_name') IN ('text-embedding-3-large', 'text-embedding-3-small')
    SQL
  end

  # Rolling back is explicitly unsupported.
  #
  # @raise [ActiveRecord::IrreversibleMigration] always
  def down
    raise ActiveRecord::IrreversibleMigration
  end
end
|
|
@ -99,15 +99,10 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
|
|||
|
||||
it_behaves_like "generates and store embeddings using a vector definition"
|
||||
|
||||
context "when working with models that support shortening embeddings" do
|
||||
context "when matryoshka_dimensions is enabled" do
|
||||
it "passes the dimensions param" do
|
||||
shorter_dimensions = 10
|
||||
vdef.update!(
|
||||
dimensions: shorter_dimensions,
|
||||
provider_params: {
|
||||
model_name: "text-embedding-3-small",
|
||||
},
|
||||
)
|
||||
vdef.update!(dimensions: shorter_dimensions, matryoshka_dimensions: true)
|
||||
text = "This is a piece of text"
|
||||
short_expected_embedding = [0.0038493] * shorter_dimensions
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ RSpec.describe DiscourseAi::Admin::AiEmbeddingsController do
|
|||
tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
|
||||
embed_prompt: "I come first:",
|
||||
search_prompt: "prefix for search",
|
||||
matryoshka_dimensions: true,
|
||||
}
|
||||
end
|
||||
|
||||
|
@ -31,6 +32,7 @@ RSpec.describe DiscourseAi::Admin::AiEmbeddingsController do
|
|||
expect(created_def.display_name).to eq(valid_attrs[:display_name])
|
||||
expect(created_def.embed_prompt).to eq(valid_attrs[:embed_prompt])
|
||||
expect(created_def.search_prompt).to eq(valid_attrs[:search_prompt])
|
||||
expect(created_def.matryoshka_dimensions).to eq(true)
|
||||
end
|
||||
|
||||
it "stores provider-specific config params" do
|
||||
|
|
Loading…
Reference in New Issue