FEATURE: allows forced LLM tool use (#818)
* FEATURE: allows forced LLM tool use Sometimes we need to force LLMs to use tools, for example in RAG like use cases we may want to force an unconditional search. The new framework allows you backend to force tool usage. Front end commit to follow * UI for forcing tools now works, but it does not react right * fix bugs * fix tests, this is now ready for review
This commit is contained in:
parent
c294b6d394
commit
545500b329
|
@ -120,15 +120,12 @@ module DiscourseAi
|
||||||
def permit_tools(tools)
|
def permit_tools(tools)
|
||||||
return [] if !tools.is_a?(Array)
|
return [] if !tools.is_a?(Array)
|
||||||
|
|
||||||
tools.filter_map do |tool, options|
|
tools.filter_map do |tool, options, force_tool|
|
||||||
break nil if !tool.is_a?(String)
|
break nil if !tool.is_a?(String)
|
||||||
options&.permit! if options && options.is_a?(ActionController::Parameters)
|
options&.permit! if options && options.is_a?(ActionController::Parameters)
|
||||||
|
|
||||||
if options
|
# this is simpler from a storage perspective, 1 way to store tools
|
||||||
[tool, options]
|
[tool, options, !!force_tool]
|
||||||
else
|
|
||||||
tool
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -136,17 +136,23 @@ class AiPersona < ActiveRecord::Base
|
||||||
end
|
end
|
||||||
|
|
||||||
options = {}
|
options = {}
|
||||||
|
force_tool_use = []
|
||||||
|
|
||||||
tools =
|
tools =
|
||||||
self.tools.filter_map do |element|
|
self.tools.filter_map do |element|
|
||||||
klass = nil
|
klass = nil
|
||||||
|
|
||||||
if element.is_a?(String) && element.start_with?("custom-")
|
element = [element] if element.is_a?(String)
|
||||||
custom_tool_id = element.split("-", 2).last.to_i
|
|
||||||
|
inner_name, current_options, should_force_tool_use =
|
||||||
|
element.is_a?(Array) ? element : [element, nil]
|
||||||
|
|
||||||
|
if inner_name.start_with?("custom-")
|
||||||
|
custom_tool_id = inner_name.split("-", 2).last.to_i
|
||||||
if AiTool.exists?(id: custom_tool_id, enabled: true)
|
if AiTool.exists?(id: custom_tool_id, enabled: true)
|
||||||
klass = DiscourseAi::AiBot::Tools::Custom.class_instance(custom_tool_id)
|
klass = DiscourseAi::AiBot::Tools::Custom.class_instance(custom_tool_id)
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
inner_name, current_options = element.is_a?(Array) ? element : [element, nil]
|
|
||||||
inner_name = inner_name.gsub("Tool", "")
|
inner_name = inner_name.gsub("Tool", "")
|
||||||
inner_name = "List#{inner_name}" if %w[Categories Tags].include?(inner_name)
|
inner_name = "List#{inner_name}" if %w[Categories Tags].include?(inner_name)
|
||||||
|
|
||||||
|
@ -155,9 +161,10 @@ class AiPersona < ActiveRecord::Base
|
||||||
options[klass] = current_options if current_options
|
options[klass] = current_options if current_options
|
||||||
rescue StandardError
|
rescue StandardError
|
||||||
end
|
end
|
||||||
|
|
||||||
klass
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
force_tool_use << klass if should_force_tool_use
|
||||||
|
klass
|
||||||
end
|
end
|
||||||
|
|
||||||
ai_persona_id = self.id
|
ai_persona_id = self.id
|
||||||
|
@ -177,6 +184,7 @@ class AiPersona < ActiveRecord::Base
|
||||||
end
|
end
|
||||||
|
|
||||||
define_method(:tools) { tools }
|
define_method(:tools) { tools }
|
||||||
|
define_method(:force_tool_use) { force_tool_use }
|
||||||
define_method(:options) { options }
|
define_method(:options) { options }
|
||||||
define_method(:temperature) { @ai_persona&.temperature }
|
define_method(:temperature) { @ai_persona&.temperature }
|
||||||
define_method(:top_p) { @ai_persona&.top_p }
|
define_method(:top_p) { @ai_persona&.top_p }
|
||||||
|
|
|
@ -59,21 +59,25 @@ class ToolOption {
|
||||||
export default class AiPersona extends RestModel {
|
export default class AiPersona extends RestModel {
|
||||||
// this code is here to convert the wire schema to easier to work with object
|
// this code is here to convert the wire schema to easier to work with object
|
||||||
// on the wire we pass in/out tools as an Array.
|
// on the wire we pass in/out tools as an Array.
|
||||||
// [[ToolName, {option1: value, option2: value}], ToolName2, ToolName3]
|
// [[ToolName, {option1: value, option2: value}, force], ToolName2, ToolName3]
|
||||||
// So we rework this into a "tools" property and nested toolOptions
|
// So we rework this into a "tools" property and nested toolOptions
|
||||||
init(properties) {
|
init(properties) {
|
||||||
|
this.forcedTools = [];
|
||||||
if (properties.tools) {
|
if (properties.tools) {
|
||||||
properties.tools = properties.tools.map((tool) => {
|
properties.tools = properties.tools.map((tool) => {
|
||||||
if (typeof tool === "string") {
|
if (typeof tool === "string") {
|
||||||
return tool;
|
return tool;
|
||||||
} else {
|
} else {
|
||||||
let [toolId, options] = tool;
|
let [toolId, options, force] = tool;
|
||||||
for (let optionId in options) {
|
for (let optionId in options) {
|
||||||
if (!options.hasOwnProperty(optionId)) {
|
if (!options.hasOwnProperty(optionId)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
this.getToolOption(toolId, optionId).value = options[optionId];
|
this.getToolOption(toolId, optionId).value = options[optionId];
|
||||||
}
|
}
|
||||||
|
if (force) {
|
||||||
|
this.forcedTools.push(toolId);
|
||||||
|
}
|
||||||
return toolId;
|
return toolId;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -109,6 +113,8 @@ export default class AiPersona extends RestModel {
|
||||||
if (typeof toolId !== "string") {
|
if (typeof toolId !== "string") {
|
||||||
toolId = toolId[0];
|
toolId = toolId[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let force = this.forcedTools.includes(toolId);
|
||||||
if (this.toolOptions && this.toolOptions[toolId]) {
|
if (this.toolOptions && this.toolOptions[toolId]) {
|
||||||
let options = this.toolOptions[toolId];
|
let options = this.toolOptions[toolId];
|
||||||
let optionsWithValues = {};
|
let optionsWithValues = {};
|
||||||
|
@ -119,9 +125,9 @@ export default class AiPersona extends RestModel {
|
||||||
let option = options[optionId];
|
let option = options[optionId];
|
||||||
optionsWithValues[optionId] = option.value;
|
optionsWithValues[optionId] = option.value;
|
||||||
}
|
}
|
||||||
toolsWithOptions.push([toolId, optionsWithValues]);
|
toolsWithOptions.push([toolId, optionsWithValues, force]);
|
||||||
} else {
|
} else {
|
||||||
toolsWithOptions.push(toolId);
|
toolsWithOptions.push([toolId, {}, force]);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
attrs.tools = toolsWithOptions;
|
attrs.tools = toolsWithOptions;
|
||||||
|
@ -133,7 +139,6 @@ export default class AiPersona extends RestModel {
|
||||||
: this.getProperties(CREATE_ATTRIBUTES);
|
: this.getProperties(CREATE_ATTRIBUTES);
|
||||||
attrs.id = this.id;
|
attrs.id = this.id;
|
||||||
this.populateToolOptions(attrs);
|
this.populateToolOptions(attrs);
|
||||||
|
|
||||||
return attrs;
|
return attrs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -146,6 +151,9 @@ export default class AiPersona extends RestModel {
|
||||||
workingCopy() {
|
workingCopy() {
|
||||||
let attrs = this.getProperties(CREATE_ATTRIBUTES);
|
let attrs = this.getProperties(CREATE_ATTRIBUTES);
|
||||||
this.populateToolOptions(attrs);
|
this.populateToolOptions(attrs);
|
||||||
return AiPersona.create(attrs);
|
|
||||||
|
const persona = AiPersona.create(attrs);
|
||||||
|
persona.forcedTools = (this.forcedTools || []).slice();
|
||||||
|
return persona;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,10 +40,39 @@ export default class PersonaEditor extends Component {
|
||||||
@tracked maxPixelsValue = null;
|
@tracked maxPixelsValue = null;
|
||||||
@tracked ragIndexingStatuses = null;
|
@tracked ragIndexingStatuses = null;
|
||||||
|
|
||||||
|
@tracked selectedTools = [];
|
||||||
|
@tracked selectedToolNames = [];
|
||||||
|
@tracked forcedToolNames = [];
|
||||||
|
|
||||||
get chatPluginEnabled() {
|
get chatPluginEnabled() {
|
||||||
return this.siteSettings.chat_enabled;
|
return this.siteSettings.chat_enabled;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
get allowForceTools() {
|
||||||
|
return !this.editingModel?.system && this.editingModel?.tools?.length > 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@action
|
||||||
|
forcedToolsChanged(tools) {
|
||||||
|
this.forcedToolNames = tools;
|
||||||
|
this.editingModel.forcedTools = this.forcedToolNames;
|
||||||
|
}
|
||||||
|
|
||||||
|
@action
|
||||||
|
toolsChanged(tools) {
|
||||||
|
this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) =>
|
||||||
|
tools.includes(tool.id)
|
||||||
|
);
|
||||||
|
this.selectedToolNames = tools.slice();
|
||||||
|
|
||||||
|
this.forcedToolNames = this.forcedToolNames.filter(
|
||||||
|
(tool) => this.editingModel.tools.indexOf(tool) !== -1
|
||||||
|
);
|
||||||
|
|
||||||
|
this.editingModel.tools = this.selectedToolNames;
|
||||||
|
this.editingModel.forcedTools = this.forcedToolNames;
|
||||||
|
}
|
||||||
|
|
||||||
@action
|
@action
|
||||||
updateModel() {
|
updateModel() {
|
||||||
this.editingModel = this.args.model.workingCopy();
|
this.editingModel = this.args.model.workingCopy();
|
||||||
|
@ -51,6 +80,12 @@ export default class PersonaEditor extends Component {
|
||||||
this.maxPixelsValue = this.findClosestPixelValue(
|
this.maxPixelsValue = this.findClosestPixelValue(
|
||||||
this.editingModel.vision_max_pixels
|
this.editingModel.vision_max_pixels
|
||||||
);
|
);
|
||||||
|
|
||||||
|
this.selectedToolNames = this.editingModel.tools || [];
|
||||||
|
this.selectedTools = this.args.personas.resultSetMeta.tools.filter((tool) =>
|
||||||
|
this.selectedToolNames.includes(tool.id)
|
||||||
|
);
|
||||||
|
this.forcedToolNames = this.editingModel.forcedTools || [];
|
||||||
}
|
}
|
||||||
|
|
||||||
findClosestPixelValue(pixels) {
|
findClosestPixelValue(pixels) {
|
||||||
|
@ -336,15 +371,27 @@ export default class PersonaEditor extends Component {
|
||||||
<label>{{I18n.t "discourse_ai.ai_persona.tools"}}</label>
|
<label>{{I18n.t "discourse_ai.ai_persona.tools"}}</label>
|
||||||
<AiToolSelector
|
<AiToolSelector
|
||||||
class="ai-persona-editor__tools"
|
class="ai-persona-editor__tools"
|
||||||
@value={{this.editingModel.tools}}
|
@value={{this.selectedToolNames}}
|
||||||
@disabled={{this.editingModel.system}}
|
@disabled={{this.editingModel.system}}
|
||||||
@tools={{@personas.resultSetMeta.tools}}
|
@tools={{@personas.resultSetMeta.tools}}
|
||||||
|
@onChange={{this.toolsChanged}}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
{{#if this.allowForceTools}}
|
||||||
|
<div class="control-group">
|
||||||
|
<label>{{I18n.t "discourse_ai.ai_persona.forced_tools"}}</label>
|
||||||
|
<AiToolSelector
|
||||||
|
class="ai-persona-editor__tools"
|
||||||
|
@value={{this.forcedToolNames}}
|
||||||
|
@tools={{this.selectedTools}}
|
||||||
|
@onChange={{this.forcedToolsChanged}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
{{/if}}
|
||||||
{{#unless this.editingModel.system}}
|
{{#unless this.editingModel.system}}
|
||||||
<AiPersonaToolOptions
|
<AiPersonaToolOptions
|
||||||
@persona={{this.editingModel}}
|
@persona={{this.editingModel}}
|
||||||
@tools={{this.editingModel.tools}}
|
@tools={{this.selectedToolNames}}
|
||||||
@allTools={{@personas.resultSetMeta.tools}}
|
@allTools={{@personas.resultSetMeta.tools}}
|
||||||
/>
|
/>
|
||||||
{{/unless}}
|
{{/unless}}
|
||||||
|
|
|
@ -6,7 +6,7 @@ export default MultiSelectComponent.extend({
|
||||||
this.selectKit.options.set("disabled", this.get("attrs.disabled.value"));
|
this.selectKit.options.set("disabled", this.get("attrs.disabled.value"));
|
||||||
}),
|
}),
|
||||||
|
|
||||||
content: computed(function () {
|
content: computed("tools", function () {
|
||||||
return this.tools;
|
return this.tools;
|
||||||
}),
|
}),
|
||||||
|
|
||||||
|
|
|
@ -148,6 +148,7 @@ en:
|
||||||
saved: AI Persona Saved
|
saved: AI Persona Saved
|
||||||
enabled: "Enabled?"
|
enabled: "Enabled?"
|
||||||
tools: Enabled Tools
|
tools: Enabled Tools
|
||||||
|
forced_tools: Forced Tools
|
||||||
allowed_groups: Allowed Groups
|
allowed_groups: Allowed Groups
|
||||||
confirm_delete: Are you sure you want to delete this persona?
|
confirm_delete: Are you sure you want to delete this persona?
|
||||||
new: "New Persona"
|
new: "New Persona"
|
||||||
|
|
|
@ -67,6 +67,19 @@ module DiscourseAi
|
||||||
.last
|
.last
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def force_tool_if_needed(prompt, context)
|
||||||
|
context[:chosen_tools] ||= []
|
||||||
|
forced_tools = persona.force_tool_use.map { |tool| tool.name }
|
||||||
|
force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) }
|
||||||
|
|
||||||
|
if force_tool
|
||||||
|
context[:chosen_tools] << force_tool
|
||||||
|
prompt.tool_choice = force_tool
|
||||||
|
else
|
||||||
|
prompt.tool_choice = nil
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
def reply(context, &update_blk)
|
def reply(context, &update_blk)
|
||||||
llm = DiscourseAi::Completions::Llm.proxy(model)
|
llm = DiscourseAi::Completions::Llm.proxy(model)
|
||||||
prompt = persona.craft_prompt(context, llm: llm)
|
prompt = persona.craft_prompt(context, llm: llm)
|
||||||
|
@ -85,6 +98,7 @@ module DiscourseAi
|
||||||
|
|
||||||
while total_completions <= MAX_COMPLETIONS && ongoing_chain
|
while total_completions <= MAX_COMPLETIONS && ongoing_chain
|
||||||
tool_found = false
|
tool_found = false
|
||||||
|
force_tool_if_needed(prompt, context)
|
||||||
|
|
||||||
result =
|
result =
|
||||||
llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
|
llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
|
||||||
|
|
|
@ -113,6 +113,10 @@ module DiscourseAi
|
||||||
[]
|
[]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def force_tool_use
|
||||||
|
[]
|
||||||
|
end
|
||||||
|
|
||||||
def required_tools
|
def required_tools
|
||||||
[]
|
[]
|
||||||
end
|
end
|
||||||
|
|
|
@ -60,6 +60,10 @@ module DiscourseAi
|
||||||
@tools ||= tools_dialect.translated_tools
|
@tools ||= tools_dialect.translated_tools
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def tool_choice
|
||||||
|
prompt.tool_choice
|
||||||
|
end
|
||||||
|
|
||||||
def translate
|
def translate
|
||||||
messages = prompt.messages
|
messages = prompt.messages
|
||||||
|
|
||||||
|
|
|
@ -54,8 +54,12 @@ module DiscourseAi
|
||||||
# We'll fallback to guess this using the tokenizer.
|
# We'll fallback to guess this using the tokenizer.
|
||||||
payload[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai"
|
payload[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai"
|
||||||
end
|
end
|
||||||
|
if dialect.tools.present?
|
||||||
payload[:tools] = dialect.tools if dialect.tools.present?
|
payload[:tools] = dialect.tools
|
||||||
|
if dialect.tool_choice.present?
|
||||||
|
payload[:tool_choice] = { type: "function", function: { name: dialect.tool_choice } }
|
||||||
|
end
|
||||||
|
end
|
||||||
payload
|
payload
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -123,7 +123,7 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def record_prompt(prompt)
|
def record_prompt(prompt)
|
||||||
@prompts << prompt if @prompts
|
@prompts << prompt.dup if @prompts
|
||||||
end
|
end
|
||||||
|
|
||||||
def proxy(model)
|
def proxy(model)
|
||||||
|
|
|
@ -6,7 +6,7 @@ module DiscourseAi
|
||||||
INVALID_TURN = Class.new(StandardError)
|
INVALID_TURN = Class.new(StandardError)
|
||||||
|
|
||||||
attr_reader :messages
|
attr_reader :messages
|
||||||
attr_accessor :tools, :topic_id, :post_id, :max_pixels
|
attr_accessor :tools, :topic_id, :post_id, :max_pixels, :tool_choice
|
||||||
|
|
||||||
def initialize(
|
def initialize(
|
||||||
system_message_text = nil,
|
system_message_text = nil,
|
||||||
|
@ -14,7 +14,8 @@ module DiscourseAi
|
||||||
tools: [],
|
tools: [],
|
||||||
topic_id: nil,
|
topic_id: nil,
|
||||||
post_id: nil,
|
post_id: nil,
|
||||||
max_pixels: nil
|
max_pixels: nil,
|
||||||
|
tool_choice: nil
|
||||||
)
|
)
|
||||||
raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
|
raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
|
||||||
raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
|
raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
|
||||||
|
@ -37,6 +38,7 @@ module DiscourseAi
|
||||||
@messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) }
|
@messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) }
|
||||||
|
|
||||||
@tools = tools
|
@tools = tools
|
||||||
|
@tool_choice = tool_choice
|
||||||
end
|
end
|
||||||
|
|
||||||
def push(type:, content:, id: nil, name: nil, upload_ids: nil)
|
def push(type:, content:, id: nil, name: nil, upload_ids: nil)
|
||||||
|
|
|
@ -255,6 +255,95 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
describe "forced tool use" do
|
||||||
|
it "can properly force tool use" do
|
||||||
|
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
name: "echo",
|
||||||
|
description: "echo something",
|
||||||
|
parameters: [
|
||||||
|
{ name: "text", type: "string", description: "text to echo", required: true },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
prompt =
|
||||||
|
DiscourseAi::Completions::Prompt.new(
|
||||||
|
"You are a bot",
|
||||||
|
messages: [type: :user, id: "user1", content: "echo hello"],
|
||||||
|
tools: tools,
|
||||||
|
tool_choice: "echo",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = {
|
||||||
|
id: "chatcmpl-9JxkAzzaeO4DSV3omWvok9TKhCjBH",
|
||||||
|
object: "chat.completion",
|
||||||
|
created: 1_714_544_914,
|
||||||
|
model: "gpt-4-turbo-2024-04-09",
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
index: 0,
|
||||||
|
message: {
|
||||||
|
role: "assistant",
|
||||||
|
content: nil,
|
||||||
|
tool_calls: [
|
||||||
|
{
|
||||||
|
id: "call_I8LKnoijVuhKOM85nnEQgWwd",
|
||||||
|
type: "function",
|
||||||
|
function: {
|
||||||
|
name: "echo",
|
||||||
|
arguments: "{\"text\":\"hello\"}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
logprobs: nil,
|
||||||
|
finish_reason: "tool_calls",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 55,
|
||||||
|
completion_tokens: 13,
|
||||||
|
total_tokens: 68,
|
||||||
|
},
|
||||||
|
system_fingerprint: "fp_ea6eb70039",
|
||||||
|
}.to_json
|
||||||
|
|
||||||
|
body_json = nil
|
||||||
|
stub_request(:post, "https://api.openai.com/v1/chat/completions").with(
|
||||||
|
body: proc { |body| body_json = JSON.parse(body, symbolize_names: true) },
|
||||||
|
).to_return(body: response)
|
||||||
|
|
||||||
|
result = llm.generate(prompt, user: user)
|
||||||
|
|
||||||
|
expect(body_json[:tool_choice]).to eq({ type: "function", function: { name: "echo" } })
|
||||||
|
|
||||||
|
expected = (<<~TXT).strip
|
||||||
|
<function_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>echo</tool_name>
|
||||||
|
<parameters>
|
||||||
|
<text>hello</text>
|
||||||
|
</parameters>
|
||||||
|
<tool_id>call_I8LKnoijVuhKOM85nnEQgWwd</tool_id>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
TXT
|
||||||
|
|
||||||
|
expect(result.strip).to eq(expected)
|
||||||
|
|
||||||
|
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
|
||||||
|
body: { choices: [message: { content: "OK" }] }.to_json,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = llm.generate(prompt, user: user)
|
||||||
|
|
||||||
|
expect(result).to eq("OK")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
describe "image support" do
|
describe "image support" do
|
||||||
it "can handle images" do
|
it "can handle images" do
|
||||||
model = Fabricate(:llm_model, vision_enabled: true)
|
model = Fabricate(:llm_model, vision_enabled: true)
|
||||||
|
|
|
@ -78,13 +78,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
||||||
end
|
end
|
||||||
|
|
||||||
let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) }
|
let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) }
|
||||||
|
let(:function_call) { (<<~XML).strip }
|
||||||
it "uses custom tool in conversation" do
|
|
||||||
persona_klass = AiPersona.all_personas.find { |p| p.name == ai_persona.name }
|
|
||||||
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
|
|
||||||
playground = DiscourseAi::AiBot::Playground.new(bot)
|
|
||||||
|
|
||||||
function_call = (<<~XML).strip
|
|
||||||
<function_calls>
|
<function_calls>
|
||||||
<invoke>
|
<invoke>
|
||||||
<tool_name>search</tool_name>
|
<tool_name>search</tool_name>
|
||||||
|
@ -96,6 +90,32 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
||||||
</function_calls>",
|
</function_calls>",
|
||||||
XML
|
XML
|
||||||
|
|
||||||
|
let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user, persona: ai_persona.class_instance.new) }
|
||||||
|
|
||||||
|
let(:playground) { DiscourseAi::AiBot::Playground.new(bot) }
|
||||||
|
|
||||||
|
it "can force usage of a tool" do
|
||||||
|
tool_name = "custom-#{custom_tool.id}"
|
||||||
|
ai_persona.update!(tools: [[tool_name, nil, "force"]])
|
||||||
|
responses = [function_call, "custom tool did stuff (maybe)"]
|
||||||
|
|
||||||
|
prompt = nil
|
||||||
|
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompt|
|
||||||
|
new_post = Fabricate(:post, raw: "Can you use the custom tool?")
|
||||||
|
_reply_post = playground.reply_to(new_post)
|
||||||
|
prompt = _prompt
|
||||||
|
end
|
||||||
|
|
||||||
|
expect(prompt.length).to eq(2)
|
||||||
|
expect(prompt[0].tool_choice).to eq("search")
|
||||||
|
expect(prompt[1].tool_choice).to eq(nil)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "uses custom tool in conversation" do
|
||||||
|
persona_klass = AiPersona.all_personas.find { |p| p.name == ai_persona.name }
|
||||||
|
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
|
||||||
|
playground = DiscourseAi::AiBot::Playground.new(bot)
|
||||||
|
|
||||||
responses = [function_call, "custom tool did stuff (maybe)"]
|
responses = [function_call, "custom tool did stuff (maybe)"]
|
||||||
|
|
||||||
reply_post = nil
|
reply_post = nil
|
||||||
|
|
|
@ -160,7 +160,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
||||||
name: "superbot",
|
name: "superbot",
|
||||||
description: "Assists with tasks",
|
description: "Assists with tasks",
|
||||||
system_prompt: "you are a helpful bot",
|
system_prompt: "you are a helpful bot",
|
||||||
tools: [["search", { "base_query" => "test" }]],
|
tools: [["search", { "base_query" => "test" }, true]],
|
||||||
top_p: 0.1,
|
top_p: 0.1,
|
||||||
temperature: 0.5,
|
temperature: 0.5,
|
||||||
mentionable: true,
|
mentionable: true,
|
||||||
|
@ -186,7 +186,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
||||||
|
|
||||||
persona = AiPersona.find(persona_json["id"])
|
persona = AiPersona.find(persona_json["id"])
|
||||||
|
|
||||||
expect(persona.tools).to eq([["search", { "base_query" => "test" }]])
|
expect(persona.tools).to eq([["search", { "base_query" => "test" }, true]])
|
||||||
expect(persona.top_p).to eq(0.1)
|
expect(persona.top_p).to eq(0.1)
|
||||||
expect(persona.temperature).to eq(0.5)
|
expect(persona.temperature).to eq(0.5)
|
||||||
}.to change(AiPersona, :count).by(1)
|
}.to change(AiPersona, :count).by(1)
|
||||||
|
@ -296,7 +296,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
||||||
ai_persona.reload
|
ai_persona.reload
|
||||||
expect(ai_persona.name).to eq("SuperBot")
|
expect(ai_persona.name).to eq("SuperBot")
|
||||||
expect(ai_persona.enabled).to eq(false)
|
expect(ai_persona.enabled).to eq(false)
|
||||||
expect(ai_persona.tools).to eq(["search"])
|
expect(ai_persona.tools).to eq([["search", nil, false]])
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ RSpec.describe "Admin AI persona configuration", type: :system, js: true do
|
||||||
expect(persona.name).to eq("Test Persona")
|
expect(persona.name).to eq("Test Persona")
|
||||||
expect(persona.description).to eq("I am a test persona")
|
expect(persona.description).to eq("I am a test persona")
|
||||||
expect(persona.system_prompt).to eq("You are a helpful bot")
|
expect(persona.system_prompt).to eq("You are a helpful bot")
|
||||||
expect(persona.tools).to eq([["Read", { "read_private" => nil }]])
|
expect(persona.tools).to eq([["Read", { "read_private" => nil }, false]])
|
||||||
end
|
end
|
||||||
|
|
||||||
it "will not allow deletion or editing of system personas" do
|
it "will not allow deletion or editing of system personas" do
|
||||||
|
|
|
@ -60,7 +60,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
|
||||||
const updatedProperties = aiPersona.updateProperties();
|
const updatedProperties = aiPersona.updateProperties();
|
||||||
|
|
||||||
// perform remapping for save
|
// perform remapping for save
|
||||||
properties.tools = [["ToolName", { option1: "value1" }]];
|
properties.tools = [["ToolName", { option1: "value1" }, false]];
|
||||||
|
|
||||||
assert.deepEqual(updatedProperties, properties);
|
assert.deepEqual(updatedProperties, properties);
|
||||||
});
|
});
|
||||||
|
@ -100,7 +100,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
|
||||||
|
|
||||||
const createdProperties = aiPersona.createProperties();
|
const createdProperties = aiPersona.createProperties();
|
||||||
|
|
||||||
properties.tools = [["ToolName", { option1: "value1" }]];
|
properties.tools = [["ToolName", { option1: "value1" }, false]];
|
||||||
|
|
||||||
assert.deepEqual(createdProperties, properties);
|
assert.deepEqual(createdProperties, properties);
|
||||||
});
|
});
|
||||||
|
|
Loading…
Reference in New Issue