diff --git a/app/controllers/discourse_ai/admin/ai_embeddings_controller.rb b/app/controllers/discourse_ai/admin/ai_embeddings_controller.rb index 7000aea4..c3f21caf 100644 --- a/app/controllers/discourse_ai/admin/ai_embeddings_controller.rb +++ b/app/controllers/discourse_ai/admin/ai_embeddings_controller.rb @@ -111,6 +111,8 @@ module DiscourseAi :url, :api_key, :tokenizer_class, + :embed_prompt, + :search_prompt, ) extra_field_names = EmbeddingDefinition.provider_params.dig(permitted[:provider]&.to_sym) diff --git a/app/models/embedding_definition.rb b/app/models/embedding_definition.rb index 48fdf3a5..8f0bd5c2 100644 --- a/app/models/embedding_definition.rb +++ b/app/models/embedding_definition.rb @@ -42,6 +42,7 @@ class EmbeddingDefinition < ActiveRecord::Base pg_function: "<#>", tokenizer_class: "DiscourseAi::Tokenizer::BgeLargeEnTokenizer", provider: HUGGING_FACE, + search_prompt: "Represent this sentence for searching relevant passages:", }, { preset_id: "bge-m3", @@ -228,4 +229,6 @@ end # provider_params :jsonb # created_at :datetime not null # updated_at :datetime not null +# embed_prompt :string default(""), not null +# search_prompt :string default(""), not null # diff --git a/app/serializers/ai_embedding_definition_serializer.rb b/app/serializers/ai_embedding_definition_serializer.rb index 53a75d47..a15adcf2 100644 --- a/app/serializers/ai_embedding_definition_serializer.rb +++ b/app/serializers/ai_embedding_definition_serializer.rb @@ -13,6 +13,8 @@ class AiEmbeddingDefinitionSerializer < ApplicationSerializer :api_key, :seeded, :tokenizer_class, + :embed_prompt, + :search_prompt, :provider_params def api_key diff --git a/assets/javascripts/discourse/admin/models/ai-embedding.js b/assets/javascripts/discourse/admin/models/ai-embedding.js index b1896afa..ea312f25 100644 --- a/assets/javascripts/discourse/admin/models/ai-embedding.js +++ b/assets/javascripts/discourse/admin/models/ai-embedding.js @@ -14,7 +14,9 @@ export default class AiEmbedding extends RestModel { "api_key", "max_sequence_length", "provider_params", - "pg_function" + "pg_function", + "embed_prompt", + "search_prompt" ); } diff --git a/assets/javascripts/discourse/components/ai-embedding-editor.gjs b/assets/javascripts/discourse/components/ai-embedding-editor.gjs index f6a98b83..22a6a4dc 100644 --- a/assets/javascripts/discourse/components/ai-embedding-editor.gjs +++ b/assets/javascripts/discourse/components/ai-embedding-editor.gjs @@ -290,6 +290,24 @@ export default class AiEmbeddingEditor extends Component { {{/if}} +
+ + +
+ +
+ + +
+
: "Negative inner product (<#>)" diff --git a/db/migrate/20250121162520_configurable_embeddings_prefixes.rb b/db/migrate/20250121162520_configurable_embeddings_prefixes.rb new file mode 100644 index 00000000..2064ed85 --- /dev/null +++ b/db/migrate/20250121162520_configurable_embeddings_prefixes.rb @@ -0,0 +1,18 @@ +# frozen_string_literal: true +class ConfigurableEmbeddingsPrefixes < ActiveRecord::Migration[7.2] + def up + add_column :embedding_definitions, :embed_prompt, :string, null: false, default: "" + add_column :embedding_definitions, :search_prompt, :string, null: false, default: "" + + # 4 is bge-large-en. Default model and the only one using this so far. + execute <<~SQL + UPDATE embedding_definitions + SET search_prompt='Represent this sentence for searching relevant passages:' + WHERE id = 4 + SQL + end + + def down + raise ActiveRecord::IrreversibleMigration + end +end diff --git a/lib/embeddings/strategies/truncation.rb b/lib/embeddings/strategies/truncation.rb index cbe463c0..f35fd92c 100644 --- a/lib/embeddings/strategies/truncation.rb +++ b/lib/embeddings/strategies/truncation.rb @@ -15,23 +15,28 @@ module DiscourseAi def prepare_target_text(target, vdef) max_length = vdef.max_sequence_length - 2 - case target - when Topic - topic_truncation(target, vdef.tokenizer, max_length) - when Post - post_truncation(target, vdef.tokenizer, max_length) - when RagDocumentFragment - vdef.tokenizer.truncate(target.fragment, max_length) - else - raise ArgumentError, "Invalid target type" - end + prepared_text = + case target + when Topic + topic_truncation(target, vdef.tokenizer, max_length) + when Post + post_truncation(target, vdef.tokenizer, max_length) + when RagDocumentFragment + vdef.tokenizer.truncate(target.fragment, max_length) + else + raise ArgumentError, "Invalid target type" + end + + return prepared_text if vdef.embed_prompt.blank? + + [vdef.embed_prompt, prepared_text].join(" ") end def prepare_query_text(text, vdef, asymetric: false) - qtext = asymetric ? "#{vdef.asymmetric_query_prefix} #{text}" : text + qtext = asymetric ? "#{vdef.search_prompt} #{text}" : text max_length = vdef.max_sequence_length - 2 - vdef.tokenizer.truncate(text, max_length) + vdef.tokenizer.truncate(qtext, max_length) end private diff --git a/spec/lib/modules/embeddings/strategies/truncation_spec.rb b/spec/lib/modules/embeddings/strategies/truncation_spec.rb index 0792e5d4..9f22506c 100644 --- a/spec/lib/modules/embeddings/strategies/truncation_spec.rb +++ b/spec/lib/modules/embeddings/strategies/truncation_spec.rb @@ -3,29 +3,51 @@ RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do subject(:truncation) { described_class.new } + fab!(:open_ai_embedding_def) + let(:prefix) { "I come first:" } + + describe "#prepare_target_text" do + before { SiteSetting.max_post_length = 100_000 } + + fab!(:topic) + fab!(:post) do + Fabricate(:post, topic: topic, raw: "Baby, bird, bird, bird\nBird is the word\n" * 500) + end + fab!(:post) do + Fabricate( + :post, + topic: topic, + raw: "Don't you know about the bird?\nEverybody knows that the bird is a word\n" * 400, + ) + end + fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) } + fab!(:open_ai_embedding_def) + + it "truncates a topic" do + prepared_text = truncation.prepare_target_text(topic, open_ai_embedding_def) + + expect(open_ai_embedding_def.tokenizer.size(prepared_text)).to be <= + open_ai_embedding_def.max_sequence_length + end + + it "includes embed prefix" do + open_ai_embedding_def.update!(embed_prompt: prefix) + + prepared_text = truncation.prepare_target_text(topic, open_ai_embedding_def) + + expect(prepared_text.starts_with?(prefix)).to eq(true) + end + end + describe "#prepare_query_text" do - context "when using vector def from OpenAI" do - before { SiteSetting.max_post_length = 100_000 } + context "when search is asymetric" do + it "includes search prefix" do + open_ai_embedding_def.update!(search_prompt: prefix) - fab!(:topic) - fab!(:post) do - Fabricate(:post, topic: topic, raw: "Baby, bird, bird, bird\nBird is the word\n" * 500) - end - fab!(:post) do - Fabricate( - :post, - topic: topic, - raw: "Don't you know about the bird?\nEverybody knows that the bird is a word\n" * 400, - ) - end - fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) } - fab!(:open_ai_embedding_def) + prepared_query_text = + truncation.prepare_query_text("searching", open_ai_embedding_def, asymetric: true) - it "truncates a topic" do - prepared_text = truncation.prepare_target_text(topic, open_ai_embedding_def) - - expect(open_ai_embedding_def.tokenizer.size(prepared_text)).to be <= - open_ai_embedding_def.max_sequence_length + expect(prepared_query_text.starts_with?(prefix)).to eq(true) end end end diff --git a/spec/requests/admin/ai_embeddings_controller_spec.rb b/spec/requests/admin/ai_embeddings_controller_spec.rb index 97f38ac6..ac3f0c23 100644 --- a/spec/requests/admin/ai_embeddings_controller_spec.rb +++ b/spec/requests/admin/ai_embeddings_controller_spec.rb @@ -15,6 +15,8 @@ RSpec.describe DiscourseAi::Admin::AiEmbeddingsController do url: "https://test.com/api/v1/embeddings", api_key: "test", tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer", + embed_prompt: "I come first:", + search_prompt: "prefix for search", } end @@ -27,6 +29,8 @@ RSpec.describe DiscourseAi::Admin::AiEmbeddingsController do expect(response.status).to eq(201) expect(created_def.display_name).to eq(valid_attrs[:display_name]) + expect(created_def.embed_prompt).to eq(valid_attrs[:embed_prompt]) + expect(created_def.search_prompt).to eq(valid_attrs[:search_prompt]) end it "stores provider-specific config params" do diff --git a/spec/system/embeddings/ai_embedding_definition_spec.rb b/spec/system/embeddings/ai_embedding_definition_spec.rb index 2c7347f5..d7ea406f 100644 --- a/spec/system/embeddings/ai_embedding_definition_spec.rb +++ b/spec/system/embeddings/ai_embedding_definition_spec.rb @@ -61,6 +61,11 @@ RSpec.describe "Managing Embeddings configurations", type: :system, js: true do select_kit.expand select_kit.select_row_by_value("DiscourseAi::Tokenizer::OpenAiTokenizer") + embed_prefix = "On creation:" + search_prefix = "On search:" + find("input.ai-embedding-editor__embed_prompt").fill_in(with: embed_prefix) + find("input.ai-embedding-editor__search_prompt").fill_in(with: search_prefix) + find("input.ai-embedding-editor__dimensions").fill_in(with: 1536) find("input.ai-embedding-editor__max_sequence_length").fill_in(with: 8191) @@ -83,5 +88,7 @@ RSpec.describe "Managing Embeddings configurations", type: :system, js: true do expect(embedding_def.max_sequence_length).to eq(preset[:max_sequence_length]) expect(embedding_def.pg_function).to eq(preset[:pg_function]) expect(embedding_def.provider).to eq(preset[:provider]) + expect(embedding_def.embed_prompt).to eq(embed_prefix) + expect(embedding_def.search_prompt).to eq(search_prefix) end end