FEATURE: allows forced LLM tool use (#818)
* FEATURE: allows forced LLM tool use Sometimes we need to force LLMs to use tools, for example in RAG like use cases we may want to force an unconditional search. The new framework allows you backend to force tool usage. Front end commit to follow * UI for forcing tools now works, but it does not react right * fix bugs * fix tests, this is now ready for review
This commit is contained in:
parent
c294b6d394
commit
545500b329
|
@ -120,15 +120,12 @@ module DiscourseAi
|
|||
def permit_tools(tools)
|
||||
return [] if !tools.is_a?(Array)
|
||||
|
||||
tools.filter_map do |tool, options|
|
||||
tools.filter_map do |tool, options, force_tool|
|
||||
break nil if !tool.is_a?(String)
|
||||
options&.permit! if options && options.is_a?(ActionController::Parameters)
|
||||
|
||||
if options
|
||||
[tool, options]
|
||||
else
|
||||
tool
|
||||
end
|
||||
# this is simpler from a storage perspective, 1 way to store tools
|
||||
[tool, options, !!force_tool]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -136,17 +136,23 @@ class AiPersona < ActiveRecord::Base
|
|||
end
|
||||
|
||||
options = {}
|
||||
force_tool_use = []
|
||||
|
||||
tools =
|
||||
self.tools.filter_map do |element|
|
||||
klass = nil
|
||||
|
||||
if element.is_a?(String) && element.start_with?("custom-")
|
||||
custom_tool_id = element.split("-", 2).last.to_i
|
||||
element = [element] if element.is_a?(String)
|
||||
|
||||
inner_name, current_options, should_force_tool_use =
|
||||
element.is_a?(Array) ? element : [element, nil]
|
||||
|
||||
if inner_name.start_with?("custom-")
|
||||
custom_tool_id = inner_name.split("-", 2).last.to_i
|
||||
if AiTool.exists?(id: custom_tool_id, enabled: true)
|
||||
klass = DiscourseAi::AiBot::Tools::Custom.class_instance(custom_tool_id)
|
||||
end
|
||||
else
|
||||
inner_name, current_options = element.is_a?(Array) ? element : [element, nil]
|
||||
inner_name = inner_name.gsub("Tool", "")
|
||||
inner_name = "List#{inner_name}" if %w[Categories Tags].include?(inner_name)
|
||||
|
||||
|
@ -155,9 +161,10 @@ class AiPersona < ActiveRecord::Base
|
|||
options[klass] = current_options if current_options
|
||||
rescue StandardError
|
||||
end
|
||||
|
||||
klass
|
||||
end
|
||||
|
||||
force_tool_use << klass if should_force_tool_use
|
||||
klass
|
||||
end
|
||||
|
||||
ai_persona_id = self.id
|
||||
|
@ -177,6 +184,7 @@ class AiPersona < ActiveRecord::Base
|
|||
end
|
||||
|
||||
define_method(:tools) { tools }
|
||||
define_method(:force_tool_use) { force_tool_use }
|
||||
define_method(:options) { options }
|
||||
define_method(:temperature) { @ai_persona&.temperature }
|
||||
define_method(:top_p) { @ai_persona&.top_p }
|
||||
|
|
|
@ -59,21 +59,25 @@ class ToolOption {
|
|||
export default class AiPersona extends RestModel {
|
||||
// this code is here to convert the wire schema to easier to work with object
|
||||
// on the wire we pass in/out tools as an Array.
|
||||
// [[ToolName, {option1: value, option2: value}], ToolName2, ToolName3]
|
||||
// [[ToolName, {option1: value, option2: value}, force], ToolName2, ToolName3]
|
||||
// So we rework this into a "tools" property and nested toolOptions
|
||||
init(properties) {
|
||||
this.forcedTools = [];
|
||||
if (properties.tools) {
|
||||
properties.tools = properties.tools.map((tool) => {
|
||||
if (typeof tool === "string") {
|
||||
return tool;
|
||||
} else {
|
||||
let [toolId, options] = tool;
|
||||
let [toolId, options, force] = tool;
|
||||
for (let optionId in options) {
|
||||
if (!options.hasOwnProperty(optionId)) {
|
||||
continue;
|
||||
}
|
||||
this.getToolOption(toolId, optionId).value = options[optionId];
|
||||
}
|
||||
if (force) {
|
||||
this.forcedTools.push(toolId);
|
||||
}
|
||||
return toolId;
|
||||
}
|
||||
});
|
||||
|
@ -109,6 +113,8 @@ export default class AiPersona extends RestModel {
|
|||
if (typeof toolId !== "string") {
|
||||
toolId = toolId[0];
|
||||
}
|
||||
|
||||
let force = this.forcedTools.includes(toolId);
|
||||
if (this.toolOptions && this.toolOptions[toolId]) {
|
||||
let options = this.toolOptions[toolId];
|
||||
let optionsWithValues = {};
|
||||
|
@ -119,9 +125,9 @@ export default class AiPersona extends RestModel {
|
|||
let option = options[optionId];
|
||||
optionsWithValues[optionId] = option.value;
|
||||
}
|
||||
toolsWithOptions.push([toolId, optionsWithValues]);
|
||||
toolsWithOptions.push([toolId, optionsWithValues, force]);
|
||||
} else {
|
||||
toolsWithOptions.push(toolId);
|
||||
toolsWithOptions.push([toolId, {}, force]);
|
||||
}
|
||||
});
|
||||
attrs.tools = toolsWithOptions;
|
||||
|
@ -133,7 +139,6 @@ export default class AiPersona extends RestModel {
|
|||
: this.getProperties(CREATE_ATTRIBUTES);
|
||||
attrs.id = this.id;
|
||||
this.populateToolOptions(attrs);
|
||||
|
||||
return attrs;
|
||||
}
|
||||
|
||||
|
@ -146,6 +151,9 @@ export default class AiPersona extends RestModel {
|
|||
workingCopy() {
|
||||
let attrs = this.getProperties(CREATE_ATTRIBUTES);
|
||||
this.populateToolOptions(attrs);
|
||||
return AiPersona.create(attrs);
|
||||
|
||||
const persona = AiPersona.create(attrs);
|
||||
persona.forcedTools = (this.forcedTools || []).slice();
|
||||
return persona;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,10 +40,39 @@ export default class PersonaEditor extends Component {
|
|||
@tracked maxPixelsValue = null;
|
||||
@tracked ragIndexingStatuses = null;
|
||||
|
||||
@tracked selectedTools = [];
|
||||
@tracked selectedToolNames = [];
|
||||
@tracked forcedToolNames = [];
|
||||
|
||||
get chatPluginEnabled() {
|
||||
return this.siteSettings.chat_enabled;
|
||||
}
|
||||
|
||||
get allowForceTools() {
|
||||
return !this.editingModel?.system && this.editingModel?.tools?.length > 0;
|
||||
}
|
||||
|
||||
@action
|
||||
forcedToolsChanged(tools) {
|
||||
this.forcedToolNames = tools;
|
||||
this.editingModel.forcedTools = this.forcedToolNames;
|
||||
}
|
||||
|
||||
@action
|
||||
toolsChanged(tools) {
|
||||
this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) =>
|
||||
tools.includes(tool.id)
|
||||
);
|
||||
this.selectedToolNames = tools.slice();
|
||||
|
||||
this.forcedToolNames = this.forcedToolNames.filter(
|
||||
(tool) => this.editingModel.tools.indexOf(tool) !== -1
|
||||
);
|
||||
|
||||
this.editingModel.tools = this.selectedToolNames;
|
||||
this.editingModel.forcedTools = this.forcedToolNames;
|
||||
}
|
||||
|
||||
@action
|
||||
updateModel() {
|
||||
this.editingModel = this.args.model.workingCopy();
|
||||
|
@ -51,6 +80,12 @@ export default class PersonaEditor extends Component {
|
|||
this.maxPixelsValue = this.findClosestPixelValue(
|
||||
this.editingModel.vision_max_pixels
|
||||
);
|
||||
|
||||
this.selectedToolNames = this.editingModel.tools || [];
|
||||
this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) =>
|
||||
this.selectedToolNames.includes(tool.id)
|
||||
);
|
||||
this.forcedToolNames = this.editingModel.forcedTools || [];
|
||||
}
|
||||
|
||||
findClosestPixelValue(pixels) {
|
||||
|
@ -336,15 +371,27 @@ export default class PersonaEditor extends Component {
|
|||
<label>{{I18n.t "discourse_ai.ai_persona.tools"}}</label>
|
||||
<AiToolSelector
|
||||
class="ai-persona-editor__tools"
|
||||
@value={{this.editingModel.tools}}
|
||||
@value={{this.selectedToolNames}}
|
||||
@disabled={{this.editingModel.system}}
|
||||
@tools={{@personas.resultSetMeta.tools}}
|
||||
@onChange={{this.toolsChanged}}
|
||||
/>
|
||||
</div>
|
||||
{{#if this.allowForceTools}}
|
||||
<div class="control-group">
|
||||
<label>{{I18n.t "discourse_ai.ai_persona.forced_tools"}}</label>
|
||||
<AiToolSelector
|
||||
class="ai-persona-editor__tools"
|
||||
@value={{this.forcedToolNames}}
|
||||
@tools={{this.selectedTools}}
|
||||
@onChange={{this.forcedToolsChanged}}
|
||||
/>
|
||||
</div>
|
||||
{{/if}}
|
||||
{{#unless this.editingModel.system}}
|
||||
<AiPersonaToolOptions
|
||||
@persona={{this.editingModel}}
|
||||
@tools={{this.editingModel.tools}}
|
||||
@tools={{this.selectedToolNames}}
|
||||
@allTools={{@personas.resultSetMeta.tools}}
|
||||
/>
|
||||
{{/unless}}
|
||||
|
|
|
@ -6,7 +6,7 @@ export default MultiSelectComponent.extend({
|
|||
this.selectKit.options.set("disabled", this.get("attrs.disabled.value"));
|
||||
}),
|
||||
|
||||
content: computed(function () {
|
||||
content: computed("tools", function () {
|
||||
return this.tools;
|
||||
}),
|
||||
|
||||
|
|
|
@ -148,6 +148,7 @@ en:
|
|||
saved: AI Persona Saved
|
||||
enabled: "Enabled?"
|
||||
tools: Enabled Tools
|
||||
forced_tools: Forced Tools
|
||||
allowed_groups: Allowed Groups
|
||||
confirm_delete: Are you sure you want to delete this persona?
|
||||
new: "New Persona"
|
||||
|
|
|
@ -67,6 +67,19 @@ module DiscourseAi
|
|||
.last
|
||||
end
|
||||
|
||||
def force_tool_if_needed(prompt, context)
|
||||
context[:chosen_tools] ||= []
|
||||
forced_tools = persona.force_tool_use.map { |tool| tool.name }
|
||||
force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) }
|
||||
|
||||
if force_tool
|
||||
context[:chosen_tools] << force_tool
|
||||
prompt.tool_choice = force_tool
|
||||
else
|
||||
prompt.tool_choice = nil
|
||||
end
|
||||
end
|
||||
|
||||
def reply(context, &update_blk)
|
||||
llm = DiscourseAi::Completions::Llm.proxy(model)
|
||||
prompt = persona.craft_prompt(context, llm: llm)
|
||||
|
@ -85,6 +98,7 @@ module DiscourseAi
|
|||
|
||||
while total_completions <= MAX_COMPLETIONS && ongoing_chain
|
||||
tool_found = false
|
||||
force_tool_if_needed(prompt, context)
|
||||
|
||||
result =
|
||||
llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
|
||||
|
|
|
@ -113,6 +113,10 @@ module DiscourseAi
|
|||
[]
|
||||
end
|
||||
|
||||
def force_tool_use
|
||||
[]
|
||||
end
|
||||
|
||||
def required_tools
|
||||
[]
|
||||
end
|
||||
|
|
|
@ -60,6 +60,10 @@ module DiscourseAi
|
|||
@tools ||= tools_dialect.translated_tools
|
||||
end
|
||||
|
||||
def tool_choice
|
||||
prompt.tool_choice
|
||||
end
|
||||
|
||||
def translate
|
||||
messages = prompt.messages
|
||||
|
||||
|
|
|
@ -54,8 +54,12 @@ module DiscourseAi
|
|||
# We'll fallback to guess this using the tokenizer.
|
||||
payload[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai"
|
||||
end
|
||||
|
||||
payload[:tools] = dialect.tools if dialect.tools.present?
|
||||
if dialect.tools.present?
|
||||
payload[:tools] = dialect.tools
|
||||
if dialect.tool_choice.present?
|
||||
payload[:tool_choice] = { type: "function", function: { name: dialect.tool_choice } }
|
||||
end
|
||||
end
|
||||
payload
|
||||
end
|
||||
|
||||
|
|
|
@ -123,7 +123,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def record_prompt(prompt)
|
||||
@prompts << prompt if @prompts
|
||||
@prompts << prompt.dup if @prompts
|
||||
end
|
||||
|
||||
def proxy(model)
|
||||
|
|
|
@ -6,7 +6,7 @@ module DiscourseAi
|
|||
INVALID_TURN = Class.new(StandardError)
|
||||
|
||||
attr_reader :messages
|
||||
attr_accessor :tools, :topic_id, :post_id, :max_pixels
|
||||
attr_accessor :tools, :topic_id, :post_id, :max_pixels, :tool_choice
|
||||
|
||||
def initialize(
|
||||
system_message_text = nil,
|
||||
|
@ -14,7 +14,8 @@ module DiscourseAi
|
|||
tools: [],
|
||||
topic_id: nil,
|
||||
post_id: nil,
|
||||
max_pixels: nil
|
||||
max_pixels: nil,
|
||||
tool_choice: nil
|
||||
)
|
||||
raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
|
||||
raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
|
||||
|
@ -37,6 +38,7 @@ module DiscourseAi
|
|||
@messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) }
|
||||
|
||||
@tools = tools
|
||||
@tool_choice = tool_choice
|
||||
end
|
||||
|
||||
def push(type:, content:, id: nil, name: nil, upload_ids: nil)
|
||||
|
|
|
@ -255,6 +255,95 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
|||
end
|
||||
end
|
||||
|
||||
describe "forced tool use" do
|
||||
it "can properly force tool use" do
|
||||
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||
|
||||
tools = [
|
||||
{
|
||||
name: "echo",
|
||||
description: "echo something",
|
||||
parameters: [
|
||||
{ name: "text", type: "string", description: "text to echo", required: true },
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
prompt =
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
"You are a bot",
|
||||
messages: [type: :user, id: "user1", content: "echo hello"],
|
||||
tools: tools,
|
||||
tool_choice: "echo",
|
||||
)
|
||||
|
||||
response = {
|
||||
id: "chatcmpl-9JxkAzzaeO4DSV3omWvok9TKhCjBH",
|
||||
object: "chat.completion",
|
||||
created: 1_714_544_914,
|
||||
model: "gpt-4-turbo-2024-04-09",
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: nil,
|
||||
tool_calls: [
|
||||
{
|
||||
id: "call_I8LKnoijVuhKOM85nnEQgWwd",
|
||||
type: "function",
|
||||
function: {
|
||||
name: "echo",
|
||||
arguments: "{\"text\":\"hello\"}",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
logprobs: nil,
|
||||
finish_reason: "tool_calls",
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 55,
|
||||
completion_tokens: 13,
|
||||
total_tokens: 68,
|
||||
},
|
||||
system_fingerprint: "fp_ea6eb70039",
|
||||
}.to_json
|
||||
|
||||
body_json = nil
|
||||
stub_request(:post, "https://api.openai.com/v1/chat/completions").with(
|
||||
body: proc { |body| body_json = JSON.parse(body, symbolize_names: true) },
|
||||
).to_return(body: response)
|
||||
|
||||
result = llm.generate(prompt, user: user)
|
||||
|
||||
expect(body_json[:tool_choice]).to eq({ type: "function", function: { name: "echo" } })
|
||||
|
||||
expected = (<<~TXT).strip
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>echo</tool_name>
|
||||
<parameters>
|
||||
<text>hello</text>
|
||||
</parameters>
|
||||
<tool_id>call_I8LKnoijVuhKOM85nnEQgWwd</tool_id>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
TXT
|
||||
|
||||
expect(result.strip).to eq(expected)
|
||||
|
||||
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
|
||||
body: { choices: [message: { content: "OK" }] }.to_json,
|
||||
)
|
||||
|
||||
result = llm.generate(prompt, user: user)
|
||||
|
||||
expect(result).to eq("OK")
|
||||
end
|
||||
end
|
||||
|
||||
describe "image support" do
|
||||
it "can handle images" do
|
||||
model = Fabricate(:llm_model, vision_enabled: true)
|
||||
|
|
|
@ -78,13 +78,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
|||
end
|
||||
|
||||
let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) }
|
||||
|
||||
it "uses custom tool in conversation" do
|
||||
persona_klass = AiPersona.all_personas.find { |p| p.name == ai_persona.name }
|
||||
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
|
||||
playground = DiscourseAi::AiBot::Playground.new(bot)
|
||||
|
||||
function_call = (<<~XML).strip
|
||||
let(:function_call) { (<<~XML).strip }
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>search</tool_name>
|
||||
|
@ -96,6 +90,32 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
|||
</function_calls>",
|
||||
XML
|
||||
|
||||
let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user, persona: ai_persona.class_instance.new) }
|
||||
|
||||
let(:playground) { DiscourseAi::AiBot::Playground.new(bot) }
|
||||
|
||||
it "can force usage of a tool" do
|
||||
tool_name = "custom-#{custom_tool.id}"
|
||||
ai_persona.update!(tools: [[tool_name, nil, "force"]])
|
||||
responses = [function_call, "custom tool did stuff (maybe)"]
|
||||
|
||||
prompt = nil
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompt|
|
||||
new_post = Fabricate(:post, raw: "Can you use the custom tool?")
|
||||
_reply_post = playground.reply_to(new_post)
|
||||
prompt = _prompt
|
||||
end
|
||||
|
||||
expect(prompt.length).to eq(2)
|
||||
expect(prompt[0].tool_choice).to eq("search")
|
||||
expect(prompt[1].tool_choice).to eq(nil)
|
||||
end
|
||||
|
||||
it "uses custom tool in conversation" do
|
||||
persona_klass = AiPersona.all_personas.find { |p| p.name == ai_persona.name }
|
||||
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
|
||||
playground = DiscourseAi::AiBot::Playground.new(bot)
|
||||
|
||||
responses = [function_call, "custom tool did stuff (maybe)"]
|
||||
|
||||
reply_post = nil
|
||||
|
|
|
@ -160,7 +160,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
name: "superbot",
|
||||
description: "Assists with tasks",
|
||||
system_prompt: "you are a helpful bot",
|
||||
tools: [["search", { "base_query" => "test" }]],
|
||||
tools: [["search", { "base_query" => "test" }, true]],
|
||||
top_p: 0.1,
|
||||
temperature: 0.5,
|
||||
mentionable: true,
|
||||
|
@ -186,7 +186,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
|
||||
persona = AiPersona.find(persona_json["id"])
|
||||
|
||||
expect(persona.tools).to eq([["search", { "base_query" => "test" }]])
|
||||
expect(persona.tools).to eq([["search", { "base_query" => "test" }, true]])
|
||||
expect(persona.top_p).to eq(0.1)
|
||||
expect(persona.temperature).to eq(0.5)
|
||||
}.to change(AiPersona, :count).by(1)
|
||||
|
@ -296,7 +296,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
ai_persona.reload
|
||||
expect(ai_persona.name).to eq("SuperBot")
|
||||
expect(ai_persona.enabled).to eq(false)
|
||||
expect(ai_persona.tools).to eq(["search"])
|
||||
expect(ai_persona.tools).to eq([["search", nil, false]])
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ RSpec.describe "Admin AI persona configuration", type: :system, js: true do
|
|||
expect(persona.name).to eq("Test Persona")
|
||||
expect(persona.description).to eq("I am a test persona")
|
||||
expect(persona.system_prompt).to eq("You are a helpful bot")
|
||||
expect(persona.tools).to eq([["Read", { "read_private" => nil }]])
|
||||
expect(persona.tools).to eq([["Read", { "read_private" => nil }, false]])
|
||||
end
|
||||
|
||||
it "will not allow deletion or editing of system personas" do
|
||||
|
|
|
@ -60,7 +60,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
|
|||
const updatedProperties = aiPersona.updateProperties();
|
||||
|
||||
// perform remapping for save
|
||||
properties.tools = [["ToolName", { option1: "value1" }]];
|
||||
properties.tools = [["ToolName", { option1: "value1" }, false]];
|
||||
|
||||
assert.deepEqual(updatedProperties, properties);
|
||||
});
|
||||
|
@ -100,7 +100,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
|
|||
|
||||
const createdProperties = aiPersona.createProperties();
|
||||
|
||||
properties.tools = [["ToolName", { option1: "value1" }]];
|
||||
properties.tools = [["ToolName", { option1: "value1" }, false]];
|
||||
|
||||
assert.deepEqual(createdProperties, properties);
|
||||
});
|
||||
|
|
Loading…
Reference in New Issue