FEATURE: allow persona to only force tool calls on limited replies (#827)
This introduces another configuration that allows operators to limit the amount of interactions with forced tool usage. Forced tools are very handy in initial llm interactions, but as conversation progresses they can hinder by slowing down stuff and adding confusion.
This commit is contained in:
parent
52d90cf1bc
commit
6c4c96e83c
|
@ -106,6 +106,7 @@ module DiscourseAi
|
||||||
:question_consolidator_llm,
|
:question_consolidator_llm,
|
||||||
:allow_chat,
|
:allow_chat,
|
||||||
:tool_details,
|
:tool_details,
|
||||||
|
:forced_tool_count,
|
||||||
allowed_group_ids: [],
|
allowed_group_ids: [],
|
||||||
rag_uploads: [:id],
|
rag_uploads: [:id],
|
||||||
)
|
)
|
||||||
|
|
|
@ -20,6 +20,7 @@ class AiPersona < ActiveRecord::Base
|
||||||
validates :rag_chunk_tokens, numericality: { greater_than: 0, maximum: 50_000 }
|
validates :rag_chunk_tokens, numericality: { greater_than: 0, maximum: 50_000 }
|
||||||
validates :rag_chunk_overlap_tokens, numericality: { greater_than: -1, maximum: 200 }
|
validates :rag_chunk_overlap_tokens, numericality: { greater_than: -1, maximum: 200 }
|
||||||
validates :rag_conversation_chunks, numericality: { greater_than: 0, maximum: 1000 }
|
validates :rag_conversation_chunks, numericality: { greater_than: 0, maximum: 1000 }
|
||||||
|
validates :forced_tool_count, numericality: { greater_than: -2, maximum: 100_000 }
|
||||||
has_many :rag_document_fragments, dependent: :destroy, as: :target
|
has_many :rag_document_fragments, dependent: :destroy, as: :target
|
||||||
|
|
||||||
belongs_to :created_by, class_name: "User"
|
belongs_to :created_by, class_name: "User"
|
||||||
|
@ -185,6 +186,7 @@ class AiPersona < ActiveRecord::Base
|
||||||
|
|
||||||
define_method(:tools) { tools }
|
define_method(:tools) { tools }
|
||||||
define_method(:force_tool_use) { force_tool_use }
|
define_method(:force_tool_use) { force_tool_use }
|
||||||
|
define_method(:forced_tool_count) { @ai_persona&.forced_tool_count }
|
||||||
define_method(:options) { options }
|
define_method(:options) { options }
|
||||||
define_method(:temperature) { @ai_persona&.temperature }
|
define_method(:temperature) { @ai_persona&.temperature }
|
||||||
define_method(:top_p) { @ai_persona&.top_p }
|
define_method(:top_p) { @ai_persona&.top_p }
|
||||||
|
@ -282,11 +284,19 @@ end
|
||||||
# mentionable :boolean default(FALSE), not null
|
# mentionable :boolean default(FALSE), not null
|
||||||
# default_llm :text
|
# default_llm :text
|
||||||
# max_context_posts :integer
|
# max_context_posts :integer
|
||||||
|
# max_post_context_tokens :integer
|
||||||
|
# max_context_tokens :integer
|
||||||
# vision_enabled :boolean default(FALSE), not null
|
# vision_enabled :boolean default(FALSE), not null
|
||||||
# vision_max_pixels :integer default(1048576), not null
|
# vision_max_pixels :integer default(1048576), not null
|
||||||
# rag_chunk_tokens :integer default(374), not null
|
# rag_chunk_tokens :integer default(374), not null
|
||||||
# rag_chunk_overlap_tokens :integer default(10), not null
|
# rag_chunk_overlap_tokens :integer default(10), not null
|
||||||
# rag_conversation_chunks :integer default(10), not null
|
# rag_conversation_chunks :integer default(10), not null
|
||||||
|
# role :enum default("bot"), not null
|
||||||
|
# role_category_ids :integer default([]), not null, is an Array
|
||||||
|
# role_tags :string default([]), not null, is an Array
|
||||||
|
# role_group_ids :integer default([]), not null, is an Array
|
||||||
|
# role_whispers :boolean default(FALSE), not null
|
||||||
|
# role_max_responses_per_hour :integer default(50), not null
|
||||||
# question_consolidator_llm :text
|
# question_consolidator_llm :text
|
||||||
# allow_chat :boolean default(FALSE), not null
|
# allow_chat :boolean default(FALSE), not null
|
||||||
# tool_details :boolean default(TRUE), not null
|
# tool_details :boolean default(TRUE), not null
|
||||||
|
|
|
@ -25,7 +25,8 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer
|
||||||
:rag_conversation_chunks,
|
:rag_conversation_chunks,
|
||||||
:question_consolidator_llm,
|
:question_consolidator_llm,
|
||||||
:allow_chat,
|
:allow_chat,
|
||||||
:tool_details
|
:tool_details,
|
||||||
|
:forced_tool_count
|
||||||
|
|
||||||
has_one :user, serializer: BasicUserSerializer, embed: :object
|
has_one :user, serializer: BasicUserSerializer, embed: :object
|
||||||
has_many :rag_uploads, serializer: UploadSerializer, embed: :object
|
has_many :rag_uploads, serializer: UploadSerializer, embed: :object
|
||||||
|
|
|
@ -28,6 +28,7 @@ const CREATE_ATTRIBUTES = [
|
||||||
"question_consolidator_llm",
|
"question_consolidator_llm",
|
||||||
"allow_chat",
|
"allow_chat",
|
||||||
"tool_details",
|
"tool_details",
|
||||||
|
"forced_tool_count",
|
||||||
];
|
];
|
||||||
|
|
||||||
const SYSTEM_ATTRIBUTES = [
|
const SYSTEM_ATTRIBUTES = [
|
||||||
|
@ -154,6 +155,7 @@ export default class AiPersona extends RestModel {
|
||||||
|
|
||||||
const persona = AiPersona.create(attrs);
|
const persona = AiPersona.create(attrs);
|
||||||
persona.forcedTools = (this.forcedTools || []).slice();
|
persona.forcedTools = (this.forcedTools || []).slice();
|
||||||
|
persona.forced_tool_count = this.forced_tool_count || -1;
|
||||||
return persona;
|
return persona;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
import { computed } from "@ember/object";
|
||||||
|
import I18n from "discourse-i18n";
|
||||||
|
import ComboBox from "select-kit/components/combo-box";
|
||||||
|
|
||||||
|
export default ComboBox.extend({
|
||||||
|
content: computed(function () {
|
||||||
|
const content = [
|
||||||
|
{
|
||||||
|
id: -1,
|
||||||
|
name: I18n.t("discourse_ai.ai_persona.tool_strategies.all"),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
[1, 2, 5].forEach((i) => {
|
||||||
|
content.push({
|
||||||
|
id: i,
|
||||||
|
name: I18n.t("discourse_ai.ai_persona.tool_strategies.replies", {
|
||||||
|
count: i,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return content;
|
||||||
|
}),
|
||||||
|
|
||||||
|
selectKitOptions: {
|
||||||
|
filterable: false,
|
||||||
|
},
|
||||||
|
});
|
|
@ -20,6 +20,7 @@ import AdminUser from "admin/models/admin-user";
|
||||||
import ComboBox from "select-kit/components/combo-box";
|
import ComboBox from "select-kit/components/combo-box";
|
||||||
import GroupChooser from "select-kit/components/group-chooser";
|
import GroupChooser from "select-kit/components/group-chooser";
|
||||||
import DTooltip from "float-kit/components/d-tooltip";
|
import DTooltip from "float-kit/components/d-tooltip";
|
||||||
|
import AiForcedToolStrategySelector from "./ai-forced-tool-strategy-selector";
|
||||||
import AiLlmSelector from "./ai-llm-selector";
|
import AiLlmSelector from "./ai-llm-selector";
|
||||||
import AiPersonaToolOptions from "./ai-persona-tool-options";
|
import AiPersonaToolOptions from "./ai-persona-tool-options";
|
||||||
import AiToolSelector from "./ai-tool-selector";
|
import AiToolSelector from "./ai-tool-selector";
|
||||||
|
@ -49,7 +50,11 @@ export default class PersonaEditor extends Component {
|
||||||
}
|
}
|
||||||
|
|
||||||
get allowForceTools() {
|
get allowForceTools() {
|
||||||
return !this.editingModel?.system && this.editingModel?.tools?.length > 0;
|
return !this.editingModel?.system && this.selectedToolNames.length > 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
get hasForcedTools() {
|
||||||
|
return this.forcedToolNames.length > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@action
|
@action
|
||||||
|
@ -381,12 +386,23 @@ export default class PersonaEditor extends Component {
|
||||||
<div class="control-group">
|
<div class="control-group">
|
||||||
<label>{{I18n.t "discourse_ai.ai_persona.forced_tools"}}</label>
|
<label>{{I18n.t "discourse_ai.ai_persona.forced_tools"}}</label>
|
||||||
<AiToolSelector
|
<AiToolSelector
|
||||||
class="ai-persona-editor__tools"
|
class="ai-persona-editor__forced_tools"
|
||||||
@value={{this.forcedToolNames}}
|
@value={{this.forcedToolNames}}
|
||||||
@tools={{this.selectedTools}}
|
@tools={{this.selectedTools}}
|
||||||
@onChange={{this.forcedToolsChanged}}
|
@onChange={{this.forcedToolsChanged}}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
{{#if this.hasForcedTools}}
|
||||||
|
<div class="control-group">
|
||||||
|
<label>{{I18n.t
|
||||||
|
"discourse_ai.ai_persona.forced_tool_strategy"
|
||||||
|
}}</label>
|
||||||
|
<AiForcedToolStrategySelector
|
||||||
|
class="ai-persona-editor__forced_tool_strategy"
|
||||||
|
@value={{this.editingModel.forced_tool_count}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
{{/if}}
|
||||||
{{/if}}
|
{{/if}}
|
||||||
{{#unless this.editingModel.system}}
|
{{#unless this.editingModel.system}}
|
||||||
<AiPersonaToolOptions
|
<AiPersonaToolOptions
|
||||||
|
|
|
@ -116,6 +116,11 @@ en:
|
||||||
select_option: "Select an option..."
|
select_option: "Select an option..."
|
||||||
|
|
||||||
ai_persona:
|
ai_persona:
|
||||||
|
tool_strategies:
|
||||||
|
all: "Apply to all replies"
|
||||||
|
replies:
|
||||||
|
one: "Apply to first reply only"
|
||||||
|
other: "Apply to first %{count} replies"
|
||||||
back: Back
|
back: Back
|
||||||
name: Name
|
name: Name
|
||||||
edit: Edit
|
edit: Edit
|
||||||
|
@ -142,6 +147,7 @@ en:
|
||||||
question_consolidator_llm: Language Model for Question Consolidator
|
question_consolidator_llm: Language Model for Question Consolidator
|
||||||
question_consolidator_llm_help: The language model to use for the question consolidator, you may choose a less powerful model to save costs.
|
question_consolidator_llm_help: The language model to use for the question consolidator, you may choose a less powerful model to save costs.
|
||||||
system_prompt: System Prompt
|
system_prompt: System Prompt
|
||||||
|
forced_tool_strategy: Forced Tool Strategy
|
||||||
allow_chat: "Allow Chat"
|
allow_chat: "Allow Chat"
|
||||||
allow_chat_help: "If enabled, users in allowed groups can DM this persona"
|
allow_chat_help: "If enabled, users in allowed groups can DM this persona"
|
||||||
save: Save
|
save: Save
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
class AddForcedToolCountToAiPersonas < ActiveRecord::Migration[7.1]
|
||||||
|
def change
|
||||||
|
add_column :ai_personas, :forced_tool_count, :integer, default: -1, null: false
|
||||||
|
end
|
||||||
|
end
|
|
@ -72,6 +72,11 @@ module DiscourseAi
|
||||||
forced_tools = persona.force_tool_use.map { |tool| tool.name }
|
forced_tools = persona.force_tool_use.map { |tool| tool.name }
|
||||||
force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) }
|
force_tool = forced_tools.find { |name| !context[:chosen_tools].include?(name) }
|
||||||
|
|
||||||
|
if force_tool && persona.forced_tool_count > 0
|
||||||
|
user_turns = prompt.messages.select { |m| m[:type] == :user }.length
|
||||||
|
force_tool = false if user_turns > persona.forced_tool_count
|
||||||
|
end
|
||||||
|
|
||||||
if force_tool
|
if force_tool
|
||||||
context[:chosen_tools] << force_tool
|
context[:chosen_tools] << force_tool
|
||||||
prompt.tool_choice = force_tool
|
prompt.tool_choice = force_tool
|
||||||
|
|
|
@ -117,6 +117,10 @@ module DiscourseAi
|
||||||
[]
|
[]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def forced_tool_count
|
||||||
|
-1
|
||||||
|
end
|
||||||
|
|
||||||
def required_tools
|
def required_tools
|
||||||
[]
|
[]
|
||||||
end
|
end
|
||||||
|
|
|
@ -127,19 +127,33 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
||||||
|
|
||||||
it "can force usage of a tool" do
|
it "can force usage of a tool" do
|
||||||
tool_name = "custom-#{custom_tool.id}"
|
tool_name = "custom-#{custom_tool.id}"
|
||||||
ai_persona.update!(tools: [[tool_name, nil, true]])
|
ai_persona.update!(tools: [[tool_name, nil, true]], forced_tool_count: 1)
|
||||||
responses = [function_call, "custom tool did stuff (maybe)"]
|
responses = [function_call, "custom tool did stuff (maybe)"]
|
||||||
|
|
||||||
prompts = nil
|
prompts = nil
|
||||||
|
reply_post = nil
|
||||||
|
|
||||||
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts|
|
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts|
|
||||||
new_post = Fabricate(:post, raw: "Can you use the custom tool?")
|
new_post = Fabricate(:post, raw: "Can you use the custom tool?")
|
||||||
_reply_post = playground.reply_to(new_post)
|
reply_post = playground.reply_to(new_post)
|
||||||
prompts = _prompts
|
prompts = _prompts
|
||||||
end
|
end
|
||||||
|
|
||||||
expect(prompts.length).to eq(2)
|
expect(prompts.length).to eq(2)
|
||||||
expect(prompts[0].tool_choice).to eq("search")
|
expect(prompts[0].tool_choice).to eq("search")
|
||||||
expect(prompts[1].tool_choice).to eq(nil)
|
expect(prompts[1].tool_choice).to eq(nil)
|
||||||
|
|
||||||
|
ai_persona.update!(forced_tool_count: 1)
|
||||||
|
responses = ["no tool call here"]
|
||||||
|
|
||||||
|
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts|
|
||||||
|
new_post = Fabricate(:post, raw: "Will you use the custom tool?", topic: reply_post.topic)
|
||||||
|
_reply_post = playground.reply_to(new_post)
|
||||||
|
prompts = _prompts
|
||||||
|
end
|
||||||
|
|
||||||
|
expect(prompts.length).to eq(1)
|
||||||
|
expect(prompts[0].tool_choice).to eq(nil)
|
||||||
end
|
end
|
||||||
|
|
||||||
it "uses custom tool in conversation" do
|
it "uses custom tool in conversation" do
|
||||||
|
|
|
@ -39,9 +39,10 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
||||||
Fabricate(
|
Fabricate(
|
||||||
:ai_persona,
|
:ai_persona,
|
||||||
name: "search2",
|
name: "search2",
|
||||||
tools: [["SearchCommand", { base_query: "test" }]],
|
tools: [["SearchCommand", { base_query: "test" }, true]],
|
||||||
mentionable: true,
|
mentionable: true,
|
||||||
default_llm: "anthropic:claude-2",
|
default_llm: "anthropic:claude-2",
|
||||||
|
forced_tool_count: 2,
|
||||||
)
|
)
|
||||||
persona2.create_user!
|
persona2.create_user!
|
||||||
|
|
||||||
|
@ -55,6 +56,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
||||||
expect(serializer_persona2["default_llm"]).to eq("anthropic:claude-2")
|
expect(serializer_persona2["default_llm"]).to eq("anthropic:claude-2")
|
||||||
expect(serializer_persona2["user_id"]).to eq(persona2.user_id)
|
expect(serializer_persona2["user_id"]).to eq(persona2.user_id)
|
||||||
expect(serializer_persona2["user"]["id"]).to eq(persona2.user_id)
|
expect(serializer_persona2["user"]["id"]).to eq(persona2.user_id)
|
||||||
|
expect(serializer_persona2["forced_tool_count"]).to eq(2)
|
||||||
|
|
||||||
tools = response.parsed_body["meta"]["tools"]
|
tools = response.parsed_body["meta"]["tools"]
|
||||||
search_tool = tools.find { |c| c["id"] == "Search" }
|
search_tool = tools.find { |c| c["id"] == "Search" }
|
||||||
|
@ -85,7 +87,9 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
||||||
)
|
)
|
||||||
|
|
||||||
expect(serializer_persona1["tools"]).to eq(["SearchCommand"])
|
expect(serializer_persona1["tools"]).to eq(["SearchCommand"])
|
||||||
expect(serializer_persona2["tools"]).to eq([["SearchCommand", { "base_query" => "test" }]])
|
expect(serializer_persona2["tools"]).to eq(
|
||||||
|
[["SearchCommand", { "base_query" => "test" }, true]],
|
||||||
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
context "with translations" do
|
context "with translations" do
|
||||||
|
@ -165,6 +169,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
||||||
temperature: 0.5,
|
temperature: 0.5,
|
||||||
mentionable: true,
|
mentionable: true,
|
||||||
default_llm: "anthropic:claude-2",
|
default_llm: "anthropic:claude-2",
|
||||||
|
forced_tool_count: 2,
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -183,6 +188,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
||||||
expect(persona_json["temperature"]).to eq(0.5)
|
expect(persona_json["temperature"]).to eq(0.5)
|
||||||
expect(persona_json["mentionable"]).to eq(true)
|
expect(persona_json["mentionable"]).to eq(true)
|
||||||
expect(persona_json["default_llm"]).to eq("anthropic:claude-2")
|
expect(persona_json["default_llm"]).to eq("anthropic:claude-2")
|
||||||
|
expect(persona_json["forced_tool_count"]).to eq(2)
|
||||||
|
|
||||||
persona = AiPersona.find(persona_json["id"])
|
persona = AiPersona.find(persona_json["id"])
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,17 @@ RSpec.describe "Admin AI persona configuration", type: :system, js: true do
|
||||||
tool_selector = PageObjects::Components::SelectKit.new(".ai-persona-editor__tools")
|
tool_selector = PageObjects::Components::SelectKit.new(".ai-persona-editor__tools")
|
||||||
tool_selector.expand
|
tool_selector.expand
|
||||||
tool_selector.select_row_by_value("Read")
|
tool_selector.select_row_by_value("Read")
|
||||||
|
tool_selector.collapse
|
||||||
|
|
||||||
|
tool_selector = PageObjects::Components::SelectKit.new(".ai-persona-editor__forced_tools")
|
||||||
|
tool_selector.expand
|
||||||
|
tool_selector.select_row_by_value("Read")
|
||||||
|
tool_selector.collapse
|
||||||
|
|
||||||
|
strategy_selector =
|
||||||
|
PageObjects::Components::SelectKit.new(".ai-persona-editor__forced_tool_strategy")
|
||||||
|
strategy_selector.expand
|
||||||
|
strategy_selector.select_row_by_value(1)
|
||||||
|
|
||||||
find(".ai-persona-editor__save").click()
|
find(".ai-persona-editor__save").click()
|
||||||
|
|
||||||
|
@ -30,7 +41,8 @@ RSpec.describe "Admin AI persona configuration", type: :system, js: true do
|
||||||
expect(persona.name).to eq("Test Persona")
|
expect(persona.name).to eq("Test Persona")
|
||||||
expect(persona.description).to eq("I am a test persona")
|
expect(persona.description).to eq("I am a test persona")
|
||||||
expect(persona.system_prompt).to eq("You are a helpful bot")
|
expect(persona.system_prompt).to eq("You are a helpful bot")
|
||||||
expect(persona.tools).to eq([["Read", { "read_private" => nil }, false]])
|
expect(persona.tools).to eq([["Read", { "read_private" => nil }, true]])
|
||||||
|
expect(persona.forced_tool_count).to eq(1)
|
||||||
end
|
end
|
||||||
|
|
||||||
it "will not allow deletion or editing of system personas" do
|
it "will not allow deletion or editing of system personas" do
|
||||||
|
|
|
@ -51,6 +51,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
|
||||||
question_consolidator_llm: "Question Consolidator LLM",
|
question_consolidator_llm: "Question Consolidator LLM",
|
||||||
allow_chat: false,
|
allow_chat: false,
|
||||||
tool_details: true,
|
tool_details: true,
|
||||||
|
forced_tool_count: -1,
|
||||||
};
|
};
|
||||||
|
|
||||||
const aiPersona = AiPersona.create({ ...properties });
|
const aiPersona = AiPersona.create({ ...properties });
|
||||||
|
@ -92,6 +93,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
|
||||||
question_consolidator_llm: "Question Consolidator LLM",
|
question_consolidator_llm: "Question Consolidator LLM",
|
||||||
allow_chat: false,
|
allow_chat: false,
|
||||||
tool_details: true,
|
tool_details: true,
|
||||||
|
forced_tool_count: -1,
|
||||||
};
|
};
|
||||||
|
|
||||||
const aiPersona = AiPersona.create({ ...properties });
|
const aiPersona = AiPersona.create({ ...properties });
|
||||||
|
|
Loading…
Reference in New Issue