FIX: scrub invalid prompts when truncating (#426)

When we trim a prompt we never want to end up in a state where there
is a "tool" reply without a corresponding tool call; that makes no
sense

Also

- GPT-4-Turbo is 128k, fix that
- Claude was not preserving username in prompt
- We were throwing away unicode usernames instead of adding them to the
message content
This commit is contained in:
Sam 2024-01-16 13:48:00 +11:00 committed by GitHub
parent ff4da6ace8
commit 05d8b021f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 117 additions and 45 deletions

View File

@ -21,6 +21,8 @@ module DiscourseAi
end
end
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
def translate
messages = prompt.messages
@ -32,6 +34,11 @@ module DiscourseAi
trimmed_messages = trim_messages(messages)
embed_user_ids =
trimmed_messages.any? do |m|
m[:id] && m[:type] == :user && !m[:id].to_s.match?(VALID_ID_REGEX)
end
trimmed_messages.map do |msg|
if msg[:type] == :system
{ role: "system", content: msg[:content] }
@ -49,9 +56,15 @@ module DiscourseAi
elsif msg[:type] == :tool
{ role: "tool", tool_call_id: msg[:id], content: msg[:content] }
else
{ role: "user", content: msg[:content] }.tap do |user_msg|
user_msg[:name] = msg[:id] if msg[:id]
user_message = { role: "user", content: msg[:content] }
if msg[:id]
if embed_user_ids
user_message[:content] = "#{msg[:id]}: #{msg[:content]}"
else
user_message[:name] = msg[:id]
end
end
user_message
end
end
end
@ -109,6 +122,10 @@ module DiscourseAi
8192
when "gpt-4-32k"
32_768
when "gpt-4-1106-preview"
131_072
when "gpt-4-turbo"
131_072
else
8192
end

View File

@ -29,7 +29,7 @@ module DiscourseAi
if msg[:type] == :system
memo << "Human: " unless uses_system_message?
memo << msg[:content]
if prompt.tools
if prompt.tools.present?
memo << "\n"
memo << build_tools_prompt
end
@ -49,7 +49,9 @@ module DiscourseAi
</function_results>
TEXT
else
memo << "\n\nHuman: #{msg[:content]}"
memo << "\n\nHuman: "
memo << "#{msg[:id]}: " if msg[:id]
memo << msg[:content]
end
memo

View File

@ -157,6 +157,8 @@ module DiscourseAi
reversed_trimmed_msgs << dupped_msg
end
reversed_trimmed_msgs.pop if reversed_trimmed_msgs.last&.dig(:type) == :tool
trimmed_messages.concat(reversed_trimmed_msgs.reverse)
end

View File

@ -22,10 +22,6 @@ module DiscourseAi
@messages.concat(messages)
@messages.each { |message| validate_message(message) }
@messages.each do |message|
message[:id] = clean_username(message[:id]) if message[:type] == :user &&
message[:id].present?
end
@messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) }
@tools = tools
@ -34,8 +30,7 @@ module DiscourseAi
def push(type:, content:, id: nil)
return if type == :system
new_message = { type: type, content: content }
new_message[:id] = type == :user ? clean_username(id) : id if id && type != :model
new_message[:id] = id.to_s if id
validate_message(new_message)
validate_turn(messages.last, new_message)
@ -45,16 +40,6 @@ module DiscourseAi
private
# Normalize a username into a value OpenAI accepts for the `name` field.
#
# Valid names match /\A[a-zA-Z0-9_-]{1,64}\z/ and pass through untouched.
# Anything else (e.g. unicode usernames) has every disallowed character
# squashed to "_" and is truncated to 64 characters.
#
# NOTE: the original anchor was a garbled `\0` (NUL byte); `\A` is clearly
# intended — compare VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/ elsewhere in this
# commit. Without `\A` the pattern is unanchored on the left and a NUL byte
# would match literally, letting malformed names slip through unchanged.
def clean_username(username)
  if username.match?(/\A[a-zA-Z0-9_-]{1,64}\z/)
    username
  else
    # not the best in the world, but this is what we have to work with
    # if sites enable unicode usernames this can get messy
    username.gsub(/[^a-zA-Z0-9_-]/, "_")[0..63]
  end
end
def validate_message(message)
valid_types = %i[system user model tool tool_call]
if !valid_types.include?(message[:type])

View File

