FIX: scrub invalid prompts when truncating (#426)
When trimming a prompt, we never want to end up in a state where there is a "tool" reply without a corresponding tool call; that makes no sense. Also: - GPT-4-Turbo is 128k, fix that - Claude was not preserving the username in the prompt - We were throwing away unicode usernames instead of adding them to the message
This commit is contained in:
parent
ff4da6ace8
commit
05d8b021f1
|
@ -21,6 +21,8 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
|
||||
|
||||
def translate
|
||||
messages = prompt.messages
|
||||
|
||||
|
@ -32,6 +34,11 @@ module DiscourseAi
|
|||
|
||||
trimmed_messages = trim_messages(messages)
|
||||
|
||||
embed_user_ids =
|
||||
trimmed_messages.any? do |m|
|
||||
m[:id] && m[:type] == :user && !m[:id].to_s.match?(VALID_ID_REGEX)
|
||||
end
|
||||
|
||||
trimmed_messages.map do |msg|
|
||||
if msg[:type] == :system
|
||||
{ role: "system", content: msg[:content] }
|
||||
|
@ -49,9 +56,15 @@ module DiscourseAi
|
|||
elsif msg[:type] == :tool
|
||||
{ role: "tool", tool_call_id: msg[:id], content: msg[:content] }
|
||||
else
|
||||
{ role: "user", content: msg[:content] }.tap do |user_msg|
|
||||
user_msg[:name] = msg[:id] if msg[:id]
|
||||
user_message = { role: "user", content: msg[:content] }
|
||||
if msg[:id]
|
||||
if embed_user_ids
|
||||
user_message[:content] = "#{msg[:id]}: #{msg[:content]}"
|
||||
else
|
||||
user_message[:name] = msg[:id]
|
||||
end
|
||||
end
|
||||
user_message
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -109,6 +122,10 @@ module DiscourseAi
|
|||
8192
|
||||
when "gpt-4-32k"
|
||||
32_768
|
||||
when "gpt-4-1106-preview"
|
||||
131_072
|
||||
when "gpt-4-turbo"
|
||||
131_072
|
||||
else
|
||||
8192
|
||||
end
|
||||
|
|
|
@ -29,7 +29,7 @@ module DiscourseAi
|
|||
if msg[:type] == :system
|
||||
memo << "Human: " unless uses_system_message?
|
||||
memo << msg[:content]
|
||||
if prompt.tools
|
||||
if prompt.tools.present?
|
||||
memo << "\n"
|
||||
memo << build_tools_prompt
|
||||
end
|
||||
|
@ -49,7 +49,9 @@ module DiscourseAi
|
|||
</function_results>
|
||||
TEXT
|
||||
else
|
||||
memo << "\n\nHuman: #{msg[:content]}"
|
||||
memo << "\n\nHuman: "
|
||||
memo << "#{msg[:id]}: " if msg[:id]
|
||||
memo << msg[:content]
|
||||
end
|
||||
|
||||
memo
|
||||
|
|
|
@ -157,6 +157,8 @@ module DiscourseAi
|
|||
reversed_trimmed_msgs << dupped_msg
|
||||
end
|
||||
|
||||
reversed_trimmed_msgs.pop if reversed_trimmed_msgs.last&.dig(:type) == :tool
|
||||
|
||||
trimmed_messages.concat(reversed_trimmed_msgs.reverse)
|
||||
end
|
||||
|
||||
|
|
|
@ -22,10 +22,6 @@ module DiscourseAi
|
|||
@messages.concat(messages)
|
||||
|
||||
@messages.each { |message| validate_message(message) }
|
||||
@messages.each do |message|
|
||||
message[:id] = clean_username(message[:id]) if message[:type] == :user &&
|
||||
message[:id].present?
|
||||
end
|
||||
@messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) }
|
||||
|
||||
@tools = tools
|
||||
|
@ -34,8 +30,7 @@ module DiscourseAi
|
|||
def push(type:, content:, id: nil)
|
||||
return if type == :system
|
||||
new_message = { type: type, content: content }
|
||||
|
||||
new_message[:id] = type == :user ? clean_username(id) : id if id && type != :model
|
||||
new_message[:id] = id.to_s if id
|
||||
|
||||
validate_message(new_message)
|
||||
validate_turn(messages.last, new_message)
|
||||
|
@ -45,16 +40,6 @@ module DiscourseAi
|
|||
|
||||
private
|
||||
|
||||
# Normalizes a username so it only contains ASCII letters, digits, "_" and "-",
# and is at most 64 characters long.
#
# @param username [String] the raw username
# @return [String] the username itself when it already satisfies the
#   constraints, otherwise a sanitised copy with every disallowed character
#   replaced by "_" and truncated to 64 characters
def clean_username(username)
  # Anchor with \A (start of string). The previous pattern started with \0,
  # which matches a literal NUL byte, so any string containing a NUL followed
  # by valid characters was returned unsanitised.
  return username if username.match?(/\A[a-zA-Z0-9_-]{1,64}\z/)

  # not the best in the world, but this is what we have to work with
  # if sites enable unicode usernames this can get messy
  username.gsub(/[^a-zA-Z0-9_-]/, "_")[0..63]
