diff --git a/app/controllers/discourse_ai/admin/ai_personas_controller.rb b/app/controllers/discourse_ai/admin/ai_personas_controller.rb index ee653d5b..2f3563ac 100644 --- a/app/controllers/discourse_ai/admin/ai_personas_controller.rb +++ b/app/controllers/discourse_ai/admin/ai_personas_controller.rb @@ -120,15 +120,12 @@ module DiscourseAi def permit_tools(tools) return [] if !tools.is_a?(Array) - tools.filter_map do |tool, options| + tools.filter_map do |tool, options, force_tool| break nil if !tool.is_a?(String) options&.permit! if options && options.is_a?(ActionController::Parameters) - if options - [tool, options] - else - tool - end + # this is simpler from a storage perspective, 1 way to store tools + [tool, options, !!force_tool] end end end diff --git a/app/models/ai_persona.rb b/app/models/ai_persona.rb index 29d31e54..54eb6310 100644 --- a/app/models/ai_persona.rb +++ b/app/models/ai_persona.rb @@ -136,17 +136,23 @@ class AiPersona < ActiveRecord::Base end options = {} + force_tool_use = [] + tools = self.tools.filter_map do |element| klass = nil - if element.is_a?(String) && element.start_with?("custom-") - custom_tool_id = element.split("-", 2).last.to_i + element = [element] if element.is_a?(String) + + inner_name, current_options, should_force_tool_use = + element.is_a?(Array) ? element : [element, nil] + + if inner_name.start_with?("custom-") + custom_tool_id = inner_name.split("-", 2).last.to_i if AiTool.exists?(id: custom_tool_id, enabled: true) klass = DiscourseAi::AiBot::Tools::Custom.class_instance(custom_tool_id) end else - inner_name, current_options = element.is_a?(Array) ? element : [element, nil] inner_name = inner_name.gsub("Tool", "") inner_name = "List#{inner_name}" if %w[Categories Tags].include?(inner_name) @@ -155,9 +161,10 @@ class AiPersona < ActiveRecord::Base options[klass] = current_options if current_options rescue StandardError end - - klass end + + force_tool_use << klass if should_force_tool_use + klass end ai_persona_id = self.id @@ -177,6 +184,7 @@ class AiPersona < ActiveRecord::Base end define_method(:tools) { tools } + define_method(:force_tool_use) { force_tool_use } define_method(:options) { options } define_method(:temperature) { @ai_persona&.temperature } define_method(:top_p) { @ai_persona&.top_p } diff --git a/assets/javascripts/discourse/admin/models/ai-persona.js b/assets/javascripts/discourse/admin/models/ai-persona.js index c12c461c..9344f579 100644 --- a/assets/javascripts/discourse/admin/models/ai-persona.js +++ b/assets/javascripts/discourse/admin/models/ai-persona.js @@ -59,21 +59,25 @@ class ToolOption { export default class AiPersona extends RestModel { // this code is here to convert the wire schema to easier to work with object // on the wire we pass in/out tools as an Array. - // [[ToolName, {option1: value, option2: value}], ToolName2, ToolName3] + // [[ToolName, {option1: value, option2: value}, force], ToolName2, ToolName3] // So we rework this into a "tools" property and nested toolOptions init(properties) { + this.forcedTools = []; if (properties.tools) { properties.tools = properties.tools.map((tool) => { if (typeof tool === "string") { return tool; } else { - let [toolId, options] = tool; + let [toolId, options, force] = tool; for (let optionId in options) { if (!options.hasOwnProperty(optionId)) { continue; } this.getToolOption(toolId, optionId).value = options[optionId]; } + if (force) { + this.forcedTools.push(toolId); + } return toolId; } }); @@ -109,6 +113,8 @@ export default class AiPersona extends RestModel { if (typeof toolId !== "string") { toolId = toolId[0]; } + + let force = this.forcedTools.includes(toolId); if (this.toolOptions && this.toolOptions[toolId]) { let options = this.toolOptions[toolId]; let optionsWithValues = {}; @@ -119,9 +125,9 @@ export default class AiPersona extends RestModel { let option = options[optionId]; optionsWithValues[optionId] = option.value; } - toolsWithOptions.push([toolId, optionsWithValues]); + toolsWithOptions.push([toolId, optionsWithValues, force]); } else { - toolsWithOptions.push(toolId); + toolsWithOptions.push([toolId, {}, force]); } }); attrs.tools = toolsWithOptions; @@ -133,7 +139,6 @@ export default class AiPersona extends RestModel { : this.getProperties(CREATE_ATTRIBUTES); attrs.id = this.id; this.populateToolOptions(attrs); - return attrs; } @@ -146,6 +151,9 @@ export default class AiPersona extends RestModel { workingCopy() { let attrs = this.getProperties(CREATE_ATTRIBUTES); this.populateToolOptions(attrs); - return AiPersona.create(attrs); + + const persona = AiPersona.create(attrs); + persona.forcedTools = (this.forcedTools || []).slice(); + return persona; } } diff --git a/assets/javascripts/discourse/components/ai-persona-editor.gjs b/assets/javascripts/discourse/components/ai-persona-editor.gjs index 139768c1..574911ec 100644 --- a/assets/javascripts/discourse/components/ai-persona-editor.gjs +++ b/assets/javascripts/discourse/components/ai-persona-editor.gjs @@ -40,10 +40,39 @@ export default class PersonaEditor extends Component { @tracked maxPixelsValue = null; @tracked ragIndexingStatuses = null; + @tracked selectedTools = []; + @tracked selectedToolNames = []; + @tracked forcedToolNames = []; + get chatPluginEnabled() { return this.siteSettings.chat_enabled; } + get allowForceTools() { + return !this.editingModel?.system && this.editingModel?.tools?.length > 0; + } + + @action + forcedToolsChanged(tools) { + this.forcedToolNames = tools; + this.editingModel.forcedTools = this.forcedToolNames; + } + + @action + toolsChanged(tools) { + this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) => + tools.includes(tool.id) + ); + this.selectedToolNames = tools.slice(); + + this.forcedToolNames = this.forcedToolNames.filter( + (tool) => this.editingModel.tools.indexOf(tool) !== -1 + ); + + this.editingModel.tools = this.selectedToolNames; + this.editingModel.forcedTools = this.forcedToolNames; + } + @action updateModel() { this.editingModel = this.args.model.workingCopy(); @@ -51,6 +80,12 @@ export default class PersonaEditor extends Component { this.maxPixelsValue = this.findClosestPixelValue( this.editingModel.vision_max_pixels ); + + this.selectedToolNames = this.editingModel.tools || []; + this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) => + this.selectedToolNames.includes(tool.id) + ); + this.forcedToolNames = this.editingModel.forcedTools || []; } findClosestPixelValue(pixels) { @@ -336,15 +371,27 @@ export default class PersonaEditor extends Component { + {{#if this.allowForceTools}} +
+ + +
+ {{/if}} {{#unless this.editingModel.system}} {{/unless}} diff --git a/assets/javascripts/discourse/components/ai-tool-selector.js b/assets/javascripts/discourse/components/ai-tool-selector.js index 0060e06f..c3959eff 100644 --- a/assets/javascripts/discourse/components/ai-tool-selector.js +++ b/assets/javascripts/discourse/components/ai-tool-selector.js @@ -6,7 +6,7 @@ export default MultiSelectComponent.extend({ this.selectKit.options.set("disabled", this.get("attrs.disabled.value")); }), - content: computed(function () { + content: computed("tools", function () { return this.tools; }), diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index 7b6d7f35..f4719967 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -148,6 +148,7 @@ en: saved: AI Persona Saved enabled: "Enabled?" tools: Enabled Tools + forced_tools: Forced Tools allowed_groups: Allowed Groups confirm_delete: Are you sure you want to delete this persona? new: "New Persona" diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb index e3941e0e..391ee16b 100644 --- a/lib/ai_bot/bot.rb +++ b/lib/ai_bot/bot.rb @@ -67,6 +67,19 @@ module DiscourseAi .last end + def force_tool_if_needed(prompt, context) + context[:chosen_tools] ||= [] + forced_tools = persona.force_tool_use.map { |tool| tool.name } + force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) } + + if force_tool + context[:chosen_tools] << force_tool + prompt.tool_choice = force_tool + else + prompt.tool_choice = nil + end + end + def reply(context, &update_blk) llm = DiscourseAi::Completions::Llm.proxy(model) prompt = persona.craft_prompt(context, llm: llm) @@ -85,6 +98,7 @@ module DiscourseAi while total_completions <= MAX_COMPLETIONS && ongoing_chain tool_found = false + force_tool_if_needed(prompt, context) result = llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel| diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index b46abebe..0a31598c 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -113,6 +113,10 @@ module DiscourseAi [] end + def force_tool_use + [] + end + def required_tools [] end diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb index 5420e643..fa3a9ca4 100644 --- a/lib/completions/dialects/dialect.rb +++ b/lib/completions/dialects/dialect.rb @@ -60,6 +60,10 @@ module DiscourseAi @tools ||= tools_dialect.translated_tools end + def tool_choice + prompt.tool_choice + end + def translate messages = prompt.messages diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb index 43f19053..35b3e724 100644 --- a/lib/completions/endpoints/open_ai.rb +++ b/lib/completions/endpoints/open_ai.rb @@ -54,8 +54,12 @@ module DiscourseAi # We'll fallback to guess this using the tokenizer. payload[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai" end - - payload[:tools] = dialect.tools if dialect.tools.present? + if dialect.tools.present? + payload[:tools] = dialect.tools + if dialect.tool_choice.present? + payload[:tool_choice] = { type: "function", function: { name: dialect.tool_choice } } + end + end payload end diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb index 527ec87f..445bfc19 100644 --- a/lib/completions/llm.rb +++ b/lib/completions/llm.rb @@ -123,7 +123,7 @@ module DiscourseAi end def record_prompt(prompt) - @prompts << prompt if @prompts + @prompts << prompt.dup if @prompts end def proxy(model) diff --git a/lib/completions/prompt.rb b/lib/completions/prompt.rb index 051818d2..9a6d4d61 100644 --- a/lib/completions/prompt.rb +++ b/lib/completions/prompt.rb @@ -6,7 +6,7 @@ module DiscourseAi INVALID_TURN = Class.new(StandardError) attr_reader :messages - attr_accessor :tools, :topic_id, :post_id, :max_pixels + attr_accessor :tools, :topic_id, :post_id, :max_pixels, :tool_choice def initialize( system_message_text = nil, @@ -14,7 +14,8 @@ module DiscourseAi tools: [], topic_id: nil, post_id: nil, - max_pixels: nil + max_pixels: nil, + tool_choice: nil ) raise ArgumentError, "messages must be an array" if !messages.is_a?(Array) raise ArgumentError, "tools must be an array" if !tools.is_a?(Array) @@ -37,6 +38,7 @@ module DiscourseAi @messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) } @tools = tools + @tool_choice = tool_choice end def push(type:, content:, id: nil, name: nil, upload_ids: nil) diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb index 8c665639..44ff136c 100644 --- a/spec/lib/completions/endpoints/open_ai_spec.rb +++ b/spec/lib/completions/endpoints/open_ai_spec.rb @@ -255,6 +255,95 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do end end + describe "forced tool use" do + it "can properly force tool use" do + llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") + + tools = [ + { + name: "echo", + description: "echo something", + parameters: [ + { name: "text", type: "string", description: "text to echo", required: true }, + ], + }, + ] + + prompt = + DiscourseAi::Completions::Prompt.new( + "You are a bot", + messages: [type: :user, id: "user1", content: "echo hello"], + tools: tools, + tool_choice: "echo", + ) + + response = { + id: "chatcmpl-9JxkAzzaeO4DSV3omWvok9TKhCjBH", + object: "chat.completion", + created: 1_714_544_914, + model: "gpt-4-turbo-2024-04-09", + choices: [ + { + index: 0, + message: { + role: "assistant", + content: nil, + tool_calls: [ + { + id: "call_I8LKnoijVuhKOM85nnEQgWwd", + type: "function", + function: { + name: "echo", + arguments: "{\"text\":\"hello\"}", + }, + }, + ], + }, + logprobs: nil, + finish_reason: "tool_calls", + }, + ], + usage: { + prompt_tokens: 55, + completion_tokens: 13, + total_tokens: 68, + }, + system_fingerprint: "fp_ea6eb70039", + }.to_json + + body_json = nil + stub_request(:post, "https://api.openai.com/v1/chat/completions").with( + body: proc { |body| body_json = JSON.parse(body, symbolize_names: true) }, + ).to_return(body: response) + + result = llm.generate(prompt, user: user) + + expect(body_json[:tool_choice]).to eq({ type: "function", function: { name: "echo" } }) + + expected = (<<~TXT).strip + + + echo + + hello + + call_I8LKnoijVuhKOM85nnEQgWwd + + + TXT + + expect(result.strip).to eq(expected) + + stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return( + body: { choices: [message: { content: "OK" }] }.to_json, + ) + + result = llm.generate(prompt, user: user) + + expect(result).to eq("OK") + end + end + describe "image support" do it "can handle images" do model = Fabricate(:llm_model, vision_enabled: true) diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index f087414b..c3f054df 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -78,13 +78,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do end let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) } - - it "uses custom tool in conversation" do - persona_klass = AiPersona.all_personas.find { |p| p.name == ai_persona.name } - bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new) - playground = DiscourseAi::AiBot::Playground.new(bot) - - function_call = (<<~XML).strip + let(:function_call) { (<<~XML).strip } search @@ -96,6 +90,32 @@ RSpec.describe DiscourseAi::AiBot::Playground do ", XML + let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user, persona: ai_persona.class_instance.new) } + + let(:playground) { DiscourseAi::AiBot::Playground.new(bot) } + + it "can force usage of a tool" do + tool_name = "custom-#{custom_tool.id}" + ai_persona.update!(tools: [[tool_name, nil, "force"]]) + responses = [function_call, "custom tool did stuff (maybe)"] + + prompt = nil + DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompt| + new_post = Fabricate(:post, raw: "Can you use the custom tool?") + _reply_post = playground.reply_to(new_post) + prompt = _prompt + end + + expect(prompt.length).to eq(2) + expect(prompt[0].tool_choice).to eq("search") + expect(prompt[1].tool_choice).to eq(nil) + end + + it "uses custom tool in conversation" do + persona_klass = AiPersona.all_personas.find { |p| p.name == ai_persona.name } + bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new) + playground = DiscourseAi::AiBot::Playground.new(bot) + responses = [function_call, "custom tool did stuff (maybe)"] reply_post = nil diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb index 203abe33..1b1da5fc 100644 --- a/spec/requests/admin/ai_personas_controller_spec.rb +++ b/spec/requests/admin/ai_personas_controller_spec.rb @@ -160,7 +160,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do name: "superbot", description: "Assists with tasks", system_prompt: "you are a helpful bot", - tools: [["search", { "base_query" => "test" }]], + tools: [["search", { "base_query" => "test" }, true]], top_p: 0.1, temperature: 0.5, mentionable: true, @@ -186,7 +186,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do persona = AiPersona.find(persona_json["id"]) - expect(persona.tools).to eq([["search", { "base_query" => "test" }]]) + expect(persona.tools).to eq([["search", { "base_query" => "test" }, true]]) expect(persona.top_p).to eq(0.1) expect(persona.temperature).to eq(0.5) }.to change(AiPersona, :count).by(1) @@ -296,7 +296,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do ai_persona.reload expect(ai_persona.name).to eq("SuperBot") expect(ai_persona.enabled).to eq(false) - expect(ai_persona.tools).to eq(["search"]) + expect(ai_persona.tools).to eq([["search", nil, false]]) end end diff --git a/spec/system/admin_ai_persona_spec.rb b/spec/system/admin_ai_persona_spec.rb index 7ecd57dd..171cfcc8 100644 --- a/spec/system/admin_ai_persona_spec.rb +++ b/spec/system/admin_ai_persona_spec.rb @@ -30,7 +30,7 @@ RSpec.describe "Admin AI persona configuration", type: :system, js: true do expect(persona.name).to eq("Test Persona") expect(persona.description).to eq("I am a test persona") expect(persona.system_prompt).to eq("You are a helpful bot") - expect(persona.tools).to eq([["Read", { "read_private" => nil }]]) + expect(persona.tools).to eq([["Read", { "read_private" => nil }, false]]) end it "will not allow deletion or editing of system personas" do diff --git a/test/javascripts/unit/models/ai-persona-test.js b/test/javascripts/unit/models/ai-persona-test.js index f785ffba..c1f25aeb 100644 --- a/test/javascripts/unit/models/ai-persona-test.js +++ b/test/javascripts/unit/models/ai-persona-test.js @@ -60,7 +60,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () { const updatedProperties = aiPersona.updateProperties(); // perform remapping for save - properties.tools = [["ToolName", { option1: "value1" }]]; + properties.tools = [["ToolName", { option1: "value1" }, false]]; assert.deepEqual(updatedProperties, properties); }); @@ -100,7 +100,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () { const createdProperties = aiPersona.createProperties(); - properties.tools = [["ToolName", { option1: "value1" }]]; + properties.tools = [["ToolName", { option1: "value1" }, false]]; assert.deepEqual(createdProperties, properties); });