FEATURE: allow persona to only force tool calls on limited replies (#827)

This introduces another configuration that allows operators to
limit the amount of interactions with forced tool usage.

Forced tools are very handy in initial llm interactions, but as
conversation progresses they can hinder by slowing down stuff
and adding confusion.
This commit is contained in:
Sam 2024-10-11 07:23:42 +11:00 committed by GitHub
parent 52d90cf1bc
commit 6c4c96e83c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 149 additions and 34 deletions

View File

@ -106,6 +106,7 @@ module DiscourseAi
:question_consolidator_llm,
:allow_chat,
:tool_details,
:forced_tool_count,
allowed_group_ids: [],
rag_uploads: [:id],
)

View File

@ -20,6 +20,7 @@ class AiPersona < ActiveRecord::Base
validates :rag_chunk_tokens, numericality: { greater_than: 0, maximum: 50_000 }
validates :rag_chunk_overlap_tokens, numericality: { greater_than: -1, maximum: 200 }
validates :rag_conversation_chunks, numericality: { greater_than: 0, maximum: 1000 }
validates :forced_tool_count, numericality: { greater_than: -2, maximum: 100_000 }
has_many :rag_document_fragments, dependent: :destroy, as: :target
belongs_to :created_by, class_name: "User"
@ -185,6 +186,7 @@ class AiPersona < ActiveRecord::Base
define_method(:tools) { tools }
define_method(:force_tool_use) { force_tool_use }
define_method(:forced_tool_count) { @ai_persona&.forced_tool_count }
define_method(:options) { options }
define_method(:temperature) { @ai_persona&.temperature }
define_method(:top_p) { @ai_persona&.top_p }
@ -265,32 +267,40 @@ end
#
# Table name: ai_personas
#
# id :bigint not null, primary key
# name :string(100) not null
# description :string(2000) not null
# system_prompt :string(10000000) not null
# allowed_group_ids :integer default([]), not null, is an Array
# created_by_id :integer
# enabled :boolean default(TRUE), not null
# created_at :datetime not null
# updated_at :datetime not null
# system :boolean default(FALSE), not null
# priority :boolean default(FALSE), not null
# temperature :float
# top_p :float
# user_id :integer
# mentionable :boolean default(FALSE), not null
# default_llm :text
# max_context_posts :integer
# vision_enabled :boolean default(FALSE), not null
# vision_max_pixels :integer default(1048576), not null
# rag_chunk_tokens :integer default(374), not null
# rag_chunk_overlap_tokens :integer default(10), not null
# rag_conversation_chunks :integer default(10), not null
# question_consolidator_llm :text
# allow_chat :boolean default(FALSE), not null
# tool_details :boolean default(TRUE), not null
# tools :json not null
# id :bigint not null, primary key
# name :string(100) not null
# description :string(2000) not null
# system_prompt :string(10000000) not null
# allowed_group_ids :integer default([]), not null, is an Array
# created_by_id :integer
# enabled :boolean default(TRUE), not null
# created_at :datetime not null
# updated_at :datetime not null
# system :boolean default(FALSE), not null
# priority :boolean default(FALSE), not null
# temperature :float
# top_p :float
# user_id :integer
# mentionable :boolean default(FALSE), not null
# default_llm :text
# max_context_posts :integer
# max_post_context_tokens :integer
# max_context_tokens :integer
# vision_enabled :boolean default(FALSE), not null
# vision_max_pixels :integer default(1048576), not null
# rag_chunk_tokens :integer default(374), not null
# rag_chunk_overlap_tokens :integer default(10), not null
# rag_conversation_chunks :integer default(10), not null
# role :enum default("bot"), not null
# role_category_ids :integer default([]), not null, is an Array
# role_tags :string default([]), not null, is an Array
# role_group_ids :integer default([]), not null, is an Array
# role_whispers :boolean default(FALSE), not null
# role_max_responses_per_hour :integer default(50), not null
# question_consolidator_llm :text
# allow_chat :boolean default(FALSE), not null
# tool_details :boolean default(TRUE), not null
# tools :json not null
#
# Indexes
#

View File

@ -25,7 +25,8 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer
:rag_conversation_chunks,
:question_consolidator_llm,
:allow_chat,
:tool_details
:tool_details,
:forced_tool_count
has_one :user, serializer: BasicUserSerializer, embed: :object
has_many :rag_uploads, serializer: UploadSerializer, embed: :object

View File

