diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index 481d0ddc..a28411fe 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -122,16 +122,22 @@ module DiscourseAi tool_klass = available_tools.find { |c| c.signature.dig(:name) == function_name } return false if tool_klass.nil? - arguments = - tool_klass.signature[:parameters] - .to_a - .reduce({}) do |memo, p| - argument = parsed_function.at(p[:name])&.text - next(memo) unless argument + arguments = {} + tool_klass.signature[:parameters].to_a.each do |param| + name = param[:name] + value = parsed_function.at(name)&.text - memo[p[:name].to_sym] = argument - memo - end + if param[:type] == "array" && value + value = + begin + JSON.parse(value) + rescue JSON::ParserError + nil + end + end + + arguments[name.to_sym] = value if value + end tool_klass.new( arguments, diff --git a/lib/ai_bot/site_settings_extension.rb b/lib/ai_bot/site_settings_extension.rb index 77f17641..d98873d4 100644 --- a/lib/ai_bot/site_settings_extension.rb +++ b/lib/ai_bot/site_settings_extension.rb @@ -24,14 +24,14 @@ module DiscourseAi::AiBot::SiteSettingsExtension ) user.save!(validate: false) else - user.update!(active: true) + user.update_columns(active: true) end elsif !active && user # will include deleted has_posts = DB.query_single("SELECT 1 FROM posts WHERE user_id = #{id} LIMIT 1").present? if has_posts - user.update!(active: false) if user.active + user.update_columns(active: false) if user.active else user.destroy end diff --git a/lib/ai_bot/tools/dall_e.rb b/lib/ai_bot/tools/dall_e.rb index 1e9532c3..85621056 100644 --- a/lib/ai_bot/tools/dall_e.rb +++ b/lib/ai_bot/tools/dall_e.rb @@ -99,17 +99,17 @@ module DiscourseAi end self.custom_raw = <<~RAW - - [grid] - #{ + + [grid] + #{ uploads .map do |item| "![#{item[:prompt].gsub(/\|\'\"/, "")}|512x512, 50%](#{item[:upload].short_url})" end .join(" ") } - [/grid] - RAW + [/grid] + RAW { prompts: uploads.map { |item| item[:prompt] } } end diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb index c09728f6..97db50cf 100644 --- a/spec/lib/modules/ai_bot/personas/persona_spec.rb +++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb @@ -20,131 +20,157 @@ class TestPersona < DiscourseAi::AiBot::Personas::Persona end end -module DiscourseAi::AiBot::Personas - RSpec.describe Persona do - let :persona do - TestPersona.new - end +RSpec.describe DiscourseAi::AiBot::Personas::Persona do + let :persona do + TestPersona.new + end - let :topic_with_users do - topic = Topic.new - topic.allowed_users = [User.new(username: "joe"), User.new(username: "jane")] - topic - end + let :topic_with_users do + topic = Topic.new + topic.allowed_users = [User.new(username: "joe"), User.new(username: "jane")] + topic + end - after do - # we are rolling back transactions so we can create poison cache - AiPersona.persona_cache.flush! - end + after do + # we are rolling back transactions so we can create poison cache + AiPersona.persona_cache.flush! + end - let(:context) do - { - site_url: Discourse.base_url, - site_title: "test site title", - site_description: "test site description", - time: Time.zone.now, - participants: topic_with_users.allowed_users.map(&:username).join(", "), - } - end + let(:context) do + { + site_url: Discourse.base_url, + site_title: "test site title", + site_description: "test site description", + time: Time.zone.now, + participants: topic_with_users.allowed_users.map(&:username).join(", "), + } + end - fab!(:user) + fab!(:user) - it "renders the system prompt" do - freeze_time + it "renders the system prompt" do + freeze_time - rendered = persona.craft_prompt(context) + rendered = persona.craft_prompt(context) - expect(rendered[:insts]).to include(Discourse.base_url) - expect(rendered[:insts]).to include("test site title") - expect(rendered[:insts]).to include("test site description") - expect(rendered[:insts]).to include("joe, jane") - expect(rendered[:insts]).to include(Time.zone.now.to_s) + expect(rendered[:insts]).to include(Discourse.base_url) + expect(rendered[:insts]).to include("test site title") + expect(rendered[:insts]).to include("test site description") + expect(rendered[:insts]).to include("joe, jane") + expect(rendered[:insts]).to include(Time.zone.now.to_s) - tools = rendered[:tools] + tools = rendered[:tools] - expect(tools.find { |t| t[:name] == "search" }).to be_present - expect(tools.find { |t| t[:name] == "tags" }).to be_present + expect(tools.find { |t| t[:name] == "search" }).to be_present + expect(tools.find { |t| t[:name] == "tags" }).to be_present - # needs to be configured so it is not available - expect(tools.find { |t| t[:name] == "image" }).to be_nil - end + # needs to be configured so it is not available + expect(tools.find { |t| t[:name] == "image" }).to be_nil + end - describe "custom personas" do - it "is able to find custom personas" do - Group.refresh_automatic_groups! + it "can correctly parse arrays in tools" do + SiteSetting.ai_openai_api_key = "123" - # define an ai persona everyone can see - persona = - AiPersona.create!( - name: "zzzpun_bot", - description: "you write puns", - system_prompt: "you are pun bot", - commands: ["ImageCommand"], - allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]], - ) + # Dall E tool uses an array for params + xml = <<~XML + + + dall_e + call_JtYQMful5QKqw97XFsHzPweB + + ["cat oil painting", "big car"] + + + + XML + dall_e = DiscourseAi::AiBot::Personas::DallE3.new.find_tool(xml) + expect(dall_e.parameters[:prompts]).to eq(["cat oil painting", "big car"]) + end - custom_persona = DiscourseAi::AiBot::Personas::Persona.all(user: user).last - expect(custom_persona.name).to eq("zzzpun_bot") - expect(custom_persona.description).to eq("you write puns") + describe "custom personas" do + it "is able to find custom personas" do + Group.refresh_automatic_groups! - instance = custom_persona.new - expect(instance.tools).to eq([DiscourseAi::AiBot::Tools::Image]) - expect(instance.craft_prompt(context).dig(:insts)).to eq("you are pun bot\n\n") - - # should update - persona.update!(name: "zzzpun_bot2") - custom_persona = DiscourseAi::AiBot::Personas::Persona.all(user: user).last - expect(custom_persona.name).to eq("zzzpun_bot2") - - # can be disabled - persona.update!(enabled: false) - last_persona = DiscourseAi::AiBot::Personas::Persona.all(user: user).last - expect(last_persona.name).not_to eq("zzzpun_bot2") - - persona.update!(enabled: true) - # no groups have access - persona.update!(allowed_group_ids: []) - - last_persona = DiscourseAi::AiBot::Personas::Persona.all(user: user).last - expect(last_persona.name).not_to eq("zzzpun_bot2") - end - end - - describe "available personas" do - it "includes all personas by default" do - Group.refresh_automatic_groups! - - # must be enabled to see it - SiteSetting.ai_stability_api_key = "abc" - SiteSetting.ai_google_custom_search_api_key = "abc" - SiteSetting.ai_google_custom_search_cx = "abc123" - - # should be ordered by priority and then alpha - expect(DiscourseAi::AiBot::Personas::Persona.all(user: user)).to eq( - [General, Artist, Creative, Researcher, SettingsExplorer, SqlHelper], + # define an ai persona everyone can see + persona = + AiPersona.create!( + name: "zzzpun_bot", + description: "you write puns", + system_prompt: "you are pun bot", + commands: ["ImageCommand"], + allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]], ) - # omits personas if key is missing - SiteSetting.ai_stability_api_key = "" - SiteSetting.ai_google_custom_search_api_key = "" + custom_persona = DiscourseAi::AiBot::Personas::Persona.all(user: user).last + expect(custom_persona.name).to eq("zzzpun_bot") + expect(custom_persona.description).to eq("you write puns") - expect(DiscourseAi::AiBot::Personas::Persona.all(user: user)).to contain_exactly( - General, - SqlHelper, - SettingsExplorer, - Creative, - ) + instance = custom_persona.new + expect(instance.tools).to eq([DiscourseAi::AiBot::Tools::Image]) + expect(instance.craft_prompt(context).dig(:insts)).to eq("you are pun bot\n\n") - AiPersona.find(DiscourseAi::AiBot::Personas::Persona.system_personas[General]).update!( - enabled: false, - ) + # should update + persona.update!(name: "zzzpun_bot2") + custom_persona = DiscourseAi::AiBot::Personas::Persona.all(user: user).last + expect(custom_persona.name).to eq("zzzpun_bot2") - expect(DiscourseAi::AiBot::Personas::Persona.all(user: user)).to contain_exactly( - SqlHelper, - SettingsExplorer, - Creative, - ) - end + # can be disabled + persona.update!(enabled: false) + last_persona = DiscourseAi::AiBot::Personas::Persona.all(user: user).last + expect(last_persona.name).not_to eq("zzzpun_bot2") + + persona.update!(enabled: true) + # no groups have access + persona.update!(allowed_group_ids: []) + + last_persona = DiscourseAi::AiBot::Personas::Persona.all(user: user).last + expect(last_persona.name).not_to eq("zzzpun_bot2") + end + end + + describe "available personas" do + it "includes all personas by default" do + Group.refresh_automatic_groups! + + # must be enabled to see it + SiteSetting.ai_stability_api_key = "abc" + SiteSetting.ai_google_custom_search_api_key = "abc" + SiteSetting.ai_google_custom_search_cx = "abc123" + + # should be ordered by priority and then alpha + expect(DiscourseAi::AiBot::Personas::Persona.all(user: user)).to eq( + [ + DiscourseAi::AiBot::Personas::General, + DiscourseAi::AiBot::Personas::Artist, + DiscourseAi::AiBot::Personas::Creative, + DiscourseAi::AiBot::Personas::Researcher, + DiscourseAi::AiBot::Personas::SettingsExplorer, + DiscourseAi::AiBot::Personas::SqlHelper, + ], + ) + + # omits personas if key is missing + SiteSetting.ai_stability_api_key = "" + SiteSetting.ai_google_custom_search_api_key = "" + + expect(DiscourseAi::AiBot::Personas::Persona.all(user: user)).to contain_exactly( + DiscourseAi::AiBot::Personas::General, + DiscourseAi::AiBot::Personas::SqlHelper, + DiscourseAi::AiBot::Personas::SettingsExplorer, + DiscourseAi::AiBot::Personas::Creative, + ) + + AiPersona.find( + DiscourseAi::AiBot::Personas::Persona.system_personas[ + DiscourseAi::AiBot::Personas::General + ], + ).update!(enabled: false) + + expect(DiscourseAi::AiBot::Personas::Persona.all(user: user)).to contain_exactly( + DiscourseAi::AiBot::Personas::SqlHelper, + DiscourseAi::AiBot::Personas::SettingsExplorer, + DiscourseAi::AiBot::Personas::Creative, + ) end end end