FIX: don't include <details> in context (#406)

* FIX: don't include <details> in context

We need to be careful when adding <details> blocks to the conversation
context: they can cause LLMs to hallucinate results (see the sketch below).

* Fix Gemini multi-turn context flattening

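For illustration only (this sketch is not part of the commit): the bot shows tool
invocations to users inside collapsible <details> placeholders, but the text we
persist for future prompts should be the bare completion. The markup and values
below are hypothetical:

    # What the user sees in the post (placeholder markup is illustrative):
    rendered_reply = <<~MD
      <details>
        <summary>Searching for: testing various things</summary>
        ...tool output rendered for humans...
      </details>
      I found some really amazing stuff!
    MD

    # What should be stored for future prompts: only the model's own words,
    # so the LLM never sees, and never learns to imitate, tool-result markup.
    persisted_context = "I found some really amazing stuff!"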
---------

Co-authored-by: Roman Rizzi <rizziromanalejandro@gmail.com>
Authored by Sam on 2024-01-06 05:21:14 +11:00, committed by GitHub
commit 17cc09ec9c (parent 7201d482d5)
6 changed files with 188 additions and 50 deletions


@@ -48,42 +48,46 @@ module DiscourseAi
      llm = DiscourseAi::Completions::Llm.proxy(current_model)
      tool_found = false

-      llm.generate(prompt, user: context[:user]) do |partial, cancel|
-        if (tool = persona.find_tool(partial))
-          tool_found = true
-          ongoing_chain = tool.chain_next_response?
-          low_cost = tool.low_cost?
-          tool_call_id = tool.tool_call_id
-          invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json
+      result =
+        llm.generate(prompt, user: context[:user]) do |partial, cancel|
+          if (tool = persona.find_tool(partial))
+            tool_found = true
+            ongoing_chain = tool.chain_next_response?
+            low_cost = tool.low_cost?
+            tool_call_id = tool.tool_call_id
+            invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json

-          invocation_context = {
-            type: "tool",
-            name: tool_call_id,
-            content: invocation_result_json,
-          }
-          tool_context = {
-            type: "tool_call",
-            name: tool_call_id,
-            content: { name: tool.name, arguments: tool.parameters }.to_json,
-          }
+            invocation_context = {
+              type: "tool",
+              name: tool_call_id,
+              content: invocation_result_json,
+            }
+            tool_context = {
+              type: "tool_call",
+              name: tool_call_id,
+              content: { name: tool.name, arguments: tool.parameters }.to_json,
+            }

-          prompt[:conversation_context] ||= []
+            prompt[:conversation_context] ||= []

-          if tool.standalone?
-            prompt[:conversation_context] = [invocation_context, tool_context]
-          else
-            prompt[:conversation_context] = [invocation_context, tool_context] +
-              prompt[:conversation_context]
-          end
+            if tool.standalone?
+              prompt[:conversation_context] = [invocation_context, tool_context]
+            else
+              prompt[:conversation_context] = [invocation_context, tool_context] +
+                prompt[:conversation_context]
+            end

-          raw_context << [tool_context[:content], tool_call_id, "tool_call"]
-          raw_context << [invocation_result_json, tool_call_id, "tool"]
-        else
-          update_blk.call(partial, cancel, nil)
-        end
-      end
+            raw_context << [tool_context[:content], tool_call_id, "tool_call"]
+            raw_context << [invocation_result_json, tool_call_id, "tool"]
+          else
+            update_blk.call(partial, cancel, nil)
+          end
+        end

-      ongoing_chain = false if !tool_found
+      if !tool_found
+        ongoing_chain = false
+        raw_context << [result, bot_user.username]
+      end

      total_completions += 1

      # do not allow tools when we are at the end of a chain (total_completions == MAX_COMPLETIONS)

@@ -93,10 +97,10 @@ module DiscourseAi
      raw_context
    end

-    private
-
    attr_reader :persona

+    private
+
    def invoke_tool(tool, llm, cancel, &update_blk)
      update_blk.call("", cancel, build_placeholder(tool.summary, ""))

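A side note for readers (not part of the commit): with this change, raw_context
accumulates one tuple per tool call, one per tool result, and, for a turn with no
tool call, the raw completion attributed to the bot user. A sketch of the
resulting shape, with hypothetical values:

    # Hypothetical contents of raw_context after a tool-assisted exchange.
    raw_context = [
      # the tool invocation, as JSON produced by the model
      ['{"name":"search","arguments":{"search_query":"testing"}}', "search", "tool_call"],
      # the tool's serialized result
      ['{"results":[]}', "search", "tool"],
      # a plain completion: the raw text, never the rendered <details> placeholder
      ["I found some really amazing stuff!", "bot_username"], # username is illustrative
    ]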

@@ -139,22 +139,19 @@ module DiscourseAi
      return if reply.blank?

-      reply_post.tap do |bot_reply|
-        publish_update(bot_reply, done: true)
+      publish_update(reply_post, done: true)

-        bot_reply.revise(
-          bot.bot_user,
-          { raw: reply },
-          skip_validations: true,
-          skip_revision: true,
-        )
+      reply_post.revise(bot.bot_user, { raw: reply }, skip_validations: true, skip_revision: true)

-        bot_reply.post_custom_prompt ||= bot_reply.build_post_custom_prompt(custom_prompt: [])
-        prompt = bot_reply.post_custom_prompt.custom_prompt || []
+      # no need to add a custom prompt for a single reply
+      if new_custom_prompts.length > 1
+        reply_post.post_custom_prompt ||= reply_post.build_post_custom_prompt(custom_prompt: [])
+        prompt = reply_post.post_custom_prompt.custom_prompt || []
        prompt.concat(new_custom_prompts)
-        prompt << [reply, bot.bot_user.username]
-        bot_reply.post_custom_prompt.update!(custom_prompt: prompt)
+        reply_post.post_custom_prompt.update!(custom_prompt: prompt)
      end
+
+      reply_post
    end

    private


@@ -130,8 +130,13 @@ module DiscourseAi
      def flatten_context(context)
        context.map do |a_context|
          if a_context[:type] == "multi_turn"
-            # Drop old tool calls and only keep the bot response.
-            a_context[:content].find { |c| c[:type] == "assistant" }
+            # Some multi-turn contexts, like the ones that generate images, don't chain a
+            # next response. We don't have an assistant call for those, so we use the
+            # tool_call instead. We cannot use tool, since it confuses the model, making
+            # it stop calling tools in subsequent responses and reply with JSON instead.
+            a_context[:content].find { |c| c[:type] == "assistant" } ||
+              a_context[:content].find { |c| c[:type] == "tool_call" }
          else
            a_context
          end

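To make the fallback concrete, here is a standalone sketch (the method body is
copied from the diff above; the sample data is hypothetical):

    def flatten_context(context)
      context.map do |a_context|
        if a_context[:type] == "multi_turn"
          a_context[:content].find { |c| c[:type] == "assistant" } ||
            a_context[:content].find { |c| c[:type] == "tool_call" }
        else
          a_context
        end
      end
    end

    # Non-chaining turn (e.g. image generation): no assistant entry, so the
    # tool_call is kept.
    image_turn = {
      type: "multi_turn",
      content: [{ type: "tool_call", name: "dall_e", content: "{}" }],
    }
    flatten_context([image_turn])
    # => [{ type: "tool_call", name: "dall_e", content: "{}" }]

    # Chaining turn: the assistant reply wins and the tool entries are dropped.
    weather_turn = {
      type: "multi_turn",
      content: [
        { type: "tool", name: "get_weather", content: "10 C" },
        { type: "assistant", content: "It is 10 C in Sydney." },
      ],
    }
    flatten_context([weather_turn])
    # => [{ type: "assistant", content: "It is 10 C in Sydney." }]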

@@ -124,6 +124,69 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
      expect(translated_context.last.dig(:parts, :text).length).to be <
        context.last[:content].length
    end
+
+    context "when working with multi-turn contexts" do
+      context "when the multi-turn is for a turn that doesn't chain" do
+        it "uses the tool_call context" do
+          prompt[:conversation_context] = [
+            {
+              type: "multi_turn",
+              content: [
+                {
+                  type: "tool_call",
+                  name: "get_weather",
+                  content: {
+                    name: "get_weather",
+                    arguments: {
+                      location: "Sydney",
+                      unit: "c",
+                    },
+                  }.to_json,
+                },
+                { type: "tool", name: "get_weather", content: "I'm a tool result" },
+              ],
+            },
+          ]
+
+          translated_context = dialect.conversation_context
+
+          expect(translated_context.size).to eq(1)
+          expect(translated_context.last[:role]).to eq("model")
+          expect(translated_context.last.dig(:parts, :functionCall)).to be_present
+        end
+      end
+
+      context "when the multi-turn is from a chainable tool" do
+        it "uses the assistant context" do
+          prompt[:conversation_context] = [
+            {
+              type: "multi_turn",
+              content: [
+                {
+                  type: "tool_call",
+                  name: "get_weather",
+                  content: {
+                    name: "get_weather",
+                    arguments: {
+                      location: "Sydney",
+                      unit: "c",
+                    },
+                  }.to_json,
+                },
+                { type: "tool", name: "get_weather", content: "I'm a tool result" },
+                { type: "assistant", content: "I'm a bot reply!" },
+              ],
+            },
+          ]
+
+          translated_context = dialect.conversation_context
+
+          expect(translated_context.size).to eq(1)
+          expect(translated_context.last[:role]).to eq("model")
+          expect(translated_context.last.dig(:parts, :text)).to be_present
+        end
+      end
+    end
  end

  describe "#tools" do


@@ -4,11 +4,11 @@ RSpec.describe DiscourseAi::AiBot::Playground do
  subject(:playground) { described_class.new(bot) }

  before do
-    SiteSetting.ai_bot_enabled_chat_bots = "gpt-4"
+    SiteSetting.ai_bot_enabled_chat_bots = "claude-2"
    SiteSetting.ai_bot_enabled = true
  end

-  let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID) }
+  let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID) }
  let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user) }

  fab!(:user) { Fabricate(:user) }
@@ -74,6 +74,77 @@ RSpec.describe DiscourseAi::AiBot::Playground do
      expect(pm.reload.posts.last.cooked).to eq(PrettyText.cook(expected_bot_response))
      end
    end
+
+    it "does not include placeholders in conversation context but includes all completions" do
+      response1 = (<<~TXT).strip
+        <function_calls>
+        <invoke>
+        <tool_name>search</tool_name>
+        <tool_id>search</tool_id>
+        <parameters>
+        <search_query>testing various things</search_query>
+        </parameters>
+        </invoke>
+        </function_calls>
+      TXT
+
+      response2 = "I found some really amazing stuff!"
+
+      DiscourseAi::Completions::Llm.with_prepared_responses([response1, response2]) do
+        playground.reply_to(third_post)
+      end
+
+      last_post = third_post.topic.reload.posts.order(:post_number).last
+      custom_prompt = PostCustomPrompt.where(post_id: last_post.id).first.custom_prompt
+
+      expect(custom_prompt.length).to eq(3)
+      expect(custom_prompt.to_s).not_to include("<details>")
+      expect(custom_prompt.last.first).to eq(response2)
+      expect(custom_prompt.last.last).to eq(bot_user.username)
+    end
+
+    context "with Dall E bot" do
+      let(:bot) do
+        DiscourseAi::AiBot::Bot.as(bot_user, persona: DiscourseAi::AiBot::Personas::DallE3.new)
+      end
+
+      it "does not include placeholders in conversation context (simulate DALL-E)" do
+        SiteSetting.ai_openai_api_key = "123"
+
+        response = (<<~TXT).strip
+          <function_calls>
+          <invoke>
+          <tool_name>dall_e</tool_name>
+          <tool_id>dall_e</tool_id>
+          <parameters>
+          <prompts>["a pink cow"]</prompts>
+          </parameters>
+          </invoke>
+          </function_calls>
+        TXT
+
+        image =
+          "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
+
+        data = [{ b64_json: image, revised_prompt: "a pink cow 1" }]
+
+        WebMock.stub_request(:post, SiteSetting.ai_openai_dall_e_3_url).to_return(
+          status: 200,
+          body: { data: data }.to_json,
+        )
+
+        DiscourseAi::Completions::Llm.with_prepared_responses([response]) do
+          playground.reply_to(third_post)
+        end
+
+        last_post = third_post.topic.reload.posts.order(:post_number).last
+        custom_prompt = PostCustomPrompt.where(post_id: last_post.id).first.custom_prompt
+
+        # DALL E has custom_raw, we do not want to inject this into the prompt stream
+        expect(custom_prompt.length).to eq(2)
+        expect(custom_prompt.to_s).not_to include("<details>")
+      end
+    end
  end

  describe "#conversation_context" do

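As a reading aid (not part of the commit): after the first spec above runs, the
persisted custom_prompt holds exactly three entries, the tool call, the tool
result, and the final completion, with no placeholder markup. A hypothetical
rendering of its shape:

    # Hypothetical values; the tuple shapes follow the bot.rb change above.
    custom_prompt = [
      ['{"name":"search","arguments":{"search_query":"testing various things"}}', "search", "tool_call"],
      ['{"results":[]}', "search", "tool"],
      ["I found some really amazing stuff!", "bot_username"], # username is illustrative
    ]
    # Note: no <details> markup anywhere in the stored context.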

@@ -13,7 +13,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
  describe "#process" do
    it "can generate correct info with azure" do
-      post = Fabricate(:post)
+      _post = Fabricate(:post)

      SiteSetting.ai_openai_api_key = "abc"
      SiteSetting.ai_openai_dall_e_3_url = "https://test.azure.com/some_url"

@@ -43,8 +43,6 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
    end

    it "can generate correct info" do
-      post = Fabricate(:post)
-
      SiteSetting.ai_openai_api_key = "abc"

      image =