FEATURE: basic infrastructure for custom personas (#288)
- New AiPersona model which can store custom personas - Persona are restricted via group security - They can contain custom system messages - They can support a list of commands optionally To avoid expensive DB calls in the serializer a Multisite friendly Hash was introduced (which can be expired on transaction commit)
This commit is contained in:
parent
d0198c5c5b
commit
a4f419f54f
|
@ -0,0 +1,121 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
class AiPersona < ActiveRecord::Base
|
||||||
|
# places a hard limit, so per site we cache a maximum of 500 classes
|
||||||
|
MAX_PERSONAS_PER_SITE = 500
|
||||||
|
|
||||||
|
class MultisiteHash
|
||||||
|
def initialize(id)
|
||||||
|
@hash = Hash.new { |h, k| h[k] = {} }
|
||||||
|
@id = id
|
||||||
|
|
||||||
|
MessageBus.subscribe(channel_name) { |message| @hash[message.data] = {} }
|
||||||
|
end
|
||||||
|
|
||||||
|
def channel_name
|
||||||
|
"/multisite-hash-#{@id}"
|
||||||
|
end
|
||||||
|
|
||||||
|
def current_db
|
||||||
|
RailsMultisite::ConnectionManagement.current_db
|
||||||
|
end
|
||||||
|
|
||||||
|
def [](key)
|
||||||
|
@hash.dig(current_db, key)
|
||||||
|
end
|
||||||
|
|
||||||
|
def []=(key, val)
|
||||||
|
@hash[current_db][key] = val
|
||||||
|
end
|
||||||
|
|
||||||
|
def flush!
|
||||||
|
@hash[current_db] = {}
|
||||||
|
MessageBus.publish(channel_name, current_db)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def self.persona_cache
|
||||||
|
@persona_cache ||= MultisiteHash.new("persona_cache")
|
||||||
|
end
|
||||||
|
|
||||||
|
def self.all_personas
|
||||||
|
persona_cache[:value] ||= AiPersona
|
||||||
|
.order(:name)
|
||||||
|
.where(enabled: true)
|
||||||
|
.all
|
||||||
|
.limit(MAX_PERSONAS_PER_SITE)
|
||||||
|
.map do |ai_persona|
|
||||||
|
name = ai_persona.name
|
||||||
|
description = ai_persona.description
|
||||||
|
ai_persona_id = ai_persona.id
|
||||||
|
allowed_group_ids = ai_persona.allowed_group_ids
|
||||||
|
commands =
|
||||||
|
ai_persona.commands.filter_map do |inner_name|
|
||||||
|
begin
|
||||||
|
("DiscourseAi::AiBot::Commands::#{inner_name}").constantize
|
||||||
|
rescue StandardError
|
||||||
|
nil
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
Class.new(DiscourseAi::AiBot::Personas::Persona) do
|
||||||
|
define_singleton_method :name do
|
||||||
|
name
|
||||||
|
end
|
||||||
|
|
||||||
|
define_singleton_method :description do
|
||||||
|
description
|
||||||
|
end
|
||||||
|
|
||||||
|
define_singleton_method :allowed_group_ids do
|
||||||
|
allowed_group_ids
|
||||||
|
end
|
||||||
|
|
||||||
|
define_singleton_method :to_s do
|
||||||
|
"#<DiscourseAi::AiBot::Personas::Persona::Custom @name=#{self.name} @allowed_group_ids=#{self.allowed_group_ids.join(",")}>"
|
||||||
|
end
|
||||||
|
|
||||||
|
define_singleton_method :inspect do
|
||||||
|
"#<DiscourseAi::AiBot::Personas::Persona::Custom @name=#{self.name} @allowed_group_ids=#{self.allowed_group_ids.join(",")}>"
|
||||||
|
end
|
||||||
|
|
||||||
|
define_method :initialize do |*args, **kwargs|
|
||||||
|
@ai_persona = AiPersona.find_by(id: ai_persona_id)
|
||||||
|
super(*args, **kwargs)
|
||||||
|
end
|
||||||
|
|
||||||
|
define_method :commands do
|
||||||
|
commands
|
||||||
|
end
|
||||||
|
|
||||||
|
define_method :system_prompt do
|
||||||
|
@ai_persona&.system_prompt || "You are a helpful bot."
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
after_commit :bump_cache
|
||||||
|
|
||||||
|
def bump_cache
|
||||||
|
self.class.persona_cache.flush!
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
# == Schema Information
|
||||||
|
#
|
||||||
|
# Table name: ai_personas
|
||||||
|
#
|
||||||
|
# id :bigint not null, primary key
|
||||||
|
# name :string(100) not null
|
||||||
|
# description :string(2000) not null
|
||||||
|
# commands :string default([]), not null, is an Array
|
||||||
|
# system_prompt :string not null
|
||||||
|
# allowed_group_ids :integer default([]), not null, is an Array
|
||||||
|
# created_at :datetime not null
|
||||||
|
# updated_at :datetime not null
|
||||||
|
#
|
||||||
|
# Indexes
|
||||||
|
#
|
||||||
|
# index_ai_personas_on_name (name) UNIQUE
|
||||||
|
#
|
|
@ -0,0 +1,18 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
#
|
||||||
|
class CreateAiPersonas < ActiveRecord::Migration[7.0]
|
||||||
|
def change
|
||||||
|
create_table :ai_personas do |t|
|
||||||
|
t.string :name, null: false, unique: true, limit: 100
|
||||||
|
t.string :description, null: false, limit: 2000
|
||||||
|
t.string :commands, array: true, default: [], null: false
|
||||||
|
t.string :system_prompt, null: false, limit: 10_000_000
|
||||||
|
t.integer :allowed_group_ids, array: true, default: [], null: false
|
||||||
|
t.integer :created_by_id
|
||||||
|
t.boolean :enabled, default: true, null: false
|
||||||
|
t.timestamps
|
||||||
|
end
|
||||||
|
|
||||||
|
add_index :ai_personas, :name, unique: true
|
||||||
|
end
|
||||||
|
end
|
|
@ -116,7 +116,9 @@ module DiscourseAi
|
||||||
@persona = DiscourseAi::AiBot::Personas::General.new(allow_commands: allow_commands)
|
@persona = DiscourseAi::AiBot::Personas::General.new(allow_commands: allow_commands)
|
||||||
if persona_name = post.topic.custom_fields["ai_persona"]
|
if persona_name = post.topic.custom_fields["ai_persona"]
|
||||||
persona_class =
|
persona_class =
|
||||||
DiscourseAi::AiBot::Personas.all.find { |current| current.name == persona_name }
|
DiscourseAi::AiBot::Personas
|
||||||
|
.all(user: post.user)
|
||||||
|
.find { |current| current.name == persona_name }
|
||||||
@persona = persona_class.new(allow_commands: allow_commands) if persona_class
|
@persona = persona_class.new(allow_commands: allow_commands) if persona_class
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -74,7 +74,9 @@ module DiscourseAi
|
||||||
scope.user.in_any_groups?(SiteSetting.ai_bot_allowed_groups_map)
|
scope.user.in_any_groups?(SiteSetting.ai_bot_allowed_groups_map)
|
||||||
end,
|
end,
|
||||||
) do
|
) do
|
||||||
Personas.all.map { |persona| { name: persona.name, description: persona.description } }
|
Personas
|
||||||
|
.all(user: scope.user)
|
||||||
|
.map { |persona| { name: persona.name, description: persona.description } }
|
||||||
end
|
end
|
||||||
|
|
||||||
plugin.add_to_serializer(
|
plugin.add_to_serializer(
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module AiBot
|
module AiBot
|
||||||
module Personas
|
module Personas
|
||||||
def self.all
|
def self.all(user: nil)
|
||||||
personas = [Personas::General, Personas::SqlHelper]
|
personas = [Personas::General, Personas::SqlHelper]
|
||||||
personas << Personas::Artist if SiteSetting.ai_stability_api_key.present?
|
personas << Personas::Artist if SiteSetting.ai_stability_api_key.present?
|
||||||
personas << Personas::SettingsExplorer
|
personas << Personas::SettingsExplorer
|
||||||
|
@ -11,7 +11,20 @@ module DiscourseAi
|
||||||
personas << Personas::Creative
|
personas << Personas::Creative
|
||||||
|
|
||||||
personas_allowed = SiteSetting.ai_bot_enabled_personas.split("|")
|
personas_allowed = SiteSetting.ai_bot_enabled_personas.split("|")
|
||||||
personas.filter { |persona| personas_allowed.include?(persona.to_s.demodulize.underscore) }
|
personas =
|
||||||
|
personas.filter do |persona|
|
||||||
|
personas_allowed.include?(persona.to_s.demodulize.underscore)
|
||||||
|
end
|
||||||
|
|
||||||
|
if user
|
||||||
|
personas.concat(
|
||||||
|
AiPersona.all_personas.filter do |persona|
|
||||||
|
user.in_any_groups?(persona.allowed_group_ids)
|
||||||
|
end,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
personas
|
||||||
end
|
end
|
||||||
|
|
||||||
class Persona
|
class Persona
|
||||||
|
|
|
@ -34,6 +34,8 @@ module DiscourseAi::AiBot::Personas
|
||||||
topic
|
topic
|
||||||
end
|
end
|
||||||
|
|
||||||
|
fab!(:user) { Fabricate(:user) }
|
||||||
|
|
||||||
it "can disable commands via constructor" do
|
it "can disable commands via constructor" do
|
||||||
persona = TestPersona.new(allow_commands: false)
|
persona = TestPersona.new(allow_commands: false)
|
||||||
|
|
||||||
|
@ -72,6 +74,49 @@ module DiscourseAi::AiBot::Personas
|
||||||
expect(rendered).not_to include("!tags")
|
expect(rendered).not_to include("!tags")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
describe "custom personas" do
|
||||||
|
it "is able to find custom personas" do
|
||||||
|
Group.refresh_automatic_groups!
|
||||||
|
|
||||||
|
# define an ai persona everyone can see
|
||||||
|
persona =
|
||||||
|
AiPersona.create!(
|
||||||
|
name: "pun_bot",
|
||||||
|
description: "you write puns",
|
||||||
|
system_prompt: "you are pun bot",
|
||||||
|
commands: ["ImageCommand"],
|
||||||
|
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
|
||||||
|
)
|
||||||
|
|
||||||
|
custom_persona = DiscourseAi::AiBot::Personas.all(user: user).last
|
||||||
|
expect(custom_persona.name).to eq("pun_bot")
|
||||||
|
expect(custom_persona.description).to eq("you write puns")
|
||||||
|
|
||||||
|
instance = custom_persona.new
|
||||||
|
expect(instance.commands).to eq([DiscourseAi::AiBot::Commands::ImageCommand])
|
||||||
|
expect(instance.render_system_prompt(render_function_instructions: true)).to eq(
|
||||||
|
"you are pun bot",
|
||||||
|
)
|
||||||
|
|
||||||
|
# should update
|
||||||
|
persona.update!(name: "pun_bot2")
|
||||||
|
custom_persona = DiscourseAi::AiBot::Personas.all(user: user).last
|
||||||
|
expect(custom_persona.name).to eq("pun_bot2")
|
||||||
|
|
||||||
|
# can be disabled
|
||||||
|
persona.update!(enabled: false)
|
||||||
|
last_persona = DiscourseAi::AiBot::Personas.all(user: user).last
|
||||||
|
expect(last_persona.name).not_to eq("pun_bot2")
|
||||||
|
|
||||||
|
persona.update!(enabled: true)
|
||||||
|
# no groups have access
|
||||||
|
persona.update!(allowed_group_ids: [])
|
||||||
|
|
||||||
|
last_persona = DiscourseAi::AiBot::Personas.all(user: user).last
|
||||||
|
expect(last_persona.name).not_to eq("pun_bot2")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
describe "available personas" do
|
describe "available personas" do
|
||||||
it "includes all personas by default" do
|
it "includes all personas by default" do
|
||||||
# must be enabled to see it
|
# must be enabled to see it
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
RSpec.describe AiPersona do
|
||||||
|
it "does not leak caches between sites" do
|
||||||
|
AiPersona.create!(
|
||||||
|
name: "pun_bot",
|
||||||
|
description: "you write puns",
|
||||||
|
system_prompt: "you are pun bot",
|
||||||
|
commands: ["ImageCommand"],
|
||||||
|
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
|
||||||
|
)
|
||||||
|
|
||||||
|
AiPersona.all_personas
|
||||||
|
|
||||||
|
expect(AiPersona.persona_cache[:value].length).to eq(1)
|
||||||
|
RailsMultisite::ConnectionManagement.stubs(:current_db) { "abc" }
|
||||||
|
expect(AiPersona.persona_cache[:value]).to eq(nil)
|
||||||
|
end
|
||||||
|
end
|
Loading…
Reference in New Issue