@ -18,6 +18,29 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
expect(translated).to contain_exactly(*open_ai_version)
end
it "will retain usernames for unicode usernames, correctly in mixed mode" do
  messages = [
    { id: "👻", type: :user, content: "Message1" },
    { type: :model, content: "Ok" },
    { id: "joe", type: :user, content: "Message2" },
  ]
  prompt = DiscourseAi::Completions::Prompt.new("You are a bot", messages: messages)

  # The unicode id cannot be sent as an OpenAI `name`, so ids are embedded
  # into the content of every user turn instead — including the ASCII one,
  # to keep the turns consistent ("mixed mode").
  expect(context.dialect(prompt).translate).to eq(
    [
      { role: "system", content: "You are a bot" },
      { role: "user", content: "👻: Message1" },
      { role: "assistant", content: "Ok" },
      { role: "user", content: "joe: Message2" },
    ],
  )
end
it "translates tool_call and tool messages" do
expect(context.multi_turn_scenario).to eq(
[

View File

@ -31,11 +31,11 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
<tools>
#{context.dialect_tools}</tools>
Human: This is a message by a user
Human: user1: This is a message by a user
Assistant: I'm a previous bot reply, that's why there's no user
Human: This is a new message by a user
Human: user1: This is a new message by a user
Assistant:
<function_results>
@ -60,5 +60,31 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
expect(translated.length).to be < context.long_message_text(length: length).length
end
# Regression test: both unicode and ASCII user ids must survive into the
# rendered Claude prompt, prefixed onto each Human turn as "<id>: ".
it "retains usernames in generated prompt" do
prompt =
DiscourseAi::Completions::Prompt.new(
"You are a bot",
messages: [
{ id: "👻", type: :user, content: "Message1" },
{ type: :model, content: "Ok" },
{ id: "joe", type: :user, content: "Message2" },
],
)
translated = context.dialect(prompt).translate
# NOTE(review): the heredoc spacing below reflects the extracted diff; the
# real Claude prompt likely separates turns with blank lines ("\n\nHuman:")
# — verify the expected text against the dialect before relying on it.
expect(translated).to eq(<<~TEXT.strip)
You are a bot
Human: 👻: Message1
Assistant: Ok
Human: joe: Message2
Assistant:
TEXT
end
end
end

View File

@ -0,0 +1,38 @@
# frozen_string_literal: true
# Minimal concrete dialect used only by the trimming specs below.
class TestDialect < DiscourseAi::Completions::Dialects::Dialect
  # Specs assign this directly to force trimming at a known token budget.
  attr_accessor :max_prompt_tokens

  # Public wrapper so specs can call the trimming helper directly.
  def trim(messages)
    trim_messages(messages)
  end

  # Stub tokenizer that counts one token per character, making the
  # budgets in the specs trivial to reason about.
  def self.tokenizer
    Class.new.tap do |tokenizer|
      tokenizer.define_singleton_method(:size) { |str| str.length }
    end
  end
end
RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
  describe "#trim_messages" do
    it "should trim tool messages if tool_calls are trimmed" do
      prompt = DiscourseAi::Completions::Prompt.new("12345")
      [
        { type: :user, content: "12345" },
        { type: :tool_call, content: "12345", id: 1 },
        { type: :tool, content: "12345", id: 1 },
        { type: :user, content: "12345" },
      ].each { |message| prompt.push(**message) }

      dialect = TestDialect.new(prompt, "test")
      # Budget covers the system message, both user messages and the
      # tool_call — but once the tool_call is trimmed, the orphaned tool
      # reply must be dropped too rather than kept without its call.
      dialect.max_prompt_tokens = 15 # fits the user messages and the tool_call message

      expect(dialect.trim(prompt.messages)).to eq(
        [{ type: :system, content: "12345" }, { type: :user, content: "12345" }],
      )
    end
  end
end

View File

@ -21,18 +21,6 @@ RSpec.describe DiscourseAi::Completions::Prompt do
bad_messages3 = [{ content: "some content associated to no one" }]
expect { described_class.new("a bot", messages: bad_messages3) }.to raise_error(ArgumentError)
end
it "cleans unicode usernames" do
unicode_username = "罗马罗马"
prompt =
described_class.new(
"a bot",
messages: [{ type: :user, content: user_msg, id: unicode_username }],
)
expect(prompt.messages.last[:id]).to eq("____")
end
end
describe "#push" do
@ -74,14 +62,5 @@ RSpec.describe DiscourseAi::Completions::Prompt do
expect(system_message[:content]).to eq(user_msg)
expect(system_message[:id]).to eq(username)
end
it "cleans unicode usernames" do
unicode_username = "罗马罗马"
prompt.push(type: :user, content: user_msg, id: unicode_username)
user_message = prompt.messages.last
expect(user_message[:id]).to eq("____")
end
end
end