@ -28,6 +28,7 @@ const CREATE_ATTRIBUTES = [
"question_consolidator_llm",
"allow_chat",
"tool_details",
"forced_tool_count",
];
const SYSTEM_ATTRIBUTES = [
@ -154,6 +155,7 @@ export default class AiPersona extends RestModel {
const persona = AiPersona.create(attrs);
persona.forcedTools = (this.forcedTools || []).slice();
persona.forced_tool_count = this.forced_tool_count || -1;
return persona;
}
}

View File

@ -0,0 +1,29 @@
import { computed } from "@ember/object";
import I18n from "discourse-i18n";
import ComboBox from "select-kit/components/combo-box";
export default ComboBox.extend({
content: computed(function () {
const content = [
{
id: -1,
name: I18n.t("discourse_ai.ai_persona.tool_strategies.all"),
},
];
[1, 2, 5].forEach((i) => {
content.push({
id: i,
name: I18n.t("discourse_ai.ai_persona.tool_strategies.replies", {
count: i,
}),
});
});
return content;
}),
selectKitOptions: {
filterable: false,
},
});

View File

@ -20,6 +20,7 @@ import AdminUser from "admin/models/admin-user";
import ComboBox from "select-kit/components/combo-box";
import GroupChooser from "select-kit/components/group-chooser";
import DTooltip from "float-kit/components/d-tooltip";
import AiForcedToolStrategySelector from "./ai-forced-tool-strategy-selector";
import AiLlmSelector from "./ai-llm-selector";
import AiPersonaToolOptions from "./ai-persona-tool-options";
import AiToolSelector from "./ai-tool-selector";
@ -49,7 +50,11 @@ export default class PersonaEditor extends Component {
}
get allowForceTools() {
return !this.editingModel?.system && this.editingModel?.tools?.length > 0;
return !this.editingModel?.system && this.selectedToolNames.length > 0;
}
get hasForcedTools() {
return this.forcedToolNames.length > 0;
}
@action
@ -381,12 +386,23 @@ export default class PersonaEditor extends Component {
<div class="control-group">
<label>{{I18n.t "discourse_ai.ai_persona.forced_tools"}}</label>
<AiToolSelector
class="ai-persona-editor__tools"
class="ai-persona-editor__forced_tools"
@value={{this.forcedToolNames}}
@tools={{this.selectedTools}}
@onChange={{this.forcedToolsChanged}}
/>
</div>
{{#if this.hasForcedTools}}
<div class="control-group">
<label>{{I18n.t
"discourse_ai.ai_persona.forced_tool_strategy"
}}</label>
<AiForcedToolStrategySelector
class="ai-persona-editor__forced_tool_strategy"
@value={{this.editingModel.forced_tool_count}}
/>
</div>
{{/if}}
{{/if}}
{{#unless this.editingModel.system}}
<AiPersonaToolOptions

View File

@ -116,6 +116,11 @@ en:
select_option: "Select an option..."
ai_persona:
tool_strategies:
all: "Apply to all replies"
replies:
one: "Apply to first reply only"
other: "Apply to first %{count} replies"
back: Back
name: Name
edit: Edit
@ -142,6 +147,7 @@ en:
question_consolidator_llm: Language Model for Question Consolidator
question_consolidator_llm_help: The language model to use for the question consolidator, you may choose a less powerful model to save costs.
system_prompt: System Prompt
forced_tool_strategy: Forced Tool Strategy
allow_chat: "Allow Chat"
allow_chat_help: "If enabled, users in allowed groups can DM this persona"
save: Save

View File

@ -0,0 +1,7 @@
# frozen_string_literal: true
class AddForcedToolCountToAiPersonas < ActiveRecord::Migration[7.1]
def change
add_column :ai_personas, :forced_tool_count, :integer, default: -1, null: false
end
end

View File

@ -72,6 +72,11 @@ module DiscourseAi
forced_tools = persona.force_tool_use.map { |tool| tool.name }
force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) }
if force_tool && persona.forced_tool_count > 0
user_turns = prompt.messages.select { |m| m[:type] == :user }.length
force_tool = false if user_turns > persona.forced_tool_count
end
if force_tool
context[:chosen_tools] << force_tool
prompt.tool_choice = force_tool

View File

@ -117,6 +117,10 @@ module DiscourseAi
[]
end
def forced_tool_count
-1
end
def required_tools
[]
end

View File

@ -127,19 +127,33 @@ RSpec.describe DiscourseAi::AiBot::Playground do
it "can force usage of a tool" do
tool_name = "custom-#{custom_tool.id}"
ai_persona.update!(tools: [[tool_name, nil, true]])
ai_persona.update!(tools: [[tool_name, nil, true]], forced_tool_count: 1)
responses = [function_call, "custom tool did stuff (maybe)"]
prompts = nil
reply_post = nil
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts|
new_post = Fabricate(:post, raw: "Can you use the custom tool?")
_reply_post = playground.reply_to(new_post)
reply_post = playground.reply_to(new_post)
prompts = _prompts
end
expect(prompts.length).to eq(2)
expect(prompts[0].tool_choice).to eq("search")
expect(prompts[1].tool_choice).to eq(nil)
ai_persona.update!(forced_tool_count: 1)
responses = ["no tool call here"]
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts|
new_post = Fabricate(:post, raw: "Will you use the custom tool?", topic: reply_post.topic)
_reply_post = playground.reply_to(new_post)
prompts = _prompts
end
expect(prompts.length).to eq(1)
expect(prompts[0].tool_choice).to eq(nil)
end
it "uses custom tool in conversation" do

View File

@ -39,9 +39,10 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
Fabricate(
:ai_persona,
name: "search2",
tools: [["SearchCommand", { base_query: "test" }]],
tools: [["SearchCommand", { base_query: "test" }, true]],
mentionable: true,
default_llm: "anthropic:claude-2",
forced_tool_count: 2,
)
persona2.create_user!
@ -55,6 +56,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
expect(serializer_persona2["default_llm"]).to eq("anthropic:claude-2")
expect(serializer_persona2["user_id"]).to eq(persona2.user_id)
expect(serializer_persona2["user"]["id"]).to eq(persona2.user_id)
expect(serializer_persona2["forced_tool_count"]).to eq(2)
tools = response.parsed_body["meta"]["tools"]
search_tool = tools.find { |c| c["id"] == "Search" }
@ -85,7 +87,9 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
)
expect(serializer_persona1["tools"]).to eq(["SearchCommand"])
expect(serializer_persona2["tools"]).to eq([["SearchCommand", { "base_query" => "test" }]])
expect(serializer_persona2["tools"]).to eq(
[["SearchCommand", { "base_query" => "test" }, true]],
)
end
context "with translations" do
@ -165,6 +169,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
temperature: 0.5,
mentionable: true,
default_llm: "anthropic:claude-2",
forced_tool_count: 2,
}
end
@ -183,6 +188,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
expect(persona_json["temperature"]).to eq(0.5)
expect(persona_json["mentionable"]).to eq(true)
expect(persona_json["default_llm"]).to eq("anthropic:claude-2")
expect(persona_json["forced_tool_count"]).to eq(2)
persona = AiPersona.find(persona_json["id"])

View File

@ -19,6 +19,17 @@ RSpec.describe "Admin AI persona configuration", type: :system, js: true do
tool_selector = PageObjects::Components::SelectKit.new(".ai-persona-editor__tools")
tool_selector.expand
tool_selector.select_row_by_value("Read")
tool_selector.collapse
tool_selector = PageObjects::Components::SelectKit.new(".ai-persona-editor__forced_tools")
tool_selector.expand
tool_selector.select_row_by_value("Read")
tool_selector.collapse
strategy_selector =
PageObjects::Components::SelectKit.new(".ai-persona-editor__forced_tool_strategy")
strategy_selector.expand
strategy_selector.select_row_by_value(1)
find(".ai-persona-editor__save").click()
@ -30,7 +41,8 @@ RSpec.describe "Admin AI persona configuration", type: :system, js: true do
expect(persona.name).to eq("Test Persona")
expect(persona.description).to eq("I am a test persona")
expect(persona.system_prompt).to eq("You are a helpful bot")
expect(persona.tools).to eq([["Read", { "read_private" => nil }, false]])
expect(persona.tools).to eq([["Read", { "read_private" => nil }, true]])
expect(persona.forced_tool_count).to eq(1)
end
it "will not allow deletion or editing of system personas" do

View File

@ -51,6 +51,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
question_consolidator_llm: "Question Consolidator LLM",
allow_chat: false,
tool_details: true,
forced_tool_count: -1,
};
const aiPersona = AiPersona.create({ ...properties });
@ -92,6 +93,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
question_consolidator_llm: "Question Consolidator LLM",
allow_chat: false,
tool_details: true,
forced_tool_count: -1,
};
const aiPersona = AiPersona.create({ ...properties });