FEATURE: allows forced LLM tool use (#818)

* FEATURE: allows forced LLM tool use

Sometimes we need to force LLMs to use tools, for example in RAG
like use cases we may want to force an unconditional search.

The new framework allows you backend to force tool usage.

Front end commit to follow

* UI for forcing tools now works, but it does not react right

* fix bugs

* fix tests, this is now ready for review
This commit is contained in:
Sam 2024-10-05 08:46:57 +09:00 committed by GitHub
parent c294b6d394
commit 545500b329
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 236 additions and 38 deletions

View File

@ -120,15 +120,12 @@ module DiscourseAi
def permit_tools(tools) def permit_tools(tools)
return [] if !tools.is_a?(Array) return [] if !tools.is_a?(Array)
tools.filter_map do |tool, options| tools.filter_map do |tool, options, force_tool|
break nil if !tool.is_a?(String) break nil if !tool.is_a?(String)
options&.permit! if options && options.is_a?(ActionController::Parameters) options&.permit! if options && options.is_a?(ActionController::Parameters)
if options # this is simpler from a storage perspective, 1 way to store tools
[tool, options] [tool, options, !!force_tool]
else
tool
end
end end
end end
end end

View File

@ -136,17 +136,23 @@ class AiPersona < ActiveRecord::Base
end end
options = {} options = {}
force_tool_use = []
tools = tools =
self.tools.filter_map do |element| self.tools.filter_map do |element|
klass = nil klass = nil
if element.is_a?(String) && element.start_with?("custom-") element = [element] if element.is_a?(String)
custom_tool_id = element.split("-", 2).last.to_i
inner_name, current_options, should_force_tool_use =
element.is_a?(Array) ? element : [element, nil]
if inner_name.start_with?("custom-")
custom_tool_id = inner_name.split("-", 2).last.to_i
if AiTool.exists?(id: custom_tool_id, enabled: true) if AiTool.exists?(id: custom_tool_id, enabled: true)
klass = DiscourseAi::AiBot::Tools::Custom.class_instance(custom_tool_id) klass = DiscourseAi::AiBot::Tools::Custom.class_instance(custom_tool_id)
end end
else else
inner_name, current_options = element.is_a?(Array) ? element : [element, nil]
inner_name = inner_name.gsub("Tool", "") inner_name = inner_name.gsub("Tool", "")
inner_name = "List#{inner_name}" if %w[Categories Tags].include?(inner_name) inner_name = "List#{inner_name}" if %w[Categories Tags].include?(inner_name)
@ -155,9 +161,10 @@ class AiPersona < ActiveRecord::Base
options[klass] = current_options if current_options options[klass] = current_options if current_options
rescue StandardError rescue StandardError
end end
klass
end end
force_tool_use << klass if should_force_tool_use
klass
end end
ai_persona_id = self.id ai_persona_id = self.id
@ -177,6 +184,7 @@ class AiPersona < ActiveRecord::Base
end end
define_method(:tools) { tools } define_method(:tools) { tools }
define_method(:force_tool_use) { force_tool_use }
define_method(:options) { options } define_method(:options) { options }
define_method(:temperature) { @ai_persona&.temperature } define_method(:temperature) { @ai_persona&.temperature }
define_method(:top_p) { @ai_persona&.top_p } define_method(:top_p) { @ai_persona&.top_p }

View File

@ -59,21 +59,25 @@ class ToolOption {
export default class AiPersona extends RestModel { export default class AiPersona extends RestModel {
// this code is here to convert the wire schema to easier to work with object // this code is here to convert the wire schema to easier to work with object
// on the wire we pass in/out tools as an Array. // on the wire we pass in/out tools as an Array.
// [[ToolName, {option1: value, option2: value}], ToolName2, ToolName3] // [[ToolName, {option1: value, option2: value}, force], ToolName2, ToolName3]
// So we rework this into a "tools" property and nested toolOptions // So we rework this into a "tools" property and nested toolOptions
init(properties) { init(properties) {
this.forcedTools = [];
if (properties.tools) { if (properties.tools) {
properties.tools = properties.tools.map((tool) => { properties.tools = properties.tools.map((tool) => {
if (typeof tool === "string") { if (typeof tool === "string") {
return tool; return tool;
} else { } else {
let [toolId, options] = tool; let [toolId, options, force] = tool;
for (let optionId in options) { for (let optionId in options) {
if (!options.hasOwnProperty(optionId)) { if (!options.hasOwnProperty(optionId)) {
continue; continue;
} }
this.getToolOption(toolId, optionId).value = options[optionId]; this.getToolOption(toolId, optionId).value = options[optionId];
} }
if (force) {
this.forcedTools.push(toolId);
}
return toolId; return toolId;
} }
}); });
@ -109,6 +113,8 @@ export default class AiPersona extends RestModel {
if (typeof toolId !== "string") { if (typeof toolId !== "string") {
toolId = toolId[0]; toolId = toolId[0];
} }
let force = this.forcedTools.includes(toolId);
if (this.toolOptions && this.toolOptions[toolId]) { if (this.toolOptions && this.toolOptions[toolId]) {
let options = this.toolOptions[toolId]; let options = this.toolOptions[toolId];
let optionsWithValues = {}; let optionsWithValues = {};
@ -119,9 +125,9 @@ export default class AiPersona extends RestModel {
let option = options[optionId]; let option = options[optionId];
optionsWithValues[optionId] = option.value; optionsWithValues[optionId] = option.value;
} }
toolsWithOptions.push([toolId, optionsWithValues]); toolsWithOptions.push([toolId, optionsWithValues, force]);
} else { } else {
toolsWithOptions.push(toolId); toolsWithOptions.push([toolId, {}, force]);
} }
}); });
attrs.tools = toolsWithOptions; attrs.tools = toolsWithOptions;
@ -133,7 +139,6 @@ export default class AiPersona extends RestModel {
: this.getProperties(CREATE_ATTRIBUTES); : this.getProperties(CREATE_ATTRIBUTES);
attrs.id = this.id; attrs.id = this.id;
this.populateToolOptions(attrs); this.populateToolOptions(attrs);
return attrs; return attrs;
} }
@ -146,6 +151,9 @@ export default class AiPersona extends RestModel {
workingCopy() { workingCopy() {
let attrs = this.getProperties(CREATE_ATTRIBUTES); let attrs = this.getProperties(CREATE_ATTRIBUTES);
this.populateToolOptions(attrs); this.populateToolOptions(attrs);
return AiPersona.create(attrs);
const persona = AiPersona.create(attrs);
persona.forcedTools = (this.forcedTools || []).slice();
return persona;
} }
} }

View File

@ -40,10 +40,39 @@ export default class PersonaEditor extends Component {
@tracked maxPixelsValue = null; @tracked maxPixelsValue = null;
@tracked ragIndexingStatuses = null; @tracked ragIndexingStatuses = null;
@tracked selectedTools = [];
@tracked selectedToolNames = [];
@tracked forcedToolNames = [];
get chatPluginEnabled() { get chatPluginEnabled() {
return this.siteSettings.chat_enabled; return this.siteSettings.chat_enabled;
} }
get allowForceTools() {
return !this.editingModel?.system && this.editingModel?.tools?.length > 0;
}
@action
forcedToolsChanged(tools) {
this.forcedToolNames = tools;
this.editingModel.forcedTools = this.forcedToolNames;
}
@action
toolsChanged(tools) {
this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) =>
tools.includes(tool.id)
);
this.selectedToolNames = tools.slice();
this.forcedToolNames = this.forcedToolNames.filter(
(tool) => this.editingModel.tools.indexOf(tool) !== -1
);
this.editingModel.tools = this.selectedToolNames;
this.editingModel.forcedTools = this.forcedToolNames;
}
@action @action
updateModel() { updateModel() {
this.editingModel = this.args.model.workingCopy(); this.editingModel = this.args.model.workingCopy();
@ -51,6 +80,12 @@ export default class PersonaEditor extends Component {
this.maxPixelsValue = this.findClosestPixelValue( this.maxPixelsValue = this.findClosestPixelValue(
this.editingModel.vision_max_pixels this.editingModel.vision_max_pixels
); );
this.selectedToolNames = this.editingModel.tools || [];
this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) =>
this.selectedToolNames.includes(tool.id)
);
this.forcedToolNames = this.editingModel.forcedTools || [];
} }
findClosestPixelValue(pixels) { findClosestPixelValue(pixels) {
@ -336,15 +371,27 @@ export default class PersonaEditor extends Component {
<label>{{I18n.t "discourse_ai.ai_persona.tools"}}</label> <label>{{I18n.t "discourse_ai.ai_persona.tools"}}</label>
<AiToolSelector <AiToolSelector
class="ai-persona-editor__tools" class="ai-persona-editor__tools"
@value={{this.editingModel.tools}} @value={{this.selectedToolNames}}
@disabled={{this.editingModel.system}} @disabled={{this.editingModel.system}}
@tools={{@personas.resultSetMeta.tools}} @tools={{@personas.resultSetMeta.tools}}
@onChange={{this.toolsChanged}}
/> />
</div> </div>
{{#if this.allowForceTools}}
<div class="control-group">
<label>{{I18n.t "discourse_ai.ai_persona.forced_tools"}}</label>
<AiToolSelector
class="ai-persona-editor__tools"
@value={{this.forcedToolNames}}
@tools={{this.selectedTools}}
@onChange={{this.forcedToolsChanged}}
/>
</div>
{{/if}}
{{#unless this.editingModel.system}} {{#unless this.editingModel.system}}
<AiPersonaToolOptions <AiPersonaToolOptions
@persona={{this.editingModel}} @persona={{this.editingModel}}
@tools={{this.editingModel.tools}} @tools={{this.selectedToolNames}}
@allTools={{@personas.resultSetMeta.tools}} @allTools={{@personas.resultSetMeta.tools}}
/> />
{{/unless}} {{/unless}}

View File

@ -6,7 +6,7 @@ export default MultiSelectComponent.extend({
this.selectKit.options.set("disabled", this.get("attrs.disabled.value")); this.selectKit.options.set("disabled", this.get("attrs.disabled.value"));
}), }),
content: computed(function () { content: computed("tools", function () {
return this.tools; return this.tools;
}), }),

View File

@ -148,6 +148,7 @@ en:
saved: AI Persona Saved saved: AI Persona Saved
enabled: "Enabled?" enabled: "Enabled?"
tools: Enabled Tools tools: Enabled Tools
forced_tools: Forced Tools
allowed_groups: Allowed Groups allowed_groups: Allowed Groups
confirm_delete: Are you sure you want to delete this persona? confirm_delete: Are you sure you want to delete this persona?
new: "New Persona" new: "New Persona"

View File

@ -67,6 +67,19 @@ module DiscourseAi
.last .last
end end
def force_tool_if_needed(prompt, context)
context[:chosen_tools] ||= []
forced_tools = persona.force_tool_use.map { |tool| tool.name }
force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) }
if force_tool
context[:chosen_tools] << force_tool
prompt.tool_choice = force_tool
else
prompt.tool_choice = nil
end
end
def reply(context, &update_blk) def reply(context, &update_blk)
llm = DiscourseAi::Completions::Llm.proxy(model) llm = DiscourseAi::Completions::Llm.proxy(model)
prompt = persona.craft_prompt(context, llm: llm) prompt = persona.craft_prompt(context, llm: llm)
@ -85,6 +98,7 @@ module DiscourseAi
while total_completions <= MAX_COMPLETIONS && ongoing_chain while total_completions <= MAX_COMPLETIONS && ongoing_chain
tool_found = false tool_found = false
force_tool_if_needed(prompt, context)
result = result =
llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel| llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|

View File

@ -113,6 +113,10 @@ module DiscourseAi
[] []
end end
def force_tool_use
[]
end
def required_tools def required_tools
[] []
end end

View File

@ -60,6 +60,10 @@ module DiscourseAi
@tools ||= tools_dialect.translated_tools @tools ||= tools_dialect.translated_tools
end end
def tool_choice
prompt.tool_choice
end
def translate def translate
messages = prompt.messages messages = prompt.messages

View File

@ -54,8 +54,12 @@ module DiscourseAi
# We'll fallback to guess this using the tokenizer. # We'll fallback to guess this using the tokenizer.
payload[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai" payload[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai"
end end
if dialect.tools.present?
payload[:tools] = dialect.tools if dialect.tools.present? payload[:tools] = dialect.tools
if dialect.tool_choice.present?
payload[:tool_choice] = { type: "function", function: { name: dialect.tool_choice } }
end
end
payload payload
end end

View File

@ -123,7 +123,7 @@ module DiscourseAi
end end
def record_prompt(prompt) def record_prompt(prompt)
@prompts << prompt if @prompts @prompts << prompt.dup if @prompts
end end
def proxy(model) def proxy(model)

View File

@ -6,7 +6,7 @@ module DiscourseAi
INVALID_TURN = Class.new(StandardError) INVALID_TURN = Class.new(StandardError)
attr_reader :messages attr_reader :messages
attr_accessor :tools, :topic_id, :post_id, :max_pixels attr_accessor :tools, :topic_id, :post_id, :max_pixels, :tool_choice
def initialize( def initialize(
system_message_text = nil, system_message_text = nil,
@ -14,7 +14,8 @@ module DiscourseAi
tools: [], tools: [],
topic_id: nil, topic_id: nil,
post_id: nil, post_id: nil,
max_pixels: nil max_pixels: nil,
tool_choice: nil
) )
raise ArgumentError, "messages must be an array" if !messages.is_a?(Array) raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
raise ArgumentError, "tools must be an array" if !tools.is_a?(Array) raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
@ -37,6 +38,7 @@ module DiscourseAi
@messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) } @messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) }
@tools = tools @tools = tools
@tool_choice = tool_choice
end end
def push(type:, content:, id: nil, name: nil, upload_ids: nil) def push(type:, content:, id: nil, name: nil, upload_ids: nil)

