diff --git a/app/controllers/discourse_ai/admin/ai_embeddings_controller.rb b/app/controllers/discourse_ai/admin/ai_embeddings_controller.rb
index 7000aea4..c3f21caf 100644
--- a/app/controllers/discourse_ai/admin/ai_embeddings_controller.rb
+++ b/app/controllers/discourse_ai/admin/ai_embeddings_controller.rb
@@ -111,6 +111,8 @@ module DiscourseAi
:url,
:api_key,
:tokenizer_class,
+ :embed_prompt,
+ :search_prompt,
)
extra_field_names = EmbeddingDefinition.provider_params.dig(permitted[:provider]&.to_sym)
diff --git a/app/models/embedding_definition.rb b/app/models/embedding_definition.rb
index 48fdf3a5..8f0bd5c2 100644
--- a/app/models/embedding_definition.rb
+++ b/app/models/embedding_definition.rb
@@ -42,6 +42,7 @@ class EmbeddingDefinition < ActiveRecord::Base
pg_function: "<#>",
tokenizer_class: "DiscourseAi::Tokenizer::BgeLargeEnTokenizer",
provider: HUGGING_FACE,
+ search_prompt: "Represent this sentence for searching relevant passages:",
},
{
preset_id: "bge-m3",
@@ -228,4 +229,6 @@ end
# provider_params :jsonb
# created_at :datetime not null
# updated_at :datetime not null
+# embed_prompt :string default(""), not null
+# search_prompt :string default(""), not null
#
diff --git a/app/serializers/ai_embedding_definition_serializer.rb b/app/serializers/ai_embedding_definition_serializer.rb
index 53a75d47..a15adcf2 100644
--- a/app/serializers/ai_embedding_definition_serializer.rb
+++ b/app/serializers/ai_embedding_definition_serializer.rb
@@ -13,6 +13,8 @@ class AiEmbeddingDefinitionSerializer < ApplicationSerializer
:api_key,
:seeded,
:tokenizer_class,
+ :embed_prompt,
+ :search_prompt,
:provider_params
def api_key
diff --git a/assets/javascripts/discourse/admin/models/ai-embedding.js b/assets/javascripts/discourse/admin/models/ai-embedding.js
index b1896afa..ea312f25 100644
--- a/assets/javascripts/discourse/admin/models/ai-embedding.js
+++ b/assets/javascripts/discourse/admin/models/ai-embedding.js
@@ -14,7 +14,9 @@ export default class AiEmbedding extends RestModel {
"api_key",
"max_sequence_length",
"provider_params",
- "pg_function"
+ "pg_function",
+ "embed_prompt",
+ "search_prompt"
);
}
diff --git a/assets/javascripts/discourse/components/ai-embedding-editor.gjs b/assets/javascripts/discourse/components/ai-embedding-editor.gjs
index f6a98b83..22a6a4dc 100644
--- a/assets/javascripts/discourse/components/ai-embedding-editor.gjs
+++ b/assets/javascripts/discourse/components/ai-embedding-editor.gjs
@@ -290,6 +290,24 @@ export default class AiEmbeddingEditor extends Component {
{{/if}}
+
+
+
+
+
+
+
+
+
+
: "Negative inner product (<#>)"
diff --git a/db/migrate/20250121162520_configurable_embeddings_prefixes.rb b/db/migrate/20250121162520_configurable_embeddings_prefixes.rb
new file mode 100644
index 00000000..2064ed85
--- /dev/null
+++ b/db/migrate/20250121162520_configurable_embeddings_prefixes.rb
@@ -0,0 +1,18 @@
+# frozen_string_literal: true
+class ConfigurableEmbeddingsPrefixes < ActiveRecord::Migration[7.2]
+ def up
+ add_column :embedding_definitions, :embed_prompt, :string, null: false, default: ""
+ add_column :embedding_definitions, :search_prompt, :string, null: false, default: ""
+
+ # id 4 is bge-large-en — the default model and, so far, the only one that uses a search prompt.
+ execute <<~SQL
+ UPDATE embedding_definitions
+ SET search_prompt='Represent this sentence for searching relevant passages:'
+ WHERE id = 4
+ SQL
+ end
+
+ def down
+ raise ActiveRecord::IrreversibleMigration
+ end
+end
diff --git a/lib/embeddings/strategies/truncation.rb b/lib/embeddings/strategies/truncation.rb
index cbe463c0..f35fd92c 100644
--- a/lib/embeddings/strategies/truncation.rb
+++ b/lib/embeddings/strategies/truncation.rb
@@ -15,23 +15,28 @@ module DiscourseAi
def prepare_target_text(target, vdef)
max_length = vdef.max_sequence_length - 2
- case target
- when Topic
- topic_truncation(target, vdef.tokenizer, max_length)
- when Post
- post_truncation(target, vdef.tokenizer, max_length)
- when RagDocumentFragment
- vdef.tokenizer.truncate(target.fragment, max_length)
- else
- raise ArgumentError, "Invalid target type"
- end
+ prepared_text =
+ case target
+ when Topic
+ topic_truncation(target, vdef.tokenizer, max_length)
+ when Post
+ post_truncation(target, vdef.tokenizer, max_length)
+ when RagDocumentFragment
+ vdef.tokenizer.truncate(target.fragment, max_length)
+ else
+ raise ArgumentError, "Invalid target type"
+ end
+
+ return prepared_text if vdef.embed_prompt.blank?
+
+ [vdef.embed_prompt, prepared_text].join(" ")
end
def prepare_query_text(text, vdef, asymetric: false)
- qtext = asymetric ? "#{vdef.asymmetric_query_prefix} #{text}" : text
+ qtext = asymetric ? "#{vdef.search_prompt} #{text}" : text
max_length = vdef.max_sequence_length - 2
- vdef.tokenizer.truncate(text, max_length)
+ vdef.tokenizer.truncate(qtext, max_length)
end
private
diff --git a/spec/lib/modules/embeddings/strategies/truncation_spec.rb b/spec/lib/modules/embeddings/strategies/truncation_spec.rb
index 0792e5d4..9f22506c 100644
--- a/spec/lib/modules/embeddings/strategies/truncation_spec.rb
+++ b/spec/lib/modules/embeddings/strategies/truncation_spec.rb
@@ -3,29 +3,51 @@
RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
subject(:truncation) { described_class.new }
+ fab!(:open_ai_embedding_def)
+ let(:prefix) { "I come first:" }
+
+ describe "#prepare_target_text" do
+ before { SiteSetting.max_post_length = 100_000 }
+
+ fab!(:topic)
+ fab!(:post) do
+ Fabricate(:post, topic: topic, raw: "Baby, bird, bird, bird\nBird is the word\n" * 500)
+ end
+ fab!(:post) do
+ Fabricate(
+ :post,
+ topic: topic,
+ raw: "Don't you know about the bird?\nEverybody knows that the bird is a word\n" * 400,
+ )
+ end
+ fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
+ fab!(:open_ai_embedding_def)
+
+ it "truncates a topic" do
+ prepared_text = truncation.prepare_target_text(topic, open_ai_embedding_def)
+
+ expect(open_ai_embedding_def.tokenizer.size(prepared_text)).to be <=
+ open_ai_embedding_def.max_sequence_length
+ end
+
+ it "includes embed prefix" do
+ open_ai_embedding_def.update!(embed_prompt: prefix)
+
+ prepared_text = truncation.prepare_target_text(topic, open_ai_embedding_def)
+
+ expect(prepared_text.starts_with?(prefix)).to eq(true)
+ end
+ end
+
describe "#prepare_query_text" do
- context "when using vector def from OpenAI" do
- before { SiteSetting.max_post_length = 100_000 }
+ context "when search is asymetric" do
+ it "includes search prefix" do
+ open_ai_embedding_def.update!(search_prompt: prefix)
- fab!(:topic)
- fab!(:post) do
- Fabricate(:post, topic: topic, raw: "Baby, bird, bird, bird\nBird is the word\n" * 500)
- end
- fab!(:post) do
- Fabricate(
- :post,
- topic: topic,
- raw: "Don't you know about the bird?\nEverybody knows that the bird is a word\n" * 400,
- )
- end
- fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
- fab!(:open_ai_embedding_def)
+ prepared_query_text =
+ truncation.prepare_query_text("searching", open_ai_embedding_def, asymetric: true)
- it "truncates a topic" do
- prepared_text = truncation.prepare_target_text(topic, open_ai_embedding_def)
-
- expect(open_ai_embedding_def.tokenizer.size(prepared_text)).to be <=
- open_ai_embedding_def.max_sequence_length
+ expect(prepared_query_text.starts_with?(prefix)).to eq(true)
end
end
end
diff --git a/spec/requests/admin/ai_embeddings_controller_spec.rb b/spec/requests/admin/ai_embeddings_controller_spec.rb
index 97f38ac6..ac3f0c23 100644
--- a/spec/requests/admin/ai_embeddings_controller_spec.rb
+++ b/spec/requests/admin/ai_embeddings_controller_spec.rb
@@ -15,6 +15,8 @@ RSpec.describe DiscourseAi::Admin::AiEmbeddingsController do
url: "https://test.com/api/v1/embeddings",
api_key: "test",
tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
+ embed_prompt: "I come first:",
+ search_prompt: "prefix for search",
}
end
@@ -27,6 +29,8 @@ RSpec.describe DiscourseAi::Admin::AiEmbeddingsController do
expect(response.status).to eq(201)
expect(created_def.display_name).to eq(valid_attrs[:display_name])
+ expect(created_def.embed_prompt).to eq(valid_attrs[:embed_prompt])
+ expect(created_def.search_prompt).to eq(valid_attrs[:search_prompt])
end
it "stores provider-specific config params" do
diff --git a/spec/system/embeddings/ai_embedding_definition_spec.rb b/spec/system/embeddings/ai_embedding_definition_spec.rb
index 2c7347f5..d7ea406f 100644
--- a/spec/system/embeddings/ai_embedding_definition_spec.rb
+++ b/spec/system/embeddings/ai_embedding_definition_spec.rb
@@ -61,6 +61,11 @@ RSpec.describe "Managing Embeddings configurations", type: :system, js: true do
select_kit.expand
select_kit.select_row_by_value("DiscourseAi::Tokenizer::OpenAiTokenizer")
+ embed_prefix = "On creation:"
+ search_prefix = "On search:"
+ find("input.ai-embedding-editor__embed_prompt").fill_in(with: embed_prefix)
+ find("input.ai-embedding-editor__search_prompt").fill_in(with: search_prefix)
+
find("input.ai-embedding-editor__dimensions").fill_in(with: 1536)
find("input.ai-embedding-editor__max_sequence_length").fill_in(with: 8191)
@@ -83,5 +88,7 @@ RSpec.describe "Managing Embeddings configurations", type: :system, js: true do
expect(embedding_def.max_sequence_length).to eq(preset[:max_sequence_length])
expect(embedding_def.pg_function).to eq(preset[:pg_function])
expect(embedding_def.provider).to eq(preset[:provider])
+ expect(embedding_def.embed_prompt).to eq(embed_prefix)
+ expect(embedding_def.search_prompt).to eq(search_prefix)
end
end