diff --git a/app/controllers/discourse_ai/admin/ai_personas_controller.rb b/app/controllers/discourse_ai/admin/ai_personas_controller.rb
index ee653d5b..2f3563ac 100644
--- a/app/controllers/discourse_ai/admin/ai_personas_controller.rb
+++ b/app/controllers/discourse_ai/admin/ai_personas_controller.rb
@@ -120,15 +120,12 @@ module DiscourseAi
def permit_tools(tools)
return [] if !tools.is_a?(Array)
- tools.filter_map do |tool, options|
+ tools.filter_map do |tool, options, force_tool|
break nil if !tool.is_a?(String)
options&.permit! if options && options.is_a?(ActionController::Parameters)
- if options
- [tool, options]
- else
- tool
- end
+ # this is simpler from a storage perspective, 1 way to store tools
+ [tool, options, !!force_tool]
end
end
end
diff --git a/app/models/ai_persona.rb b/app/models/ai_persona.rb
index 29d31e54..54eb6310 100644
--- a/app/models/ai_persona.rb
+++ b/app/models/ai_persona.rb
@@ -136,17 +136,23 @@ class AiPersona < ActiveRecord::Base
end
options = {}
+ force_tool_use = []
+
tools =
self.tools.filter_map do |element|
klass = nil
- if element.is_a?(String) && element.start_with?("custom-")
- custom_tool_id = element.split("-", 2).last.to_i
+ element = [element] if element.is_a?(String)
+
+ inner_name, current_options, should_force_tool_use =
+ element.is_a?(Array) ? element : [element, nil]
+
+ if inner_name.start_with?("custom-")
+ custom_tool_id = inner_name.split("-", 2).last.to_i
if AiTool.exists?(id: custom_tool_id, enabled: true)
klass = DiscourseAi::AiBot::Tools::Custom.class_instance(custom_tool_id)
end
else
- inner_name, current_options = element.is_a?(Array) ? element : [element, nil]
inner_name = inner_name.gsub("Tool", "")
inner_name = "List#{inner_name}" if %w[Categories Tags].include?(inner_name)
@@ -155,9 +161,10 @@ class AiPersona < ActiveRecord::Base
options[klass] = current_options if current_options
rescue StandardError
end
-
- klass
end
+
+ force_tool_use << klass if should_force_tool_use
+ klass
end
ai_persona_id = self.id
@@ -177,6 +184,7 @@ class AiPersona < ActiveRecord::Base
end
define_method(:tools) { tools }
+ define_method(:force_tool_use) { force_tool_use }
define_method(:options) { options }
define_method(:temperature) { @ai_persona&.temperature }
define_method(:top_p) { @ai_persona&.top_p }
diff --git a/assets/javascripts/discourse/admin/models/ai-persona.js b/assets/javascripts/discourse/admin/models/ai-persona.js
index c12c461c..9344f579 100644
--- a/assets/javascripts/discourse/admin/models/ai-persona.js
+++ b/assets/javascripts/discourse/admin/models/ai-persona.js
@@ -59,21 +59,25 @@ class ToolOption {
export default class AiPersona extends RestModel {
// this code is here to convert the wire schema to easier to work with object
// on the wire we pass in/out tools as an Array.
- // [[ToolName, {option1: value, option2: value}], ToolName2, ToolName3]
+ // [[ToolName, {option1: value, option2: value}, force], ToolName2, ToolName3]
// So we rework this into a "tools" property and nested toolOptions
init(properties) {
+ this.forcedTools = [];
if (properties.tools) {
properties.tools = properties.tools.map((tool) => {
if (typeof tool === "string") {
return tool;
} else {
- let [toolId, options] = tool;
+ let [toolId, options, force] = tool;
for (let optionId in options) {
if (!options.hasOwnProperty(optionId)) {
continue;
}
this.getToolOption(toolId, optionId).value = options[optionId];
}
+ if (force) {
+ this.forcedTools.push(toolId);
+ }
return toolId;
}
});
@@ -109,6 +113,8 @@ export default class AiPersona extends RestModel {
if (typeof toolId !== "string") {
toolId = toolId[0];
}
+
+ let force = this.forcedTools.includes(toolId);
if (this.toolOptions && this.toolOptions[toolId]) {
let options = this.toolOptions[toolId];
let optionsWithValues = {};
@@ -119,9 +125,9 @@ export default class AiPersona extends RestModel {
let option = options[optionId];
optionsWithValues[optionId] = option.value;
}
- toolsWithOptions.push([toolId, optionsWithValues]);
+ toolsWithOptions.push([toolId, optionsWithValues, force]);
} else {
- toolsWithOptions.push(toolId);
+ toolsWithOptions.push([toolId, {}, force]);
}
});
attrs.tools = toolsWithOptions;
@@ -133,7 +139,6 @@ export default class AiPersona extends RestModel {
: this.getProperties(CREATE_ATTRIBUTES);
attrs.id = this.id;
this.populateToolOptions(attrs);
-
return attrs;
}
@@ -146,6 +151,9 @@ export default class AiPersona extends RestModel {
workingCopy() {
let attrs = this.getProperties(CREATE_ATTRIBUTES);
this.populateToolOptions(attrs);
- return AiPersona.create(attrs);
+
+ const persona = AiPersona.create(attrs);
+ persona.forcedTools = (this.forcedTools || []).slice();
+ return persona;
}
}
diff --git a/assets/javascripts/discourse/components/ai-persona-editor.gjs b/assets/javascripts/discourse/components/ai-persona-editor.gjs
index 139768c1..574911ec 100644
--- a/assets/javascripts/discourse/components/ai-persona-editor.gjs
+++ b/assets/javascripts/discourse/components/ai-persona-editor.gjs
@@ -40,10 +40,39 @@ export default class PersonaEditor extends Component {
@tracked maxPixelsValue = null;
@tracked ragIndexingStatuses = null;
+ @tracked selectedTools = [];
+ @tracked selectedToolNames = [];
+ @tracked forcedToolNames = [];
+
get chatPluginEnabled() {
return this.siteSettings.chat_enabled;
}
+ get allowForceTools() {
+ return !this.editingModel?.system && this.editingModel?.tools?.length > 0;
+ }
+
+ @action
+ forcedToolsChanged(tools) {
+ this.forcedToolNames = tools;
+ this.editingModel.forcedTools = this.forcedToolNames;
+ }
+
+ @action
+ toolsChanged(tools) {
+ this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) =>
+ tools.includes(tool.id)
+ );
+ this.selectedToolNames = tools.slice();
+
+ this.forcedToolNames = this.forcedToolNames.filter(
+ (tool) => this.editingModel.tools.indexOf(tool) !== -1
+ );
+
+ this.editingModel.tools = this.selectedToolNames;
+ this.editingModel.forcedTools = this.forcedToolNames;
+ }
+
@action
updateModel() {
this.editingModel = this.args.model.workingCopy();
@@ -51,6 +80,12 @@ export default class PersonaEditor extends Component {
this.maxPixelsValue = this.findClosestPixelValue(
this.editingModel.vision_max_pixels
);
+
+ this.selectedToolNames = this.editingModel.tools || [];
+ this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) =>
+ this.selectedToolNames.includes(tool.id)
+ );
+ this.forcedToolNames = this.editingModel.forcedTools || [];
}
findClosestPixelValue(pixels) {
@@ -336,15 +371,27 @@ export default class PersonaEditor extends Component {
+ {{#if this.allowForceTools}}
+
+
+
+
+ {{/if}}
{{#unless this.editingModel.system}}
{{/unless}}
diff --git a/assets/javascripts/discourse/components/ai-tool-selector.js b/assets/javascripts/discourse/components/ai-tool-selector.js
index 0060e06f..c3959eff 100644
--- a/assets/javascripts/discourse/components/ai-tool-selector.js
+++ b/assets/javascripts/discourse/components/ai-tool-selector.js
@@ -6,7 +6,7 @@ export default MultiSelectComponent.extend({
this.selectKit.options.set("disabled", this.get("attrs.disabled.value"));
}),
- content: computed(function () {
+ content: computed("tools", function () {
return this.tools;
}),
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index 7b6d7f35..f4719967 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -148,6 +148,7 @@ en:
saved: AI Persona Saved
enabled: "Enabled?"
tools: Enabled Tools
+ forced_tools: Forced Tools
allowed_groups: Allowed Groups
confirm_delete: Are you sure you want to delete this persona?
new: "New Persona"
diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index e3941e0e..391ee16b 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -67,6 +67,19 @@ module DiscourseAi
.last
end
+ def force_tool_if_needed(prompt, context)
+ context[:chosen_tools] ||= []
+ forced_tools = persona.force_tool_use.map { |tool| tool.name }
+ force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) }
+
+ if force_tool
+ context[:chosen_tools] << force_tool
+ prompt.tool_choice = force_tool
+ else
+ prompt.tool_choice = nil
+ end
+ end
+
def reply(context, &update_blk)
llm = DiscourseAi::Completions::Llm.proxy(model)
prompt = persona.craft_prompt(context, llm: llm)
@@ -85,6 +98,7 @@ module DiscourseAi
while total_completions <= MAX_COMPLETIONS && ongoing_chain
tool_found = false
+ force_tool_if_needed(prompt, context)
result =
llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb
index b46abebe..0a31598c 100644
--- a/lib/ai_bot/personas/persona.rb
+++ b/lib/ai_bot/personas/persona.rb
@@ -113,6 +113,10 @@ module DiscourseAi
[]
end
+ def force_tool_use
+ []
+ end
+
def required_tools
[]
end
diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb
index 5420e643..fa3a9ca4 100644
--- a/lib/completions/dialects/dialect.rb
+++ b/lib/completions/dialects/dialect.rb
@@ -60,6 +60,10 @@ module DiscourseAi
@tools ||= tools_dialect.translated_tools
end
+ def tool_choice
+ prompt.tool_choice
+ end
+
def translate
messages = prompt.messages
diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb
index 43f19053..35b3e724 100644
--- a/lib/completions/endpoints/open_ai.rb
+++ b/lib/completions/endpoints/open_ai.rb
@@ -54,8 +54,12 @@ module DiscourseAi
# We'll fallback to guess this using the tokenizer.
payload[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai"
end
-
- payload[:tools] = dialect.tools if dialect.tools.present?
+ if dialect.tools.present?
+ payload[:tools] = dialect.tools
+ if dialect.tool_choice.present?
+ payload[:tool_choice] = { type: "function", function: { name: dialect.tool_choice } }
+ end
+ end
payload
end
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index 527ec87f..445bfc19 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -123,7 +123,7 @@ module DiscourseAi
end
def record_prompt(prompt)
- @prompts << prompt if @prompts
+ @prompts << prompt.dup if @prompts
end
def proxy(model)
diff --git a/lib/completions/prompt.rb b/lib/completions/prompt.rb
index 051818d2..9a6d4d61 100644
--- a/lib/completions/prompt.rb
+++ b/lib/completions/prompt.rb
@@ -6,7 +6,7 @@ module DiscourseAi
INVALID_TURN = Class.new(StandardError)
attr_reader :messages
- attr_accessor :tools, :topic_id, :post_id, :max_pixels
+ attr_accessor :tools, :topic_id, :post_id, :max_pixels, :tool_choice
def initialize(
system_message_text = nil,
@@ -14,7 +14,8 @@ module DiscourseAi
tools: [],
topic_id: nil,
post_id: nil,
- max_pixels: nil
+ max_pixels: nil,
+ tool_choice: nil
)
raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
@@ -37,6 +38,7 @@ module DiscourseAi
@messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) }
@tools = tools
+ @tool_choice = tool_choice
end
def push(type:, content:, id: nil, name: nil, upload_ids: nil)
diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb
index 8c665639..44ff136c 100644
--- a/spec/lib/completions/endpoints/open_ai_spec.rb
+++ b/spec/lib/completions/endpoints/open_ai_spec.rb
@@ -255,6 +255,95 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
end
end
+ describe "forced tool use" do
+ it "can properly force tool use" do
+ llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
+
+ tools = [
+ {
+ name: "echo",
+ description: "echo something",
+ parameters: [
+ { name: "text", type: "string", description: "text to echo", required: true },
+ ],
+ },
+ ]
+
+ prompt =
+ DiscourseAi::Completions::Prompt.new(
+ "You are a bot",
+ messages: [type: :user, id: "user1", content: "echo hello"],
+ tools: tools,
+ tool_choice: "echo",
+ )
+
+ response = {
+ id: "chatcmpl-9JxkAzzaeO4DSV3omWvok9TKhCjBH",
+ object: "chat.completion",
+ created: 1_714_544_914,
+ model: "gpt-4-turbo-2024-04-09",
+ choices: [
+ {
+ index: 0,
+ message: {
+ role: "assistant",
+ content: nil,
+ tool_calls: [
+ {
+ id: "call_I8LKnoijVuhKOM85nnEQgWwd",
+ type: "function",
+ function: {
+ name: "echo",
+ arguments: "{\"text\":\"hello\"}",
+ },
+ },
+ ],
+ },
+ logprobs: nil,
+ finish_reason: "tool_calls",
+ },
+ ],
+ usage: {
+ prompt_tokens: 55,
+ completion_tokens: 13,
+ total_tokens: 68,
+ },
+ system_fingerprint: "fp_ea6eb70039",
+ }.to_json
+
+ body_json = nil
+ stub_request(:post, "https://api.openai.com/v1/chat/completions").with(
+ body: proc { |body| body_json = JSON.parse(body, symbolize_names: true) },
+ ).to_return(body: response)
+
+ result = llm.generate(prompt, user: user)
+
+ expect(body_json[:tool_choice]).to eq({ type: "function", function: { name: "echo" } })
+
+ expected = (<<~TXT).strip
+
+
+ echo
+
+ hello
+
+ call_I8LKnoijVuhKOM85nnEQgWwd
+
+
+ TXT
+
+ expect(result.strip).to eq(expected)
+
+ stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
+ body: { choices: [message: { content: "OK" }] }.to_json,
+ )
+
+ result = llm.generate(prompt, user: user)
+
+ expect(result).to eq("OK")
+ end
+ end
+
describe "image support" do
it "can handle images" do
model = Fabricate(:llm_model, vision_enabled: true)
diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb
index f087414b..c3f054df 100644
--- a/spec/lib/modules/ai_bot/playground_spec.rb
+++ b/spec/lib/modules/ai_bot/playground_spec.rb
@@ -78,13 +78,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
end
let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) }
-
- it "uses custom tool in conversation" do
- persona_klass = AiPersona.all_personas.find { |p| p.name == ai_persona.name }
- bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
- playground = DiscourseAi::AiBot::Playground.new(bot)
-
- function_call = (<<~XML).strip
+ let(:function_call) { (<<~XML).strip }
search
@@ -96,6 +90,32 @@ RSpec.describe DiscourseAi::AiBot::Playground do
",
XML
+ let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user, persona: ai_persona.class_instance.new) }
+
+ let(:playground) { DiscourseAi::AiBot::Playground.new(bot) }
+
+ it "can force usage of a tool" do
+ tool_name = "custom-#{custom_tool.id}"
+ ai_persona.update!(tools: [[tool_name, nil, "force"]])
+ responses = [function_call, "custom tool did stuff (maybe)"]
+
+ prompt = nil
+ DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompt|
+ new_post = Fabricate(:post, raw: "Can you use the custom tool?")
+ _reply_post = playground.reply_to(new_post)
+ prompt = _prompt
+ end
+
+ expect(prompt.length).to eq(2)
+ expect(prompt[0].tool_choice).to eq("search")
+ expect(prompt[1].tool_choice).to eq(nil)
+ end
+
+ it "uses custom tool in conversation" do
+ persona_klass = AiPersona.all_personas.find { |p| p.name == ai_persona.name }
+ bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
+ playground = DiscourseAi::AiBot::Playground.new(bot)
+
responses = [function_call, "custom tool did stuff (maybe)"]
reply_post = nil
diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb
index 203abe33..1b1da5fc 100644
--- a/spec/requests/admin/ai_personas_controller_spec.rb
+++ b/spec/requests/admin/ai_personas_controller_spec.rb
@@ -160,7 +160,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
name: "superbot",
description: "Assists with tasks",
system_prompt: "you are a helpful bot",
- tools: [["search", { "base_query" => "test" }]],
+ tools: [["search", { "base_query" => "test" }, true]],
top_p: 0.1,
temperature: 0.5,
mentionable: true,
@@ -186,7 +186,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
persona = AiPersona.find(persona_json["id"])
- expect(persona.tools).to eq([["search", { "base_query" => "test" }]])
+ expect(persona.tools).to eq([["search", { "base_query" => "test" }, true]])
expect(persona.top_p).to eq(0.1)
expect(persona.temperature).to eq(0.5)
}.to change(AiPersona, :count).by(1)
@@ -296,7 +296,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
ai_persona.reload
expect(ai_persona.name).to eq("SuperBot")
expect(ai_persona.enabled).to eq(false)
- expect(ai_persona.tools).to eq(["search"])
+ expect(ai_persona.tools).to eq([["search", nil, false]])
end
end
diff --git a/spec/system/admin_ai_persona_spec.rb b/spec/system/admin_ai_persona_spec.rb
index 7ecd57dd..171cfcc8 100644
--- a/spec/system/admin_ai_persona_spec.rb
+++ b/spec/system/admin_ai_persona_spec.rb
@@ -30,7 +30,7 @@ RSpec.describe "Admin AI persona configuration", type: :system, js: true do
expect(persona.name).to eq("Test Persona")
expect(persona.description).to eq("I am a test persona")
expect(persona.system_prompt).to eq("You are a helpful bot")
- expect(persona.tools).to eq([["Read", { "read_private" => nil }]])
+ expect(persona.tools).to eq([["Read", { "read_private" => nil }, false]])
end
it "will not allow deletion or editing of system personas" do
diff --git a/test/javascripts/unit/models/ai-persona-test.js b/test/javascripts/unit/models/ai-persona-test.js
index f785ffba..c1f25aeb 100644
--- a/test/javascripts/unit/models/ai-persona-test.js
+++ b/test/javascripts/unit/models/ai-persona-test.js
@@ -60,7 +60,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
const updatedProperties = aiPersona.updateProperties();
// perform remapping for save
- properties.tools = [["ToolName", { option1: "value1" }]];
+ properties.tools = [["ToolName", { option1: "value1" }, false]];
assert.deepEqual(updatedProperties, properties);
});
@@ -100,7 +100,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
const createdProperties = aiPersona.createProperties();
- properties.tools = [["ToolName", { option1: "value1" }]];
+ properties.tools = [["ToolName", { option1: "value1" }, false]];
assert.deepEqual(createdProperties, properties);
});