View File

@ -255,6 +255,95 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
end end
end end
describe "forced tool use" do
it "can properly force tool use" do
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
tools = [
{
name: "echo",
description: "echo something",
parameters: [
{ name: "text", type: "string", description: "text to echo", required: true },
],
},
]
prompt =
DiscourseAi::Completions::Prompt.new(
"You are a bot",
messages: [type: :user, id: "user1", content: "echo hello"],
tools: tools,
tool_choice: "echo",
)
response = {
id: "chatcmpl-9JxkAzzaeO4DSV3omWvok9TKhCjBH",
object: "chat.completion",
created: 1_714_544_914,
model: "gpt-4-turbo-2024-04-09",
choices: [
{
index: 0,
message: {
role: "assistant",
content: nil,
tool_calls: [
{
id: "call_I8LKnoijVuhKOM85nnEQgWwd",
type: "function",
function: {
name: "echo",
arguments: "{\"text\":\"hello\"}",
},
},
],
},
logprobs: nil,
finish_reason: "tool_calls",
},
],
usage: {
prompt_tokens: 55,
completion_tokens: 13,
total_tokens: 68,
},
system_fingerprint: "fp_ea6eb70039",
}.to_json
body_json = nil
stub_request(:post, "https://api.openai.com/v1/chat/completions").with(
body: proc { |body| body_json = JSON.parse(body, symbolize_names: true) },
).to_return(body: response)
result = llm.generate(prompt, user: user)
expect(body_json[:tool_choice]).to eq({ type: "function", function: { name: "echo" } })
expected = (<<~TXT).strip
<function_calls>
<invoke>
<tool_name>echo</tool_name>
<parameters>
<text>hello</text>
</parameters>
<tool_id>call_I8LKnoijVuhKOM85nnEQgWwd</tool_id>
</invoke>
</function_calls>
TXT
expect(result.strip).to eq(expected)
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
body: { choices: [message: { content: "OK" }] }.to_json,
)
result = llm.generate(prompt, user: user)
expect(result).to eq("OK")
end
end
describe "image support" do describe "image support" do
it "can handle images" do it "can handle images" do
model = Fabricate(:llm_model, vision_enabled: true) model = Fabricate(:llm_model, vision_enabled: true)

