FEATURE: allows forced LLM tool use (#818)

* FEATURE: allows forced LLM tool use

Sometimes we need to force LLMs to use tools; for example, in RAG-like
use cases we may want to force an unconditional search.

The new framework allows the backend to force tool usage.

Front-end commit to follow.

* UI for forcing tools now works, but it does not react correctly yet

* fix bugs

* fix tests; this is now ready for review
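
A minimal sketch of how the new option can be exercised at the completions
layer (not part of the diff; it mirrors the OpenAI endpoint spec further down,
and the "echo" tool and llm_model record are illustrative):

llm = DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}")

# tool definition in the completions framework's own schema
tools = [
  {
    name: "echo",
    description: "echo something",
    parameters: [
      { name: "text", type: "string", description: "text to echo", required: true },
    ],
  },
]

# tool_choice pins the named tool, so the model must call it on this completion
prompt =
  DiscourseAi::Completions::Prompt.new(
    "You are a bot",
    messages: [{ type: :user, id: "user1", content: "echo hello" }],
    tools: tools,
    tool_choice: "echo",
  )

llm.generate(prompt, user: Discourse.system_user)
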
Sam 2024-10-05 08:46:57 +09:00 committed by GitHub
parent c294b6d394
commit 545500b329
17 changed files with 236 additions and 38 deletions

View File

@@ -120,15 +120,12 @@ module DiscourseAi
def permit_tools(tools)
return [] if !tools.is_a?(Array)
tools.filter_map do |tool, options|
tools.filter_map do |tool, options, force_tool|
break nil if !tool.is_a?(String)
options&.permit! if options && options.is_a?(ActionController::Parameters)
if options
[tool, options]
else
tool
end
# this is simpler from a storage perspective, 1 way to store tools
[tool, options, !!force_tool]
end
end
end
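
For reference (not part of the diff): after this normalization every persisted
tool entry becomes a [name, options, force] triple, no matter how it was
submitted. An illustrative example of the stored shape, with values borrowed
from the specs further down:

tools = [
  ["search", { "base_query" => "test" }, true], # forced tool with options
  ["Read", nil, false],                         # plain tool, not forced
]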

View File

@@ -136,17 +136,23 @@ class AiPersona < ActiveRecord::Base
end
options = {}
force_tool_use = []
tools =
self.tools.filter_map do |element|
klass = nil
if element.is_a?(String) && element.start_with?("custom-")
custom_tool_id = element.split("-", 2).last.to_i
element = [element] if element.is_a?(String)
inner_name, current_options, should_force_tool_use =
element.is_a?(Array) ? element : [element, nil]
if inner_name.start_with?("custom-")
custom_tool_id = inner_name.split("-", 2).last.to_i
if AiTool.exists?(id: custom_tool_id, enabled: true)
klass = DiscourseAi::AiBot::Tools::Custom.class_instance(custom_tool_id)
end
else
inner_name, current_options = element.is_a?(Array) ? element : [element, nil]
inner_name = inner_name.gsub("Tool", "")
inner_name = "List#{inner_name}" if %w[Categories Tags].include?(inner_name)
@@ -155,9 +161,10 @@ class AiPersona < ActiveRecord::Base
options[klass] = current_options if current_options
rescue StandardError
end
klass
end
force_tool_use << klass if should_force_tool_use
klass
end
ai_persona_id = self.id
@@ -177,6 +184,7 @@ class AiPersona < ActiveRecord::Base
end
define_method(:tools) { tools }
define_method(:force_tool_use) { force_tool_use }
define_method(:options) { options }
define_method(:temperature) { @ai_persona&.temperature }
define_method(:top_p) { @ai_persona&.top_p }

View File

@@ -59,21 +59,25 @@ class ToolOption {
export default class AiPersona extends RestModel {
// this code is here to convert the wire schema to an easier-to-work-with object
// on the wire we pass in/out tools as an Array.
// [[ToolName, {option1: value, option2: value}], ToolName2, ToolName3]
// [[ToolName, {option1: value, option2: value}, force], ToolName2, ToolName3]
// So we rework this into a "tools" property and nested toolOptions
init(properties) {
this.forcedTools = [];
if (properties.tools) {
properties.tools = properties.tools.map((tool) => {
if (typeof tool === "string") {
return tool;
} else {
let [toolId, options] = tool;
let [toolId, options, force] = tool;
for (let optionId in options) {
if (!options.hasOwnProperty(optionId)) {
continue;
}
this.getToolOption(toolId, optionId).value = options[optionId];
}
if (force) {
this.forcedTools.push(toolId);
}
return toolId;
}
});
@@ -109,6 +113,8 @@ export default class AiPersona extends RestModel {
if (typeof toolId !== "string") {
toolId = toolId[0];
}
let force = this.forcedTools.includes(toolId);
if (this.toolOptions && this.toolOptions[toolId]) {
let options = this.toolOptions[toolId];
let optionsWithValues = {};
@@ -119,9 +125,9 @@ export default class AiPersona extends RestModel {
let option = options[optionId];
optionsWithValues[optionId] = option.value;
}
toolsWithOptions.push([toolId, optionsWithValues]);
toolsWithOptions.push([toolId, optionsWithValues, force]);
} else {
toolsWithOptions.push(toolId);
toolsWithOptions.push([toolId, {}, force]);
}
});
attrs.tools = toolsWithOptions;
@@ -133,7 +139,6 @@ export default class AiPersona extends RestModel {
: this.getProperties(CREATE_ATTRIBUTES);
attrs.id = this.id;
this.populateToolOptions(attrs);
return attrs;
}
@@ -146,6 +151,9 @@ export default class AiPersona extends RestModel {
workingCopy() {
let attrs = this.getProperties(CREATE_ATTRIBUTES);
this.populateToolOptions(attrs);
return AiPersona.create(attrs);
const persona = AiPersona.create(attrs);
persona.forcedTools = (this.forcedTools || []).slice();
return persona;
}
}

View File

@@ -40,10 +40,39 @@ export default class PersonaEditor extends Component {
@tracked maxPixelsValue = null;
@tracked ragIndexingStatuses = null;
@tracked selectedTools = [];
@tracked selectedToolNames = [];
@tracked forcedToolNames = [];
get chatPluginEnabled() {
return this.siteSettings.chat_enabled;
}
get allowForceTools() {
return !this.editingModel?.system && this.editingModel?.tools?.length > 0;
}
@action
forcedToolsChanged(tools) {
this.forcedToolNames = tools;
this.editingModel.forcedTools = this.forcedToolNames;
}
@action
toolsChanged(tools) {
this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) =>
tools.includes(tool.id)
);
this.selectedToolNames = tools.slice();
this.forcedToolNames = this.forcedToolNames.filter(
(tool) => this.editingModel.tools.indexOf(tool) !== -1
);
this.editingModel.tools = this.selectedToolNames;
this.editingModel.forcedTools = this.forcedToolNames;
}
@action
updateModel() {
this.editingModel = this.args.model.workingCopy();
@@ -51,6 +80,12 @@ export default class PersonaEditor extends Component {
this.maxPixelsValue = this.findClosestPixelValue(
this.editingModel.vision_max_pixels
);
this.selectedToolNames = this.editingModel.tools || [];
this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) =>
this.selectedToolNames.includes(tool.id)
);
this.forcedToolNames = this.editingModel.forcedTools || [];
}
findClosestPixelValue(pixels) {
@@ -336,15 +371,27 @@ export default class PersonaEditor extends Component {
<label>{{I18n.t "discourse_ai.ai_persona.tools"}}</label>
<AiToolSelector
class="ai-persona-editor__tools"
@value={{this.editingModel.tools}}
@value={{this.selectedToolNames}}
@disabled={{this.editingModel.system}}
@tools={{@personas.resultSetMeta.tools}}
@onChange={{this.toolsChanged}}
/>
</div>
{{#if this.allowForceTools}}
<div class="control-group">
<label>{{I18n.t "discourse_ai.ai_persona.forced_tools"}}</label>
<AiToolSelector
class="ai-persona-editor__tools"
@value={{this.forcedToolNames}}
@tools={{this.selectedTools}}
@onChange={{this.forcedToolsChanged}}
/>
</div>
{{/if}}
{{#unless this.editingModel.system}}
<AiPersonaToolOptions
@persona={{this.editingModel}}
@tools={{this.editingModel.tools}}
@tools={{this.selectedToolNames}}
@allTools={{@personas.resultSetMeta.tools}}
/>
{{/unless}}

View File

@@ -6,7 +6,7 @@ export default MultiSelectComponent.extend({
this.selectKit.options.set("disabled", this.get("attrs.disabled.value"));
}),
content: computed(function () {
content: computed("tools", function () {
return this.tools;
}),

View File

@@ -148,6 +148,7 @@ en:
saved: AI Persona Saved
enabled: "Enabled?"
tools: Enabled Tools
forced_tools: Forced Tools
allowed_groups: Allowed Groups
confirm_delete: Are you sure you want to delete this persona?
new: "New Persona"

View File

@@ -67,6 +67,19 @@ module DiscourseAi
.last
end
def force_tool_if_needed(prompt, context)
context[:chosen_tools] ||= []
forced_tools = persona.force_tool_use.map { |tool| tool.name }
force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) }
if force_tool
context[:chosen_tools] << force_tool
prompt.tool_choice = force_tool
else
prompt.tool_choice = nil
end
end
def reply(context, &update_blk)
llm = DiscourseAi::Completions::Llm.proxy(model)
prompt = persona.craft_prompt(context, llm: llm)
@@ -85,6 +98,7 @@ module DiscourseAi
while total_completions <= MAX_COMPLETIONS && ongoing_chain
tool_found = false
force_tool_if_needed(prompt, context)
result =
llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|

View File

@@ -113,6 +113,10 @@ module DiscourseAi
[]
end
def force_tool_use
[]
end
def required_tools
[]
end
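
As an illustration (not part of the diff), a hypothetical persona subclass that
overrides this hook so its search tool is always forced; the base class and
tool class names are assumed from the surrounding code:

class ForcedSearchPersona < DiscourseAi::AiBot::Personas::Persona
  def tools
    [DiscourseAi::AiBot::Tools::Search]
  end

  # tools returned here get pinned via prompt.tool_choice in the bot reply loop shown above
  def force_tool_use
    [DiscourseAi::AiBot::Tools::Search]
  end
end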

View File

@@ -60,6 +60,10 @@ module DiscourseAi
@tools ||= tools_dialect.translated_tools
end
def tool_choice
prompt.tool_choice
end
def translate
messages = prompt.messages

View File

@@ -54,8 +54,12 @@ module DiscourseAi
# We'll fallback to guess this using the tokenizer.
payload[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai"
end
payload[:tools] = dialect.tools if dialect.tools.present?
if dialect.tools.present?
payload[:tools] = dialect.tools
if dialect.tool_choice.present?
payload[:tool_choice] = { type: "function", function: { name: dialect.tool_choice } }
end
end
payload
end
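
A sketch of the resulting request body when a tool is forced (not part of the
diff; field names follow the OpenAI chat completions API, and the model and
tool values are illustrative):

payload = {
  model: "gpt-4-turbo",
  messages: [{ role: "user", content: "echo hello" }],
  tools: [
    {
      type: "function",
      function: {
        name: "echo",
        description: "echo something",
        parameters: {
          type: "object",
          properties: { text: { type: "string", description: "text to echo" } },
          required: ["text"],
        },
      },
    },
  ],
  tool_choice: { type: "function", function: { name: "echo" } },
}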

View File

@@ -123,7 +123,7 @@ module DiscourseAi
end
def record_prompt(prompt)
@prompts << prompt if @prompts
@prompts << prompt.dup if @prompts
end
def proxy(model)

View File

@@ -6,7 +6,7 @@ module DiscourseAi
INVALID_TURN = Class.new(StandardError)
attr_reader :messages
attr_accessor :tools, :topic_id, :post_id, :max_pixels
attr_accessor :tools, :topic_id, :post_id, :max_pixels, :tool_choice
def initialize(
system_message_text = nil,
@@ -14,7 +14,8 @@ module DiscourseAi
tools: [],
topic_id: nil,
post_id: nil,
max_pixels: nil
max_pixels: nil,
tool_choice: nil
)
raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
@@ -37,6 +38,7 @@ module DiscourseAi
@messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) }
@tools = tools
@tool_choice = tool_choice
end
def push(type:, content:, id: nil, name: nil, upload_ids: nil)

View File

@@ -255,6 +255,95 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
end
end
describe "forced tool use" do
it "can properly force tool use" do
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
tools = [
{
name: "echo",
description: "echo something",
parameters: [
{ name: "text", type: "string", description: "text to echo", required: true },
],
},
]
prompt =
DiscourseAi::Completions::Prompt.new(
"You are a bot",
messages: [type: :user, id: "user1", content: "echo hello"],
tools: tools,
tool_choice: "echo",
)
response = {
id: "chatcmpl-9JxkAzzaeO4DSV3omWvok9TKhCjBH",
object: "chat.completion",
created: 1_714_544_914,
model: "gpt-4-turbo-2024-04-09",
choices: [
{
index: 0,
message: {
role: "assistant",
content: nil,
tool_calls: [
{
id: "call_I8LKnoijVuhKOM85nnEQgWwd",
type: "function",
function: {
name: "echo",
arguments: "{\"text\":\"hello\"}",
},
},
],
},
logprobs: nil,
finish_reason: "tool_calls",
},
],
usage: {
prompt_tokens: 55,
completion_tokens: 13,
total_tokens: 68,
},
system_fingerprint: "fp_ea6eb70039",
}.to_json
body_json = nil
stub_request(:post, "https://api.openai.com/v1/chat/completions").with(
body: proc { |body| body_json = JSON.parse(body, symbolize_names: true) },
).to_return(body: response)
result = llm.generate(prompt, user: user)
expect(body_json[:tool_choice]).to eq({ type: "function", function: { name: "echo" } })
expected = (<<~TXT).strip
<function_calls>
<invoke>
<tool_name>echo</tool_name>
<parameters>
<text>hello</text>
</parameters>
<tool_id>call_I8LKnoijVuhKOM85nnEQgWwd</tool_id>
</invoke>
</function_calls>
TXT
expect(result.strip).to eq(expected)
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
body: { choices: [message: { content: "OK" }] }.to_json,
)
result = llm.generate(prompt, user: user)
expect(result).to eq("OK")
end
end
describe "image support" do
it "can handle images" do
model = Fabricate(:llm_model, vision_enabled: true)

View File

@@ -78,13 +78,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
end
let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) }
it "uses custom tool in conversation" do
persona_klass = AiPersona.all_personas.find { |p| p.name == ai_persona.name }
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
playground = DiscourseAi::AiBot::Playground.new(bot)
function_call = (<<~XML).strip
let(:function_call) { (<<~XML).strip }
<function_calls>
<invoke>
<tool_name>search</tool_name>
@@ -96,6 +90,32 @@ RSpec.describe DiscourseAi::AiBot::Playground do
</function_calls>",
XML
let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user, persona: ai_persona.class_instance.new) }
let(:playground) { DiscourseAi::AiBot::Playground.new(bot) }
it "can force usage of a tool" do
tool_name = "custom-#{custom_tool.id}"
ai_persona.update!(tools: [[tool_name, nil, "force"]])
responses = [function_call, "custom tool did stuff (maybe)"]
prompt = nil
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompt|
new_post = Fabricate(:post, raw: "Can you use the custom tool?")
_reply_post = playground.reply_to(new_post)
prompt = _prompt
end
expect(prompt.length).to eq(2)
expect(prompt[0].tool_choice).to eq("search")
expect(prompt[1].tool_choice).to eq(nil)
end
it "uses custom tool in conversation" do
persona_klass = AiPersona.all_personas.find { |p| p.name == ai_persona.name }
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
playground = DiscourseAi::AiBot::Playground.new(bot)
responses = [function_call, "custom tool did stuff (maybe)"]
reply_post = nil

View File

@@ -160,7 +160,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
name: "superbot",
description: "Assists with tasks",
system_prompt: "you are a helpful bot",
tools: [["search", { "base_query" => "test" }]],
tools: [["search", { "base_query" => "test" }, true]],
top_p: 0.1,
temperature: 0.5,
mentionable: true,
@@ -186,7 +186,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
persona = AiPersona.find(persona_json["id"])
expect(persona.tools).to eq([["search", { "base_query" => "test" }]])
expect(persona.tools).to eq([["search", { "base_query" => "test" }, true]])
expect(persona.top_p).to eq(0.1)
expect(persona.temperature).to eq(0.5)
}.to change(AiPersona, :count).by(1)
@@ -296,7 +296,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
ai_persona.reload
expect(ai_persona.name).to eq("SuperBot")
expect(ai_persona.enabled).to eq(false)
expect(ai_persona.tools).to eq(["search"])
expect(ai_persona.tools).to eq([["search", nil, false]])
end
end

View File

@@ -30,7 +30,7 @@ RSpec.describe "Admin AI persona configuration", type: :system, js: true do
expect(persona.name).to eq("Test Persona")
expect(persona.description).to eq("I am a test persona")
expect(persona.system_prompt).to eq("You are a helpful bot")
expect(persona.tools).to eq([["Read", { "read_private" => nil }]])
expect(persona.tools).to eq([["Read", { "read_private" => nil }, false]])
end
it "will not allow deletion or editing of system personas" do

View File

@@ -60,7 +60,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
const updatedProperties = aiPersona.updateProperties();
// perform remapping for save
properties.tools = [["ToolName", { option1: "value1" }]];
properties.tools = [["ToolName", { option1: "value1" }, false]];
assert.deepEqual(updatedProperties, properties);
});
@@ -100,7 +100,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
const createdProperties = aiPersona.createProperties();
properties.tools = [["ToolName", { option1: "value1" }]];
properties.tools = [["ToolName", { option1: "value1" }, false]];
assert.deepEqual(createdProperties, properties);
});