diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb index 78eb8b98..73d00690 100644 --- a/lib/completions/dialects/chat_gpt.rb +++ b/lib/completions/dialects/chat_gpt.rb @@ -75,6 +75,7 @@ module DiscourseAi { role: "assistant", + content: nil, tool_calls: [{ type: "function", function: function, id: context[:name] }], } else diff --git a/spec/lib/completions/dialects/chat_gpt_spec.rb b/spec/lib/completions/dialects/chat_gpt_spec.rb index 84e348bc..9fe23942 100644 --- a/spec/lib/completions/dialects/chat_gpt_spec.rb +++ b/spec/lib/completions/dialects/chat_gpt_spec.rb @@ -85,6 +85,11 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do { type: "user", name: "user1", content: "This is a new message by a user" }, { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" }, { type: "tool", name: "tool_id", content: "I'm a tool result" }, + { + type: "tool_call", + name: "tool_id", + content: { name: "get_weather", arguments: { location: "Sydney", unit: "c" } }.to_json, + }, ] end @@ -95,7 +100,21 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do expect(translated_context).to eq( [ - { role: "tool", content: context.last[:content], tool_call_id: context.last[:name] }, + { + role: "assistant", + content: nil, + tool_calls: [ + { + type: "function", + function: { + name: "get_weather", + arguments: { location: "Sydney", unit: "c" }.to_json, + }, + id: "tool_id", + }, + ], + }, + { role: "tool", content: context.third[:content], tool_call_id: context.third[:name] }, { role: "assistant", content: context.second[:content] }, { role: "user", content: context.first[:content], name: context.first[:name] }, ], @@ -103,13 +122,13 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do end it "trims content if it's getting too long" do - context.last[:content] = context.last[:content] * 1000 + context.third[:content] = context.third[:content] * 1000 prompt[:conversation_context] = context translated_context = dialect.conversation_context - expect(translated_context.last[:content].length).to be < context.last[:content].length + expect(translated_context.third[:content].length).to be < context.third[:content].length end end