View File

@ -78,13 +78,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
end end
let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) } let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) }
let(:function_call) { (<<~XML).strip }
it "uses custom tool in conversation" do
persona_klass = AiPersona.all_personas.find { |p| p.name == ai_persona.name }
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
playground = DiscourseAi::AiBot::Playground.new(bot)
function_call = (<<~XML).strip
<function_calls> <function_calls>
<invoke> <invoke>
<tool_name>search</tool_name> <tool_name>search</tool_name>
@ -96,6 +90,32 @@ RSpec.describe DiscourseAi::AiBot::Playground do
</function_calls>", </function_calls>",
XML XML
let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user, persona: ai_persona.class_instance.new) }
let(:playground) { DiscourseAi::AiBot::Playground.new(bot) }
it "can force usage of a tool" do
tool_name = "custom-#{custom_tool.id}"
ai_persona.update!(tools: [[tool_name, nil, "force"]])
responses = [function_call, "custom tool did stuff (maybe)"]
prompt = nil
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompt|
new_post = Fabricate(:post, raw: "Can you use the custom tool?")
_reply_post = playground.reply_to(new_post)
prompt = _prompt
end
expect(prompt.length).to eq(2)
expect(prompt[0].tool_choice).to eq("search")
expect(prompt[1].tool_choice).to eq(nil)
end
it "uses custom tool in conversation" do
persona_klass = AiPersona.all_personas.find { |p| p.name == ai_persona.name }
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
playground = DiscourseAi::AiBot::Playground.new(bot)
responses = [function_call, "custom tool did stuff (maybe)"] responses = [function_call, "custom tool did stuff (maybe)"]
reply_post = nil reply_post = nil

