FIX: Restore the accidentally deleted query prefix. (#1079)

Additionally, we add a prefix for embedding generation.
Both are stored in the definitions table.
This commit is contained in:
Roman Rizzi 2025-01-21 14:10:31 -03:00 committed by GitHub
parent f5cf1019fb
commit 3b66fb3e87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 119 additions and 34 deletions

View File

@ -111,6 +111,8 @@ module DiscourseAi
:url,
:api_key,
:tokenizer_class,
:embed_prompt,
:search_prompt,
)
extra_field_names = EmbeddingDefinition.provider_params.dig(permitted[:provider]&.to_sym)

View File

@ -42,6 +42,7 @@ class EmbeddingDefinition < ActiveRecord::Base
pg_function: "<#>",
tokenizer_class: "DiscourseAi::Tokenizer::BgeLargeEnTokenizer",
provider: HUGGING_FACE,
search_prompt: "Represent this sentence for searching relevant passages:",
},
{
preset_id: "bge-m3",
@ -228,4 +229,6 @@ end
# provider_params :jsonb
# created_at :datetime not null
# updated_at :datetime not null
# embed_prompt :string default(""), not null
# search_prompt :string default(""), not null
#

View File

@ -13,6 +13,8 @@ class AiEmbeddingDefinitionSerializer < ApplicationSerializer
:api_key,
:seeded,
:tokenizer_class,
:embed_prompt,
:search_prompt,
:provider_params
def api_key

View File

@ -14,7 +14,9 @@ export default class AiEmbedding extends RestModel {
"api_key",
"max_sequence_length",
"provider_params",
"pg_function"
"pg_function",
"embed_prompt",
"search_prompt"
);
}

View File

