From aabff875014e546a46ef2d67c8f03398c99241e8 Mon Sep 17 00:00:00 2001 From: Sam <sam.saffron@gmail.com> Date: Tue, 27 Feb 2024 18:24:30 +1100 Subject: [PATCH] FIX: image generation in gemini was broken (#490) We need to inject blank model answers after tool calls if absent otherwise model will reject it. --- lib/completions/dialects/gemini.rb | 10 ++++---- .../completions/dialects/dialect_context.rb | 18 ++++++++++++++ spec/lib/completions/dialects/gemini_spec.rb | 24 +++++++++++++++++++ 3 files changed, 48 insertions(+), 4 deletions(-) diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb index cfbae79e..4987df54 100644 --- a/lib/completions/dialects/gemini.rb +++ b/lib/completions/dialects/gemini.rb @@ -23,7 +23,9 @@ module DiscourseAi # Gemini doesn't use an assistant msg to improve long-context responses. messages.pop if messages.last[:type] == :model - trim_messages(messages).reduce([]) do |memo, msg| + memo = [] + + trim_messages(messages).each do |msg| if msg[:type] == :system memo << { role: "user", parts: { text: msg[:content] } } memo << noop_model_response.dup @@ -56,15 +58,15 @@ module DiscourseAi else # Gemini quirk. Doesn't accept tool -> user or user -> user msgs. previous_msg_role = memo.last&.dig(:role) - if previous_msg_role == "user" || previous_msg_role == "tool" + if previous_msg_role == "user" || previous_msg_role == "function" memo << noop_model_response.dup end memo << { role: "user", parts: { text: msg[:content] } } end - - memo end + + memo end def tools diff --git a/spec/lib/completions/dialects/dialect_context.rb b/spec/lib/completions/dialects/dialect_context.rb index ddbb6172..f97eddb0 100644 --- a/spec/lib/completions/dialects/dialect_context.rb +++ b/spec/lib/completions/dialects/dialect_context.rb @@ -25,6 +25,24 @@ class DialectContext dialect(a_prompt).translate end + def image_generation_scenario + context_and_multi_turn = [ + { type: :user, id: "user1", content: "draw a cat" }, + { + type: :tool_call, + id: "tool_id", + content: { name: "draw", arguments: { picture: "Cat" } }.to_json, + }, + { type: :tool, id: "tool_id", content: "I'm a tool result".to_json }, + { type: :user, id: "user1", content: "draw another cat" }, + ] + + a_prompt = prompt + context_and_multi_turn.each { |msg| a_prompt.push(**msg) } + + dialect(a_prompt).translate + end + def multi_turn_scenario context_and_multi_turn = [ { type: :user, id: "user1", content: "This is a message by a user" }, diff --git a/spec/lib/completions/dialects/gemini_spec.rb b/spec/lib/completions/dialects/gemini_spec.rb index 248f51cb..534de8f4 100644 --- a/spec/lib/completions/dialects/gemini_spec.rb +++ b/spec/lib/completions/dialects/gemini_spec.rb @@ -19,6 +19,30 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do expect(translated).to eq(gemini_version) end + it "injects model after tool call" do + expect(context.image_generation_scenario).to eq( + [ + { role: "user", parts: { text: context.system_insts } }, + { parts: { text: "Ok." }, role: "model" }, + { parts: { text: "draw a cat" }, role: "user" }, + { parts: { functionCall: { args: { picture: "Cat" }, name: "draw" } }, role: "model" }, + { + parts: { + functionResponse: { + name: "tool_id", + response: { + content: "\"I'm a tool result\"", + }, + }, + }, + role: "function", + }, + { parts: { text: "Ok." }, role: "model" }, + { parts: { text: "draw another cat" }, role: "user" }, + ], + ) + end + it "translates tool_call and tool messages" do expect(context.multi_turn_scenario).to eq( [