View File

@ -160,7 +160,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
name: "superbot", name: "superbot",
description: "Assists with tasks", description: "Assists with tasks",
system_prompt: "you are a helpful bot", system_prompt: "you are a helpful bot",
tools: [["search", { "base_query" => "test" }]], tools: [["search", { "base_query" => "test" }, true]],
top_p: 0.1, top_p: 0.1,
temperature: 0.5, temperature: 0.5,
mentionable: true, mentionable: true,
@ -186,7 +186,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
persona = AiPersona.find(persona_json["id"]) persona = AiPersona.find(persona_json["id"])
expect(persona.tools).to eq([["search", { "base_query" => "test" }]]) expect(persona.tools).to eq([["search", { "base_query" => "test" }, true]])
expect(persona.top_p).to eq(0.1) expect(persona.top_p).to eq(0.1)
expect(persona.temperature).to eq(0.5) expect(persona.temperature).to eq(0.5)
}.to change(AiPersona, :count).by(1) }.to change(AiPersona, :count).by(1)
@ -296,7 +296,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
ai_persona.reload ai_persona.reload
expect(ai_persona.name).to eq("SuperBot") expect(ai_persona.name).to eq("SuperBot")
expect(ai_persona.enabled).to eq(false) expect(ai_persona.enabled).to eq(false)
expect(ai_persona.tools).to eq(["search"]) expect(ai_persona.tools).to eq([["search", nil, false]])
end end
end end