@ -290,6 +290,24 @@ export default class AiEmbeddingEditor extends Component {
{{/if}}
</div>
<div class="control-group">
<label>{{i18n "discourse_ai.embeddings.embed_prompt"}}</label>
<Input
@type="text"
class="ai-embedding-editor-input ai-embedding-editor__embed_prompt"
@value={{this.editingModel.embed_prompt}}
/>
</div>
<div class="control-group">
<label>{{i18n "discourse_ai.embeddings.search_prompt"}}</label>
<Input
@type="text"
class="ai-embedding-editor-input ai-embedding-editor__search_prompt"
@value={{this.editingModel.search_prompt}}
/>
</div>
<div class="control-group">
<label>{{i18n "discourse_ai.embeddings.max_sequence_length"}}</label>
<Input

View File

@ -530,7 +530,9 @@ en:
tokenizer: "Tokenizer"
dimensions: "Embedding dimensions"
max_sequence_length: "Sequence length"
embed_prompt: "Embed prompt"
search_prompt: "Search prompt"
distance_function: "Distance function"
distance_functions:
<#>: "Negative inner product (<#>)"

View File

@ -0,0 +1,18 @@
# frozen_string_literal: true

# Adds two configurable text-prefix columns to embedding_definitions:
#  - embed_prompt:  prepended to content before generating an embedding
#  - search_prompt: prepended to asymmetric search queries
# Both default to "" and are NOT NULL, so existing rows stay valid.
class ConfigurableEmbeddingsPrefixes < ActiveRecord::Migration[7.2]
def up
add_column :embedding_definitions, :embed_prompt, :string, null: false, default: ""
add_column :embedding_definitions, :search_prompt, :string, null: false, default: ""
# 4 is bge-large-en. Default model and the only one using this so far.
# Backfills the search prefix that this model's preset previously hard-coded.
execute <<~SQL
UPDATE embedding_definitions
SET search_prompt='Represent this sentence for searching relevant passages:'
WHERE id = 4
SQL
end
def down
# Dropping the columns would lose any operator-entered prompts, so the
# migration is deliberately one-way.
raise ActiveRecord::IrreversibleMigration
end
end

View File

@ -15,23 +15,28 @@ module DiscourseAi
# Prepares a record's text for embedding generation: truncates it to fit
# within the vector definition's sequence length and, when configured,
# prepends the definition's embed_prompt prefix.
#
# Fix: the previous revision left a stale duplicate of the `case target`
# block in place (diff-merge residue), so the truncation work ran twice
# and its first result was discarded. Only the assigned case remains.
#
# @param target [Topic, Post, RagDocumentFragment] record to embed
# @param vdef   [EmbeddingDefinition] tokenizer, length limit, and prompts
# @return [String] truncated (and possibly prefixed) text
# @raise [ArgumentError] when target is not a supported type
def prepare_target_text(target, vdef)
  # Leaves 2 tokens of headroom — presumably for model special tokens; confirm.
  max_length = vdef.max_sequence_length - 2
  prepared_text =
    case target
    when Topic
      topic_truncation(target, vdef.tokenizer, max_length)
    when Post
      post_truncation(target, vdef.tokenizer, max_length)
    when RagDocumentFragment
      vdef.tokenizer.truncate(target.fragment, max_length)
    else
      raise ArgumentError, "Invalid target type"
    end
  return prepared_text if vdef.embed_prompt.blank?
  [vdef.embed_prompt, prepared_text].join(" ")
end
# Prepares a search query for embedding: optionally prepends the
# definition's search_prompt (asymmetric retrieval) and truncates the
# result to the model's sequence length.
#
# Fix: removed diff-merge residue — a stale assignment referencing the
# deleted `asymmetric_query_prefix` accessor (which would raise
# NoMethodError) and a dead `truncate(text, ...)` call whose result was
# discarded. The prefixed `qtext` is now the only string truncated.
#
# @param text [String] the raw query
# @param vdef [EmbeddingDefinition] tokenizer, length limit, and prompts
# @param asymetric [Boolean] prepend search_prompt when true
# @return [String] the (possibly prefixed) truncated query text
def prepare_query_text(text, vdef, asymetric: false)
  qtext = asymetric ? "#{vdef.search_prompt} #{text}" : text
  # Leaves 2 tokens of headroom — presumably for model special tokens; confirm.
  max_length = vdef.max_sequence_length - 2
  vdef.tokenizer.truncate(qtext, max_length)
end
private

View File

@ -3,29 +3,51 @@
RSpec.describe DiscourseAi::Embeddings::Strategies::Truncation do
subject(:truncation) { described_class.new }
fab!(:open_ai_embedding_def)
let(:prefix) { "I come first:" }
describe "#prepare_target_text" do
before { SiteSetting.max_post_length = 100_000 }
fab!(:topic)
fab!(:post) do
Fabricate(:post, topic: topic, raw: "Baby, bird, bird, bird\nBird is the word\n" * 500)
end
fab!(:post) do
Fabricate(
:post,
topic: topic,
raw: "Don't you know about the bird?\nEverybody knows that the bird is a word\n" * 400,
)
end
fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
fab!(:open_ai_embedding_def)
it "truncates a topic" do
prepared_text = truncation.prepare_target_text(topic, open_ai_embedding_def)
expect(open_ai_embedding_def.tokenizer.size(prepared_text)).to be <=
open_ai_embedding_def.max_sequence_length
end
it "includes embed prefix" do
open_ai_embedding_def.update!(embed_prompt: prefix)
prepared_text = truncation.prepare_target_text(topic, open_ai_embedding_def)
expect(prepared_text.starts_with?(prefix)).to eq(true)
end
end
describe "#prepare_query_text" do
context "when using vector def from OpenAI" do
before { SiteSetting.max_post_length = 100_000 }
context "when search is asymetric" do
it "includes search prefix" do
open_ai_embedding_def.update!(search_prompt: prefix)
fab!(:topic)
fab!(:post) do
Fabricate(:post, topic: topic, raw: "Baby, bird, bird, bird\nBird is the word\n" * 500)
end
fab!(:post) do
Fabricate(
:post,
topic: topic,
raw: "Don't you know about the bird?\nEverybody knows that the bird is a word\n" * 400,
)
end
fab!(:post) { Fabricate(:post, topic: topic, raw: "Surfin' bird\n" * 800) }
fab!(:open_ai_embedding_def)
prepared_query_text =
truncation.prepare_query_text("searching", open_ai_embedding_def, asymetric: true)
it "truncates a topic" do
prepared_text = truncation.prepare_target_text(topic, open_ai_embedding_def)
expect(open_ai_embedding_def.tokenizer.size(prepared_text)).to be <=
open_ai_embedding_def.max_sequence_length
expect(prepared_query_text.starts_with?(prefix)).to eq(true)
end
end
end

View File

@ -15,6 +15,8 @@ RSpec.describe DiscourseAi::Admin::AiEmbeddingsController do
url: "https://test.com/api/v1/embeddings",
api_key: "test",
tokenizer_class: "DiscourseAi::Tokenizer::BgeM3Tokenizer",
embed_prompt: "I come first:",
search_prompt: "prefix for search",
}
end
@ -27,6 +29,8 @@ RSpec.describe DiscourseAi::Admin::AiEmbeddingsController do
expect(response.status).to eq(201)
expect(created_def.display_name).to eq(valid_attrs[:display_name])
expect(created_def.embed_prompt).to eq(valid_attrs[:embed_prompt])
expect(created_def.search_prompt).to eq(valid_attrs[:search_prompt])
end
it "stores provider-specific config params" do

View File

@ -61,6 +61,11 @@ RSpec.describe "Managing Embeddings configurations", type: :system, js: true do
select_kit.expand
select_kit.select_row_by_value("DiscourseAi::Tokenizer::OpenAiTokenizer")
embed_prefix = "On creation:"
search_prefix = "On search:"
find("input.ai-embedding-editor__embed_prompt").fill_in(with: embed_prefix)
find("input.ai-embedding-editor__search_prompt").fill_in(with: search_prefix)
find("input.ai-embedding-editor__dimensions").fill_in(with: 1536)
find("input.ai-embedding-editor__max_sequence_length").fill_in(with: 8191)
@ -83,5 +88,7 @@ RSpec.describe "Managing Embeddings configurations", type: :system, js: true do
expect(embedding_def.max_sequence_length).to eq(preset[:max_sequence_length])
expect(embedding_def.pg_function).to eq(preset[:pg_function])
expect(embedding_def.provider).to eq(preset[:provider])
expect(embedding_def.embed_prompt).to eq(embed_prefix)
expect(embedding_def.search_prompt).to eq(search_prefix)
end
end