FIX: image generation in gemini was broken (#490)

We need to inject blank model answers after tool calls if absent
otherwise model will reject it.
This commit is contained in:
Sam 2024-02-27 18:24:30 +11:00 committed by GitHub
parent 2c7d34ff1f
commit aabff87501
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 48 additions and 4 deletions

View File

@ -23,7 +23,9 @@ module DiscourseAi
# Gemini doesn't use an assistant msg to improve long-context responses.
messages.pop if messages.last[:type] == :model
trim_messages(messages).reduce([]) do |memo, msg|
memo = []
trim_messages(messages).each do |msg|
if msg[:type] == :system
memo << { role: "user", parts: { text: msg[:content] } }
memo << noop_model_response.dup
@ -56,15 +58,15 @@ module DiscourseAi
else
# Gemini quirk. Doesn't accept tool -> user or user -> user msgs.
previous_msg_role = memo.last&.dig(:role)
if previous_msg_role == "user" || previous_msg_role == "tool"
if previous_msg_role == "user" || previous_msg_role == "function"
memo << noop_model_response.dup
end
memo << { role: "user", parts: { text: msg[:content] } }
end
memo
end
memo
end
def tools

View File

@ -25,6 +25,24 @@ class DialectContext
dialect(a_prompt).translate
end
def image_generation_scenario
context_and_multi_turn = [
{ type: :user, id: "user1", content: "draw a cat" },
{
type: :tool_call,
id: "tool_id",
content: { name: "draw", arguments: { picture: "Cat" } }.to_json,
},
{ type: :tool, id: "tool_id", content: "I'm a tool result".to_json },
{ type: :user, id: "user1", content: "draw another cat" },
]
a_prompt = prompt
context_and_multi_turn.each { |msg| a_prompt.push(**msg) }
dialect(a_prompt).translate
end
def multi_turn_scenario
context_and_multi_turn = [
{ type: :user, id: "user1", content: "This is a message by a user" },

View File

@ -19,6 +19,30 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
expect(translated).to eq(gemini_version)
end
it "injects model after tool call" do
expect(context.image_generation_scenario).to eq(
[
{ role: "user", parts: { text: context.system_insts } },
{ parts: { text: "Ok." }, role: "model" },
{ parts: { text: "draw a cat" }, role: "user" },
{ parts: { functionCall: { args: { picture: "Cat" }, name: "draw" } }, role: "model" },
{
parts: {
functionResponse: {
name: "tool_id",
response: {
content: "\"I'm a tool result\"",
},
},
},
role: "function",
},
{ parts: { text: "Ok." }, role: "model" },
{ parts: { text: "draw another cat" }, role: "user" },
],
)
end
it "translates tool_call and tool messages" do
expect(context.multi_turn_scenario).to eq(
[