From aabff875014e546a46ef2d67c8f03398c99241e8 Mon Sep 17 00:00:00 2001
From: Sam <sam.saffron@gmail.com>
Date: Tue, 27 Feb 2024 18:24:30 +1100
Subject: [PATCH] FIX: image generation in gemini was broken (#490)

We need to inject a blank model answer after a tool call if one is absent;
otherwise, the model will reject the request.
---
 lib/completions/dialects/gemini.rb            | 10 ++++----
 .../completions/dialects/dialect_context.rb   | 18 ++++++++++++++
 spec/lib/completions/dialects/gemini_spec.rb  | 24 +++++++++++++++++++
 3 files changed, 48 insertions(+), 4 deletions(-)

diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb
index cfbae79e..4987df54 100644
--- a/lib/completions/dialects/gemini.rb
+++ b/lib/completions/dialects/gemini.rb
@@ -23,7 +23,9 @@ module DiscourseAi
           # Gemini doesn't use an assistant msg to improve long-context responses.
           messages.pop if messages.last[:type] == :model
 
-          trim_messages(messages).reduce([]) do |memo, msg|
+          memo = []
+
+          trim_messages(messages).each do |msg|
             if msg[:type] == :system
               memo << { role: "user", parts: { text: msg[:content] } }
               memo << noop_model_response.dup
@@ -56,15 +58,15 @@ module DiscourseAi
             else
               # Gemini quirk. Doesn't accept tool -> user or user -> user msgs.
               previous_msg_role = memo.last&.dig(:role)
-              if previous_msg_role == "user" || previous_msg_role == "tool"
+              if previous_msg_role == "user" || previous_msg_role == "function"
                 memo << noop_model_response.dup
               end
 
               memo << { role: "user", parts: { text: msg[:content] } }
             end
-
-            memo
           end
+
+          memo
         end
 
         def tools
diff --git a/spec/lib/completions/dialects/dialect_context.rb b/spec/lib/completions/dialects/dialect_context.rb
index ddbb6172..f97eddb0 100644
--- a/spec/lib/completions/dialects/dialect_context.rb
+++ b/spec/lib/completions/dialects/dialect_context.rb
@@ -25,6 +25,24 @@ class DialectContext
     dialect(a_prompt).translate
   end
 
+  def image_generation_scenario
+    context_and_multi_turn = [
+      { type: :user, id: "user1", content: "draw a cat" },
+      {
+        type: :tool_call,
+        id: "tool_id",
+        content: { name: "draw", arguments: { picture: "Cat" } }.to_json,
+      },
+      { type: :tool, id: "tool_id", content: "I'm a tool result".to_json },
+      { type: :user, id: "user1", content: "draw another cat" },
+    ]
+
+    a_prompt = prompt
+    context_and_multi_turn.each { |msg| a_prompt.push(**msg) }
+
+    dialect(a_prompt).translate
+  end
+
   def multi_turn_scenario
     context_and_multi_turn = [
       { type: :user, id: "user1", content: "This is a message by a user" },
diff --git a/spec/lib/completions/dialects/gemini_spec.rb b/spec/lib/completions/dialects/gemini_spec.rb
index 248f51cb..534de8f4 100644
--- a/spec/lib/completions/dialects/gemini_spec.rb
+++ b/spec/lib/completions/dialects/gemini_spec.rb
@@ -19,6 +19,30 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
       expect(translated).to eq(gemini_version)
     end
 
+    it "injects model after tool call" do
+      expect(context.image_generation_scenario).to eq(
+        [
+          { role: "user", parts: { text: context.system_insts } },
+          { parts: { text: "Ok." }, role: "model" },
+          { parts: { text: "draw a cat" }, role: "user" },
+          { parts: { functionCall: { args: { picture: "Cat" }, name: "draw" } }, role: "model" },
+          {
+            parts: {
+              functionResponse: {
+                name: "tool_id",
+                response: {
+                  content: "\"I'm a tool result\"",
+                },
+              },
+            },
+            role: "function",
+          },
+          { parts: { text: "Ok." }, role: "model" },
+          { parts: { text: "draw another cat" }, role: "user" },
+        ],
+      )
+    end
+
     it "translates tool_call and tool messages" do
       expect(context.multi_turn_scenario).to eq(
         [