From 545500b32967d9806b15899d3ed3d125a5545f30 Mon Sep 17 00:00:00 2001 From: Sam Date: Sat, 5 Oct 2024 08:46:57 +0900 Subject: [PATCH] FEATURE: allows forced LLM tool use (#818) * FEATURE: allows forced LLM tool use Sometimes we need to force LLMs to use tools, for example in RAG like use cases we may want to force an unconditional search. The new framework allows you backend to force tool usage. Front end commit to follow * UI for forcing tools now works, but it does not react right * fix bugs * fix tests, this is now ready for review --- .../admin/ai_personas_controller.rb | 9 +- app/models/ai_persona.rb | 18 ++-- .../discourse/admin/models/ai-persona.js | 20 +++-- .../components/ai-persona-editor.gjs | 51 ++++++++++- .../discourse/components/ai-tool-selector.js | 2 +- config/locales/client.en.yml | 1 + lib/ai_bot/bot.rb | 14 +++ lib/ai_bot/personas/persona.rb | 4 + lib/completions/dialects/dialect.rb | 4 + lib/completions/endpoints/open_ai.rb | 8 +- lib/completions/llm.rb | 2 +- lib/completions/prompt.rb | 6 +- .../lib/completions/endpoints/open_ai_spec.rb | 89 +++++++++++++++++++ spec/lib/modules/ai_bot/playground_spec.rb | 34 +++++-- .../admin/ai_personas_controller_spec.rb | 6 +- spec/system/admin_ai_persona_spec.rb | 2 +- .../unit/models/ai-persona-test.js | 4 +- 17 files changed, 236 insertions(+), 38 deletions(-) diff --git a/app/controllers/discourse_ai/admin/ai_personas_controller.rb b/app/controllers/discourse_ai/admin/ai_personas_controller.rb index ee653d5b..2f3563ac 100644 --- a/app/controllers/discourse_ai/admin/ai_personas_controller.rb +++ b/app/controllers/discourse_ai/admin/ai_personas_controller.rb @@ -120,15 +120,12 @@ module DiscourseAi def permit_tools(tools) return [] if !tools.is_a?(Array) - tools.filter_map do |tool, options| + tools.filter_map do |tool, options, force_tool| break nil if !tool.is_a?(String) options&.permit! if options && options.is_a?(ActionController::Parameters) - if options - [tool, options] - else - tool - end + # this is simpler from a storage perspective, 1 way to store tools + [tool, options, !!force_tool] end end end diff --git a/app/models/ai_persona.rb b/app/models/ai_persona.rb index 29d31e54..54eb6310 100644 --- a/app/models/ai_persona.rb +++ b/app/models/ai_persona.rb @@ -136,17 +136,23 @@ class AiPersona < ActiveRecord::Base end options = {} + force_tool_use = [] + tools = self.tools.filter_map do |element| klass = nil - if element.is_a?(String) && element.start_with?("custom-") - custom_tool_id = element.split("-", 2).last.to_i + element = [element] if element.is_a?(String) + + inner_name, current_options, should_force_tool_use = + element.is_a?(Array) ? element : [element, nil] + + if inner_name.start_with?("custom-") + custom_tool_id = inner_name.split("-", 2).last.to_i if AiTool.exists?(id: custom_tool_id, enabled: true) klass = DiscourseAi::AiBot::Tools::Custom.class_instance(custom_tool_id) end else - inner_name, current_options = element.is_a?(Array) ? element : [element, nil] inner_name = inner_name.gsub("Tool", "") inner_name = "List#{inner_name}" if %w[Categories Tags].include?(inner_name) @@ -155,9 +161,10 @@ class AiPersona < ActiveRecord::Base options[klass] = current_options if current_options rescue StandardError end - - klass end + + force_tool_use << klass if should_force_tool_use + klass end ai_persona_id = self.id @@ -177,6 +184,7 @@ class AiPersona < ActiveRecord::Base end define_method(:tools) { tools } + define_method(:force_tool_use) { force_tool_use } define_method(:options) { options } define_method(:temperature) { @ai_persona&.temperature } define_method(:top_p) { @ai_persona&.top_p } diff --git a/assets/javascripts/discourse/admin/models/ai-persona.js b/assets/javascripts/discourse/admin/models/ai-persona.js index c12c461c..9344f579 100644 --- a/assets/javascripts/discourse/admin/models/ai-persona.js +++ b/assets/javascripts/discourse/admin/models/ai-persona.js @@ -59,21 +59,25 @@ class ToolOption { export default class AiPersona extends RestModel { // this code is here to convert the wire schema to easier to work with object // on the wire we pass in/out tools as an Array. - // [[ToolName, {option1: value, option2: value}], ToolName2, ToolName3] + // [[ToolName, {option1: value, option2: value}, force], ToolName2, ToolName3] // So we rework this into a "tools" property and nested toolOptions init(properties) { + this.forcedTools = []; if (properties.tools) { properties.tools = properties.tools.map((tool) => { if (typeof tool === "string") { return tool; } else { - let [toolId, options] = tool; + let [toolId, options, force] = tool; for (let optionId in options) { if (!options.hasOwnProperty(optionId)) { continue; } this.getToolOption(toolId, optionId).value = options[optionId]; } + if (force) { + this.forcedTools.push(toolId); + } return toolId; } }); @@ -109,6 +113,8 @@ export default class AiPersona extends RestModel { if (typeof toolId !== "string") { toolId = toolId[0]; } + + let force = this.forcedTools.includes(toolId); if (this.toolOptions && this.toolOptions[toolId]) { let options = this.toolOptions[toolId]; let optionsWithValues = {}; @@ -119,9 +125,9 @@ export default class AiPersona extends RestModel { let option = options[optionId]; optionsWithValues[optionId] = option.value; } - toolsWithOptions.push([toolId, optionsWithValues]); + toolsWithOptions.push([toolId, optionsWithValues, force]); } else { - toolsWithOptions.push(toolId); + toolsWithOptions.push([toolId, {}, force]); } }); attrs.tools = toolsWithOptions; @@ -133,7 +139,6 @@ export default class AiPersona extends RestModel { : this.getProperties(CREATE_ATTRIBUTES); attrs.id = this.id; this.populateToolOptions(attrs); - return attrs; } @@ -146,6 +151,9 @@ export default class AiPersona extends RestModel { workingCopy() { let attrs = this.getProperties(CREATE_ATTRIBUTES); this.populateToolOptions(attrs); - return AiPersona.create(attrs); + + const persona = AiPersona.create(attrs); + persona.forcedTools = (this.forcedTools || []).slice(); + return persona; } } diff --git a/assets/javascripts/discourse/components/ai-persona-editor.gjs b/assets/javascripts/discourse/components/ai-persona-editor.gjs index 139768c1..574911ec 100644 --- a/assets/javascripts/discourse/components/ai-persona-editor.gjs +++ b/assets/javascripts/discourse/components/ai-persona-editor.gjs @@ -40,10 +40,39 @@ export default class PersonaEditor extends Component { @tracked maxPixelsValue = null; @tracked ragIndexingStatuses = null; + @tracked selectedTools = []; + @tracked selectedToolNames = []; + @tracked forcedToolNames = []; + get chatPluginEnabled() { return this.siteSettings.chat_enabled; } + get allowForceTools() { + return !this.editingModel?.system && this.editingModel?.tools?.length > 0; + } + + @action + forcedToolsChanged(tools) { + this.forcedToolNames = tools; + this.editingModel.forcedTools = this.forcedToolNames; + } + + @action + toolsChanged(tools) { + this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) => + tools.includes(tool.id) + ); + this.selectedToolNames = tools.slice(); + + this.forcedToolNames = this.forcedToolNames.filter( + (tool) => this.editingModel.tools.indexOf(tool) !== -1 + ); + + this.editingModel.tools = this.selectedToolNames; + this.editingModel.forcedTools = this.forcedToolNames; + } + @action updateModel() { this.editingModel = this.args.model.workingCopy(); @@ -51,6 +80,12 @@ export default class PersonaEditor extends Component { this.maxPixelsValue = this.findClosestPixelValue( this.editingModel.vision_max_pixels ); + + this.selectedToolNames = this.editingModel.tools || []; + this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) => + this.selectedToolNames.includes(tool.id) + ); + this.forcedToolNames = this.editingModel.forcedTools || []; } findClosestPixelValue(pixels) { @@ -336,15 +371,27 @@ export default class PersonaEditor extends Component { + {{#if this.allowForceTools}} +
+ + +
+ {{/if}} {{#unless this.editingModel.system}} {{/unless}} diff --git a/assets/javascripts/discourse/components/ai-tool-selector.js b/assets/javascripts/discourse/components/ai-tool-selector.js index 0060e06f..c3959eff 100644 --- a/assets/javascripts/discourse/components/ai-tool-selector.js +++ b/assets/javascripts/discourse/components/ai-tool-selector.js @@ -6,7 +6,7 @@ export default MultiSelectComponent.extend({ this.selectKit.options.set("disabled", this.get("attrs.disabled.value")); }), - content: computed(function () { + content: computed("tools", function () { return this.tools; }), diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index 7b6d7f35..f4719967 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -148,6 +148,7 @@ en: saved: AI Persona Saved enabled: "Enabled?" tools: Enabled Tools + forced_tools: Forced Tools allowed_groups: Allowed Groups confirm_delete: Are you sure you want to delete this persona? new: "New Persona" diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb index e3941e0e..391ee16b 100644 --- a/lib/ai_bot/bot.rb +++ b/lib/ai_bot/bot.rb @@ -67,6 +67,19 @@ module DiscourseAi .last end + def force_tool_if_needed(prompt, context) + context[:chosen_tools] ||= [] + forced_tools = persona.force_tool_use.map { |tool| tool.name } + force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) } + + if force_tool + context[:chosen_tools] << force_tool + prompt.tool_choice = force_tool + else + prompt.tool_choice = nil + end + end + def reply(context, &update_blk) llm = DiscourseAi::Completions::Llm.proxy(model) prompt = persona.craft_prompt(context, llm: llm) @@ -85,6 +98,7 @@ module DiscourseAi while total_completions <= MAX_COMPLETIONS && ongoing_chain tool_found = false + force_tool_if_needed(prompt, context) result = llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel| diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index b46abebe..0a31598c 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -113,6 +113,10 @@ module DiscourseAi [] end + def force_tool_use + [] + end + def required_tools [] end diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb index 5420e643..fa3a9ca4 100644 --- a/lib/completions/dialects/dialect.rb +++ b/lib/completions/dialects/dialect.rb @@ -60,6 +60,10 @@ module DiscourseAi @tools ||= tools_dialect.translated_tools end + def tool_choice + prompt.tool_choice + end + def translate messages = prompt.messages diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb index 43f19053..35b3e724 100644 --- a/lib/completions/endpoints/open_ai.rb +++ b/lib/completions/endpoints/open_ai.rb @@ -54,8 +54,12 @@ module DiscourseAi # We'll fallback to guess this using the tokenizer. payload[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai" end - - payload[:tools] = dialect.tools if dialect.tools.present? + if dialect.tools.present? + payload[:tools] = dialect.tools + if dialect.tool_choice.present? + payload[:tool_choice] = { type: "function", function: { name: dialect.tool_choice } } + end + end payload end diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb index 527ec87f..445bfc19 100644 --- a/lib/completions/llm.rb +++ b/lib/completions/llm.rb @@ -123,7 +123,7 @@ module DiscourseAi end def record_prompt(prompt) - @prompts << prompt if @prompts + @prompts << prompt.dup if @prompts end def proxy(model) diff --git a/lib/completions/prompt.rb b/lib/completions/prompt.rb index 051818d2..9a6d4d61 100644 --- a/lib/completions/prompt.rb +++ b/lib/completions/prompt.rb @@ -6,7 +6,7 @@ module DiscourseAi INVALID_TURN = Class.new(StandardError) attr_reader :messages - attr_accessor :tools, :topic_id, :post_id, :max_pixels + attr_accessor :tools, :topic_id, :post_id, :max_pixels, :tool_choice def initialize( system_message_text = nil, @@ -14,7 +14,8 @@ module DiscourseAi tools: [], topic_id: nil, post_id: nil, - max_pixels: nil + max_pixels: nil, + tool_choice: nil ) raise ArgumentError, "messages must be an array" if !messages.is_a?(Array) raise ArgumentError, "tools must be an array" if !tools.is_a?(Array) @@ -37,6 +38,7 @@ module DiscourseAi @messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) } @tools = tools + @tool_choice = tool_choice end def push(type:, content:, id: nil, name: nil, upload_ids: nil) diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb index 8c665639..44ff136c 100644 --- a/spec/lib/completions/endpoints/open_ai_spec.rb +++ b/spec/lib/completions/endpoints/open_ai_spec.rb @@ -255,6 +255,95 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do end end + describe "forced tool use" do + it "can properly force tool use" do + llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") + + tools = [ + { + name: "echo", + description: "echo something", + parameters: [ + { name: "text", type: "string", description: "text to echo", required: true }, + ], + }, + ] + + prompt = + DiscourseAi::Completions::Prompt.new( + "You are a bot", + messages: [type: :user, id: "user1", content: "echo hello"], + tools: tools, + tool_choice: "echo", + ) + + response = { + id: "chatcmpl-9JxkAzzaeO4DSV3omWvok9TKhCjBH", + object: "chat.completion", + created: 1_714_544_914, + model: "gpt-4-turbo-2024-04-09", + choices: [ + { + index: 0, + message: { + role: "assistant", + content: nil, + tool_calls: [ + { + id: "call_I8LKnoijVuhKOM85nnEQgWwd", + type: "function", + function: { + name: "echo", + arguments: "{\"text\":\"hello\"}", + }, + }, + ], + }, + logprobs: nil, + finish_reason: "tool_calls", + }, + ], + usage: { + prompt_tokens: 55, + completion_tokens: 13, + total_tokens: 68, + }, + system_fingerprint: "fp_ea6eb70039", + }.to_json + + body_json = nil + stub_request(:post, "https://api.openai.com/v1/chat/completions").with( + body: proc { |body| body_json = JSON.parse(body, symbolize_names: true) }, + ).to_return(body: response) + + result = llm.generate(prompt, user: user) + + expect(body_json[:tool_choice]).to eq({ type: "function", function: { name: "echo" } }) + + expected = (<<~TXT).strip + + + echo + + hello + + call_I8LKnoijVuhKOM85nnEQgWwd + + + TXT + + expect(result.strip).to eq(expected) + + stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return( + body: { choices: [message: { content: "OK" }] }.to_json, + ) + + result = llm.generate(prompt, user: user) + + expect(result).to eq("OK") + end + end + describe "image support" do it "can handle images" do model = Fabricate(:llm_model, vision_enabled: true) diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index f087414b..c3f054df 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -78,13 +78,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do end let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) } - - it "uses custom tool in conversation" do - persona_klass = AiPersona.all_personas.find { |p| p.name == ai_persona.name } - bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new) - playground = DiscourseAi::AiBot::Playground.new(bot) - - function_call = (<<~XML).strip + let(:function_call) { (<<~XML).strip } search @@ -96,6 +90,32 @@ RSpec.describe DiscourseAi::AiBot::Playground do ", XML + let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user, persona: ai_persona.class_instance.new) } + + let(:playground) { DiscourseAi::AiBot::Playground.new(bot) } + + it "can force usage of a tool" do + tool_name = "custom-#{custom_tool.id}" + ai_persona.update!(tools: [[tool_name, nil, "force"]]) + responses = [function_call, "custom tool did stuff (maybe)"] + + prompt = nil + DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompt| + new_post = Fabricate(:post, raw: "Can you use the custom tool?") + _reply_post = playground.reply_to(new_post) + prompt = _prompt + end + + expect(prompt.length).to eq(2) + expect(prompt[0].tool_choice).to eq("search") + expect(prompt[1].tool_choice).to eq(nil) + end + + it "uses custom tool in conversation" do + persona_klass = AiPersona.all_personas.find { |p| p.name == ai_persona.name } + bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new) + playground = DiscourseAi::AiBot::Playground.new(bot) + responses = [function_call, "custom tool did stuff (maybe)"] reply_post = nil diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb index 203abe33..1b1da5fc 100644 --- a/spec/requests/admin/ai_personas_controller_spec.rb +++ b/spec/requests/admin/ai_personas_controller_spec.rb @@ -160,7 +160,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do name: "superbot", description: "Assists with tasks", system_prompt: "you are a helpful bot", - tools: [["search", { "base_query" => "test" }]], + tools: [["search", { "base_query" => "test" }, true]], top_p: 0.1, temperature: 0.5, mentionable: true, @@ -186,7 +186,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do persona = AiPersona.find(persona_json["id"]) - expect(persona.tools).to eq([["search", { "base_query" => "test" }]]) + expect(persona.tools).to eq([["search", { "base_query" => "test" }, true]]) expect(persona.top_p).to eq(0.1) expect(persona.temperature).to eq(0.5) }.to change(AiPersona, :count).by(1) @@ -296,7 +296,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do ai_persona.reload expect(ai_persona.name).to eq("SuperBot") expect(ai_persona.enabled).to eq(false) - expect(ai_persona.tools).to eq(["search"]) + expect(ai_persona.tools).to eq([["search", nil, false]]) end end diff --git a/spec/system/admin_ai_persona_spec.rb b/spec/system/admin_ai_persona_spec.rb index 7ecd57dd..171cfcc8 100644 --- a/spec/system/admin_ai_persona_spec.rb +++ b/spec/system/admin_ai_persona_spec.rb @@ -30,7 +30,7 @@ RSpec.describe "Admin AI persona configuration", type: :system, js: true do expect(persona.name).to eq("Test Persona") expect(persona.description).to eq("I am a test persona") expect(persona.system_prompt).to eq("You are a helpful bot") - expect(persona.tools).to eq([["Read", { "read_private" => nil }]]) + expect(persona.tools).to eq([["Read", { "read_private" => nil }, false]]) end it "will not allow deletion or editing of system personas" do diff --git a/test/javascripts/unit/models/ai-persona-test.js b/test/javascripts/unit/models/ai-persona-test.js index f785ffba..c1f25aeb 100644 --- a/test/javascripts/unit/models/ai-persona-test.js +++ b/test/javascripts/unit/models/ai-persona-test.js @@ -60,7 +60,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () { const updatedProperties = aiPersona.updateProperties(); // perform remapping for save - properties.tools = [["ToolName", { option1: "value1" }]]; + properties.tools = [["ToolName", { option1: "value1" }, false]]; assert.deepEqual(updatedProperties, properties); }); @@ -100,7 +100,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () { const createdProperties = aiPersona.createProperties(); - properties.tools = [["ToolName", { option1: "value1" }]]; + properties.tools = [["ToolName", { option1: "value1" }, false]]; assert.deepEqual(createdProperties, properties); });