From 0a8195242b811883aa21e46e646f2be683f54941 Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Fri, 12 Jul 2024 15:09:01 -0300 Subject: [PATCH] FIX: Limit system message size to 60% of available tokens. (#714) Using RAG fragments can lead to considerably big system messages, which becomes problematic when models have a smaller context window. Before this change, we only look at the rest of the conversation to make sure we don't surpass the limit, which could lead to two unwanted scenarios when having large system messages: All other messages are excluded due to size. The system message already exceeds the limit. As a result, I'm putting a hard-limit of 60% of available tokens. We don't want to aggresively truncate because if rag fragments are included, the system message contains a lot of context to improve the model response, but we also want to make room for the recent messages in the conversation. --- lib/completions/dialects/dialect.rb | 10 +++++ spec/lib/completions/dialects/dialect_spec.rb | 37 +++++++++++++------ 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb index d36c8f1d..7d49e396 100644 --- a/lib/completions/dialects/dialect.rb +++ b/lib/completions/dialects/dialect.rb @@ -95,7 +95,17 @@ module DiscourseAi range = (0..-1) if messages.dig(0, :type) == :system + max_system_tokens = prompt_limit * 0.6 system_message = messages[0] + system_size = calculate_message_token(system_message) + + if system_size > max_system_tokens + system_message[:content] = tokenizer.truncate( + system_message[:content], + max_system_tokens, + ) + end + trimmed_messages << system_message current_token_count += calculate_message_token(system_message) range = (1..-1) diff --git a/spec/lib/completions/dialects/dialect_spec.rb b/spec/lib/completions/dialects/dialect_spec.rb index e5511bbf..a7667a7c 100644 --- a/spec/lib/completions/dialects/dialect_spec.rb +++ b/spec/lib/completions/dialects/dialect_spec.rb @@ -8,22 +8,20 @@ class TestDialect < DiscourseAi::Completions::Dialects::Dialect end def tokenizer - Class.new do - def self.size(str) - str.length - end - end + DiscourseAi::Tokenizer::OpenAiTokenizer end end RSpec.describe DiscourseAi::Completions::Dialects::Dialect do describe "#trim_messages" do + let(:five_token_msg) { "This represents five tokens." } + it "should trim tool messages if tool_calls are trimmed" do - prompt = DiscourseAi::Completions::Prompt.new("12345") - prompt.push(type: :user, content: "12345") - prompt.push(type: :tool_call, content: "12345", id: 1) - prompt.push(type: :tool, content: "12345", id: 1) - prompt.push(type: :user, content: "12345") + prompt = DiscourseAi::Completions::Prompt.new(five_token_msg) + prompt.push(type: :user, content: five_token_msg) + prompt.push(type: :tool_call, content: five_token_msg, id: 1) + prompt.push(type: :tool, content: five_token_msg, id: 1) + prompt.push(type: :user, content: five_token_msg) dialect = TestDialect.new(prompt, "test") dialect.max_prompt_tokens = 15 # fits the user messages and the tool_call message @@ -31,7 +29,24 @@ RSpec.describe DiscourseAi::Completions::Dialects::Dialect do trimmed = dialect.trim(prompt.messages) expect(trimmed).to eq( - [{ type: :system, content: "12345" }, { type: :user, content: "12345" }], + [{ type: :system, content: five_token_msg }, { type: :user, content: five_token_msg }], + ) + end + + it "limits the system message to 60% of available tokens" do + prompt = DiscourseAi::Completions::Prompt.new("I'm a system message consisting of 10 tokens") + prompt.push(type: :user, content: five_token_msg) + + dialect = TestDialect.new(prompt, "test") + dialect.max_prompt_tokens = 15 + + trimmed = dialect.trim(prompt.messages) + + expect(trimmed).to eq( + [ + { type: :system, content: "I'm a system message consisting of 10" }, + { type: :user, content: five_token_msg }, + ], ) end end