FEATURE: basic infrastructure for custom personas (#288)
- New AiPersona model which can store custom personas - Persona are restricted via group security - They can contain custom system messages - They can support a list of commands optionally To avoid expensive DB calls in the serializer a Multisite friendly Hash was introduced (which can be expired on transaction commit)
This commit is contained in:
parent
d0198c5c5b
commit
a4f419f54f
|
@ -0,0 +1,121 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
class AiPersona < ActiveRecord::Base
|
||||
# places a hard limit, so per site we cache a maximum of 500 classes
|
||||
MAX_PERSONAS_PER_SITE = 500
|
||||
|
||||
class MultisiteHash
|
||||
def initialize(id)
|
||||
@hash = Hash.new { |h, k| h[k] = {} }
|
||||
@id = id
|
||||
|
||||
MessageBus.subscribe(channel_name) { |message| @hash[message.data] = {} }
|
||||
end
|
||||
|
||||
def channel_name
|
||||
"/multisite-hash-#{@id}"
|
||||
end
|
||||
|
||||
def current_db
|
||||
RailsMultisite::ConnectionManagement.current_db
|
||||
end
|
||||
|
||||
def [](key)
|
||||
@hash.dig(current_db, key)
|
||||
end
|
||||
|
||||
def []=(key, val)
|
||||
@hash[current_db][key] = val
|
||||
end
|
||||
|
||||
def flush!
|
||||
@hash[current_db] = {}
|
||||
MessageBus.publish(channel_name, current_db)
|
||||
end
|
||||
end
|
||||
|
||||
def self.persona_cache
|
||||
@persona_cache ||= MultisiteHash.new("persona_cache")
|
||||
end
|
||||
|
||||
def self.all_personas
|
||||
persona_cache[:value] ||= AiPersona
|
||||
.order(:name)
|
||||
.where(enabled: true)
|
||||
.all
|
||||
.limit(MAX_PERSONAS_PER_SITE)
|
||||
.map do |ai_persona|
|
||||
name = ai_persona.name
|
||||
description = ai_persona.description
|
||||
ai_persona_id = ai_persona.id
|
||||
allowed_group_ids = ai_persona.allowed_group_ids
|
||||
commands =
|
||||
ai_persona.commands.filter_map do |inner_name|
|
||||
begin
|
||||
("DiscourseAi::AiBot::Commands::#{inner_name}").constantize
|
||||
rescue StandardError
|
||||
nil
|
||||
end
|
||||
end
|
||||
|
||||
Class.new(DiscourseAi::AiBot::Personas::Persona) do
|
||||
define_singleton_method :name do
|
||||
name
|
||||
end
|
||||
|
||||
define_singleton_method :description do
|
||||
description
|
||||
end
|
||||
|
||||
define_singleton_method :allowed_group_ids do
|
||||
allowed_group_ids
|
||||
end
|
||||
|
||||
define_singleton_method :to_s do
|
||||
"#<DiscourseAi::AiBot::Personas::Persona::Custom @name=#{self.name} @allowed_group_ids=#{self.allowed_group_ids.join(",")}>"
|
||||
end
|
||||
|
||||
define_singleton_method :inspect do
|
||||
"#<DiscourseAi::AiBot::Personas::Persona::Custom @name=#{self.name} @allowed_group_ids=#{self.allowed_group_ids.join(",")}>"
|
||||
end
|
||||
|
||||
define_method :initialize do |*args, **kwargs|
|
||||
@ai_persona = AiPersona.find_by(id: ai_persona_id)
|
||||
super(*args, **kwargs)
|
||||
end
|
||||
|
||||
define_method :commands do
|
||||
commands
|
||||
end
|
||||
|
||||
define_method :system_prompt do
|
||||
@ai_persona&.system_prompt || "You are a helpful bot."
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
after_commit :bump_cache
|
||||
|
||||
def bump_cache
|
||||
self.class.persona_cache.flush!
|
||||
end
|
||||
end
|
||||
|
||||
# == Schema Information
|
||||
#
|
||||
# Table name: ai_personas
|
||||
#
|
||||
# id :bigint not null, primary key
|
||||
# name :string(100) not null
|
||||
# description :string(2000) not null
|
||||
# commands :string default([]), not null, is an Array
|
||||
# system_prompt :string not null
|
||||
# allowed_group_ids :integer default([]), not null, is an Array
|
||||
# created_at :datetime not null
|
||||
# updated_at :datetime not null
|
||||
#
|
||||
# Indexes
|
||||
#
|
||||
# index_ai_personas_on_name (name) UNIQUE
|
||||
#
|
|
@ -0,0 +1,18 @@
|
|||
# frozen_string_literal: true
|
||||
#
|
||||
class CreateAiPersonas < ActiveRecord::Migration[7.0]
|
||||
def change
|
||||
create_table :ai_personas do |t|
|
||||
t.string :name, null: false, unique: true, limit: 100
|
||||
t.string :description, null: false, limit: 2000
|
||||
t.string :commands, array: true, default: [], null: false
|
||||
t.string :system_prompt, null: false, limit: 10_000_000
|
||||
t.integer :allowed_group_ids, array: true, default: [], null: false
|
||||
t.integer :created_by_id
|
||||
t.boolean :enabled, default: true, null: false
|
||||
t.timestamps
|
||||
end
|
||||
|
||||
add_index :ai_personas, :name, unique: true
|
||||
end
|
||||
end
|
|
@ -116,7 +116,9 @@ module DiscourseAi
|
|||
@persona = DiscourseAi::AiBot::Personas::General.new(allow_commands: allow_commands)
|
||||
if persona_name = post.topic.custom_fields["ai_persona"]
|
||||
persona_class =
|
||||
DiscourseAi::AiBot::Personas.all.find { |current| current.name == persona_name }
|
||||
DiscourseAi::AiBot::Personas
|
||||
.all(user: post.user)
|
||||
.find { |current| current.name == persona_name }
|
||||
@persona = persona_class.new(allow_commands: allow_commands) if persona_class
|
||||
end
|
||||
|
||||
|
|
|
@ -74,7 +74,9 @@ module DiscourseAi
|
|||
scope.user.in_any_groups?(SiteSetting.ai_bot_allowed_groups_map)
|
||||
end,
|
||||
) do
|
||||
Personas.all.map { |persona| { name: persona.name, description: persona.description } }
|
||||
Personas
|
||||
.all(user: scope.user)
|
||||
.map { |persona| { name: persona.name, description: persona.description } }
|
||||
end
|
||||
|
||||
plugin.add_to_serializer(
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
module DiscourseAi
|
||||
module AiBot
|
||||
module Personas
|
||||
def self.all
|
||||
def self.all(user: nil)
|
||||
personas = [Personas::General, Personas::SqlHelper]
|
||||
personas << Personas::Artist if SiteSetting.ai_stability_api_key.present?
|
||||
personas << Personas::SettingsExplorer
|
||||
|
@ -11,7 +11,20 @@ module DiscourseAi
|
|||
personas << Personas::Creative
|
||||
|
||||
personas_allowed = SiteSetting.ai_bot_enabled_personas.split("|")
|
||||
personas.filter { |persona| personas_allowed.include?(persona.to_s.demodulize.underscore) }
|
||||
personas =
|
||||
personas.filter do |persona|
|
||||
personas_allowed.include?(persona.to_s.demodulize.underscore)
|
||||
end
|
||||
|
||||
if user
|
||||
personas.concat(
|
||||
AiPersona.all_personas.filter do |persona|
|
||||
user.in_any_groups?(persona.allowed_group_ids)
|
||||
end,
|
||||
)
|
||||
end
|
||||
|
||||
personas
|
||||
end
|
||||
|
||||
class Persona
|
||||
|
|
|
@ -34,6 +34,8 @@ module DiscourseAi::AiBot::Personas
|
|||
topic
|
||||
end
|
||||
|
||||
fab!(:user) { Fabricate(:user) }
|
||||
|
||||
it "can disable commands via constructor" do
|
||||
persona = TestPersona.new(allow_commands: false)
|
||||
|
||||
|
@ -72,6 +74,49 @@ module DiscourseAi::AiBot::Personas
|
|||
expect(rendered).not_to include("!tags")
|
||||
end
|
||||
|
||||
describe "custom personas" do
|
||||
it "is able to find custom personas" do
|
||||
Group.refresh_automatic_groups!
|
||||
|
||||
# define an ai persona everyone can see
|
||||
persona =
|
||||
AiPersona.create!(
|
||||
name: "pun_bot",
|
||||
description: "you write puns",
|
||||
system_prompt: "you are pun bot",
|
||||
commands: ["ImageCommand"],
|
||||
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
|
||||
)
|
||||
|
||||
custom_persona = DiscourseAi::AiBot::Personas.all(user: user).last
|
||||
expect(custom_persona.name).to eq("pun_bot")
|
||||
expect(custom_persona.description).to eq("you write puns")
|
||||
|
||||
instance = custom_persona.new
|
||||
expect(instance.commands).to eq([DiscourseAi::AiBot::Commands::ImageCommand])
|
||||
expect(instance.render_system_prompt(render_function_instructions: true)).to eq(
|
||||
"you are pun bot",
|
||||
)
|
||||
|
||||
# should update
|
||||
persona.update!(name: "pun_bot2")
|
||||
custom_persona = DiscourseAi::AiBot::Personas.all(user: user).last
|
||||
expect(custom_persona.name).to eq("pun_bot2")
|
||||
|
||||
# can be disabled
|
||||
persona.update!(enabled: false)
|
||||
last_persona = DiscourseAi::AiBot::Personas.all(user: user).last
|
||||
expect(last_persona.name).not_to eq("pun_bot2")
|
||||
|
||||
persona.update!(enabled: true)
|
||||
# no groups have access
|
||||
persona.update!(allowed_group_ids: [])
|
||||
|
||||
last_persona = DiscourseAi::AiBot::Personas.all(user: user).last
|
||||
expect(last_persona.name).not_to eq("pun_bot2")
|
||||
end
|
||||
end
|
||||
|
||||
describe "available personas" do
|
||||
it "includes all personas by default" do
|
||||
# must be enabled to see it
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe AiPersona do
|
||||
it "does not leak caches between sites" do
|
||||
AiPersona.create!(
|
||||
name: "pun_bot",
|
||||
description: "you write puns",
|
||||
system_prompt: "you are pun bot",
|
||||
commands: ["ImageCommand"],
|
||||
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
|
||||
)
|
||||
|
||||
AiPersona.all_personas
|
||||
|
||||
expect(AiPersona.persona_cache[:value].length).to eq(1)
|
||||
RailsMultisite::ConnectionManagement.stubs(:current_db) { "abc" }
|
||||
expect(AiPersona.persona_cache[:value]).to eq(nil)
|
||||
end
|
||||
end
|
Loading…
Reference in New Issue