FEATURE: allow persona to only force tool calls on limited replies (#827)
This introduces another configuration that allows operators to limit the amount of interactions with forced tool usage. Forced tools are very handy in initial llm interactions, but as conversation progresses they can hinder by slowing down stuff and adding confusion.
This commit is contained in:
parent
52d90cf1bc
commit
6c4c96e83c
|
@ -106,6 +106,7 @@ module DiscourseAi
|
|||
:question_consolidator_llm,
|
||||
:allow_chat,
|
||||
:tool_details,
|
||||
:forced_tool_count,
|
||||
allowed_group_ids: [],
|
||||
rag_uploads: [:id],
|
||||
)
|
||||
|
|
|
@ -20,6 +20,7 @@ class AiPersona < ActiveRecord::Base
|
|||
validates :rag_chunk_tokens, numericality: { greater_than: 0, maximum: 50_000 }
|
||||
validates :rag_chunk_overlap_tokens, numericality: { greater_than: -1, maximum: 200 }
|
||||
validates :rag_conversation_chunks, numericality: { greater_than: 0, maximum: 1000 }
|
||||
validates :forced_tool_count, numericality: { greater_than: -2, maximum: 100_000 }
|
||||
has_many :rag_document_fragments, dependent: :destroy, as: :target
|
||||
|
||||
belongs_to :created_by, class_name: "User"
|
||||
|
@ -185,6 +186,7 @@ class AiPersona < ActiveRecord::Base
|
|||
|
||||
define_method(:tools) { tools }
|
||||
define_method(:force_tool_use) { force_tool_use }
|
||||
define_method(:forced_tool_count) { @ai_persona&.forced_tool_count }
|
||||
define_method(:options) { options }
|
||||
define_method(:temperature) { @ai_persona&.temperature }
|
||||
define_method(:top_p) { @ai_persona&.top_p }
|
||||
|
@ -265,32 +267,40 @@ end
|
|||
#
|
||||
# Table name: ai_personas
|
||||
#
|
||||
# id :bigint not null, primary key
|
||||
# name :string(100) not null
|
||||
# description :string(2000) not null
|
||||
# system_prompt :string(10000000) not null
|
||||
# allowed_group_ids :integer default([]), not null, is an Array
|
||||
# created_by_id :integer
|
||||
# enabled :boolean default(TRUE), not null
|
||||
# created_at :datetime not null
|
||||
# updated_at :datetime not null
|
||||
# system :boolean default(FALSE), not null
|
||||
# priority :boolean default(FALSE), not null
|
||||
# temperature :float
|
||||
# top_p :float
|
||||
# user_id :integer
|
||||
# mentionable :boolean default(FALSE), not null
|
||||
# default_llm :text
|
||||
# max_context_posts :integer
|
||||
# vision_enabled :boolean default(FALSE), not null
|
||||
# vision_max_pixels :integer default(1048576), not null
|
||||
# rag_chunk_tokens :integer default(374), not null
|
||||
# rag_chunk_overlap_tokens :integer default(10), not null
|
||||
# rag_conversation_chunks :integer default(10), not null
|
||||
# question_consolidator_llm :text
|
||||
# allow_chat :boolean default(FALSE), not null
|
||||
# tool_details :boolean default(TRUE), not null
|
||||
# tools :json not null
|
||||
# id :bigint not null, primary key
|
||||
# name :string(100) not null
|
||||
# description :string(2000) not null
|
||||
# system_prompt :string(10000000) not null
|
||||
# allowed_group_ids :integer default([]), not null, is an Array
|
||||
# created_by_id :integer
|
||||
# enabled :boolean default(TRUE), not null
|
||||
# created_at :datetime not null
|
||||
# updated_at :datetime not null
|
||||
# system :boolean default(FALSE), not null
|
||||
# priority :boolean default(FALSE), not null
|
||||
# temperature :float
|
||||
# top_p :float
|
||||
# user_id :integer
|
||||
# mentionable :boolean default(FALSE), not null
|
||||
# default_llm :text
|
||||
# max_context_posts :integer
|
||||
# max_post_context_tokens :integer
|
||||
# max_context_tokens :integer
|
||||
# vision_enabled :boolean default(FALSE), not null
|
||||
# vision_max_pixels :integer default(1048576), not null
|
||||
# rag_chunk_tokens :integer default(374), not null
|
||||
# rag_chunk_overlap_tokens :integer default(10), not null
|
||||
# rag_conversation_chunks :integer default(10), not null
|
||||
# role :enum default("bot"), not null
|
||||
# role_category_ids :integer default([]), not null, is an Array
|
||||
# role_tags :string default([]), not null, is an Array
|
||||
# role_group_ids :integer default([]), not null, is an Array
|
||||
# role_whispers :boolean default(FALSE), not null
|
||||
# role_max_responses_per_hour :integer default(50), not null
|
||||
# question_consolidator_llm :text
|
||||
# allow_chat :boolean default(FALSE), not null
|
||||
# tool_details :boolean default(TRUE), not null
|
||||
# tools :json not null
|
||||
#
|
||||
# Indexes
|
||||
#
|
||||
|
|
|
@ -25,7 +25,8 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer
|
|||
:rag_conversation_chunks,
|
||||
:question_consolidator_llm,
|
||||
:allow_chat,
|
||||
:tool_details
|
||||
:tool_details,
|
||||
:forced_tool_count
|
||||
|
||||
has_one :user, serializer: BasicUserSerializer, embed: :object
|
||||
has_many :rag_uploads, serializer: UploadSerializer, embed: :object
|
||||
|
|
|
@ -28,6 +28,7 @@ const CREATE_ATTRIBUTES = [
|
|||
"question_consolidator_llm",
|
||||
"allow_chat",
|
||||
"tool_details",
|
||||
"forced_tool_count",
|
||||
];
|
||||
|
||||
const SYSTEM_ATTRIBUTES = [
|
||||
|
@ -154,6 +155,7 @@ export default class AiPersona extends RestModel {
|
|||
|
||||
const persona = AiPersona.create(attrs);
|
||||
persona.forcedTools = (this.forcedTools || []).slice();
|
||||
persona.forced_tool_count = this.forced_tool_count || -1;
|
||||
return persona;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
import { computed } from "@ember/object";
|
||||
import I18n from "discourse-i18n";
|
||||
import ComboBox from "select-kit/components/combo-box";
|
||||
|
||||
export default ComboBox.extend({
|
||||
content: computed(function () {
|
||||
const content = [
|
||||
{
|
||||
id: -1,
|
||||
name: I18n.t("discourse_ai.ai_persona.tool_strategies.all"),
|
||||
},
|
||||
];
|
||||
|
||||
[1, 2, 5].forEach((i) => {
|
||||
content.push({
|
||||
id: i,
|
||||
name: I18n.t("discourse_ai.ai_persona.tool_strategies.replies", {
|
||||
count: i,
|
||||
}),
|
||||
});
|
||||
});
|
||||
|
||||
return content;
|
||||
}),
|
||||
|
||||
selectKitOptions: {
|
||||
filterable: false,
|
||||
},
|
||||
});
|
|
@ -20,6 +20,7 @@ import AdminUser from "admin/models/admin-user";
|
|||
import ComboBox from "select-kit/components/combo-box";
|
||||
import GroupChooser from "select-kit/components/group-chooser";
|
||||
import DTooltip from "float-kit/components/d-tooltip";
|
||||
import AiForcedToolStrategySelector from "./ai-forced-tool-strategy-selector";
|
||||
import AiLlmSelector from "./ai-llm-selector";
|
||||
import AiPersonaToolOptions from "./ai-persona-tool-options";
|
||||
import AiToolSelector from "./ai-tool-selector";
|
||||
|
@ -49,7 +50,11 @@ export default class PersonaEditor extends Component {
|
|||
}
|
||||
|
||||
get allowForceTools() {
|
||||
return !this.editingModel?.system && this.editingModel?.tools?.length > 0;
|
||||
return !this.editingModel?.system && this.selectedToolNames.length > 0;
|
||||
}
|
||||
|
||||
get hasForcedTools() {
|
||||
return this.forcedToolNames.length > 0;
|
||||
}
|
||||
|
||||
@action
|
||||
|
@ -381,12 +386,23 @@ export default class PersonaEditor extends Component {
|
|||
<div class="control-group">
|
||||
<label>{{I18n.t "discourse_ai.ai_persona.forced_tools"}}</label>
|
||||
<AiToolSelector
|
||||
class="ai-persona-editor__tools"
|
||||
class="ai-persona-editor__forced_tools"
|
||||
@value={{this.forcedToolNames}}
|
||||
@tools={{this.selectedTools}}
|
||||
@onChange={{this.forcedToolsChanged}}
|
||||
/>
|
||||
</div>
|
||||
{{#if this.hasForcedTools}}
|
||||
<div class="control-group">
|
||||
<label>{{I18n.t
|
||||
"discourse_ai.ai_persona.forced_tool_strategy"
|
||||
}}</label>
|
||||
<AiForcedToolStrategySelector
|
||||
class="ai-persona-editor__forced_tool_strategy"
|
||||
@value={{this.editingModel.forced_tool_count}}
|
||||
/>
|
||||
</div>
|
||||
{{/if}}
|
||||
{{/if}}
|
||||
{{#unless this.editingModel.system}}
|
||||
<AiPersonaToolOptions
|
||||
|
|
|
@ -116,6 +116,11 @@ en:
|
|||
select_option: "Select an option..."
|
||||
|
||||
ai_persona:
|
||||
tool_strategies:
|
||||
all: "Apply to all replies"
|
||||
replies:
|
||||
one: "Apply to first reply only"
|
||||
other: "Apply to first %{count} replies"
|
||||
back: Back
|
||||
name: Name
|
||||
edit: Edit
|
||||
|
@ -142,6 +147,7 @@ en:
|
|||
question_consolidator_llm: Language Model for Question Consolidator
|
||||
question_consolidator_llm_help: The language model to use for the question consolidator, you may choose a less powerful model to save costs.
|
||||
system_prompt: System Prompt
|
||||
forced_tool_strategy: Forced Tool Strategy
|
||||
allow_chat: "Allow Chat"
|
||||
allow_chat_help: "If enabled, users in allowed groups can DM this persona"
|
||||
save: Save
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
class AddForcedToolCountToAiPersonas < ActiveRecord::Migration[7.1]
|
||||
def change
|
||||
add_column :ai_personas, :forced_tool_count, :integer, default: -1, null: false
|
||||
end
|
||||
end
|
|
@ -72,6 +72,11 @@ module DiscourseAi
|
|||
forced_tools = persona.force_tool_use.map { |tool| tool.name }
|
||||
force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) }
|
||||
|
||||
if force_tool && persona.forced_tool_count > 0
|
||||
user_turns = prompt.messages.select { |m| m[:type] == :user }.length
|
||||
force_tool = false if user_turns > persona.forced_tool_count
|
||||
end
|
||||
|
||||
if force_tool
|
||||
context[:chosen_tools] << force_tool
|
||||
prompt.tool_choice = force_tool
|
||||
|
|
|
@ -117,6 +117,10 @@ module DiscourseAi
|
|||
[]
|
||||
end
|
||||
|
||||
def forced_tool_count
|
||||
-1
|
||||
end
|
||||
|
||||
def required_tools
|
||||
[]
|
||||
end
|
||||
|
|
|
@ -127,19 +127,33 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
|||
|
||||
it "can force usage of a tool" do
|
||||
tool_name = "custom-#{custom_tool.id}"
|
||||
ai_persona.update!(tools: [[tool_name, nil, true]])
|
||||
ai_persona.update!(tools: [[tool_name, nil, true]], forced_tool_count: 1)
|
||||
responses = [function_call, "custom tool did stuff (maybe)"]
|
||||
|
||||
prompts = nil
|
||||
reply_post = nil
|
||||
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts|
|
||||
new_post = Fabricate(:post, raw: "Can you use the custom tool?")
|
||||
_reply_post = playground.reply_to(new_post)
|
||||
reply_post = playground.reply_to(new_post)
|
||||
prompts = _prompts
|
||||
end
|
||||
|
||||
expect(prompts.length).to eq(2)
|
||||
expect(prompts[0].tool_choice).to eq("search")
|
||||
expect(prompts[1].tool_choice).to eq(nil)
|
||||
|
||||
ai_persona.update!(forced_tool_count: 1)
|
||||
responses = ["no tool call here"]
|
||||
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts|
|
||||
new_post = Fabricate(:post, raw: "Will you use the custom tool?", topic: reply_post.topic)
|
||||
_reply_post = playground.reply_to(new_post)
|
||||
prompts = _prompts
|
||||
end
|
||||
|
||||
expect(prompts.length).to eq(1)
|
||||
expect(prompts[0].tool_choice).to eq(nil)
|
||||
end
|
||||
|
||||
it "uses custom tool in conversation" do
|
||||
|
|
|
@ -39,9 +39,10 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
Fabricate(
|
||||
:ai_persona,
|
||||
name: "search2",
|
||||
tools: [["SearchCommand", { base_query: "test" }]],
|
||||
tools: [["SearchCommand", { base_query: "test" }, true]],
|
||||
mentionable: true,
|
||||
default_llm: "anthropic:claude-2",
|
||||
forced_tool_count: 2,
|
||||
)
|
||||
persona2.create_user!
|
||||
|
||||
|
@ -55,6 +56,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
expect(serializer_persona2["default_llm"]).to eq("anthropic:claude-2")
|
||||
expect(serializer_persona2["user_id"]).to eq(persona2.user_id)
|
||||
expect(serializer_persona2["user"]["id"]).to eq(persona2.user_id)
|
||||
expect(serializer_persona2["forced_tool_count"]).to eq(2)
|
||||
|
||||
tools = response.parsed_body["meta"]["tools"]
|
||||
search_tool = tools.find { |c| c["id"] == "Search" }
|
||||
|
@ -85,7 +87,9 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
)
|
||||
|
||||
expect(serializer_persona1["tools"]).to eq(["SearchCommand"])
|
||||
expect(serializer_persona2["tools"]).to eq([["SearchCommand", { "base_query" => "test" }]])
|
||||
expect(serializer_persona2["tools"]).to eq(
|
||||
[["SearchCommand", { "base_query" => "test" }, true]],
|
||||
)
|
||||
end
|
||||
|
||||
context "with translations" do
|
||||
|
@ -165,6 +169,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
temperature: 0.5,
|
||||
mentionable: true,
|
||||
default_llm: "anthropic:claude-2",
|
||||
forced_tool_count: 2,
|
||||
}
|
||||
end
|
||||
|
||||
|
@ -183,6 +188,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
expect(persona_json["temperature"]).to eq(0.5)
|
||||
expect(persona_json["mentionable"]).to eq(true)
|
||||
expect(persona_json["default_llm"]).to eq("anthropic:claude-2")
|
||||
expect(persona_json["forced_tool_count"]).to eq(2)
|
||||
|
||||
persona = AiPersona.find(persona_json["id"])
|
||||
|
||||
|
|
|
@ -19,6 +19,17 @@ RSpec.describe "Admin AI persona configuration", type: :system, js: true do
|
|||
tool_selector = PageObjects::Components::SelectKit.new(".ai-persona-editor__tools")
|
||||
tool_selector.expand
|
||||
tool_selector.select_row_by_value("Read")
|
||||
tool_selector.collapse
|
||||
|
||||
tool_selector = PageObjects::Components::SelectKit.new(".ai-persona-editor__forced_tools")
|
||||
tool_selector.expand
|
||||
tool_selector.select_row_by_value("Read")
|
||||
tool_selector.collapse
|
||||
|
||||
strategy_selector =
|
||||
PageObjects::Components::SelectKit.new(".ai-persona-editor__forced_tool_strategy")
|
||||
strategy_selector.expand
|
||||
strategy_selector.select_row_by_value(1)
|
||||
|
||||
find(".ai-persona-editor__save").click()
|
||||
|
||||
|
@ -30,7 +41,8 @@ RSpec.describe "Admin AI persona configuration", type: :system, js: true do
|
|||
expect(persona.name).to eq("Test Persona")
|
||||
expect(persona.description).to eq("I am a test persona")
|
||||
expect(persona.system_prompt).to eq("You are a helpful bot")
|
||||
expect(persona.tools).to eq([["Read", { "read_private" => nil }, false]])
|
||||
expect(persona.tools).to eq([["Read", { "read_private" => nil }, true]])
|
||||
expect(persona.forced_tool_count).to eq(1)
|
||||
end
|
||||
|
||||
it "will not allow deletion or editing of system personas" do
|
||||
|
|
|
@ -51,6 +51,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
|
|||
question_consolidator_llm: "Question Consolidator LLM",
|
||||
allow_chat: false,
|
||||
tool_details: true,
|
||||
forced_tool_count: -1,
|
||||
};
|
||||
|
||||
const aiPersona = AiPersona.create({ ...properties });
|
||||
|
@ -92,6 +93,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
|
|||
question_consolidator_llm: "Question Consolidator LLM",
|
||||
allow_chat: false,
|
||||
tool_details: true,
|
||||
forced_tool_count: -1,
|
||||
};
|
||||
|
||||
const aiPersona = AiPersona.create({ ...properties });
|
||||
|
|
Loading…
Reference in New Issue