diff --git a/assets/javascripts/discourse/connectors/after-d-editor/composer-open.js b/assets/javascripts/discourse/connectors/after-d-editor/composer-open.js index 20a3044f..da5633db 100644 --- a/assets/javascripts/discourse/connectors/after-d-editor/composer-open.js +++ b/assets/javascripts/discourse/connectors/after-d-editor/composer-open.js @@ -3,10 +3,6 @@ import { inject as service } from "@ember/service"; import { computed } from "@ember/object"; export default class extends Component { - static shouldRender() { - return true; - } - @service currentUser; @service siteSettings; diff --git a/assets/javascripts/discourse/connectors/composer-fields/persona-selector.hbs b/assets/javascripts/discourse/connectors/composer-fields/persona-selector.hbs new file mode 100644 index 00000000..f79ee223 --- /dev/null +++ b/assets/javascripts/discourse/connectors/composer-fields/persona-selector.hbs @@ -0,0 +1,7 @@ +
+ +
\ No newline at end of file diff --git a/assets/javascripts/discourse/connectors/composer-fields/persona-selector.js b/assets/javascripts/discourse/connectors/composer-fields/persona-selector.js new file mode 100644 index 00000000..14fa85c7 --- /dev/null +++ b/assets/javascripts/discourse/connectors/composer-fields/persona-selector.js @@ -0,0 +1,53 @@ +import Component from "@glimmer/component"; +import { inject as service } from "@ember/service"; + +function isBotMessage(composer, currentUser) { + if ( + composer && + composer.targetRecipients && + currentUser.ai_enabled_chat_bots + ) { + const reciepients = composer.targetRecipients.split(","); + + return currentUser.ai_enabled_chat_bots.any((bot) => + reciepients.any((username) => username === bot.username) + ); + } + return false; +} + +export default class BotSelector extends Component { + static shouldRender(args, container) { + return ( + container?.currentUser?.ai_enabled_personas && + isBotMessage(args.model, container.currentUser) + ); + } + + @service currentUser; + + get composer() { + return this.args?.outletArgs?.model; + } + + get botOptions() { + if (this.currentUser.ai_enabled_personas) { + return this.currentUser.ai_enabled_personas.map((persona) => { + return { + id: persona.name, + name: persona.name, + description: persona.description, + }; + }); + } + } + + get value() { + return this._value || this.botOptions[0].id; + } + + set value(val) { + this._value = val; + this.composer.metaData = { ai_persona: val }; + } +} diff --git a/assets/javascripts/initializers/ai-bot-replies.js b/assets/javascripts/initializers/ai-bot-replies.js index 22d5af2e..053c9763 100644 --- a/assets/javascripts/initializers/ai-bot-replies.js +++ b/assets/javascripts/initializers/ai-bot-replies.js @@ -4,6 +4,8 @@ import { ajax } from "discourse/lib/ajax"; import { popupAjaxError } from "discourse/lib/ajax-error"; import loadScript from "discourse/lib/load-script"; import { composeAiBotMessage } from "discourse/plugins/discourse-ai/discourse/lib/ai-bot-helper"; +import { registerWidgetShim } from "discourse/widgets/render-glimmer"; +import { hbs } from "ember-cli-htmlbars"; function isGPTBot(user) { return user && [-110, -111, -112].includes(user.id); @@ -138,6 +140,32 @@ function initializeAIBotReplies(api) { }); } +function initializePersonaDecorator(api) { + let topicController = null; + api.decorateWidget(`poster-name:after`, (dec) => { + if (!isGPTBot(dec.attrs.user)) { + return; + } + // this is hacky and will need to change + // trouble is we need to get the model for the topic + // and it is not available in the decorator + // long term this will not be a problem once we remove widgets and + // have a saner structure for our model + topicController = + topicController || api.container.lookup("controller:topic"); + + return dec.widget.attach("persona-flair", { + topicController, + }); + }); + + registerWidgetShim( + "persona-flair", + "span.persona-flair", + hbs`{{@data.topicController.model.ai_persona_name}}` + ); +} + export default { name: "discourse-ai-bot-replies", @@ -157,6 +185,7 @@ export default { if (aiBotEnaled && canInteractWithAIBots) { withPluginApi("1.6.0", attachHeaderIcon); withPluginApi("1.6.0", initializeAIBotReplies); + withPluginApi("1.6.0", initializePersonaDecorator); } }, }; diff --git a/assets/stylesheets/modules/ai-bot/common/bot-replies.scss b/assets/stylesheets/modules/ai-bot/common/bot-replies.scss index 926ed647..0c487257 100644 --- a/assets/stylesheets/modules/ai-bot/common/bot-replies.scss +++ b/assets/stylesheets/modules/ai-bot/common/bot-replies.scss @@ -2,9 +2,16 @@ nav.post-controls .actions button.cancel-streaming { display: none; } -.ai-bot-chat #reply-control { - .title-and-category { - display: none; +.ai-bot-chat { + #reply-control { + .title-and-category, + #private-message-users { + display: none; + } + } + .gpt-persona { + margin-bottom: 5px; + margin-top: -10px; } } @@ -39,3 +46,9 @@ article.streaming nav.post-controls .actions button.cancel-streaming { } } } + +.topic-body .persona-flair { + order: 2; + font-size: var(--font-down-1); + padding-top: 3px; +} diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index 75e5d13a..0a16980b 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -93,6 +93,16 @@ en: markdown_table: Generate Markdown table ai_bot: + personas: + general: + name: Forum Helper + description: "General purpose AI Bot capable of performing various tasks" + artist: + name: Artist + description: "AI Bot specialized in generating images" + sql_helper: + name: SQL Helper + description: "AI Bot specialized in helping craft SQL queries on this Discourse instance" default_pm_prefix: "[Untitled AI bot PM]" topic_not_found: "Summary unavailable, topic not found!" command_summary: @@ -105,6 +115,7 @@ en: google: "Search Google" read: "Read topic" setting_context: "Look up site setting context" + schema: "Look up database schema" command_description: read: "Reading: %{title}" time: "Time in %{timezone} is %{time}" @@ -123,6 +134,7 @@ en: one: "Found %{count} result for '%{query}'" other: "Found %{count} results for '%{query}'" setting_context: "Reading context for: %{setting_name}" + schema: "%{tables}" summarization: configuration_hint: diff --git a/lib/modules/ai_bot/bot.rb b/lib/modules/ai_bot/bot.rb index 97e3412e..cab1bb6e 100644 --- a/lib/modules/ai_bot/bot.rb +++ b/lib/modules/ai_bot/bot.rb @@ -58,6 +58,7 @@ module DiscourseAi def initialize(bot_user) @bot_user = bot_user + @persona = DiscourseAi::AiBot::Personas::General.new end def update_pm_title(post) @@ -90,6 +91,13 @@ module DiscourseAi ) return if total_completions > MAX_COMPLETIONS + @persona = DiscourseAi::AiBot::Personas::General.new + if persona_name = post.topic.custom_fields["ai_persona"] + persona_class = + DiscourseAi::AiBot::Personas.all.find { |current| current.name == persona_name } + @persona = persona_class.new if persona_class + end + prompt = if standalone && post.post_custom_prompt username, standalone_prompt = post.post_custom_prompt.custom_prompt.last @@ -265,27 +273,7 @@ module DiscourseAi end def available_commands - return @cmds if @cmds - - all_commands = - [ - Commands::CategoriesCommand, - Commands::TimeCommand, - Commands::SearchCommand, - Commands::SummarizeCommand, - Commands::ReadCommand, - Commands::SettingContextCommand, - ].tap do |cmds| - cmds << Commands::TagsCommand if SiteSetting.tagging_enabled - cmds << Commands::ImageCommand if SiteSetting.ai_stability_api_key.present? - if SiteSetting.ai_google_custom_search_api_key.present? && - SiteSetting.ai_google_custom_search_cx.present? - cmds << Commands::GoogleCommand - end - end - - allowed_commands = SiteSetting.ai_bot_enabled_chat_commands.split("|") - @cmds = all_commands.filter { |klass| allowed_commands.include?(klass.name) } + @persona.available_commands end def system_prompt_style!(style) @@ -295,26 +283,10 @@ module DiscourseAi def system_prompt(post) return "You are a helpful Bot" if @style == :simple - prompt = +<<~TEXT - You are a helpful Discourse assistant. - You understand and generate Discourse Markdown. - You live in a Discourse Forum Message. - - You live in the forum with the URL: #{Discourse.base_url} - The title of your site: #{SiteSetting.title} - The description is: #{SiteSetting.site_description} - The participants in this conversation are: #{post.topic.allowed_users.map(&:username).join(", ")} - The date now is: #{Time.zone.now}, much has changed since you were trained. - TEXT - - if include_function_instructions_in_system_prompt? - prompt << "\n" - prompt << function_list.system_prompt - prompt << "\n" - end - - prompt << available_commands.map(&:custom_system_message).compact.join("\n") - prompt + @persona.render_system_prompt( + topic: post.topic, + render_function_instructions: include_function_instructions_in_system_prompt?, + ) end def include_function_instructions_in_system_prompt? @@ -322,11 +294,7 @@ module DiscourseAi end def function_list - return @function_list if @function_list - - @function_list = DiscourseAi::Inference::FunctionList.new - available_functions.each { |function| @function_list << function } - @function_list + @persona.function_list end def tokenizer @@ -363,29 +331,7 @@ module DiscourseAi end def available_functions - # note if defined? can be a problem in test - # this can never be nil so it is safe - return @available_functions if @available_functions - - functions = [] - - functions = - available_commands.map do |command| - function = - DiscourseAi::Inference::Function.new(name: command.name, description: command.desc) - command.parameters.each do |parameter| - function.add_parameter( - name: parameter.name, - type: parameter.type, - description: parameter.description, - required: parameter.required, - enum: parameter.enum, - ) - end - function - end - - @available_functions = functions + @persona.available_functions end protected diff --git a/lib/modules/ai_bot/commands/db_schema_command.rb b/lib/modules/ai_bot/commands/db_schema_command.rb new file mode 100644 index 00000000..4f01c765 --- /dev/null +++ b/lib/modules/ai_bot/commands/db_schema_command.rb @@ -0,0 +1,54 @@ +#frozen_string_literal: true + +module DiscourseAi::AiBot::Commands + class DbSchemaCommand < Command + class << self + def name + "schema" + end + + def desc + "Will load schema information for specific tables in the database" + end + + def parameters + [ + Parameter.new( + name: "tables", + description: + "list of tables to load schema information for, comma seperated list eg: (users,posts))", + type: "string", + required: true, + ), + ] + end + end + + def result_name + "results" + end + + def description_args + { tables: @tables.join(", ") } + end + + def process(tables:) + @tables = tables.split(",").map(&:strip) + + table_info = {} + DB + .query(<<~SQL, @tables) + select table_name, column_name, data_type from information_schema.columns + where table_schema = 'public' + and table_name in (?) + order by table_name + SQL + .each { |row| (table_info[row.table_name] ||= []) << "#{row.column_name} #{row.data_type}" } + + schema_info = + table_info.map { |table_name, columns| "#{table_name}(#{columns.join(",")})" }.join("\n") + + { tables: @tables, schema_info: schema_info } + end + end +end diff --git a/lib/modules/ai_bot/entry_point.rb b/lib/modules/ai_bot/entry_point.rb index a6aa336e..e0ce6f0b 100644 --- a/lib/modules/ai_bot/entry_point.rb +++ b/lib/modules/ai_bot/entry_point.rb @@ -39,6 +39,11 @@ module DiscourseAi require_relative "commands/google_command" require_relative "commands/read_command" require_relative "commands/setting_context_command" + require_relative "commands/db_schema_command" + require_relative "personas/persona" + require_relative "personas/artist" + require_relative "personas/general" + require_relative "personas/sql_helper" end def inject_into(plugin) @@ -46,6 +51,17 @@ module DiscourseAi Rails.root.join("plugins", "discourse-ai", "db", "fixtures", "ai_bot"), ) + plugin.add_to_serializer( + :current_user, + :ai_enabled_personas, + include_condition: -> do + SiteSetting.ai_bot_enabled && scope.authenticated? && + scope.user.in_any_groups?(SiteSetting.ai_bot_allowed_groups_map) + end, + ) do + Personas.all.map { |persona| { name: persona.name, description: persona.description } } + end + plugin.add_to_serializer( :current_user, :ai_enabled_chat_bots, @@ -75,6 +91,12 @@ module DiscourseAi plugin.register_svg_icon("robot") + plugin.add_to_serializer( + :topic_view, + :ai_persona_name, + include_condition: -> { SiteSetting.ai_bot_enabled && object.topic.private_message? }, + ) { topic.custom_fields["ai_persona"] } + plugin.on(:post_created) do |post| bot_ids = BOTS.map(&:first) diff --git a/lib/modules/ai_bot/personas/artist.rb b/lib/modules/ai_bot/personas/artist.rb new file mode 100644 index 00000000..9f520c57 --- /dev/null +++ b/lib/modules/ai_bot/personas/artist.rb @@ -0,0 +1,33 @@ +#frozen_string_literal: true + +module DiscourseAi + module AiBot + module Personas + class Artist < Persona + def commands + [Commands::ImageCommand] + end + + def system_prompt + <<~PROMPT + You are artistbot and you are here to help people generate images. + + You generate images using stable diffusion. + + - A good prompt needs to be detailed and specific. + - You can specify subject, medium (e.g. oil on canvas), artist (person who drew it or photographed it) + - You can specify details about lighting or time of day. + - You can specify a particular website you would like to emulate (artstation or deviantart) + - You can specify additional details such as "beutiful, dystopian, futuristic, etc." + - Prompts should generally be 10-20 words long + - Do not include any connector words such as "and" or "but" etc. + - You are extremely creative, when given short non descriptive prompts from a user you add your own details + + {commands} + + PROMPT + end + end + end + end +end diff --git a/lib/modules/ai_bot/personas/general.rb b/lib/modules/ai_bot/personas/general.rb new file mode 100644 index 00000000..0448c631 --- /dev/null +++ b/lib/modules/ai_bot/personas/general.rb @@ -0,0 +1,29 @@ +#frozen_string_literal: true + +module DiscourseAi + module AiBot + module Personas + class General < Persona + def commands + all_available_commands + end + + def system_prompt + <<~PROMPT + You are a helpful Discourse assistant. + You understand and generate Discourse Markdown. + You live in a Discourse Forum Message. + + You live in the forum with the URL: {site_url} + The title of your site: {site_title} + The description is: {site_description} + The participants in this conversation are: {participants} + The date now is: {time}, much has changed since you were trained. + + {commands} + PROMPT + end + end + end + end +end diff --git a/lib/modules/ai_bot/personas/persona.rb b/lib/modules/ai_bot/personas/persona.rb new file mode 100644 index 00000000..693159d3 --- /dev/null +++ b/lib/modules/ai_bot/personas/persona.rb @@ -0,0 +1,110 @@ +#frozen_string_literal: true + +module DiscourseAi + module AiBot + module Personas + def self.all + personas = [Personas::General, Personas::SqlHelper] + personas << Personas::Artist if SiteSetting.ai_stability_api_key.present? + personas + end + + class Persona + def self.name + I18n.t("discourse_ai.ai_bot.personas.#{to_s.demodulize.underscore}.name") + end + + def self.description + I18n.t("discourse_ai.ai_bot.personas.#{to_s.demodulize.underscore}.description") + end + + def commands + [] + end + + def render_commands(render_function_instructions:) + result = +"" + if render_function_instructions + result << "\n" + result << function_list.system_prompt + result << "\n" + end + result << available_commands.map(&:custom_system_message).compact.join("\n") + result + end + + def render_system_prompt(topic: nil, render_function_instructions: true) + substitutions = { + site_url: Discourse.base_url, + site_title: SiteSetting.title, + site_description: SiteSetting.site_description, + time: Time.zone.now, + commands: render_commands(render_function_instructions: render_function_instructions), + } + + substitutions[:participants] = topic.allowed_users.map(&:username).join(", ") if topic + + system_prompt.gsub(/\{(\w+)\}/) do |match| + found = substitutions[match[1..-2].to_sym] + found.nil? ? match : found.to_s + end + end + + def available_commands + return @available_commands if @available_commands + + @available_commands = all_available_commands.filter { |cmd| commands.include?(cmd) } + end + + def available_functions + # note if defined? can be a problem in test + # this can never be nil so it is safe + return @available_functions if @available_functions + + functions = [] + + functions = + available_commands.map do |command| + function = + DiscourseAi::Inference::Function.new(name: command.name, description: command.desc) + command.parameters.each { |parameter| function.add_parameter(parameter) } + function + end + + @available_functions = functions + end + + def function_list + return @function_list if @function_list + + @function_list = DiscourseAi::Inference::FunctionList.new + available_functions.each { |function| @function_list << function } + @function_list + end + + def all_available_commands + return @cmds if @cmds + + all_commands = [ + Commands::CategoriesCommand, + Commands::TimeCommand, + Commands::SearchCommand, + Commands::SummarizeCommand, + Commands::ReadCommand, + Commands::SettingContextCommand, + ] + + all_commands << Commands::TagsCommand if SiteSetting.tagging_enabled + all_commands << Commands::ImageCommand if SiteSetting.ai_stability_api_key.present? + if SiteSetting.ai_google_custom_search_api_key.present? && + SiteSetting.ai_google_custom_search_cx.present? + all_commands << Commands::GoogleCommand + end + + allowed_commands = SiteSetting.ai_bot_enabled_chat_commands.split("|") + @cmds = all_commands.filter { |klass| allowed_commands.include?(klass.name) } + end + end + end + end +end diff --git a/lib/modules/ai_bot/personas/sql_helper.rb b/lib/modules/ai_bot/personas/sql_helper.rb new file mode 100644 index 00000000..321a4e2d --- /dev/null +++ b/lib/modules/ai_bot/personas/sql_helper.rb @@ -0,0 +1,64 @@ +#frozen_string_literal: true + +module DiscourseAi + module AiBot + module Personas + class SqlHelper < Persona + def self.schema + return @schema if defined?(@schema) + + tables = Hash.new + priority_tables = %w[posts topics notifications users user_actions] + + DB.query(<<~SQL).each { |row| (tables[row.table_name] ||= []) << row.column_name } + select table_name, column_name from information_schema.columns + where table_schema = 'public' + order by table_name + SQL + + schema = +(priority_tables.map { |name| "#{name}(#{tables[name].join(",")})" }.join("\n")) + + schema << "\nOther tables (schema redacted, available on request): " + tables.each do |table_name, _| + next if priority_tables.include?(table_name) + schema << "#{table_name} " + end + + @schema = schema + end + + def commands + all_available_commands + end + + def all_available_commands + [DiscourseAi::AiBot::Commands::DbSchemaCommand] + end + + def system_prompt + <<~PROMPT + You are a PostgreSQL expert. + You understand and generate Discourse Markdown but specialize in creating queries. + You live in a Discourse Forum Message. + The schema in your training set MAY be out of date. + + The user_actions tables stores likes (action_type 1). + the topics table stores private/personal messages it uses archetype private_message for them. + notification_level can be: {muted: 0, regular: 1, tracking: 2, watching: 3, watching_first_post: 4}. + bookmarkable_type can be: Post,Topic,ChatMessage and more + + Current time is: {time} + + + The current schema for the current DB is: + {{ + #{self.class.schema} + }} + + {commands} + PROMPT + end + end + end + end +end diff --git a/lib/shared/inference/function.rb b/lib/shared/inference/function.rb index 616229a3..d84acd42 100644 --- a/lib/shared/inference/function.rb +++ b/lib/shared/inference/function.rb @@ -12,7 +12,21 @@ module ::DiscourseAi @parameters = [] end - def add_parameter(name:, type:, description:, enum: nil, required: false) + def add_parameter(parameter = nil, **kwargs) + if parameter + add_parameter_kwargs( + name: parameter.name, + type: parameter.type, + description: parameter.description, + required: parameter.required, + enum: parameter.enum, + ) + else + add_parameter_kwargs(**kwargs) + end + end + + def add_parameter_kwargs(name:, type:, description:, enum: nil, required: false) @parameters << { name: name, type: type, diff --git a/lib/shared/inference/function_list.rb b/lib/shared/inference/function_list.rb index fab9b0e6..598e47de 100644 --- a/lib/shared/inference/function_list.rb +++ b/lib/shared/inference/function_list.rb @@ -28,7 +28,22 @@ module ::DiscourseAi next if function.blank? arguments = arguments[0..-2] if arguments.end_with?(")") - arguments = arguments.split(",").map(&:strip) + + temp_string = +"" + in_string = nil + replace = SecureRandom.hex(10) + arguments.each_char do |char| + if %w[" '].include?(char) && !in_string + in_string = char + elsif char == in_string + in_string = nil + elsif char == "," && in_string + char = replace + end + temp_string << char + end + + arguments = temp_string.split(",").map { |s| s.gsub(replace, ",").strip } parsed_arguments = {} arguments.each do |argument| @@ -76,8 +91,8 @@ module ::DiscourseAi PROMPT @functions.each do |function| - prompt << " // #{function.description}\n" - prompt << " #{function.name}" + prompt << "// #{function.description}\n" + prompt << "!#{function.name}" if function.parameters.present? prompt << "(" function.parameters.each_with_index do |parameter, index| @@ -96,8 +111,9 @@ module ::DiscourseAi prompt << " /* #{description} */" if description.present? end - prompt << ")\n" + prompt << ")" end + prompt << "\n" end prompt << <<~PROMPT @@ -109,8 +125,8 @@ module ::DiscourseAi For example for a function defined as: { - // echo a string - echo(message: string [required]) + // echo a string + !echo(message: string [required]) } Human: please echo out "hello" diff --git a/spec/lib/modules/ai_bot/commands/db_schema_command_spec.rb b/spec/lib/modules/ai_bot/commands/db_schema_command_spec.rb new file mode 100644 index 00000000..15da0041 --- /dev/null +++ b/spec/lib/modules/ai_bot/commands/db_schema_command_spec.rb @@ -0,0 +1,16 @@ +#frozen_string_literal: true + +RSpec.describe DiscourseAi::AiBot::Commands::DbSchemaCommand do + let(:command) { DiscourseAi::AiBot::Commands::DbSchemaCommand.new(bot_user: nil, args: nil) } + describe "#process" do + it "returns rich schema for tables" do + result = command.process(tables: "posts,topics") + expect(result[:schema_info]).to include("raw text") + expect(result[:schema_info]).to include("views integer") + expect(result[:schema_info]).to include("posts") + expect(result[:schema_info]).to include("topics") + + expect(result[:tables]).to eq(%w[posts topics]) + end + end +end diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb new file mode 100644 index 00000000..0bd8fa8e --- /dev/null +++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb @@ -0,0 +1,61 @@ +#frozen_string_literal: true + +class TestPersona < DiscourseAi::AiBot::Personas::Persona + def commands + [ + DiscourseAi::AiBot::Commands::TagsCommand, + DiscourseAi::AiBot::Commands::SearchCommand, + DiscourseAi::AiBot::Commands::ImageCommand, + ] + end + + def system_prompt + <<~PROMPT + {site_url} + {site_title} + {site_description} + {participants} + {time} + + {commands} + PROMPT + end +end + +RSpec.describe DiscourseAi::AiBot::Personas::Persona do + let :persona do + TestPersona.new + end + + let :topic_with_users do + topic = Topic.new + topic.allowed_users = [User.new(username: "joe"), User.new(username: "jane")] + topic + end + + it "renders the system prompt" do + freeze_time + + SiteSetting.title = "test site title" + SiteSetting.site_description = "test site description" + + rendered = + persona.render_system_prompt(topic: topic_with_users, render_function_instructions: true) + + expect(rendered).to include(Discourse.base_url) + expect(rendered).to include("test site title") + expect(rendered).to include("test site description") + expect(rendered).to include("joe, jane") + expect(rendered).to include(Time.zone.now.to_s) + expect(rendered).to include("!search") + expect(rendered).to include("!tags") + # needs to be configured so it is not available + expect(rendered).not_to include("!image") + + rendered = + persona.render_system_prompt(topic: topic_with_users, render_function_instructions: false) + + expect(rendered).not_to include("!search") + expect(rendered).not_to include("!tags") + end +end diff --git a/spec/lib/modules/ai_bot/personas/sql_helper_spec.rb b/spec/lib/modules/ai_bot/personas/sql_helper_spec.rb new file mode 100644 index 00000000..33173a09 --- /dev/null +++ b/spec/lib/modules/ai_bot/personas/sql_helper_spec.rb @@ -0,0 +1,17 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::AiBot::Personas::SqlHelper do + let :sql_helper do + subject + end + + it "renders schema" do + prompt = sql_helper.render_system_prompt + expect(prompt).to include("posts(") + expect(prompt).to include("topics(") + expect(prompt).not_to include("translation_key") # not a priority table + expect(prompt).to include("user_api_keys") # not a priority table + + expect(sql_helper.available_commands).to eq([DiscourseAi::AiBot::Commands::DbSchemaCommand]) + end +end diff --git a/spec/shared/inference/function_list_spec.rb b/spec/shared/inference/function_list_spec.rb index 81d0d89a..db95823b 100644 --- a/spec/shared/inference/function_list_spec.rb +++ b/spec/shared/inference/function_list_spec.rb @@ -29,16 +29,16 @@ module DiscourseAi::Inference it "can handle complex parsing" do raw_prompt = <<~PROMPT - !get_weather(location: "sydney", unit: "f") + !get_weather(location: "sydney,melbourne", unit: "f") !get_weather (location: sydney) - !get_weather(location : 'sydney's', unit: "m", invalid: "invalid") + !get_weather(location : "sydney's", unit: "m", invalid: "invalid") !get_weather(unit: "f", invalid: "invalid") PROMPT parsed = function_list.parse_prompt(raw_prompt) expect(parsed).to eq( [ - { name: "get_weather", arguments: { location: "sydney", unit: "f" } }, + { name: "get_weather", arguments: { location: "sydney,melbourne", unit: "f" } }, { name: "get_weather", arguments: { location: "sydney" } }, { name: "get_weather", arguments: { location: "sydney's" } }, ], @@ -52,8 +52,8 @@ module DiscourseAi::Inference # expected = <<~PROMPT { - // Get the weather in a city (default to c) - get_weather(location: string [required] /* the city name */, unit: string [optional] /* the unit of measurement celcius c or fahrenheit f [valid values: c,f] */) + // Get the weather in a city (default to c) + !get_weather(location: string [required] /* the city name */, unit: string [optional] /* the unit of measurement celcius c or fahrenheit f [valid values: c,f] */) } PROMPT expect(prompt).to include(expected)