FEATURE: basic infrastructure for custom personas (#288)

- New AiPersona model which can store custom personas
- Persona are restricted via group security
- They can contain custom system messages
- They can support a list of commands optionally

To avoid expensive DB calls in the serializer a Multisite friendly Hash was introduced (which can be expired on transaction commit)
This commit is contained in:
Sam 2023-11-10 11:39:49 +11:00 committed by GitHub
parent d0198c5c5b
commit a4f419f54f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 224 additions and 4 deletions

121
app/models/ai_persona.rb Normal file
View File

@ -0,0 +1,121 @@
# frozen_string_literal: true
class AiPersona < ActiveRecord::Base
# places a hard limit, so per site we cache a maximum of 500 classes
MAX_PERSONAS_PER_SITE = 500
class MultisiteHash
def initialize(id)
@hash = Hash.new { |h, k| h[k] = {} }
@id = id
MessageBus.subscribe(channel_name) { |message| @hash[message.data] = {} }
end
def channel_name
"/multisite-hash-#{@id}"
end
def current_db
RailsMultisite::ConnectionManagement.current_db
end
def [](key)
@hash.dig(current_db, key)
end
def []=(key, val)
@hash[current_db][key] = val
end
def flush!
@hash[current_db] = {}
MessageBus.publish(channel_name, current_db)
end
end
def self.persona_cache
@persona_cache ||= MultisiteHash.new("persona_cache")
end
def self.all_personas
persona_cache[:value] ||= AiPersona
.order(:name)
.where(enabled: true)
.all
.limit(MAX_PERSONAS_PER_SITE)
.map do |ai_persona|
name = ai_persona.name
description = ai_persona.description
ai_persona_id = ai_persona.id
allowed_group_ids = ai_persona.allowed_group_ids
commands =
ai_persona.commands.filter_map do |inner_name|
begin
("DiscourseAi::AiBot::Commands::#{inner_name}").constantize
rescue StandardError
nil
end
end
Class.new(DiscourseAi::AiBot::Personas::Persona) do
define_singleton_method :name do
name
end
define_singleton_method :description do
description
end
define_singleton_method :allowed_group_ids do
allowed_group_ids
end
define_singleton_method :to_s do
"#<DiscourseAi::AiBot::Personas::Persona::Custom @name=#{self.name} @allowed_group_ids=#{self.allowed_group_ids.join(",")}>"
end
define_singleton_method :inspect do
"#<DiscourseAi::AiBot::Personas::Persona::Custom @name=#{self.name} @allowed_group_ids=#{self.allowed_group_ids.join(",")}>"
end
define_method :initialize do |*args, **kwargs|
@ai_persona = AiPersona.find_by(id: ai_persona_id)
super(*args, **kwargs)
end
define_method :commands do
commands
end
define_method :system_prompt do
@ai_persona&.system_prompt || "You are a helpful bot."
end
end
end
end
after_commit :bump_cache
def bump_cache
self.class.persona_cache.flush!
end
end
# == Schema Information
#
# Table name: ai_personas
#
# id :bigint not null, primary key
# name :string(100) not null
# description :string(2000) not null
# commands :string default([]), not null, is an Array
# system_prompt :string not null
# allowed_group_ids :integer default([]), not null, is an Array
# created_at :datetime not null
# updated_at :datetime not null
#
# Indexes
#
# index_ai_personas_on_name (name) UNIQUE
#

View File

@ -0,0 +1,18 @@
# frozen_string_literal: true
#
class CreateAiPersonas < ActiveRecord::Migration[7.0]
def change
create_table :ai_personas do |t|
t.string :name, null: false, unique: true, limit: 100
t.string :description, null: false, limit: 2000
t.string :commands, array: true, default: [], null: false
t.string :system_prompt, null: false, limit: 10_000_000
t.integer :allowed_group_ids, array: true, default: [], null: false
t.integer :created_by_id
t.boolean :enabled, default: true, null: false
t.timestamps
end
add_index :ai_personas, :name, unique: true
end
end

View File

@ -116,7 +116,9 @@ module DiscourseAi
@persona = DiscourseAi::AiBot::Personas::General.new(allow_commands: allow_commands) @persona = DiscourseAi::AiBot::Personas::General.new(allow_commands: allow_commands)
if persona_name = post.topic.custom_fields["ai_persona"] if persona_name = post.topic.custom_fields["ai_persona"]
persona_class = persona_class =
DiscourseAi::AiBot::Personas.all.find { |current| current.name == persona_name } DiscourseAi::AiBot::Personas
.all(user: post.user)
.find { |current| current.name == persona_name }
@persona = persona_class.new(allow_commands: allow_commands) if persona_class @persona = persona_class.new(allow_commands: allow_commands) if persona_class
end end

View File

@ -74,7 +74,9 @@ module DiscourseAi
scope.user.in_any_groups?(SiteSetting.ai_bot_allowed_groups_map) scope.user.in_any_groups?(SiteSetting.ai_bot_allowed_groups_map)
end, end,
) do ) do
Personas.all.map { |persona| { name: persona.name, description: persona.description } } Personas
.all(user: scope.user)
.map { |persona| { name: persona.name, description: persona.description } }
end end
plugin.add_to_serializer( plugin.add_to_serializer(

View File

@ -3,7 +3,7 @@
module DiscourseAi module DiscourseAi
module AiBot module AiBot
module Personas module Personas
def self.all def self.all(user: nil)
personas = [Personas::General, Personas::SqlHelper] personas = [Personas::General, Personas::SqlHelper]
personas << Personas::Artist if SiteSetting.ai_stability_api_key.present? personas << Personas::Artist if SiteSetting.ai_stability_api_key.present?
personas << Personas::SettingsExplorer personas << Personas::SettingsExplorer
@ -11,7 +11,20 @@ module DiscourseAi
personas << Personas::Creative personas << Personas::Creative
personas_allowed = SiteSetting.ai_bot_enabled_personas.split("|") personas_allowed = SiteSetting.ai_bot_enabled_personas.split("|")
personas.filter { |persona| personas_allowed.include?(persona.to_s.demodulize.underscore) } personas =
personas.filter do |persona|
personas_allowed.include?(persona.to_s.demodulize.underscore)
end
if user
personas.concat(
AiPersona.all_personas.filter do |persona|
user.in_any_groups?(persona.allowed_group_ids)
end,
)
end
personas
end end
class Persona class Persona

View File

@ -34,6 +34,8 @@ module DiscourseAi::AiBot::Personas
topic topic
end end
fab!(:user) { Fabricate(:user) }
it "can disable commands via constructor" do it "can disable commands via constructor" do
persona = TestPersona.new(allow_commands: false) persona = TestPersona.new(allow_commands: false)
@ -72,6 +74,49 @@ module DiscourseAi::AiBot::Personas
expect(rendered).not_to include("!tags") expect(rendered).not_to include("!tags")
end end
describe "custom personas" do
it "is able to find custom personas" do
Group.refresh_automatic_groups!
# define an ai persona everyone can see
persona =
AiPersona.create!(
name: "pun_bot",
description: "you write puns",
system_prompt: "you are pun bot",
commands: ["ImageCommand"],
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
)
custom_persona = DiscourseAi::AiBot::Personas.all(user: user).last
expect(custom_persona.name).to eq("pun_bot")
expect(custom_persona.description).to eq("you write puns")
instance = custom_persona.new
expect(instance.commands).to eq([DiscourseAi::AiBot::Commands::ImageCommand])
expect(instance.render_system_prompt(render_function_instructions: true)).to eq(
"you are pun bot",
)
# should update
persona.update!(name: "pun_bot2")
custom_persona = DiscourseAi::AiBot::Personas.all(user: user).last
expect(custom_persona.name).to eq("pun_bot2")
# can be disabled
persona.update!(enabled: false)
last_persona = DiscourseAi::AiBot::Personas.all(user: user).last
expect(last_persona.name).not_to eq("pun_bot2")
persona.update!(enabled: true)
# no groups have access
persona.update!(allowed_group_ids: [])
last_persona = DiscourseAi::AiBot::Personas.all(user: user).last
expect(last_persona.name).not_to eq("pun_bot2")
end
end
describe "available personas" do describe "available personas" do
it "includes all personas by default" do it "includes all personas by default" do
# must be enabled to see it # must be enabled to see it

View File

@ -0,0 +1,19 @@
# frozen_string_literal: true
RSpec.describe AiPersona do
it "does not leak caches between sites" do
AiPersona.create!(
name: "pun_bot",
description: "you write puns",
system_prompt: "you are pun bot",
commands: ["ImageCommand"],
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
)
AiPersona.all_personas
expect(AiPersona.persona_cache[:value].length).to eq(1)
RailsMultisite::ConnectionManagement.stubs(:current_db) { "abc" }
expect(AiPersona.persona_cache[:value]).to eq(nil)
end
end