end
|
||||
|
||||
def validate_message(message)
|
||||
valid_types = %i[system user model tool tool_call]
|
||||
if !valid_types.include?(message[:type])
|
||||
|
|
|
@ -18,6 +18,29 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
|
|||
expect(translated).to contain_exactly(*open_ai_version)
|
||||
end
|
||||
|
||||
it "will retain usernames for unicode usernames, correctly in mixed mode" do
  # Mixed conversation: one user id is non-ASCII, the other is plain ASCII.
  conversation = [
    { id: "👻", type: :user, content: "Message1" },
    { type: :model, content: "Ok" },
    { id: "joe", type: :user, content: "Message2" },
  ]
  prompt = DiscourseAi::Completions::Prompt.new("You are a bot", messages: conversation)

  # Both user ids are expected to be prefixed onto the content, including the ASCII one.
  expected = [
    { role: "system", content: "You are a bot" },
    { role: "user", content: "👻: Message1" },
    { role: "assistant", content: "Ok" },
    { role: "user", content: "joe: Message2" },
  ]

  expect(context.dialect(prompt).translate).to eq(expected)
end
|
||||
|
||||
it "translates tool_call and tool messages" do
|
||||
expect(context.multi_turn_scenario).to eq(
|
||||
[
|
||||
|
|
|
@ -31,11 +31,11 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
|||
<tools>
|
||||
#{context.dialect_tools}</tools>
|
||||
|
||||
Human: This is a message by a user
|
||||
Human: user1: This is a message by a user
|
||||
|
||||
Assistant: I'm a previous bot reply, that's why there's no user
|
||||
|
||||
Human: This is a new message by a user
|
||||
Human: user1: This is a new message by a user
|
||||
|
||||
Assistant:
|
||||
<function_results>
|
||||
|
@ -60,5 +60,31 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
|||
|
||||
expect(translated.length).to be < context.long_message_text(length: length).length
|
||||
end
|
||||
|
||||
it "retains usernames in generated prompt" do
  conversation = [
    { id: "👻", type: :user, content: "Message1" },
    { type: :model, content: "Ok" },
    { id: "joe", type: :user, content: "Message2" },
  ]
  prompt = DiscourseAi::Completions::Prompt.new("You are a bot", messages: conversation)

  # Every Human turn is expected to carry its "<id>: " prefix.
  expected_prompt = <<~TEXT.strip
    You are a bot

    Human: 👻: Message1

    Assistant: Ok

    Human: joe: Message2

    Assistant:
  TEXT

  expect(context.dialect(prompt).translate).to eq(expected_prompt)
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
# Spec-only dialect: lets a test set the token budget directly and call
# trim_messages through a public wrapper.
class TestDialect < DiscourseAi::Completions::Dialects::Dialect
  attr_accessor :max_prompt_tokens

  # Stand-in tokenizer whose "token count" is simply the string's length.
  def self.tokenizer
    Class.new do
      class << self
        def size(text)
          text.length
        end
      end
    end
  end

  # Public entry point that delegates to trim_messages.
  def trim(messages)
    trim_messages(messages)
  end
end
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
  describe "#trim_messages" do
    it "should trim tool messages if tool_calls are trimmed" do
      prompt = DiscourseAi::Completions::Prompt.new("12345")

      # user -> tool_call -> tool -> user, every message five characters long
      turns = [
        { type: :user, content: "12345" },
        { type: :tool_call, content: "12345", id: 1 },
        { type: :tool, content: "12345", id: 1 },
        { type: :user, content: "12345" },
      ]
      turns.each { |turn| prompt.push(**turn) }

      dialect = TestDialect.new(prompt, "test")
      # TestDialect counts one token per character, so this budget covers three messages.
      dialect.max_prompt_tokens = 15

      # No "tool" reply may survive without its tool_call.
      expected = [{ type: :system, content: "12345" }, { type: :user, content: "12345" }]

      expect(dialect.trim(prompt.messages)).to eq(expected)
    end
  end
end
|
|
@ -21,18 +21,6 @@ RSpec.describe DiscourseAi::Completions::Prompt do
|
|||
bad_messages3 = [{ content: "some content associated to no one" }]
|
||||
expect { described_class.new("a bot", messages: bad_messages3) }.to raise_error(ArgumentError)
|
||||
end
|
||||
|
||||
it "cleans unicode usernames" do
  # Four non-ASCII characters are expected to become four underscores.
  raw_messages = [{ type: :user, content: user_msg, id: "罗马罗马" }]
  prompt = described_class.new("a bot", messages: raw_messages)

  stored_id = prompt.messages.last[:id]
  expect(stored_id).to eq("____")
end
|
||||
end
|
||||
|
||||
describe "#push" do
|
||||
|
@ -74,14 +62,5 @@ RSpec.describe DiscourseAi::Completions::Prompt do
|
|||
expect(system_message[:content]).to eq(user_msg)
|
||||
expect(system_message[:id]).to eq(username)
|
||||
end
|
||||
|
||||
it "cleans unicode usernames" do
  # Four non-ASCII characters are expected to become four underscores.
  prompt.push(type: :user, content: user_msg, id: "罗马罗马")

  pushed_id = prompt.messages.last[:id]
  expect(pushed_id).to eq("____")
end
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue