FEATURE: allow personas to supply top_p and temperature params (#459)

* FEATURE: allow personas to supply top_p and temperature params

Code assistants are generally more focused at a lower temperature.
This amends the SQL Helper to run at a temperature of 0.2 vs the
more common default across LLMs of 1.0.

Reduced temperature leads to more focused, concise and predictable
answers for the SQL Helper

* fix tests

* This is not perfect, but far better than what we do today

Instead of fishing for

1. Draft sequence
2. Draft body

We skip (2), this means the composer "only" needs 1 http request to
open, we also want to eliminate (1) but it is a bit of a trickier
core change, may figure out how to pull it off (defer it to first draft save)

Value of bot drafts < value of opening bot conversations really fast
This commit is contained in:
Sam 2024-02-03 07:09:34 +11:00 committed by GitHub
parent 944fd6569c
commit a3c827efcc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 186 additions and 8 deletions

View File

@ -62,6 +62,8 @@ module DiscourseAi
:enabled, :enabled,
:system_prompt, :system_prompt,
:priority, :priority,
:top_p,
:temperature,
allowed_group_ids: [], allowed_group_ids: [],
) )

View File

@ -157,6 +157,14 @@ class AiPersona < ActiveRecord::Base
options options
end end
# Per-persona LLM sampling overrides, read from the backing AiPersona
# record; nil when unset, meaning "use the model's default".
define_method :temperature do
@ai_persona&.temperature
end
define_method :top_p do
@ai_persona&.top_p
end
define_method :system_prompt do define_method :system_prompt do
@ai_persona&.system_prompt || "You are a helpful bot." @ai_persona&.system_prompt || "You are a helpful bot."
end end
@ -166,7 +174,8 @@ class AiPersona < ActiveRecord::Base
private private
def system_persona_unchangeable def system_persona_unchangeable
if system_prompt_changed? || commands_changed? || name_changed? || description_changed? if top_p_changed? || temperature_changed? || system_prompt_changed? || commands_changed? ||
name_changed? || description_changed?
errors.add(:base, I18n.t("discourse_ai.ai_bot.personas.cannot_edit_system_persona")) errors.add(:base, I18n.t("discourse_ai.ai_bot.personas.cannot_edit_system_persona"))
end end
end end
@ -186,7 +195,7 @@ end
# id :bigint not null, primary key # id :bigint not null, primary key
# name :string(100) not null # name :string(100) not null
# description :string(2000) not null # description :string(2000) not null
# commands :string default([]), not null, is an Array # commands :json not null
# system_prompt :string(10000000) not null # system_prompt :string(10000000) not null
# allowed_group_ids :integer default([]), not null, is an Array # allowed_group_ids :integer default([]), not null, is an Array
# created_by_id :integer # created_by_id :integer
@ -194,7 +203,9 @@ end
# created_at :datetime not null # created_at :datetime not null
# updated_at :datetime not null # updated_at :datetime not null
# system :boolean default(FALSE), not null # system :boolean default(FALSE), not null
# priority :integer default(0), not null # priority :boolean default(FALSE), not null
# temperature :float
# top_p :float
# #
# Indexes # Indexes
# #

View File

@ -11,7 +11,9 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer
:priority, :priority,
:commands, :commands,
:system_prompt, :system_prompt,
:allowed_group_ids :allowed_group_ids,
:temperature,
:top_p
def name def name
object.class_instance.name object.class_instance.name

View File

@ -10,6 +10,8 @@ const ATTRIBUTES = [
"enabled", "enabled",
"system", "system",
"priority", "priority",
"top_p",
"temperature",
]; ];
class CommandOption { class CommandOption {

View File

@ -73,6 +73,14 @@ export default class PersonaEditor extends Component {
} }
} }
get showTemperature() {
  // Show the field when a value is already set (so system personas can
  // display it, read-only) or whenever the persona is user-editable.
  // NOTE(review): a stored value of 0 is falsy and would hide the field
  // for system personas — presumably acceptable; confirm.
  return this.editingModel?.temperature || !this.editingModel?.system;
}
get showTopP() {
  // Same visibility rule as showTemperature, for the top_p field.
  return this.editingModel?.top_p || !this.editingModel?.system;
}
@action @action
delete() { delete() {
return this.dialog.confirm({ return this.dialog.confirm({
@ -213,6 +221,36 @@ export default class PersonaEditor extends Component {
disabled={{this.editingModel.system}} disabled={{this.editingModel.system}}
/> />
</div> </div>
{{#if this.showTemperature}}
<div class="control-group">
<label>{{I18n.t "discourse_ai.ai_persona.temperature"}}</label>
<Input
@type="number"
class="ai-persona-editor__temperature"
@value={{this.editingModel.temperature}}
disabled={{this.editingModel.system}}
/>
<DTooltip
@icon="question-circle"
@content={{I18n.t "discourse_ai.ai_persona.temperature_help"}}
/>
</div>
{{/if}}
{{#if this.showTopP}}
<div class="control-group">
<label>{{I18n.t "discourse_ai.ai_persona.top_p"}}</label>
<Input
@type="number"
class="ai-persona-editor__top_p"
@value={{this.editingModel.top_p}}
disabled={{this.editingModel.system}}
/>
<DTooltip
@icon="question-circle"
@content={{I18n.t "discourse_ai.ai_persona.top_p_help"}}
/>
</div>
{{/if}}
<div class="control-group ai-persona-editor__action_panel"> <div class="control-group ai-persona-editor__action_panel">
<DButton <DButton
class="btn-primary ai-persona-editor__save" class="btn-primary ai-persona-editor__save"

View File

@ -15,9 +15,10 @@ export function composeAiBotMessage(targetBot, composer) {
recipients: botUsername, recipients: botUsername,
topicTitle: I18n.t("discourse_ai.ai_bot.default_pm_prefix"), topicTitle: I18n.t("discourse_ai.ai_bot.default_pm_prefix"),
archetypeId: "private_message", archetypeId: "private_message",
draftKey: Composer.NEW_PRIVATE_MESSAGE_KEY, draftKey: "private_message_ai",
hasGroups: false, hasGroups: false,
warningsDisabled: true, warningsDisabled: true,
skipDraftCheck: true,
}, },
}); });
} }

View File

@ -118,6 +118,10 @@ en:
new: New new: New
title: "AI Personas" title: "AI Personas"
delete: Delete delete: Delete
temperature: Temperature
temperature_help: Temperature to use for the LLM, increase to increase creativity (leave empty to use model default, generally a value from 0.0 to 2.0)
top_p: Top P
top_p_help: Top P to use for the LLM, increase to increase randomness (leave empty to use model default, generally a value from 0.0 to 1.0)
priority: Priority priority: Priority
priority_help: Priority personas are displayed to users at the top of the persona list. If multiple personas have priority, they will be sorted alphabetically. priority_help: Priority personas are displayed to users at the top of the persona list. If multiple personas have priority, they will be sorted alphabetically.
command_options: "Command Options" command_options: "Command Options"

View File

@ -34,5 +34,7 @@ DiscourseAi::AiBot::Personas::Persona.system_personas.each do |persona_class, id
instance = persona_class.new instance = persona_class.new
persona.commands = instance.tools.map { |tool| tool.to_s.split("::").last } persona.commands = instance.tools.map { |tool| tool.to_s.split("::").last }
persona.system_prompt = instance.system_prompt persona.system_prompt = instance.system_prompt
persona.top_p = instance.top_p
persona.temperature = instance.temperature
persona.save!(validate: false) persona.save!(validate: false)
end end

View File

@ -0,0 +1,8 @@
# frozen_string_literal: true

# Adds optional per-persona LLM sampling overrides.
# NULL means "fall back to the model's default".
class AddTemperatureTopPToAiPersonas < ActiveRecord::Migration[7.0]
  def change
    %i[temperature top_p].each do |param|
      add_column :ai_personas, param, :float, null: true
    end
  end
end

View File

@ -48,13 +48,19 @@ module DiscourseAi
low_cost = false low_cost = false
raw_context = [] raw_context = []
user = context[:user]
llm_kwargs = { user: user }
llm_kwargs[:temperature] = persona.temperature if persona.temperature
llm_kwargs[:top_p] = persona.top_p if persona.top_p
while total_completions <= MAX_COMPLETIONS && ongoing_chain while total_completions <= MAX_COMPLETIONS && ongoing_chain
current_model = model(prefer_low_cost: low_cost) current_model = model(prefer_low_cost: low_cost)
llm = DiscourseAi::Completions::Llm.proxy(current_model) llm = DiscourseAi::Completions::Llm.proxy(current_model)
tool_found = false tool_found = false
result = result =
llm.generate(prompt, user: context[:user]) do |partial, cancel| llm.generate(prompt, **llm_kwargs) do |partial, cancel|
if (tool = persona.find_tool(partial)) if (tool = persona.find_tool(partial))
tool_found = true tool_found = true
ongoing_chain = tool.chain_next_response? ongoing_chain = tool.chain_next_response?

View File

@ -23,7 +23,6 @@ module DiscourseAi
def all(user:) def all(user:)
# listing tools has to be dynamic cause site settings may change # listing tools has to be dynamic cause site settings may change
AiPersona.all_personas.filter do |persona| AiPersona.all_personas.filter do |persona|
next false if !user.in_any_groups?(persona.allowed_group_ids) next false if !user.in_any_groups?(persona.allowed_group_ids)
@ -85,6 +84,14 @@ module DiscourseAi
[] []
end end
# Default sampling temperature; nil means "use the model default".
# Subclasses may override (e.g. the SQL helper pins a lower value).
def temperature
nil
end
# Default nucleus-sampling top_p; nil means "use the model default".
def top_p
nil
end
def options def options
{} {}
end end

View File

@ -31,6 +31,10 @@ module DiscourseAi
[Tools::DbSchema] [Tools::DbSchema]
end end
# Run the SQL helper at a low temperature: code generation benefits
# from more focused, predictable completions than the common 1.0 default.
def temperature
0.2
end
def system_prompt def system_prompt
<<~PROMPT <<~PROMPT
You are a PostgreSQL expert. You are a PostgreSQL expert.

View File

@ -95,7 +95,17 @@ module DiscourseAi
@chunk_count = chunk_count @chunk_count = chunk_count
end end
class << self
  # Records the parameters of the most recent perform_completion! call
  # so specs can assert on exactly what was passed to the endpoint.
  attr_accessor :last_call
end
def perform_completion!(dialect, user, model_params = {}) def perform_completion!(dialect, user, model_params = {})
self.class.last_call = { dialect: dialect, user: user, model_params: model_params }
content = self.class.fake_content content = self.class.fake_content
if block_given? if block_given?

View File

@ -3,6 +3,8 @@
RSpec.describe DiscourseAi::AiBot::Bot do RSpec.describe DiscourseAi::AiBot::Bot do
subject(:bot) { described_class.as(bot_user) } subject(:bot) { described_class.as(bot_user) }
fab!(:admin) { Fabricate(:admin) }
before do before do
SiteSetting.ai_bot_enabled_chat_bots = "gpt-4" SiteSetting.ai_bot_enabled_chat_bots = "gpt-4"
SiteSetting.ai_bot_enabled = true SiteSetting.ai_bot_enabled = true
@ -25,6 +27,40 @@ RSpec.describe DiscourseAi::AiBot::Bot do
let(:llm_responses) { [function_call, response] } let(:llm_responses) { [function_call, response] }
describe "#reply" do describe "#reply" do
it "sets top_p and temperature params" do
  # Full integration test so we have certainty the persona's sampling
  # params are passed all the way through to the completion endpoint.
  DiscourseAi::Completions::Endpoints::Fake.delays = []
  DiscourseAi::Completions::Endpoints::Fake.last_call = nil

  SiteSetting.ai_bot_enabled_chat_bots = "fake"
  SiteSetting.ai_bot_enabled = true

  Group.refresh_automatic_groups!

  bot_user = User.find(DiscourseAi::AiBot::EntryPoint::FAKE_ID)

  AiPersona.create!(
    name: "TestPersona",
    top_p: 0.5,
    temperature: 0.4,
    system_prompt: "test",
    description: "test",
    allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
  )

  # snake_case per Ruby naming convention (was: personaClass)
  persona_class = DiscourseAi::AiBot::Personas::Persona.find_by(user: admin, name: "TestPersona")
  bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_class.new)

  bot.reply(
    { conversation_context: [{ type: :user, content: "test" }] },
  ) do |_partial, _cancel, _placeholder|
    # we just need the block so bot has something to call with results
  end

  last_call = DiscourseAi::Completions::Endpoints::Fake.last_call
  expect(last_call[:model_params][:top_p]).to eq(0.5)
  expect(last_call[:model_params][:temperature]).to eq(0.4)
end
context "when using function chaining" do context "when using function chaining" do
it "yields a loading placeholder while proceeds to invoke the command" do it "yields a loading placeholder while proceeds to invoke the command" do
tool = DiscourseAi::AiBot::Tools::ListCategories.new({}) tool = DiscourseAi::AiBot::Tools::ListCategories.new({})

View File

@ -116,6 +116,8 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
description: "Assists with tasks", description: "Assists with tasks",
system_prompt: "you are a helpful bot", system_prompt: "you are a helpful bot",
commands: [["search", { "base_query" => "test" }]], commands: [["search", { "base_query" => "test" }]],
top_p: 0.1,
temperature: 0.5,
} }
end end
@ -127,8 +129,17 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
"CONTENT_TYPE" => "application/json", "CONTENT_TYPE" => "application/json",
} }
expect(response).to be_successful expect(response).to be_successful
persona = AiPersona.find(response.parsed_body["ai_persona"]["id"]) persona_json = response.parsed_body["ai_persona"]
expect(persona_json["name"]).to eq("superbot")
expect(persona_json["top_p"]).to eq(0.1)
expect(persona_json["temperature"]).to eq(0.5)
persona = AiPersona.find(persona_json["id"])
expect(persona.commands).to eq([["search", { "base_query" => "test" }]]) expect(persona.commands).to eq([["search", { "base_query" => "test" }]])
expect(persona.top_p).to eq(0.1)
expect(persona.temperature).to eq(0.5)
}.to change(AiPersona, :count).by(1) }.to change(AiPersona, :count).by(1)
end end
end end
@ -143,6 +154,36 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
end end
describe "PUT #update" do describe "PUT #update" do
it "allows us to trivially clear top_p and temperature" do
# Blank strings submitted by the admin form should end up persisted
# as NULL, restoring "use model default" behaviour for both params.
persona = Fabricate(:ai_persona, name: "test_bot2", top_p: 0.5, temperature: 0.1)
put "/admin/plugins/discourse-ai/ai_personas/#{persona.id}.json",
params: {
ai_persona: {
top_p: "",
temperature: "",
},
}
expect(response).to have_http_status(:ok)
persona.reload
expect(persona.top_p).to eq(nil)
expect(persona.temperature).to eq(nil)
end
it "does not allow temperature and top p changes on stock personas" do
# System personas are locked down: the model-level validation rejects
# edits to temperature/top_p with 422, same as system_prompt changes.
put "/admin/plugins/discourse-ai/ai_personas/#{DiscourseAi::AiBot::Personas::Persona.system_personas.values.first}.json",
params: {
ai_persona: {
top_p: 0.5,
temperature: 0.1,
},
}
expect(response).to have_http_status(:unprocessable_entity)
end
context "with valid params" do context "with valid params" do
it "updates the requested ai_persona" do it "updates the requested ai_persona" do
put "/admin/plugins/discourse-ai/ai_personas/#{ai_persona.id}.json", put "/admin/plugins/discourse-ai/ai_personas/#{ai_persona.id}.json",

View File

@ -39,6 +39,8 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
system_prompt: "System Prompt", system_prompt: "System Prompt",
priority: false, priority: false,
description: "Description", description: "Description",
top_p: 0.8,
temperature: 0.7,
}; };
const aiPersona = AiPersona.create({ ...properties }); const aiPersona = AiPersona.create({ ...properties });
@ -63,6 +65,8 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
system_prompt: "System Prompt", system_prompt: "System Prompt",
priority: false, priority: false,
description: "Description", description: "Description",
top_p: 0.8,
temperature: 0.7,
}; };
const aiPersona = AiPersona.create({ ...properties }); const aiPersona = AiPersona.create({ ...properties });