FEATURE: allow persona to only force tool calls on limited replies (#827)

This introduces another configuration option that allows operators to
limit the number of replies in which tool usage is forced.

Forced tools are very handy in the initial LLM interactions, but as the
conversation progresses they can get in the way, slowing down responses
and adding confusion.
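
In practice, the new `forced_tool_count` column defaults to `-1` (force on every reply); a positive value keeps forcing tools only while the conversation contains at most that many user messages. The sketch below is a minimal, stand-alone illustration of that gating, based on the counting rule this commit adds to the bot; the `force_tool_this_reply?` helper is hypothetical and not part of the plugin.

# Minimal sketch: should the persona still force a tool on this reply?
# forced_tool_count <= 0 (the -1 default) preserves the old behaviour of
# always forcing; a positive count only forces while the number of user
# messages in the prompt is within that count.
def force_tool_this_reply?(forced_tool_count, messages)
  return true if forced_tool_count <= 0

  user_turns = messages.count { |m| m[:type] == :user }
  user_turns <= forced_tool_count
end

# With forced_tool_count = 1, the tool is forced on the first user post only.
force_tool_this_reply?(1, [{ type: :user, content: "hi" }])
# => true
force_tool_this_reply?(1, [{ type: :user, content: "hi" },
                           { type: :assistant, content: "hello" },
                           { type: :user, content: "again" }])
# => false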
Sam 2024-10-11 07:23:42 +11:00 committed by GitHub
parent 52d90cf1bc
commit 6c4c96e83c
14 changed files with 149 additions and 34 deletions

View File

@@ -106,6 +106,7 @@ module DiscourseAi
:question_consolidator_llm,
:allow_chat,
:tool_details,
+ :forced_tool_count,
allowed_group_ids: [],
rag_uploads: [:id],
)

View File

@@ -20,6 +20,7 @@ class AiPersona < ActiveRecord::Base
validates :rag_chunk_tokens, numericality: { greater_than: 0, maximum: 50_000 }
validates :rag_chunk_overlap_tokens, numericality: { greater_than: -1, maximum: 200 }
validates :rag_conversation_chunks, numericality: { greater_than: 0, maximum: 1000 }
+ validates :forced_tool_count, numericality: { greater_than: -2, maximum: 100_000 }
has_many :rag_document_fragments, dependent: :destroy, as: :target
belongs_to :created_by, class_name: "User"

@@ -185,6 +186,7 @@ class AiPersona < ActiveRecord::Base
define_method(:tools) { tools }
define_method(:force_tool_use) { force_tool_use }
+ define_method(:forced_tool_count) { @ai_persona&.forced_tool_count }
define_method(:options) { options }
define_method(:temperature) { @ai_persona&.temperature }
define_method(:top_p) { @ai_persona&.top_p }

@@ -282,11 +284,19 @@ end
# mentionable :boolean default(FALSE), not null
# default_llm :text
# max_context_posts :integer
+ # max_post_context_tokens :integer
+ # max_context_tokens :integer
# vision_enabled :boolean default(FALSE), not null
# vision_max_pixels :integer default(1048576), not null
# rag_chunk_tokens :integer default(374), not null
# rag_chunk_overlap_tokens :integer default(10), not null
# rag_conversation_chunks :integer default(10), not null
+ # role :enum default("bot"), not null
+ # role_category_ids :integer default([]), not null, is an Array
+ # role_tags :string default([]), not null, is an Array
+ # role_group_ids :integer default([]), not null, is an Array
+ # role_whispers :boolean default(FALSE), not null
+ # role_max_responses_per_hour :integer default(50), not null
# question_consolidator_llm :text
# allow_chat :boolean default(FALSE), not null
# tool_details :boolean default(TRUE), not null

View File

@@ -25,7 +25,8 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer
:rag_conversation_chunks,
:question_consolidator_llm,
:allow_chat,
- :tool_details
+ :tool_details,
+ :forced_tool_count
has_one :user, serializer: BasicUserSerializer, embed: :object
has_many :rag_uploads, serializer: UploadSerializer, embed: :object

View File

@@ -28,6 +28,7 @@ const CREATE_ATTRIBUTES = [
"question_consolidator_llm",
"allow_chat",
"tool_details",
+ "forced_tool_count",
];
const SYSTEM_ATTRIBUTES = [

@@ -154,6 +155,7 @@ export default class AiPersona extends RestModel {
const persona = AiPersona.create(attrs);
persona.forcedTools = (this.forcedTools || []).slice();
+ persona.forced_tool_count = this.forced_tool_count || -1;
return persona;
}
}

View File

@@ -0,0 +1,29 @@
import { computed } from "@ember/object";
import I18n from "discourse-i18n";
import ComboBox from "select-kit/components/combo-box";
export default ComboBox.extend({
content: computed(function () {
const content = [
{
id: -1,
name: I18n.t("discourse_ai.ai_persona.tool_strategies.all"),
},
];
[1, 2, 5].forEach((i) => {
content.push({
id: i,
name: I18n.t("discourse_ai.ai_persona.tool_strategies.replies", {
count: i,
}),
});
});
return content;
}),
selectKitOptions: {
filterable: false,
},
});

View File

@@ -20,6 +20,7 @@ import AdminUser from "admin/models/admin-user";
import ComboBox from "select-kit/components/combo-box";
import GroupChooser from "select-kit/components/group-chooser";
import DTooltip from "float-kit/components/d-tooltip";
+ import AiForcedToolStrategySelector from "./ai-forced-tool-strategy-selector";
import AiLlmSelector from "./ai-llm-selector";
import AiPersonaToolOptions from "./ai-persona-tool-options";
import AiToolSelector from "./ai-tool-selector";

@@ -49,7 +50,11 @@ export default class PersonaEditor extends Component {
}
get allowForceTools() {
- return !this.editingModel?.system && this.editingModel?.tools?.length > 0;
+ return !this.editingModel?.system && this.selectedToolNames.length > 0;
+ }
+ get hasForcedTools() {
+ return this.forcedToolNames.length > 0;
}
@action

@@ -381,12 +386,23 @@
<div class="control-group">
<label>{{I18n.t "discourse_ai.ai_persona.forced_tools"}}</label>
<AiToolSelector
- class="ai-persona-editor__tools"
+ class="ai-persona-editor__forced_tools"
@value={{this.forcedToolNames}}
@tools={{this.selectedTools}}
@onChange={{this.forcedToolsChanged}}
/>
</div>
+ {{#if this.hasForcedTools}}
+ <div class="control-group">
+ <label>{{I18n.t
+ "discourse_ai.ai_persona.forced_tool_strategy"
+ }}</label>
+ <AiForcedToolStrategySelector
+ class="ai-persona-editor__forced_tool_strategy"
+ @value={{this.editingModel.forced_tool_count}}
+ />
+ </div>
+ {{/if}}
{{/if}}
{{#unless this.editingModel.system}}
<AiPersonaToolOptions

View File

@@ -116,6 +116,11 @@ en:
select_option: "Select an option..."
ai_persona:
+ tool_strategies:
+ all: "Apply to all replies"
+ replies:
+ one: "Apply to first reply only"
+ other: "Apply to first %{count} replies"
back: Back
name: Name
edit: Edit

@@ -142,6 +147,7 @@ en:
question_consolidator_llm: Language Model for Question Consolidator
question_consolidator_llm_help: The language model to use for the question consolidator, you may choose a less powerful model to save costs.
system_prompt: System Prompt
+ forced_tool_strategy: Forced Tool Strategy
allow_chat: "Allow Chat"
allow_chat_help: "If enabled, users in allowed groups can DM this persona"
save: Save

View File

@@ -0,0 +1,7 @@
# frozen_string_literal: true
class AddForcedToolCountToAiPersonas < ActiveRecord::Migration[7.1]
def change
add_column :ai_personas, :forced_tool_count, :integer, default: -1, null: false
end
end

View File

@@ -72,6 +72,11 @@ module DiscourseAi
forced_tools = persona.force_tool_use.map { |tool| tool.name }
force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) }
+ if force_tool && persona.forced_tool_count > 0
+ user_turns = prompt.messages.select { |m| m[:type] == :user }.length
+ force_tool = false if user_turns > persona.forced_tool_count
+ end
if force_tool
context[:chosen_tools] << force_tool
prompt.tool_choice = force_tool

View File

@@ -117,6 +117,10 @@ module DiscourseAi
[]
end
+ def forced_tool_count
+ -1
+ end
def required_tools
[]
end

View File

@@ -127,19 +127,33 @@ RSpec.describe DiscourseAi::AiBot::Playground do
it "can force usage of a tool" do
tool_name = "custom-#{custom_tool.id}"
- ai_persona.update!(tools: [[tool_name, nil, true]])
+ ai_persona.update!(tools: [[tool_name, nil, true]], forced_tool_count: 1)
responses = [function_call, "custom tool did stuff (maybe)"]
prompts = nil
+ reply_post = nil

DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts|
new_post = Fabricate(:post, raw: "Can you use the custom tool?")
- _reply_post = playground.reply_to(new_post)
+ reply_post = playground.reply_to(new_post)
prompts = _prompts
end

expect(prompts.length).to eq(2)
expect(prompts[0].tool_choice).to eq("search")
expect(prompts[1].tool_choice).to eq(nil)

+ ai_persona.update!(forced_tool_count: 1)
+ responses = ["no tool call here"]
+
+ DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts|
+ new_post = Fabricate(:post, raw: "Will you use the custom tool?", topic: reply_post.topic)
+ _reply_post = playground.reply_to(new_post)
+ prompts = _prompts
+ end
+
+ expect(prompts.length).to eq(1)
+ expect(prompts[0].tool_choice).to eq(nil)
end

it "uses custom tool in conversation" do

View File

@@ -39,9 +39,10 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
Fabricate(
:ai_persona,
name: "search2",
- tools: [["SearchCommand", { base_query: "test" }]],
+ tools: [["SearchCommand", { base_query: "test" }, true]],
mentionable: true,
default_llm: "anthropic:claude-2",
+ forced_tool_count: 2,
)
persona2.create_user!

@@ -55,6 +56,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
expect(serializer_persona2["default_llm"]).to eq("anthropic:claude-2")
expect(serializer_persona2["user_id"]).to eq(persona2.user_id)
expect(serializer_persona2["user"]["id"]).to eq(persona2.user_id)
+ expect(serializer_persona2["forced_tool_count"]).to eq(2)

tools = response.parsed_body["meta"]["tools"]
search_tool = tools.find { |c| c["id"] == "Search" }

@@ -85,7 +87,9 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
)
expect(serializer_persona1["tools"]).to eq(["SearchCommand"])
- expect(serializer_persona2["tools"]).to eq([["SearchCommand", { "base_query" => "test" }]])
+ expect(serializer_persona2["tools"]).to eq(
+ [["SearchCommand", { "base_query" => "test" }, true]],
+ )
end

context "with translations" do

@@ -165,6 +169,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
temperature: 0.5,
mentionable: true,
default_llm: "anthropic:claude-2",
+ forced_tool_count: 2,
}
end

@@ -183,6 +188,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
expect(persona_json["temperature"]).to eq(0.5)
expect(persona_json["mentionable"]).to eq(true)
expect(persona_json["default_llm"]).to eq("anthropic:claude-2")
+ expect(persona_json["forced_tool_count"]).to eq(2)

persona = AiPersona.find(persona_json["id"])

View File

@@ -19,6 +19,17 @@ RSpec.describe "Admin AI persona configuration", type: :system, js: true do
tool_selector = PageObjects::Components::SelectKit.new(".ai-persona-editor__tools")
tool_selector.expand
tool_selector.select_row_by_value("Read")
+ tool_selector.collapse
+
+ tool_selector = PageObjects::Components::SelectKit.new(".ai-persona-editor__forced_tools")
+ tool_selector.expand
+ tool_selector.select_row_by_value("Read")
+ tool_selector.collapse
+
+ strategy_selector =
+ PageObjects::Components::SelectKit.new(".ai-persona-editor__forced_tool_strategy")
+ strategy_selector.expand
+ strategy_selector.select_row_by_value(1)

find(".ai-persona-editor__save").click()

@@ -30,7 +41,8 @@ RSpec.describe "Admin AI persona configuration", type: :system, js: true do
expect(persona.name).to eq("Test Persona")
expect(persona.description).to eq("I am a test persona")
expect(persona.system_prompt).to eq("You are a helpful bot")
- expect(persona.tools).to eq([["Read", { "read_private" => nil }, false]])
+ expect(persona.tools).to eq([["Read", { "read_private" => nil }, true]])
+ expect(persona.forced_tool_count).to eq(1)
end

it "will not allow deletion or editing of system personas" do

View File

@@ -51,6 +51,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
question_consolidator_llm: "Question Consolidator LLM",
allow_chat: false,
tool_details: true,
+ forced_tool_count: -1,
};
const aiPersona = AiPersona.create({ ...properties });

@@ -92,6 +93,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
question_consolidator_llm: "Question Consolidator LLM",
allow_chat: false,
tool_details: true,
+ forced_tool_count: -1,
};
const aiPersona = AiPersona.create({ ...properties });