FEATURE: add initial support for personas (#172)
This splits out a bunch of code that used to live inside bots into a dedicated concept called a Persona. This allows us to start playing with multiple personas for the bot Ships with: artist - for making images sql helper - for helping with data explorer general - for everything and anything Also includes a few fixes that make the generic LLM function implementation more robust
This commit is contained in:
parent
6d69fb479e
commit
db19e37748
|
@ -3,10 +3,6 @@ import { inject as service } from "@ember/service";
|
||||||
import { computed } from "@ember/object";
|
import { computed } from "@ember/object";
|
||||||
|
|
||||||
export default class extends Component {
|
export default class extends Component {
|
||||||
static shouldRender() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@service currentUser;
|
@service currentUser;
|
||||||
@service siteSettings;
|
@service siteSettings;
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
<div class="gpt-persona">
|
||||||
|
<DropdownSelectBox
|
||||||
|
@value={{this.value}}
|
||||||
|
@content={{this.botOptions}}
|
||||||
|
@options={{hash icon="robot"}}
|
||||||
|
/>
|
||||||
|
</div>
|
|
@ -0,0 +1,53 @@
|
||||||
|
import Component from "@glimmer/component";
|
||||||
|
import { inject as service } from "@ember/service";
|
||||||
|
|
||||||
|
function isBotMessage(composer, currentUser) {
|
||||||
|
if (
|
||||||
|
composer &&
|
||||||
|
composer.targetRecipients &&
|
||||||
|
currentUser.ai_enabled_chat_bots
|
||||||
|
) {
|
||||||
|
const reciepients = composer.targetRecipients.split(",");
|
||||||
|
|
||||||
|
return currentUser.ai_enabled_chat_bots.any((bot) =>
|
||||||
|
reciepients.any((username) => username === bot.username)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default class BotSelector extends Component {
|
||||||
|
static shouldRender(args, container) {
|
||||||
|
return (
|
||||||
|
container?.currentUser?.ai_enabled_personas &&
|
||||||
|
isBotMessage(args.model, container.currentUser)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
@service currentUser;
|
||||||
|
|
||||||
|
get composer() {
|
||||||
|
return this.args?.outletArgs?.model;
|
||||||
|
}
|
||||||
|
|
||||||
|
get botOptions() {
|
||||||
|
if (this.currentUser.ai_enabled_personas) {
|
||||||
|
return this.currentUser.ai_enabled_personas.map((persona) => {
|
||||||
|
return {
|
||||||
|
id: persona.name,
|
||||||
|
name: persona.name,
|
||||||
|
description: persona.description,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
get value() {
|
||||||
|
return this._value || this.botOptions[0].id;
|
||||||
|
}
|
||||||
|
|
||||||
|
set value(val) {
|
||||||
|
this._value = val;
|
||||||
|
this.composer.metaData = { ai_persona: val };
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,6 +4,8 @@ import { ajax } from "discourse/lib/ajax";
|
||||||
import { popupAjaxError } from "discourse/lib/ajax-error";
|
import { popupAjaxError } from "discourse/lib/ajax-error";
|
||||||
import loadScript from "discourse/lib/load-script";
|
import loadScript from "discourse/lib/load-script";
|
||||||
import { composeAiBotMessage } from "discourse/plugins/discourse-ai/discourse/lib/ai-bot-helper";
|
import { composeAiBotMessage } from "discourse/plugins/discourse-ai/discourse/lib/ai-bot-helper";
|
||||||
|
import { registerWidgetShim } from "discourse/widgets/render-glimmer";
|
||||||
|
import { hbs } from "ember-cli-htmlbars";
|
||||||
|
|
||||||
function isGPTBot(user) {
|
function isGPTBot(user) {
|
||||||
return user && [-110, -111, -112].includes(user.id);
|
return user && [-110, -111, -112].includes(user.id);
|
||||||
|
@ -138,6 +140,32 @@ function initializeAIBotReplies(api) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function initializePersonaDecorator(api) {
|
||||||
|
let topicController = null;
|
||||||
|
api.decorateWidget(`poster-name:after`, (dec) => {
|
||||||
|
if (!isGPTBot(dec.attrs.user)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// this is hacky and will need to change
|
||||||
|
// trouble is we need to get the model for the topic
|
||||||
|
// and it is not available in the decorator
|
||||||
|
// long term this will not be a problem once we remove widgets and
|
||||||
|
// have a saner structure for our model
|
||||||
|
topicController =
|
||||||
|
topicController || api.container.lookup("controller:topic");
|
||||||
|
|
||||||
|
return dec.widget.attach("persona-flair", {
|
||||||
|
topicController,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
registerWidgetShim(
|
||||||
|
"persona-flair",
|
||||||
|
"span.persona-flair",
|
||||||
|
hbs`{{@data.topicController.model.ai_persona_name}}`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
export default {
|
export default {
|
||||||
name: "discourse-ai-bot-replies",
|
name: "discourse-ai-bot-replies",
|
||||||
|
|
||||||
|
@ -157,6 +185,7 @@ export default {
|
||||||
if (aiBotEnaled && canInteractWithAIBots) {
|
if (aiBotEnaled && canInteractWithAIBots) {
|
||||||
withPluginApi("1.6.0", attachHeaderIcon);
|
withPluginApi("1.6.0", attachHeaderIcon);
|
||||||
withPluginApi("1.6.0", initializeAIBotReplies);
|
withPluginApi("1.6.0", initializeAIBotReplies);
|
||||||
|
withPluginApi("1.6.0", initializePersonaDecorator);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
|
@ -2,10 +2,17 @@ nav.post-controls .actions button.cancel-streaming {
|
||||||
display: none;
|
display: none;
|
||||||
}
|
}
|
||||||
|
|
||||||
.ai-bot-chat #reply-control {
|
.ai-bot-chat {
|
||||||
.title-and-category {
|
#reply-control {
|
||||||
|
.title-and-category,
|
||||||
|
#private-message-users {
|
||||||
display: none;
|
display: none;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
.gpt-persona {
|
||||||
|
margin-bottom: 5px;
|
||||||
|
margin-top: -10px;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
.ai-bot-chat-warning {
|
.ai-bot-chat-warning {
|
||||||
|
@ -39,3 +46,9 @@ article.streaming nav.post-controls .actions button.cancel-streaming {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.topic-body .persona-flair {
|
||||||
|
order: 2;
|
||||||
|
font-size: var(--font-down-1);
|
||||||
|
padding-top: 3px;
|
||||||
|
}
|
||||||
|
|
|
@ -93,6 +93,16 @@ en:
|
||||||
markdown_table: Generate Markdown table
|
markdown_table: Generate Markdown table
|
||||||
|
|
||||||
ai_bot:
|
ai_bot:
|
||||||
|
personas:
|
||||||
|
general:
|
||||||
|
name: Forum Helper
|
||||||
|
description: "General purpose AI Bot capable of performing various tasks"
|
||||||
|
artist:
|
||||||
|
name: Artist
|
||||||
|
description: "AI Bot specialized in generating images"
|
||||||
|
sql_helper:
|
||||||
|
name: SQL Helper
|
||||||
|
description: "AI Bot specialized in helping craft SQL queries on this Discourse instance"
|
||||||
default_pm_prefix: "[Untitled AI bot PM]"
|
default_pm_prefix: "[Untitled AI bot PM]"
|
||||||
topic_not_found: "Summary unavailable, topic not found!"
|
topic_not_found: "Summary unavailable, topic not found!"
|
||||||
command_summary:
|
command_summary:
|
||||||
|
@ -105,6 +115,7 @@ en:
|
||||||
google: "Search Google"
|
google: "Search Google"
|
||||||
read: "Read topic"
|
read: "Read topic"
|
||||||
setting_context: "Look up site setting context"
|
setting_context: "Look up site setting context"
|
||||||
|
schema: "Look up database schema"
|
||||||
command_description:
|
command_description:
|
||||||
read: "Reading: <a href='%{url}'>%{title}</a>"
|
read: "Reading: <a href='%{url}'>%{title}</a>"
|
||||||
time: "Time in %{timezone} is %{time}"
|
time: "Time in %{timezone} is %{time}"
|
||||||
|
@ -123,6 +134,7 @@ en:
|
||||||
one: "Found %{count} <a href='%{url}'>result</a> for '%{query}'"
|
one: "Found %{count} <a href='%{url}'>result</a> for '%{query}'"
|
||||||
other: "Found %{count} <a href='%{url}'>results</a> for '%{query}'"
|
other: "Found %{count} <a href='%{url}'>results</a> for '%{query}'"
|
||||||
setting_context: "Reading context for: %{setting_name}"
|
setting_context: "Reading context for: %{setting_name}"
|
||||||
|
schema: "%{tables}"
|
||||||
|
|
||||||
summarization:
|
summarization:
|
||||||
configuration_hint:
|
configuration_hint:
|
||||||
|
|
|
@ -58,6 +58,7 @@ module DiscourseAi
|
||||||
|
|
||||||
def initialize(bot_user)
|
def initialize(bot_user)
|
||||||
@bot_user = bot_user
|
@bot_user = bot_user
|
||||||
|
@persona = DiscourseAi::AiBot::Personas::General.new
|
||||||
end
|
end
|
||||||
|
|
||||||
def update_pm_title(post)
|
def update_pm_title(post)
|
||||||
|
@ -90,6 +91,13 @@ module DiscourseAi
|
||||||
)
|
)
|
||||||
return if total_completions > MAX_COMPLETIONS
|
return if total_completions > MAX_COMPLETIONS
|
||||||
|
|
||||||
|
@persona = DiscourseAi::AiBot::Personas::General.new
|
||||||
|
if persona_name = post.topic.custom_fields["ai_persona"]
|
||||||
|
persona_class =
|
||||||
|
DiscourseAi::AiBot::Personas.all.find { |current| current.name == persona_name }
|
||||||
|
@persona = persona_class.new if persona_class
|
||||||
|
end
|
||||||
|
|
||||||
prompt =
|
prompt =
|
||||||
if standalone && post.post_custom_prompt
|
if standalone && post.post_custom_prompt
|
||||||
username, standalone_prompt = post.post_custom_prompt.custom_prompt.last
|
username, standalone_prompt = post.post_custom_prompt.custom_prompt.last
|
||||||
|
@ -265,27 +273,7 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def available_commands
|
def available_commands
|
||||||
return @cmds if @cmds
|
@persona.available_commands
|
||||||
|
|
||||||
all_commands =
|
|
||||||
[
|
|
||||||
Commands::CategoriesCommand,
|
|
||||||
Commands::TimeCommand,
|
|
||||||
Commands::SearchCommand,
|
|
||||||
Commands::SummarizeCommand,
|
|
||||||
Commands::ReadCommand,
|
|
||||||
Commands::SettingContextCommand,
|
|
||||||
].tap do |cmds|
|
|
||||||
cmds << Commands::TagsCommand if SiteSetting.tagging_enabled
|
|
||||||
cmds << Commands::ImageCommand if SiteSetting.ai_stability_api_key.present?
|
|
||||||
if SiteSetting.ai_google_custom_search_api_key.present? &&
|
|
||||||
SiteSetting.ai_google_custom_search_cx.present?
|
|
||||||
cmds << Commands::GoogleCommand
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
allowed_commands = SiteSetting.ai_bot_enabled_chat_commands.split("|")
|
|
||||||
@cmds = all_commands.filter { |klass| allowed_commands.include?(klass.name) }
|
|
||||||
end
|
end
|
||||||
|
|
||||||
def system_prompt_style!(style)
|
def system_prompt_style!(style)
|
||||||
|
@ -295,26 +283,10 @@ module DiscourseAi
|
||||||
def system_prompt(post)
|
def system_prompt(post)
|
||||||
return "You are a helpful Bot" if @style == :simple
|
return "You are a helpful Bot" if @style == :simple
|
||||||
|
|
||||||
prompt = +<<~TEXT
|
@persona.render_system_prompt(
|
||||||
You are a helpful Discourse assistant.
|
topic: post.topic,
|
||||||
You understand and generate Discourse Markdown.
|
render_function_instructions: include_function_instructions_in_system_prompt?,
|
||||||
You live in a Discourse Forum Message.
|
)
|
||||||
|
|
||||||
You live in the forum with the URL: #{Discourse.base_url}
|
|
||||||
The title of your site: #{SiteSetting.title}
|
|
||||||
The description is: #{SiteSetting.site_description}
|
|
||||||
The participants in this conversation are: #{post.topic.allowed_users.map(&:username).join(", ")}
|
|
||||||
The date now is: #{Time.zone.now}, much has changed since you were trained.
|
|
||||||
TEXT
|
|
||||||
|
|
||||||
if include_function_instructions_in_system_prompt?
|
|
||||||
prompt << "\n"
|
|
||||||
prompt << function_list.system_prompt
|
|
||||||
prompt << "\n"
|
|
||||||
end
|
|
||||||
|
|
||||||
prompt << available_commands.map(&:custom_system_message).compact.join("\n")
|
|
||||||
prompt
|
|
||||||
end
|
end
|
||||||
|
|
||||||
def include_function_instructions_in_system_prompt?
|
def include_function_instructions_in_system_prompt?
|
||||||
|
@ -322,11 +294,7 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def function_list
|
def function_list
|
||||||
return @function_list if @function_list
|
@persona.function_list
|
||||||
|
|
||||||
@function_list = DiscourseAi::Inference::FunctionList.new
|
|
||||||
available_functions.each { |function| @function_list << function }
|
|
||||||
@function_list
|
|
||||||
end
|
end
|
||||||
|
|
||||||
def tokenizer
|
def tokenizer
|
||||||
|
@ -363,29 +331,7 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def available_functions
|
def available_functions
|
||||||
# note if defined? can be a problem in test
|
@persona.available_functions
|
||||||
# this can never be nil so it is safe
|
|
||||||
return @available_functions if @available_functions
|
|
||||||
|
|
||||||
functions = []
|
|
||||||
|
|
||||||
functions =
|
|
||||||
available_commands.map do |command|
|
|
||||||
function =
|
|
||||||
DiscourseAi::Inference::Function.new(name: command.name, description: command.desc)
|
|
||||||
command.parameters.each do |parameter|
|
|
||||||
function.add_parameter(
|
|
||||||
name: parameter.name,
|
|
||||||
type: parameter.type,
|
|
||||||
description: parameter.description,
|
|
||||||
required: parameter.required,
|
|
||||||
enum: parameter.enum,
|
|
||||||
)
|
|
||||||
end
|
|
||||||
function
|
|
||||||
end
|
|
||||||
|
|
||||||
@available_functions = functions
|
|
||||||
end
|
end
|
||||||
|
|
||||||
protected
|
protected
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
#frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi::AiBot::Commands
|
||||||
|
class DbSchemaCommand < Command
|
||||||
|
class << self
|
||||||
|
def name
|
||||||
|
"schema"
|
||||||
|
end
|
||||||
|
|
||||||
|
def desc
|
||||||
|
"Will load schema information for specific tables in the database"
|
||||||
|
end
|
||||||
|
|
||||||
|
def parameters
|
||||||
|
[
|
||||||
|
Parameter.new(
|
||||||
|
name: "tables",
|
||||||
|
description:
|
||||||
|
"list of tables to load schema information for, comma seperated list eg: (users,posts))",
|
||||||
|
type: "string",
|
||||||
|
required: true,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def result_name
|
||||||
|
"results"
|
||||||
|
end
|
||||||
|
|
||||||
|
def description_args
|
||||||
|
{ tables: @tables.join(", ") }
|
||||||
|
end
|
||||||
|
|
||||||
|
def process(tables:)
|
||||||
|
@tables = tables.split(",").map(&:strip)
|
||||||
|
|
||||||
|
table_info = {}
|
||||||
|
DB
|
||||||
|
.query(<<~SQL, @tables)
|
||||||
|
select table_name, column_name, data_type from information_schema.columns
|
||||||
|
where table_schema = 'public'
|
||||||
|
and table_name in (?)
|
||||||
|
order by table_name
|
||||||
|
SQL
|
||||||
|
.each { |row| (table_info[row.table_name] ||= []) << "#{row.column_name} #{row.data_type}" }
|
||||||
|
|
||||||
|
schema_info =
|
||||||
|
table_info.map { |table_name, columns| "#{table_name}(#{columns.join(",")})" }.join("\n")
|
||||||
|
|
||||||
|
{ tables: @tables, schema_info: schema_info }
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -39,6 +39,11 @@ module DiscourseAi
|
||||||
require_relative "commands/google_command"
|
require_relative "commands/google_command"
|
||||||
require_relative "commands/read_command"
|
require_relative "commands/read_command"
|
||||||
require_relative "commands/setting_context_command"
|
require_relative "commands/setting_context_command"
|
||||||
|
require_relative "commands/db_schema_command"
|
||||||
|
require_relative "personas/persona"
|
||||||
|
require_relative "personas/artist"
|
||||||
|
require_relative "personas/general"
|
||||||
|
require_relative "personas/sql_helper"
|
||||||
end
|
end
|
||||||
|
|
||||||
def inject_into(plugin)
|
def inject_into(plugin)
|
||||||
|
@ -46,6 +51,17 @@ module DiscourseAi
|
||||||
Rails.root.join("plugins", "discourse-ai", "db", "fixtures", "ai_bot"),
|
Rails.root.join("plugins", "discourse-ai", "db", "fixtures", "ai_bot"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
plugin.add_to_serializer(
|
||||||
|
:current_user,
|
||||||
|
:ai_enabled_personas,
|
||||||
|
include_condition: -> do
|
||||||
|
SiteSetting.ai_bot_enabled && scope.authenticated? &&
|
||||||
|
scope.user.in_any_groups?(SiteSetting.ai_bot_allowed_groups_map)
|
||||||
|
end,
|
||||||
|
) do
|
||||||
|
Personas.all.map { |persona| { name: persona.name, description: persona.description } }
|
||||||
|
end
|
||||||
|
|
||||||
plugin.add_to_serializer(
|
plugin.add_to_serializer(
|
||||||
:current_user,
|
:current_user,
|
||||||
:ai_enabled_chat_bots,
|
:ai_enabled_chat_bots,
|
||||||
|
@ -75,6 +91,12 @@ module DiscourseAi
|
||||||
|
|
||||||
plugin.register_svg_icon("robot")
|
plugin.register_svg_icon("robot")
|
||||||
|
|
||||||
|
plugin.add_to_serializer(
|
||||||
|
:topic_view,
|
||||||
|
:ai_persona_name,
|
||||||
|
include_condition: -> { SiteSetting.ai_bot_enabled && object.topic.private_message? },
|
||||||
|
) { topic.custom_fields["ai_persona"] }
|
||||||
|
|
||||||
plugin.on(:post_created) do |post|
|
plugin.on(:post_created) do |post|
|
||||||
bot_ids = BOTS.map(&:first)
|
bot_ids = BOTS.map(&:first)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,33 @@
|
||||||
|
#frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module AiBot
|
||||||
|
module Personas
|
||||||
|
class Artist < Persona
|
||||||
|
def commands
|
||||||
|
[Commands::ImageCommand]
|
||||||
|
end
|
||||||
|
|
||||||
|
def system_prompt
|
||||||
|
<<~PROMPT
|
||||||
|
You are artistbot and you are here to help people generate images.
|
||||||
|
|
||||||
|
You generate images using stable diffusion.
|
||||||
|
|
||||||
|
- A good prompt needs to be detailed and specific.
|
||||||
|
- You can specify subject, medium (e.g. oil on canvas), artist (person who drew it or photographed it)
|
||||||
|
- You can specify details about lighting or time of day.
|
||||||
|
- You can specify a particular website you would like to emulate (artstation or deviantart)
|
||||||
|
- You can specify additional details such as "beutiful, dystopian, futuristic, etc."
|
||||||
|
- Prompts should generally be 10-20 words long
|
||||||
|
- Do not include any connector words such as "and" or "but" etc.
|
||||||
|
- You are extremely creative, when given short non descriptive prompts from a user you add your own details
|
||||||
|
|
||||||
|
{commands}
|
||||||
|
|
||||||
|
PROMPT
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,29 @@
|
||||||
|
#frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module AiBot
|
||||||
|
module Personas
|
||||||
|
class General < Persona
|
||||||
|
def commands
|
||||||
|
all_available_commands
|
||||||
|
end
|
||||||
|
|
||||||
|
def system_prompt
|
||||||
|
<<~PROMPT
|
||||||
|
You are a helpful Discourse assistant.
|
||||||
|
You understand and generate Discourse Markdown.
|
||||||
|
You live in a Discourse Forum Message.
|
||||||
|
|
||||||
|
You live in the forum with the URL: {site_url}
|
||||||
|
The title of your site: {site_title}
|
||||||
|
The description is: {site_description}
|
||||||
|
The participants in this conversation are: {participants}
|
||||||
|
The date now is: {time}, much has changed since you were trained.
|
||||||
|
|
||||||
|
{commands}
|
||||||
|
PROMPT
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,110 @@
|
||||||
|
#frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module AiBot
|
||||||
|
module Personas
|
||||||
|
def self.all
|
||||||
|
personas = [Personas::General, Personas::SqlHelper]
|
||||||
|
personas << Personas::Artist if SiteSetting.ai_stability_api_key.present?
|
||||||
|
personas
|
||||||
|
end
|
||||||
|
|
||||||
|
class Persona
|
||||||
|
def self.name
|
||||||
|
I18n.t("discourse_ai.ai_bot.personas.#{to_s.demodulize.underscore}.name")
|
||||||
|
end
|
||||||
|
|
||||||
|
def self.description
|
||||||
|
I18n.t("discourse_ai.ai_bot.personas.#{to_s.demodulize.underscore}.description")
|
||||||
|
end
|
||||||
|
|
||||||
|
def commands
|
||||||
|
[]
|
||||||
|
end
|
||||||
|
|
||||||
|
def render_commands(render_function_instructions:)
|
||||||
|
result = +""
|
||||||
|
if render_function_instructions
|
||||||
|
result << "\n"
|
||||||
|
result << function_list.system_prompt
|
||||||
|
result << "\n"
|
||||||
|
end
|
||||||
|
result << available_commands.map(&:custom_system_message).compact.join("\n")
|
||||||
|
result
|
||||||
|
end
|
||||||
|
|
||||||
|
def render_system_prompt(topic: nil, render_function_instructions: true)
|
||||||
|
substitutions = {
|
||||||
|
site_url: Discourse.base_url,
|
||||||
|
site_title: SiteSetting.title,
|
||||||
|
site_description: SiteSetting.site_description,
|
||||||
|
time: Time.zone.now,
|
||||||
|
commands: render_commands(render_function_instructions: render_function_instructions),
|
||||||
|
}
|
||||||
|
|
||||||
|
substitutions[:participants] = topic.allowed_users.map(&:username).join(", ") if topic
|
||||||
|
|
||||||
|
system_prompt.gsub(/\{(\w+)\}/) do |match|
|
||||||
|
found = substitutions[match[1..-2].to_sym]
|
||||||
|
found.nil? ? match : found.to_s
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def available_commands
|
||||||
|
return @available_commands if @available_commands
|
||||||
|
|
||||||
|
@available_commands = all_available_commands.filter { |cmd| commands.include?(cmd) }
|
||||||
|
end
|
||||||
|
|
||||||
|
def available_functions
|
||||||
|
# note if defined? can be a problem in test
|
||||||
|
# this can never be nil so it is safe
|
||||||
|
return @available_functions if @available_functions
|
||||||
|
|
||||||
|
functions = []
|
||||||
|
|
||||||
|
functions =
|
||||||
|
available_commands.map do |command|
|
||||||
|
function =
|
||||||
|
DiscourseAi::Inference::Function.new(name: command.name, description: command.desc)
|
||||||
|
command.parameters.each { |parameter| function.add_parameter(parameter) }
|
||||||
|
function
|
||||||
|
end
|
||||||
|
|
||||||
|
@available_functions = functions
|
||||||
|
end
|
||||||
|
|
||||||
|
def function_list
|
||||||
|
return @function_list if @function_list
|
||||||
|
|
||||||
|
@function_list = DiscourseAi::Inference::FunctionList.new
|
||||||
|
available_functions.each { |function| @function_list << function }
|
||||||
|
@function_list
|
||||||
|
end
|
||||||
|
|
||||||
|
def all_available_commands
|
||||||
|
return @cmds if @cmds
|
||||||
|
|
||||||
|
all_commands = [
|
||||||
|
Commands::CategoriesCommand,
|
||||||
|
Commands::TimeCommand,
|
||||||
|
Commands::SearchCommand,
|
||||||
|
Commands::SummarizeCommand,
|
||||||
|
Commands::ReadCommand,
|
||||||
|
Commands::SettingContextCommand,
|
||||||
|
]
|
||||||
|
|
||||||
|
all_commands << Commands::TagsCommand if SiteSetting.tagging_enabled
|
||||||
|
all_commands << Commands::ImageCommand if SiteSetting.ai_stability_api_key.present?
|
||||||
|
if SiteSetting.ai_google_custom_search_api_key.present? &&
|
||||||
|
SiteSetting.ai_google_custom_search_cx.present?
|
||||||
|
all_commands << Commands::GoogleCommand
|
||||||
|
end
|
||||||
|
|
||||||
|
allowed_commands = SiteSetting.ai_bot_enabled_chat_commands.split("|")
|
||||||
|
@cmds = all_commands.filter { |klass| allowed_commands.include?(klass.name) }
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,64 @@
|
||||||
|
#frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module AiBot
|
||||||
|
module Personas
|
||||||
|
class SqlHelper < Persona
|
||||||
|
def self.schema
|
||||||
|
return @schema if defined?(@schema)
|
||||||
|
|
||||||
|
tables = Hash.new
|
||||||
|
priority_tables = %w[posts topics notifications users user_actions]
|
||||||
|
|
||||||
|
DB.query(<<~SQL).each { |row| (tables[row.table_name] ||= []) << row.column_name }
|
||||||
|
select table_name, column_name from information_schema.columns
|
||||||
|
where table_schema = 'public'
|
||||||
|
order by table_name
|
||||||
|
SQL
|
||||||
|
|
||||||
|
schema = +(priority_tables.map { |name| "#{name}(#{tables[name].join(",")})" }.join("\n"))
|
||||||
|
|
||||||
|
schema << "\nOther tables (schema redacted, available on request): "
|
||||||
|
tables.each do |table_name, _|
|
||||||
|
next if priority_tables.include?(table_name)
|
||||||
|
schema << "#{table_name} "
|
||||||
|
end
|
||||||
|
|
||||||
|
@schema = schema
|
||||||
|
end
|
||||||
|
|
||||||
|
def commands
|
||||||
|
all_available_commands
|
||||||
|
end
|
||||||
|
|
||||||
|
def all_available_commands
|
||||||
|
[DiscourseAi::AiBot::Commands::DbSchemaCommand]
|
||||||
|
end
|
||||||
|
|
||||||
|
def system_prompt
|
||||||
|
<<~PROMPT
|
||||||
|
You are a PostgreSQL expert.
|
||||||
|
You understand and generate Discourse Markdown but specialize in creating queries.
|
||||||
|
You live in a Discourse Forum Message.
|
||||||
|
The schema in your training set MAY be out of date.
|
||||||
|
|
||||||
|
The user_actions tables stores likes (action_type 1).
|
||||||
|
the topics table stores private/personal messages it uses archetype private_message for them.
|
||||||
|
notification_level can be: {muted: 0, regular: 1, tracking: 2, watching: 3, watching_first_post: 4}.
|
||||||
|
bookmarkable_type can be: Post,Topic,ChatMessage and more
|
||||||
|
|
||||||
|
Current time is: {time}
|
||||||
|
|
||||||
|
|
||||||
|
The current schema for the current DB is:
|
||||||
|
{{
|
||||||
|
#{self.class.schema}
|
||||||
|
}}
|
||||||
|
|
||||||
|
{commands}
|
||||||
|
PROMPT
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -12,7 +12,21 @@ module ::DiscourseAi
|
||||||
@parameters = []
|
@parameters = []
|
||||||
end
|
end
|
||||||
|
|
||||||
def add_parameter(name:, type:, description:, enum: nil, required: false)
|
def add_parameter(parameter = nil, **kwargs)
|
||||||
|
if parameter
|
||||||
|
add_parameter_kwargs(
|
||||||
|
name: parameter.name,
|
||||||
|
type: parameter.type,
|
||||||
|
description: parameter.description,
|
||||||
|
required: parameter.required,
|
||||||
|
enum: parameter.enum,
|
||||||
|
)
|
||||||
|
else
|
||||||
|
add_parameter_kwargs(**kwargs)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def add_parameter_kwargs(name:, type:, description:, enum: nil, required: false)
|
||||||
@parameters << {
|
@parameters << {
|
||||||
name: name,
|
name: name,
|
||||||
type: type,
|
type: type,
|
||||||
|
|
|
@ -28,7 +28,22 @@ module ::DiscourseAi
|
||||||
next if function.blank?
|
next if function.blank?
|
||||||
|
|
||||||
arguments = arguments[0..-2] if arguments.end_with?(")")
|
arguments = arguments[0..-2] if arguments.end_with?(")")
|
||||||
arguments = arguments.split(",").map(&:strip)
|
|
||||||
|
temp_string = +""
|
||||||
|
in_string = nil
|
||||||
|
replace = SecureRandom.hex(10)
|
||||||
|
arguments.each_char do |char|
|
||||||
|
if %w[" '].include?(char) && !in_string
|
||||||
|
in_string = char
|
||||||
|
elsif char == in_string
|
||||||
|
in_string = nil
|
||||||
|
elsif char == "," && in_string
|
||||||
|
char = replace
|
||||||
|
end
|
||||||
|
temp_string << char
|
||||||
|
end
|
||||||
|
|
||||||
|
arguments = temp_string.split(",").map { |s| s.gsub(replace, ",").strip }
|
||||||
|
|
||||||
parsed_arguments = {}
|
parsed_arguments = {}
|
||||||
arguments.each do |argument|
|
arguments.each do |argument|
|
||||||
|
@ -76,8 +91,8 @@ module ::DiscourseAi
|
||||||
PROMPT
|
PROMPT
|
||||||
|
|
||||||
@functions.each do |function|
|
@functions.each do |function|
|
||||||
prompt << " // #{function.description}\n"
|
prompt << "// #{function.description}\n"
|
||||||
prompt << " #{function.name}"
|
prompt << "!#{function.name}"
|
||||||
if function.parameters.present?
|
if function.parameters.present?
|
||||||
prompt << "("
|
prompt << "("
|
||||||
function.parameters.each_with_index do |parameter, index|
|
function.parameters.each_with_index do |parameter, index|
|
||||||
|
@ -96,8 +111,9 @@ module ::DiscourseAi
|
||||||
|
|
||||||
prompt << " /* #{description} */" if description.present?
|
prompt << " /* #{description} */" if description.present?
|
||||||
end
|
end
|
||||||
prompt << ")\n"
|
prompt << ")"
|
||||||
end
|
end
|
||||||
|
prompt << "\n"
|
||||||
end
|
end
|
||||||
|
|
||||||
prompt << <<~PROMPT
|
prompt << <<~PROMPT
|
||||||
|
@ -110,7 +126,7 @@ module ::DiscourseAi
|
||||||
|
|
||||||
{
|
{
|
||||||
// echo a string
|
// echo a string
|
||||||
echo(message: string [required])
|
!echo(message: string [required])
|
||||||
}
|
}
|
||||||
|
|
||||||
Human: please echo out "hello"
|
Human: please echo out "hello"
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
#frozen_string_literal: true
|
||||||
|
|
||||||
|
RSpec.describe DiscourseAi::AiBot::Commands::DbSchemaCommand do
|
||||||
|
let(:command) { DiscourseAi::AiBot::Commands::DbSchemaCommand.new(bot_user: nil, args: nil) }
|
||||||
|
describe "#process" do
|
||||||
|
it "returns rich schema for tables" do
|
||||||
|
result = command.process(tables: "posts,topics")
|
||||||
|
expect(result[:schema_info]).to include("raw text")
|
||||||
|
expect(result[:schema_info]).to include("views integer")
|
||||||
|
expect(result[:schema_info]).to include("posts")
|
||||||
|
expect(result[:schema_info]).to include("topics")
|
||||||
|
|
||||||
|
expect(result[:tables]).to eq(%w[posts topics])
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,61 @@
|
||||||
|
#frozen_string_literal: true
|
||||||
|
|
||||||
|
class TestPersona < DiscourseAi::AiBot::Personas::Persona
|
||||||
|
def commands
|
||||||
|
[
|
||||||
|
DiscourseAi::AiBot::Commands::TagsCommand,
|
||||||
|
DiscourseAi::AiBot::Commands::SearchCommand,
|
||||||
|
DiscourseAi::AiBot::Commands::ImageCommand,
|
||||||
|
]
|
||||||
|
end
|
||||||
|
|
||||||
|
def system_prompt
|
||||||
|
<<~PROMPT
|
||||||
|
{site_url}
|
||||||
|
{site_title}
|
||||||
|
{site_description}
|
||||||
|
{participants}
|
||||||
|
{time}
|
||||||
|
|
||||||
|
{commands}
|
||||||
|
PROMPT
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
||||||
|
let :persona do
|
||||||
|
TestPersona.new
|
||||||
|
end
|
||||||
|
|
||||||
|
let :topic_with_users do
|
||||||
|
topic = Topic.new
|
||||||
|
topic.allowed_users = [User.new(username: "joe"), User.new(username: "jane")]
|
||||||
|
topic
|
||||||
|
end
|
||||||
|
|
||||||
|
it "renders the system prompt" do
|
||||||
|
freeze_time
|
||||||
|
|
||||||
|
SiteSetting.title = "test site title"
|
||||||
|
SiteSetting.site_description = "test site description"
|
||||||
|
|
||||||
|
rendered =
|
||||||
|
persona.render_system_prompt(topic: topic_with_users, render_function_instructions: true)
|
||||||
|
|
||||||
|
expect(rendered).to include(Discourse.base_url)
|
||||||
|
expect(rendered).to include("test site title")
|
||||||
|
expect(rendered).to include("test site description")
|
||||||
|
expect(rendered).to include("joe, jane")
|
||||||
|
expect(rendered).to include(Time.zone.now.to_s)
|
||||||
|
expect(rendered).to include("!search")
|
||||||
|
expect(rendered).to include("!tags")
|
||||||
|
# needs to be configured so it is not available
|
||||||
|
expect(rendered).not_to include("!image")
|
||||||
|
|
||||||
|
rendered =
|
||||||
|
persona.render_system_prompt(topic: topic_with_users, render_function_instructions: false)
|
||||||
|
|
||||||
|
expect(rendered).not_to include("!search")
|
||||||
|
expect(rendered).not_to include("!tags")
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,17 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
RSpec.describe DiscourseAi::AiBot::Personas::SqlHelper do
|
||||||
|
let :sql_helper do
|
||||||
|
subject
|
||||||
|
end
|
||||||
|
|
||||||
|
it "renders schema" do
|
||||||
|
prompt = sql_helper.render_system_prompt
|
||||||
|
expect(prompt).to include("posts(")
|
||||||
|
expect(prompt).to include("topics(")
|
||||||
|
expect(prompt).not_to include("translation_key") # not a priority table
|
||||||
|
expect(prompt).to include("user_api_keys") # not a priority table
|
||||||
|
|
||||||
|
expect(sql_helper.available_commands).to eq([DiscourseAi::AiBot::Commands::DbSchemaCommand])
|
||||||
|
end
|
||||||
|
end
|
|
@ -29,16 +29,16 @@ module DiscourseAi::Inference
|
||||||
|
|
||||||
it "can handle complex parsing" do
|
it "can handle complex parsing" do
|
||||||
raw_prompt = <<~PROMPT
|
raw_prompt = <<~PROMPT
|
||||||
!get_weather(location: "sydney", unit: "f")
|
!get_weather(location: "sydney,melbourne", unit: "f")
|
||||||
!get_weather (location: sydney)
|
!get_weather (location: sydney)
|
||||||
!get_weather(location : 'sydney's', unit: "m", invalid: "invalid")
|
!get_weather(location : "sydney's", unit: "m", invalid: "invalid")
|
||||||
!get_weather(unit: "f", invalid: "invalid")
|
!get_weather(unit: "f", invalid: "invalid")
|
||||||
PROMPT
|
PROMPT
|
||||||
parsed = function_list.parse_prompt(raw_prompt)
|
parsed = function_list.parse_prompt(raw_prompt)
|
||||||
|
|
||||||
expect(parsed).to eq(
|
expect(parsed).to eq(
|
||||||
[
|
[
|
||||||
{ name: "get_weather", arguments: { location: "sydney", unit: "f" } },
|
{ name: "get_weather", arguments: { location: "sydney,melbourne", unit: "f" } },
|
||||||
{ name: "get_weather", arguments: { location: "sydney" } },
|
{ name: "get_weather", arguments: { location: "sydney" } },
|
||||||
{ name: "get_weather", arguments: { location: "sydney's" } },
|
{ name: "get_weather", arguments: { location: "sydney's" } },
|
||||||
],
|
],
|
||||||
|
@ -53,7 +53,7 @@ module DiscourseAi::Inference
|
||||||
expected = <<~PROMPT
|
expected = <<~PROMPT
|
||||||
{
|
{
|
||||||
// Get the weather in a city (default to c)
|
// Get the weather in a city (default to c)
|
||||||
get_weather(location: string [required] /* the city name */, unit: string [optional] /* the unit of measurement celcius c or fahrenheit f [valid values: c,f] */)
|
!get_weather(location: string [required] /* the city name */, unit: string [optional] /* the unit of measurement celcius c or fahrenheit f [valid values: c,f] */)
|
||||||
}
|
}
|
||||||
PROMPT
|
PROMPT
|
||||||
expect(prompt).to include(expected)
|
expect(prompt).to include(expected)
|
||||||
|
|
Loading…
Reference in New Issue