View File

@ -30,7 +30,7 @@ RSpec.describe "Admin AI persona configuration", type: :system, js: true do
expect(persona.name).to eq("Test Persona") expect(persona.name).to eq("Test Persona")
expect(persona.description).to eq("I am a test persona") expect(persona.description).to eq("I am a test persona")
expect(persona.system_prompt).to eq("You are a helpful bot") expect(persona.system_prompt).to eq("You are a helpful bot")
expect(persona.tools).to eq([["Read", { "read_private" => nil }]]) expect(persona.tools).to eq([["Read", { "read_private" => nil }, false]])
end end
it "will not allow deletion or editing of system personas" do it "will not allow deletion or editing of system personas" do

View File

@ -60,7 +60,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
const updatedProperties = aiPersona.updateProperties(); const updatedProperties = aiPersona.updateProperties();
// perform remapping for save // perform remapping for save
properties.tools = [["ToolName", { option1: "value1" }]]; properties.tools = [["ToolName", { option1: "value1" }, false]];
assert.deepEqual(updatedProperties, properties); assert.deepEqual(updatedProperties, properties);
}); });
@ -100,7 +100,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
const createdProperties = aiPersona.createProperties(); const createdProperties = aiPersona.createProperties();
properties.tools = [["ToolName", { option1: "value1" }]]; properties.tools = [["ToolName", { option1: "value1" }, false]];
assert.deepEqual(createdProperties, properties); assert.deepEqual(createdProperties, properties);
}); });