diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index 6ae77e4b..c67880e5 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -48,42 +48,46 @@ module DiscourseAi
llm = DiscourseAi::Completions::Llm.proxy(current_model)
tool_found = false
- llm.generate(prompt, user: context[:user]) do |partial, cancel|
- if (tool = persona.find_tool(partial))
- tool_found = true
- ongoing_chain = tool.chain_next_response?
- low_cost = tool.low_cost?
- tool_call_id = tool.tool_call_id
- invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json
+ result =
+ llm.generate(prompt, user: context[:user]) do |partial, cancel|
+ if (tool = persona.find_tool(partial))
+ tool_found = true
+ ongoing_chain = tool.chain_next_response?
+ low_cost = tool.low_cost?
+ tool_call_id = tool.tool_call_id
+ invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json
- invocation_context = {
- type: "tool",
- name: tool_call_id,
- content: invocation_result_json,
- }
- tool_context = {
- type: "tool_call",
- name: tool_call_id,
- content: { name: tool.name, arguments: tool.parameters }.to_json,
- }
+ invocation_context = {
+ type: "tool",
+ name: tool_call_id,
+ content: invocation_result_json,
+ }
+ tool_context = {
+ type: "tool_call",
+ name: tool_call_id,
+ content: { name: tool.name, arguments: tool.parameters }.to_json,
+ }
- prompt[:conversation_context] ||= []
+ prompt[:conversation_context] ||= []
- if tool.standalone?
- prompt[:conversation_context] = [invocation_context, tool_context]
+ if tool.standalone?
+ prompt[:conversation_context] = [invocation_context, tool_context]
+ else
+ prompt[:conversation_context] = [invocation_context, tool_context] +
+ prompt[:conversation_context]
+ end
+
+ raw_context << [tool_context[:content], tool_call_id, "tool_call"]
+ raw_context << [invocation_result_json, tool_call_id, "tool"]
else
- prompt[:conversation_context] = [invocation_context, tool_context] +
- prompt[:conversation_context]
+ update_blk.call(partial, cancel, nil)
end
-
- raw_context << [tool_context[:content], tool_call_id, "tool_call"]
- raw_context << [invocation_result_json, tool_call_id, "tool"]
- else
- update_blk.call(partial, cancel, nil)
end
- end
- ongoing_chain = false if !tool_found
+ if !tool_found
+ ongoing_chain = false
+ raw_context << [result, bot_user.username]
+ end
total_completions += 1
# do not allow tools when we are at the end of a chain (total_completions == MAX_COMPLETIONS)
@@ -93,10 +97,10 @@ module DiscourseAi
raw_context
end
- private
-
attr_reader :persona
+ private
+
def invoke_tool(tool, llm, cancel, &update_blk)
update_blk.call("", cancel, build_placeholder(tool.summary, ""))
diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb
index 61b88d07..293c0fbe 100644
--- a/lib/ai_bot/playground.rb
+++ b/lib/ai_bot/playground.rb
@@ -139,22 +139,19 @@ module DiscourseAi
return if reply.blank?
- reply_post.tap do |bot_reply|
- publish_update(bot_reply, done: true)
+ publish_update(reply_post, done: true)
- bot_reply.revise(
- bot.bot_user,
- { raw: reply },
- skip_validations: true,
- skip_revision: true,
- )
+ reply_post.revise(bot.bot_user, { raw: reply }, skip_validations: true, skip_revision: true)
- bot_reply.post_custom_prompt ||= bot_reply.build_post_custom_prompt(custom_prompt: [])
- prompt = bot_reply.post_custom_prompt.custom_prompt || []
+      # no need to add a custom prompt for a single reply
+ if new_custom_prompts.length > 1
+ reply_post.post_custom_prompt ||= reply_post.build_post_custom_prompt(custom_prompt: [])
+ prompt = reply_post.post_custom_prompt.custom_prompt || []
prompt.concat(new_custom_prompts)
- prompt << [reply, bot.bot_user.username]
- bot_reply.post_custom_prompt.update!(custom_prompt: prompt)
+ reply_post.post_custom_prompt.update!(custom_prompt: prompt)
end
+
+ reply_post
end
private
diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb
index 06d1f7df..a127d4e7 100644
--- a/lib/completions/dialects/gemini.rb
+++ b/lib/completions/dialects/gemini.rb
@@ -130,8 +130,13 @@ module DiscourseAi
def flatten_context(context)
context.map do |a_context|
if a_context[:type] == "multi_turn"
- # Drop old tool calls and only keep bot response.
- a_context[:content].find { |c| c[:type] == "assistant" }
+          # Some multi-turn interactions, like the ones that generate images, don't chain a next
+          # response. We don't have an assistant call for those, so we use the tool_call instead.
+          # We cannot use tool since it confuses the model, making it stop calling tools in
+          # subsequent responses and reply with JSON instead.
+
+ a_context[:content].find { |c| c[:type] == "assistant" } ||
+ a_context[:content].find { |c| c[:type] == "tool_call" }
else
a_context
end
diff --git a/spec/lib/completions/dialects/gemini_spec.rb b/spec/lib/completions/dialects/gemini_spec.rb
index cc80611b..528b329a 100644
--- a/spec/lib/completions/dialects/gemini_spec.rb
+++ b/spec/lib/completions/dialects/gemini_spec.rb
@@ -124,6 +124,69 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
expect(translated_context.last.dig(:parts, :text).length).to be <
context.last[:content].length
end
+
+ context "when working with multi-turn contexts" do
+    context "when the multi-turn is for a turn that doesn't chain" do
+ it "uses the tool_call context" do
+ prompt[:conversation_context] = [
+ {
+ type: "multi_turn",
+ content: [
+ {
+ type: "tool_call",
+ name: "get_weather",
+ content: {
+ name: "get_weather",
+ arguments: {
+ location: "Sydney",
+ unit: "c",
+ },
+ }.to_json,
+ },
+ { type: "tool", name: "get_weather", content: "I'm a tool result" },
+ ],
+ },
+ ]
+
+ translated_context = dialect.conversation_context
+
+ expect(translated_context.size).to eq(1)
+ expect(translated_context.last[:role]).to eq("model")
+ expect(translated_context.last.dig(:parts, :functionCall)).to be_present
+ end
+ end
+
+ context "when the multi-turn is from a chainable tool" do
+      it "uses the assistant context" do
+ prompt[:conversation_context] = [
+ {
+ type: "multi_turn",
+ content: [
+ {
+ type: "tool_call",
+ name: "get_weather",
+ content: {
+ name: "get_weather",
+ arguments: {
+ location: "Sydney",
+ unit: "c",
+ },
+ }.to_json,
+ },
+ { type: "tool", name: "get_weather", content: "I'm a tool result" },
+ { type: "assistant", content: "I'm a bot reply!" },
+ ],
+ },
+ ]
+
+ translated_context = dialect.conversation_context
+
+ expect(translated_context.size).to eq(1)
+ expect(translated_context.last[:role]).to eq("model")
+ expect(translated_context.last.dig(:parts, :text)).to be_present
+ end
+ end
+ end
end
describe "#tools" do
diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb
index 255d21af..b9a05297 100644
--- a/spec/lib/modules/ai_bot/playground_spec.rb
+++ b/spec/lib/modules/ai_bot/playground_spec.rb
@@ -4,11 +4,11 @@ RSpec.describe DiscourseAi::AiBot::Playground do
subject(:playground) { described_class.new(bot) }
before do
- SiteSetting.ai_bot_enabled_chat_bots = "gpt-4"
+ SiteSetting.ai_bot_enabled_chat_bots = "claude-2"
SiteSetting.ai_bot_enabled = true
end
- let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID) }
+ let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID) }
let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user) }
fab!(:user) { Fabricate(:user) }
@@ -74,6 +74,77 @@ RSpec.describe DiscourseAi::AiBot::Playground do
expect(pm.reload.posts.last.cooked).to eq(PrettyText.cook(expected_bot_response))
end
end
+
+ it "does not include placeholders in conversation context but includes all completions" do
+ response1 = (<<~TXT).strip
+
+
+ search
+ search
+
+ testing various things
+
+
+
+ TXT
+
+ response2 = "I found some really amazing stuff!"
+
+ DiscourseAi::Completions::Llm.with_prepared_responses([response1, response2]) do
+ playground.reply_to(third_post)
+ end
+
+ last_post = third_post.topic.reload.posts.order(:post_number).last
+ custom_prompt = PostCustomPrompt.where(post_id: last_post.id).first.custom_prompt
+
+ expect(custom_prompt.length).to eq(3)
+ expect(custom_prompt.to_s).not_to include("")
+ expect(custom_prompt.last.first).to eq(response2)
+ expect(custom_prompt.last.last).to eq(bot_user.username)
+ end
+
+ context "with Dall E bot" do
+ let(:bot) do
+ DiscourseAi::AiBot::Bot.as(bot_user, persona: DiscourseAi::AiBot::Personas::DallE3.new)
+ end
+
+ it "does not include placeholders in conversation context (simulate DALL-E)" do
+ SiteSetting.ai_openai_api_key = "123"
+
+ response = (<<~TXT).strip
+
+
+ dall_e
+ dall_e
+
+ ["a pink cow"]
+
+
+
+ TXT
+
+ image =
+ "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
+
+ data = [{ b64_json: image, revised_prompt: "a pink cow 1" }]
+
+ WebMock.stub_request(:post, SiteSetting.ai_openai_dall_e_3_url).to_return(
+ status: 200,
+ body: { data: data }.to_json,
+ )
+
+ DiscourseAi::Completions::Llm.with_prepared_responses([response]) do
+ playground.reply_to(third_post)
+ end
+
+ last_post = third_post.topic.reload.posts.order(:post_number).last
+ custom_prompt = PostCustomPrompt.where(post_id: last_post.id).first.custom_prompt
+
+ # DALL E has custom_raw, we do not want to inject this into the prompt stream
+ expect(custom_prompt.length).to eq(2)
+ expect(custom_prompt.to_s).not_to include("")
+ end
+ end
end
describe "#conversation_context" do
diff --git a/spec/lib/modules/ai_bot/tools/dall_e_spec.rb b/spec/lib/modules/ai_bot/tools/dall_e_spec.rb
index 058692e9..46200fe6 100644
--- a/spec/lib/modules/ai_bot/tools/dall_e_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/dall_e_spec.rb
@@ -13,7 +13,7 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
describe "#process" do
it "can generate correct info with azure" do
- post = Fabricate(:post)
+ _post = Fabricate(:post)
SiteSetting.ai_openai_api_key = "abc"
SiteSetting.ai_openai_dall_e_3_url = "https://test.azure.com/some_url"
@@ -43,8 +43,6 @@ RSpec.describe DiscourseAi::AiBot::Tools::DallE do
end
it "can generate correct info" do
- post = Fabricate(:post)
-
SiteSetting.ai_openai_api_key = "abc"
image =