From 17cc09ec9c4559c1123e46d9da92b686268b8b8b Mon Sep 17 00:00:00 2001 From: Sam Date: Sat, 6 Jan 2024 05:21:14 +1100 Subject: [PATCH] FIX: don't include <details> in context (#406) * FIX: don't include <details> in context We need to be careful adding <details>
into context of conversations it can cause LLMs to hallucinate results * Fix Gemini multi-turn ctx flattening --------- Co-authored-by: Roman Rizzi --- lib/ai_bot/bot.rb | 66 +++++++++-------- lib/ai_bot/playground.rb | 21 +++--- lib/completions/dialects/gemini.rb | 9 ++- spec/lib/completions/dialects/gemini_spec.rb | 63 ++++++++++++++++ spec/lib/modules/ai_bot/playground_spec.rb | 75 +++++++++++++++++++- spec/lib/modules/ai_bot/tools/dall_e_spec.rb | 4 +- 6 files changed, 188 insertions(+), 50 deletions(-) diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb index 6ae77e4b..c67880e5 100644 --- a/lib/ai_bot/bot.rb +++ b/lib/ai_bot/bot.rb @@ -48,42 +48,46 @@ module DiscourseAi llm = DiscourseAi::Completions::Llm.proxy(current_model) tool_found = false - llm.generate(prompt, user: context[:user]) do |partial, cancel| - if (tool = persona.find_tool(partial)) - tool_found = true - ongoing_chain = tool.chain_next_response? - low_cost = tool.low_cost? - tool_call_id = tool.tool_call_id - invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json + result = + llm.generate(prompt, user: context[:user]) do |partial, cancel| + if (tool = persona.find_tool(partial)) + tool_found = true + ongoing_chain = tool.chain_next_response? + low_cost = tool.low_cost? + tool_call_id = tool.tool_call_id + invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json - invocation_context = { - type: "tool", - name: tool_call_id, - content: invocation_result_json, - } - tool_context = { - type: "tool_call", - name: tool_call_id, - content: { name: tool.name, arguments: tool.parameters }.to_json, - } + invocation_context = { + type: "tool", + name: tool_call_id, + content: invocation_result_json, + } + tool_context = { + type: "tool_call", + name: tool_call_id, + content: { name: tool.name, arguments: tool.parameters }.to_json, + } - prompt[:conversation_context] ||= [] + prompt[:conversation_context] ||= [] - if tool.standalone? 
- prompt[:conversation_context] = [invocation_context, tool_context] + if tool.standalone? + prompt[:conversation_context] = [invocation_context, tool_context] + else + prompt[:conversation_context] = [invocation_context, tool_context] + + prompt[:conversation_context] + end + + raw_context << [tool_context[:content], tool_call_id, "tool_call"] + raw_context << [invocation_result_json, tool_call_id, "tool"] else - prompt[:conversation_context] = [invocation_context, tool_context] + - prompt[:conversation_context] + update_blk.call(partial, cancel, nil) end - - raw_context << [tool_context[:content], tool_call_id, "tool_call"] - raw_context << [invocation_result_json, tool_call_id, "tool"] - else - update_blk.call(partial, cancel, nil) end - end - ongoing_chain = false if !tool_found + if !tool_found + ongoing_chain = false + raw_context << [result, bot_user.username] + end total_completions += 1 # do not allow tools when we are at the end of a chain (total_completions == MAX_COMPLETIONS) @@ -93,10 +97,10 @@ module DiscourseAi raw_context end - private - attr_reader :persona + private + def invoke_tool(tool, llm, cancel, &update_blk) update_blk.call("", cancel, build_placeholder(tool.summary, "")) diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index 61b88d07..293c0fbe 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -139,22 +139,19 @@ module DiscourseAi return if reply.blank? 
- reply_post.tap do |bot_reply| - publish_update(bot_reply, done: true) + publish_update(reply_post, done: true) - bot_reply.revise( - bot.bot_user, - { raw: reply }, - skip_validations: true, - skip_revision: true, - ) + reply_post.revise(bot.bot_user, { raw: reply }, skip_validations: true, skip_revision: true) - bot_reply.post_custom_prompt ||= bot_reply.build_post_custom_prompt(custom_prompt: []) - prompt = bot_reply.post_custom_prompt.custom_prompt || [] + # not need to add a custom prompt for a single reply + if new_custom_prompts.length > 1 + reply_post.post_custom_prompt ||= reply_post.build_post_custom_prompt(custom_prompt: []) + prompt = reply_post.post_custom_prompt.custom_prompt || [] prompt.concat(new_custom_prompts) - prompt << [reply, bot.bot_user.username] - bot_reply.post_custom_prompt.update!(custom_prompt: prompt) + reply_post.post_custom_prompt.update!(custom_prompt: prompt) end + + reply_post end private diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb index 06d1f7df..a127d4e7 100644 --- a/lib/completions/dialects/gemini.rb +++ b/lib/completions/dialects/gemini.rb @@ -130,8 +130,13 @@ module DiscourseAi def flatten_context(context) context.map do |a_context| if a_context[:type] == "multi_turn" - # Drop old tool calls and only keep bot response. - a_context[:content].find { |c| c[:type] == "assistant" } + # Some multi-turn, like the ones that generate images, doesn't chain a next + # response. We don't have an assistant call for those, so we use the tool_call instead. + # We cannot use tool since it confuses the model, making it stop calling tools in next responses, + # and replying with a JSON. 
+ + a_context[:content].find { |c| c[:type] == "assistant" } || + a_context[:content].find { |c| c[:type] == "tool_call" } else a_context end diff --git a/spec/lib/completions/dialects/gemini_spec.rb b/spec/lib/completions/dialects/gemini_spec.rb index cc80611b..528b329a 100644 --- a/spec/lib/completions/dialects/gemini_spec.rb +++ b/spec/lib/completions/dialects/gemini_spec.rb @@ -124,6 +124,69 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do expect(translated_context.last.dig(:parts, :text).length).to be < context.last[:content].length end + + context "when working with multi-turn contexts" do + context "when the multi-turn is for turn that doesn't chain" do + it "uses the tool_call context" do + prompt[:conversation_context] = [ + { + type: "multi_turn", + content: [ + { + type: "tool_call", + name: "get_weather", + content: { + name: "get_weather", + arguments: { + location: "Sydney", + unit: "c", + }, + }.to_json, + }, + { type: "tool", name: "get_weather", content: "I'm a tool result" }, + ], + }, + ] + + translated_context = dialect.conversation_context + + expect(translated_context.size).to eq(1) + expect(translated_context.last[:role]).to eq("model") + expect(translated_context.last.dig(:parts, :functionCall)).to be_present + end + end + + context "when the multi-turn is from a chainable tool" do + it "uses the assistand context" do + prompt[:conversation_context] = [ + { + type: "multi_turn", + content: [ + { + type: "tool_call", + name: "get_weather", + content: { + name: "get_weather", + arguments: { + location: "Sydney", + unit: "c", + }, + }.to_json, + }, + { type: "tool", name: "get_weather", content: "I'm a tool result" }, + { type: "assistant", content: "I'm a bot reply!" 
}, + ], + }, + ] + + translated_context = dialect.conversation_context + + expect(translated_context.size).to eq(1) + expect(translated_context.last[:role]).to eq("model") + expect(translated_context.last.dig(:parts, :text)).to be_present + end + end + end end describe "#tools" do diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index 255d21af..b9a05297 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -4,11 +4,11 @@ RSpec.describe DiscourseAi::AiBot::Playground do subject(:playground) { described_class.new(bot) } before do - SiteSetting.ai_bot_enabled_chat_bots = "gpt-4" + SiteSetting.ai_bot_enabled_chat_bots = "claude-2" SiteSetting.ai_bot_enabled = true end - let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID) } + let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID) } let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user) } fab!(:user) { Fabricate(:user) } @@ -74,6 +74,77 @@ RSpec.describe DiscourseAi::AiBot::Playground do expect(pm.reload.posts.last.cooked).to eq(PrettyText.cook(expected_bot_response)) end end + + it "does not include placeholders in conversation context but includes all completions" do + response1 = (<<~TXT).strip + + + search + search + + testing various things + + + + TXT + + response2 = "I found some really amazing stuff!" + + DiscourseAi::Completions::Llm.with_prepared_responses([response1, response2]) do + playground.reply_to(third_post) + end + + last_post = third_post.topic.reload.posts.order(:post_number).last + custom_prompt = PostCustomPrompt.where(post_id: last_post.id).first.custom_prompt + + expect(custom_prompt.length).to eq(3) + expect(custom_prompt.to_s).not_to include("
") + expect(custom_prompt.last.first).to eq(response2) + expect(custom_prompt.last.last).to eq(bot_user.username) + end + + context "with Dall E bot" do + let(:bot) do + DiscourseAi::AiBot::Bot.as(bot_user, persona: DiscourseAi::AiBot::Personas::DallE3.new) + end + + it "does not include placeholders in conversation context (simulate DALL-E)" do + SiteSetting.ai_openai_api_key = "123" + + response = (<<~TXT).strip + + + dall_e + dall_e + + ["a pink cow"] + + + + TXT + + image = + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==" + + data = [{ b64_json: image, revised_prompt: "a pink cow 1" }] + + WebMock.stub_request(:post, SiteSetting.ai_openai_dall_e_3_url).to_return( + status: 200, + body: { data: data }.to_json, + ) + + DiscourseAi::Completions::Llm.with_prepared_responses([response]) do + playground.reply_to(third_post) + end + + last_post = third_post.topic.reload.posts.order(:post_number).last + custom_prompt = PostCustomPrompt.where(post_id: last_post.id).first.custom_prompt + + # DALL E has custom_raw, we do not want to inject this into the prompt stream + expect(custom_prompt.length).to eq(2) + expect(custom_prompt.to_s).not_to include("
") + end + end end describe "#conversation_context" do diff --git a/spec/lib/modules/ai_bot/tools/dall_e_spec.rb b/spec/lib/modules/ai_bot/tools/dall_e_spec.rb index 058692e9..46200fe6 100644 --- a/spec/lib/modules/ai_bot/tools/dall_e_spec.rb +++ b/spec/lib/modules/ai_bot/tools/dall_e_spec.rb @@ -13,7 +13,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do describe "#process" do it "can generate correct info with azure" do - post = Fabricate(:post) + _post = Fabricate(:post) SiteSetting.ai_openai_api_key = "abc" SiteSetting.ai_openai_dall_e_3_url = "https://test.azure.com/some_url" @@ -43,8 +43,6 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do end it "can generate correct info" do - post = Fabricate(:post) - SiteSetting.ai_openai_api_key = "abc" image =