diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb index d3d1030b..4d511a0b 100644 --- a/lib/completions/dialects/chat_gpt.rb +++ b/lib/completions/dialects/chat_gpt.rb @@ -21,6 +21,8 @@ module DiscourseAi end end + VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/ + def translate messages = prompt.messages @@ -32,6 +34,11 @@ module DiscourseAi trimmed_messages = trim_messages(messages) + embed_user_ids = + trimmed_messages.any? do |m| + m[:id] && m[:type] == :user && !m[:id].to_s.match?(VALID_ID_REGEX) + end + trimmed_messages.map do |msg| if msg[:type] == :system { role: "system", content: msg[:content] } @@ -49,9 +56,15 @@ module DiscourseAi elsif msg[:type] == :tool { role: "tool", tool_call_id: msg[:id], content: msg[:content] } else - { role: "user", content: msg[:content] }.tap do |user_msg| - user_msg[:name] = msg[:id] if msg[:id] + user_message = { role: "user", content: msg[:content] } + if msg[:id] + if embed_user_ids + user_message[:content] = "#{msg[:id]}: #{msg[:content]}" + else + user_message[:name] = msg[:id] + end end + user_message end end end @@ -109,6 +122,10 @@ module DiscourseAi 8192 when "gpt-4-32k" 32_768 + when "gpt-4-1106-preview" + 131_072 + when "gpt-4-turbo" + 131_072 else 8192 end diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb index acb0f538..f05739c0 100644 --- a/lib/completions/dialects/claude.rb +++ b/lib/completions/dialects/claude.rb @@ -29,7 +29,7 @@ module DiscourseAi if msg[:type] == :system memo << "Human: " unless uses_system_message? memo << msg[:content] - if prompt.tools + if prompt.tools.present? memo << "\n" memo << build_tools_prompt end @@ -49,7 +49,9 @@ module DiscourseAi TEXT else - memo << "\n\nHuman: #{msg[:content]}" + memo << "\n\nHuman: " + memo << "#{msg[:id]}: " if msg[:id] + memo << msg[:content] end memo diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb index 1d968746..8b6acd59 100644 --- a/lib/completions/dialects/dialect.rb +++ b/lib/completions/dialects/dialect.rb @@ -157,6 +157,8 @@ module DiscourseAi reversed_trimmed_msgs << dupped_msg end + reversed_trimmed_msgs.pop if reversed_trimmed_msgs.last&.dig(:type) == :tool + trimmed_messages.concat(reversed_trimmed_msgs.reverse) end diff --git a/lib/completions/prompt.rb b/lib/completions/prompt.rb index e5a2e2e8..5819055f 100644 --- a/lib/completions/prompt.rb +++ b/lib/completions/prompt.rb @@ -22,10 +22,6 @@ module DiscourseAi @messages.concat(messages) @messages.each { |message| validate_message(message) } - @messages.each do |message| - message[:id] = clean_username(message[:id]) if message[:type] == :user && - message[:id].present? - end @messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) } @tools = tools @@ -34,8 +30,7 @@ module DiscourseAi def push(type:, content:, id: nil) return if type == :system new_message = { type: type, content: content } - - new_message[:id] = type == :user ? clean_username(id) : id if id && type != :model + new_message[:id] = id.to_s if id validate_message(new_message) validate_turn(messages.last, new_message) @@ -45,16 +40,6 @@ module DiscourseAi private - def clean_username(username) - if username.match?(/\0[a-zA-Z0-9_-]{1,64}\z/) - username - else - # not the best in the world, but this is what we have to work with - # if sites enable unicode usernames this can get messy - username.gsub(/[^a-zA-Z0-9_-]/, "_")[0..63] - end - end - def validate_message(message) valid_types = %i[system user model tool tool_call] if !valid_types.include?(message[:type]) diff --git a/spec/lib/completions/dialects/chat_gpt_spec.rb b/spec/lib/completions/dialects/chat_gpt_spec.rb index d26c6fe1..c1714d75 100644 --- a/spec/lib/completions/dialects/chat_gpt_spec.rb +++ b/spec/lib/completions/dialects/chat_gpt_spec.rb @@ -18,6 +18,29 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do expect(translated).to contain_exactly(*open_ai_version) end + it "will retain usernames for unicode usernames, correctly in mixed mode" do + prompt = + DiscourseAi::Completions::Prompt.new( + "You are a bot", + messages: [ + { id: "👻", type: :user, content: "Message1" }, + { type: :model, content: "Ok" }, + { id: "joe", type: :user, content: "Message2" }, + ], + ) + + translated = context.dialect(prompt).translate + + expect(translated).to eq( + [ + { role: "system", content: "You are a bot" }, + { role: "user", content: "👻: Message1" }, + { role: "assistant", content: "Ok" }, + { role: "user", content: "joe: Message2" }, + ], + ) + end + it "translates tool_call and tool messages" do expect(context.multi_turn_scenario).to eq( [ diff --git a/spec/lib/completions/dialects/claude_spec.rb b/spec/lib/completions/dialects/claude_spec.rb index 4e8b877a..5c4619f0 100644 --- a/spec/lib/completions/dialects/claude_spec.rb +++ b/spec/lib/completions/dialects/claude_spec.rb @@ -15,7 +15,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do #{context.dialect_tools} Human: #{context.simple_user_input} - + Assistant: TEXT @@ -31,11 +31,11 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do #{context.dialect_tools} - Human: This is a message by a user + Human: user1: This is a message by a user Assistant: I'm a previous bot reply, that's why there's no user - Human: This is a new message by a user + Human: user1: This is a new message by a user Assistant: @@ -46,7 +46,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do - + Assistant: TEXT @@ -60,5 +60,31 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do expect(translated.length).to be < context.long_message_text(length: length).length end + + it "retains usernames in generated prompt" do + prompt = + DiscourseAi::Completions::Prompt.new( + "You are a bot", + messages: [ + { id: "👻", type: :user, content: "Message1" }, + { type: :model, content: "Ok" }, + { id: "joe", type: :user, content: "Message2" }, + ], + ) + + translated = context.dialect(prompt).translate + + expect(translated).to eq(<<~TEXT.strip) + You are a bot + + Human: 👻: Message1 + + Assistant: Ok + + Human: joe: Message2 + + Assistant: + TEXT + end end end diff --git a/spec/lib/completions/dialects/dialect_spec.rb b/spec/lib/completions/dialects/dialect_spec.rb new file mode 100644 index 00000000..c54d1838 --- /dev/null +++ b/spec/lib/completions/dialects/dialect_spec.rb @@ -0,0 +1,38 @@ +# frozen_string_literal: true + +class TestDialect < DiscourseAi::Completions::Dialects::Dialect + attr_accessor :max_prompt_tokens + + def trim(messages) + trim_messages(messages) + end + + def self.tokenizer + Class.new do + def self.size(str) + str.length + end + end + end +end + +RSpec.describe DiscourseAi::Completions::Dialects::Dialect do + describe "#trim_messages" do + it "should trim tool messages if tool_calls are trimmed" do + prompt = DiscourseAi::Completions::Prompt.new("12345") + prompt.push(type: :user, content: "12345") + prompt.push(type: :tool_call, content: "12345", id: 1) + prompt.push(type: :tool, content: "12345", id: 1) + prompt.push(type: :user, content: "12345") + + dialect = TestDialect.new(prompt, "test") + dialect.max_prompt_tokens = 15 # fits the user messages and the tool_call message + + trimmed = dialect.trim(prompt.messages) + + expect(trimmed).to eq( + [{ type: :system, content: "12345" }, { type: :user, content: "12345" }], + ) + end + end +end diff --git a/spec/lib/completions/prompt_spec.rb b/spec/lib/completions/prompt_spec.rb index 7ef4620e..2fa150e5 100644 --- a/spec/lib/completions/prompt_spec.rb +++ b/spec/lib/completions/prompt_spec.rb @@ -21,18 +21,6 @@ RSpec.describe DiscourseAi::Completions::Prompt do bad_messages3 = [{ content: "some content associated to no one" }] expect { described_class.new("a bot", messages: bad_messages3) }.to raise_error(ArgumentError) end - - it "cleans unicode usernames" do - unicode_username = "罗马罗马" - - prompt = - described_class.new( - "a bot", - messages: [{ type: :user, content: user_msg, id: unicode_username }], - ) - - expect(prompt.messages.last[:id]).to eq("____") - end end describe "#push" do @@ -74,14 +62,5 @@ RSpec.describe DiscourseAi::Completions::Prompt do expect(system_message[:content]).to eq(user_msg) expect(system_message[:id]).to eq(username) end - - it "cleans unicode usernames" do - unicode_username = "罗马罗马" - prompt.push(type: :user, content: user_msg, id: unicode_username) - - user_message = prompt.messages.last - - expect(user_message[:id]).to eq("____") - end end end