From a4f419f54f6e466c945dd3de6322809e6d7a67e0 Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 10 Nov 2023 11:39:49 +1100 Subject: [PATCH] FEATURE: basic infrastructure for custom personas (#288) - New AiPersona model which can store custom personas - Persona are restricted via group security - They can contain custom system messages - They can support a list of commands optionally To avoid expensive DB calls in the serializer a Multisite friendly Hash was introduced (which can be expired on transaction commit) --- app/models/ai_persona.rb | 121 ++++++++++++++++++ .../20231109011155_create_ai_personas.rb | 18 +++ lib/modules/ai_bot/bot.rb | 4 +- lib/modules/ai_bot/entry_point.rb | 4 +- lib/modules/ai_bot/personas/persona.rb | 17 ++- .../modules/ai_bot/personas/persona_spec.rb | 45 +++++++ spec/models/ai_persona_spec.rb | 19 +++ 7 files changed, 224 insertions(+), 4 deletions(-) create mode 100644 app/models/ai_persona.rb create mode 100644 db/migrate/20231109011155_create_ai_personas.rb create mode 100644 spec/models/ai_persona_spec.rb diff --git a/app/models/ai_persona.rb b/app/models/ai_persona.rb new file mode 100644 index 00000000..5d19fee3 --- /dev/null +++ b/app/models/ai_persona.rb @@ -0,0 +1,121 @@ +# frozen_string_literal: true + +class AiPersona < ActiveRecord::Base + # places a hard limit, so per site we cache a maximum of 500 classes + MAX_PERSONAS_PER_SITE = 500 + + class MultisiteHash + def initialize(id) + @hash = Hash.new { |h, k| h[k] = {} } + @id = id + + MessageBus.subscribe(channel_name) { |message| @hash[message.data] = {} } + end + + def channel_name + "/multisite-hash-#{@id}" + end + + def current_db + RailsMultisite::ConnectionManagement.current_db + end + + def [](key) + @hash.dig(current_db, key) + end + + def []=(key, val) + @hash[current_db][key] = val + end + + def flush! + @hash[current_db] = {} + MessageBus.publish(channel_name, current_db) + end + end + + def self.persona_cache + @persona_cache ||= MultisiteHash.new("persona_cache") + end + + def self.all_personas + persona_cache[:value] ||= AiPersona + .order(:name) + .where(enabled: true) + .all + .limit(MAX_PERSONAS_PER_SITE) + .map do |ai_persona| + name = ai_persona.name + description = ai_persona.description + ai_persona_id = ai_persona.id + allowed_group_ids = ai_persona.allowed_group_ids + commands = + ai_persona.commands.filter_map do |inner_name| + begin + ("DiscourseAi::AiBot::Commands::#{inner_name}").constantize + rescue StandardError + nil + end + end + + Class.new(DiscourseAi::AiBot::Personas::Persona) do + define_singleton_method :name do + name + end + + define_singleton_method :description do + description + end + + define_singleton_method :allowed_group_ids do + allowed_group_ids + end + + define_singleton_method :to_s do + "#" + end + + define_singleton_method :inspect do + "#" + end + + define_method :initialize do |*args, **kwargs| + @ai_persona = AiPersona.find_by(id: ai_persona_id) + super(*args, **kwargs) + end + + define_method :commands do + commands + end + + define_method :system_prompt do + @ai_persona&.system_prompt || "You are a helpful bot." + end + end + end + end + + after_commit :bump_cache + + def bump_cache + self.class.persona_cache.flush! + end +end + +# == Schema Information +# +# Table name: ai_personas +# +# id :bigint not null, primary key +# name :string(100) not null +# description :string(2000) not null +# commands :string default([]), not null, is an Array +# system_prompt :string not null +# allowed_group_ids :integer default([]), not null, is an Array +# created_at :datetime not null +# updated_at :datetime not null +# +# Indexes +# +# index_ai_personas_on_name (name) UNIQUE +# diff --git a/db/migrate/20231109011155_create_ai_personas.rb b/db/migrate/20231109011155_create_ai_personas.rb new file mode 100644 index 00000000..da89d3ab --- /dev/null +++ b/db/migrate/20231109011155_create_ai_personas.rb @@ -0,0 +1,18 @@ +# frozen_string_literal: true +# +class CreateAiPersonas < ActiveRecord::Migration[7.0] + def change + create_table :ai_personas do |t| + t.string :name, null: false, unique: true, limit: 100 + t.string :description, null: false, limit: 2000 + t.string :commands, array: true, default: [], null: false + t.string :system_prompt, null: false, limit: 10_000_000 + t.integer :allowed_group_ids, array: true, default: [], null: false + t.integer :created_by_id + t.boolean :enabled, default: true, null: false + t.timestamps + end + + add_index :ai_personas, :name, unique: true + end +end diff --git a/lib/modules/ai_bot/bot.rb b/lib/modules/ai_bot/bot.rb index 52594b9b..86914a84 100644 --- a/lib/modules/ai_bot/bot.rb +++ b/lib/modules/ai_bot/bot.rb @@ -116,7 +116,9 @@ module DiscourseAi @persona = DiscourseAi::AiBot::Personas::General.new(allow_commands: allow_commands) if persona_name = post.topic.custom_fields["ai_persona"] persona_class = - DiscourseAi::AiBot::Personas.all.find { |current| current.name == persona_name } + DiscourseAi::AiBot::Personas + .all(user: post.user) + .find { |current| current.name == persona_name } @persona = persona_class.new(allow_commands: allow_commands) if persona_class end diff --git a/lib/modules/ai_bot/entry_point.rb b/lib/modules/ai_bot/entry_point.rb index 660d908a..b7bc227f 100644 --- a/lib/modules/ai_bot/entry_point.rb +++ b/lib/modules/ai_bot/entry_point.rb @@ -74,7 +74,9 @@ module DiscourseAi scope.user.in_any_groups?(SiteSetting.ai_bot_allowed_groups_map) end, ) do - Personas.all.map { |persona| { name: persona.name, description: persona.description } } + Personas + .all(user: scope.user) + .map { |persona| { name: persona.name, description: persona.description } } end plugin.add_to_serializer( diff --git a/lib/modules/ai_bot/personas/persona.rb b/lib/modules/ai_bot/personas/persona.rb index 3d40a5dd..e8687695 100644 --- a/lib/modules/ai_bot/personas/persona.rb +++ b/lib/modules/ai_bot/personas/persona.rb @@ -3,7 +3,7 @@ module DiscourseAi module AiBot module Personas - def self.all + def self.all(user: nil) personas = [Personas::General, Personas::SqlHelper] personas << Personas::Artist if SiteSetting.ai_stability_api_key.present? personas << Personas::SettingsExplorer @@ -11,7 +11,20 @@ module DiscourseAi personas << Personas::Creative personas_allowed = SiteSetting.ai_bot_enabled_personas.split("|") - personas.filter { |persona| personas_allowed.include?(persona.to_s.demodulize.underscore) } + personas = + personas.filter do |persona| + personas_allowed.include?(persona.to_s.demodulize.underscore) + end + + if user + personas.concat( + AiPersona.all_personas.filter do |persona| + user.in_any_groups?(persona.allowed_group_ids) + end, + ) + end + + personas end class Persona diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb index 18b84e4c..677a713d 100644 --- a/spec/lib/modules/ai_bot/personas/persona_spec.rb +++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb @@ -34,6 +34,8 @@ module DiscourseAi::AiBot::Personas topic end + fab!(:user) { Fabricate(:user) } + it "can disable commands via constructor" do persona = TestPersona.new(allow_commands: false) @@ -72,6 +74,49 @@ module DiscourseAi::AiBot::Personas expect(rendered).not_to include("!tags") end + describe "custom personas" do + it "is able to find custom personas" do + Group.refresh_automatic_groups! + + # define an ai persona everyone can see + persona = + AiPersona.create!( + name: "pun_bot", + description: "you write puns", + system_prompt: "you are pun bot", + commands: ["ImageCommand"], + allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]], + ) + + custom_persona = DiscourseAi::AiBot::Personas.all(user: user).last + expect(custom_persona.name).to eq("pun_bot") + expect(custom_persona.description).to eq("you write puns") + + instance = custom_persona.new + expect(instance.commands).to eq([DiscourseAi::AiBot::Commands::ImageCommand]) + expect(instance.render_system_prompt(render_function_instructions: true)).to eq( + "you are pun bot", + ) + + # should update + persona.update!(name: "pun_bot2") + custom_persona = DiscourseAi::AiBot::Personas.all(user: user).last + expect(custom_persona.name).to eq("pun_bot2") + + # can be disabled + persona.update!(enabled: false) + last_persona = DiscourseAi::AiBot::Personas.all(user: user).last + expect(last_persona.name).not_to eq("pun_bot2") + + persona.update!(enabled: true) + # no groups have access + persona.update!(allowed_group_ids: []) + + last_persona = DiscourseAi::AiBot::Personas.all(user: user).last + expect(last_persona.name).not_to eq("pun_bot2") + end + end + describe "available personas" do it "includes all personas by default" do # must be enabled to see it diff --git a/spec/models/ai_persona_spec.rb b/spec/models/ai_persona_spec.rb new file mode 100644 index 00000000..eb3d6f68 --- /dev/null +++ b/spec/models/ai_persona_spec.rb @@ -0,0 +1,19 @@ +# frozen_string_literal: true + +RSpec.describe AiPersona do + it "does not leak caches between sites" do + AiPersona.create!( + name: "pun_bot", + description: "you write puns", + system_prompt: "you are pun bot", + commands: ["ImageCommand"], + allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]], + ) + + AiPersona.all_personas + + expect(AiPersona.persona_cache[:value].length).to eq(1) + RailsMultisite::ConnectionManagement.stubs(:current_db) { "abc" } + expect(AiPersona.persona_cache[:value]).to eq(nil) + end +end