diff --git a/app/jobs/regular/update_ai_bot_pm_title.rb b/app/jobs/regular/update_ai_bot_pm_title.rb index 5f181e0d..5da82291 100644 --- a/app/jobs/regular/update_ai_bot_pm_title.rb +++ b/app/jobs/regular/update_ai_bot_pm_title.rb @@ -6,7 +6,7 @@ module ::Jobs def execute(args) return unless bot_user = User.find_by(id: args[:bot_user_id]) - return unless bot = DiscourseAi::AiBot::Bot.as(bot_user) + return unless bot = DiscourseAi::AiBot::Bot.as(bot_user, model: args[:model]) return unless post = Post.includes(:topic).find_by(id: args[:post_id]) return unless post.topic.custom_fields[DiscourseAi::AiBot::EntryPoint::REQUIRE_TITLE_UPDATE] diff --git a/app/models/ai_persona.rb b/app/models/ai_persona.rb index 587d2872..931e2183 100644 --- a/app/models/ai_persona.rb +++ b/app/models/ai_persona.rb @@ -60,21 +60,28 @@ class AiPersona < ActiveRecord::Base .map(&:class_instance) end - def self.mentionables - persona_cache[:mentionable_usernames] ||= AiPersona - .where(mentionable: true) - .where(enabled: true) - .joins(:user) - .pluck("ai_personas.id, users.id, users.username_lower, allowed_group_ids, default_llm") - .map do |id, user_id, username, allowed_group_ids, default_llm| - { - id: id, - user_id: user_id, - username: username, - allowed_group_ids: allowed_group_ids, - default_llm: default_llm, - } - end + def self.mentionables(user: nil) + all_mentionables = + persona_cache[:mentionable_usernames] ||= AiPersona + .where(mentionable: true) + .where(enabled: true) + .joins(:user) + .pluck("ai_personas.id, users.id, users.username_lower, allowed_group_ids, default_llm") + .map do |id, user_id, username, allowed_group_ids, default_llm| + { + id: id, + user_id: user_id, + username: username, + allowed_group_ids: allowed_group_ids, + default_llm: default_llm, + } + end + + if user + all_mentionables.select { |mentionable| user.in_any_groups?(mentionable[:allowed_group_ids]) } + else + all_mentionables + end end after_commit :bump_cache diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb index c7a38496..909a3811 100644 --- a/lib/ai_bot/bot.rb +++ b/lib/ai_bot/bot.rb @@ -3,16 +3,19 @@ module DiscourseAi module AiBot class Bot + attr_reader :model + BOT_NOT_FOUND = Class.new(StandardError) MAX_COMPLETIONS = 5 - def self.as(bot_user, persona: DiscourseAi::AiBot::Personas::General.new) - new(bot_user, persona) + def self.as(bot_user, persona: DiscourseAi::AiBot::Personas::General.new, model: nil) + new(bot_user, persona, model) end - def initialize(bot_user, persona) + def initialize(bot_user, persona, model = nil) @bot_user = bot_user @persona = persona + @model = model || self.class.guess_model(bot_user) || @persona.class.default_llm end attr_reader :bot_user @@ -46,7 +49,6 @@ module DiscourseAi total_completions = 0 ongoing_chain = true - low_cost = false raw_context = [] user = context[:user] @@ -56,7 +58,7 @@ module DiscourseAi llm_kwargs[:top_p] = persona.top_p if persona.top_p while total_completions <= MAX_COMPLETIONS && ongoing_chain - current_model = model(prefer_low_cost: low_cost) + current_model = model llm = DiscourseAi::Completions::Llm.proxy(current_model) tool_found = false @@ -65,7 +67,6 @@ module DiscourseAi if (tool = persona.find_tool(partial)) tool_found = true ongoing_chain = tool.chain_next_response? - low_cost = tool.low_cost? tool_call_id = tool.tool_call_id invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json @@ -129,43 +130,36 @@ module DiscourseAi result end - def model(prefer_low_cost: false) + def self.guess_model(bot_user) # HACK(roman): We'll do this until we define how we represent different providers in the bot settings - default_model = - case bot_user.id - when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID - if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2") - "aws_bedrock:claude-2" - else - "anthropic:claude-2" - end - when DiscourseAi::AiBot::EntryPoint::GPT4_ID - "open_ai:gpt-4" - when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID - "open_ai:gpt-4-turbo" - when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID - "open_ai:gpt-3.5-turbo-16k" - when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID - if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?( - "mistralai/Mixtral-8x7B-Instruct-v0.1", - ) - "vllm:mistralai/Mixtral-8x7B-Instruct-v0.1" - else - "hugging_face:mistralai/Mixtral-8x7B-Instruct-v0.1" - end - when DiscourseAi::AiBot::EntryPoint::GEMINI_ID - "google:gemini-pro" - when DiscourseAi::AiBot::EntryPoint::FAKE_ID - "fake:fake" + case bot_user.id + when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID + if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2") + "aws_bedrock:claude-2" else - nil + "anthropic:claude-2" end - - if %w[open_ai:gpt-4 open_ai:gpt-4-turbo].include?(default_model) && prefer_low_cost - return "open_ai:gpt-3.5-turbo-16k" + when DiscourseAi::AiBot::EntryPoint::GPT4_ID + "open_ai:gpt-4" + when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID + "open_ai:gpt-4-turbo" + when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID + "open_ai:gpt-3.5-turbo-16k" + when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID + if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?( + "mistralai/Mixtral-8x7B-Instruct-v0.1", + ) + "vllm:mistralai/Mixtral-8x7B-Instruct-v0.1" + else + "hugging_face:mistralai/Mixtral-8x7B-Instruct-v0.1" + end + when DiscourseAi::AiBot::EntryPoint::GEMINI_ID + "google:gemini-pro" + when DiscourseAi::AiBot::EntryPoint::FAKE_ID + "fake:fake" + else + nil end - - default_model end def tool_invocation?(partial) diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index 65ea6788..27386b8b 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -11,46 +11,46 @@ module DiscourseAi REQUIRE_TITLE_UPDATE = "discourse-ai-title-update" - def self.schedule_reply(post) + def self.is_bot_user_id?(user_id) bot_ids = DiscourseAi::AiBot::EntryPoint::BOT_USER_IDS + bot_ids.include?(user_id) || + begin + mentionable_ids = AiPersona.mentionables.map { |mentionable| mentionable[:user_id] } + mentionable_ids.include?(user_id) + end + end - return if bot_ids.include?(post.user_id) - if AiPersona.mentionables.any? { |mentionable| mentionable[:user_id] == post.user_id } - return - end + def self.schedule_reply(post) + return if is_bot_user_id?(post.user_id) + + bot_ids = DiscourseAi::AiBot::EntryPoint::BOT_USER_IDS + mentionables = AiPersona.mentionables(user: post.user) bot_user = nil mentioned = nil if post.topic.private_message? bot_user = post.topic.topic_allowed_users.where(user_id: bot_ids).first&.user + bot_user ||= + post + .topic + .topic_allowed_users + .where(user_id: mentionables.map { |m| m[:user_id] }) + .first + &.user end - if AiPersona.mentionables.length > 0 + if mentionables.present? mentions = post.mentions.map(&:downcase) - mentioned = - AiPersona.mentionables.find do |mentionable| - mentions.include?(mentionable[:username]) && - (post.user.group_ids & mentionable[:allowed_group_ids]).present? - end + mentioned = mentionables.find { |mentionable| mentions.include?(mentionable[:username]) } - # PM always takes precedence - if mentioned && !bot_user - model_without_provider = mentioned[:default_llm].split(":").last - user_id = - DiscourseAi::AiBot::EntryPoint.map_bot_model_to_user_id(model_without_provider) - - if !user_id - Rails.logger.warn( - "Model #{mentioned[:default_llm]} not found for persona #{mentioned[:username]}", - ) - if Rails.env.development? || Rails.env.test? - raise "Model #{mentioned[:default_llm]} not found for persona #{mentioned[:username]}" - end - else - bot_user = User.find_by(id: user_id) - end + # direct PM to mentionable + if !mentioned && bot_user + mentioned = mentionables.find { |mentionable| bot_user.id == mentionable[:user_id] } end + + # public topic so we need to use the persona user + bot_user ||= User.find_by(id: mentioned[:user_id]) if mentioned end if bot_user @@ -309,6 +309,7 @@ module DiscourseAi :update_ai_bot_pm_title, post_id: post.id, bot_user_id: bot.bot_user.id, + model: bot.model, ) end end diff --git a/lib/ai_bot/tools/summarize.rb b/lib/ai_bot/tools/summarize.rb index d8d8a472..113bd215 100644 --- a/lib/ai_bot/tools/summarize.rb +++ b/lib/ai_bot/tools/summarize.rb @@ -44,10 +44,6 @@ module DiscourseAi true end - def low_cost? - true - end - def custom_raw @last_summary || I18n.t("discourse_ai.ai_bot.topic_not_found") end diff --git a/lib/ai_bot/tools/tool.rb b/lib/ai_bot/tools/tool.rb index 03a32417..e223c9a2 100644 --- a/lib/ai_bot/tools/tool.rb +++ b/lib/ai_bot/tools/tool.rb @@ -75,10 +75,6 @@ module DiscourseAi false end - def low_cost? - false - end - protected def accepted_options diff --git a/lib/ai_helper/assistant.rb b/lib/ai_helper/assistant.rb index 8ab46361..98807428 100644 --- a/lib/ai_helper/assistant.rb +++ b/lib/ai_helper/assistant.rb @@ -147,8 +147,7 @@ module DiscourseAi { type: "text", text: - "Describe this image in a single sentence" + - custom_locale_instructions(user), + "Describe this image in a single sentence#{custom_locale_instructions(user)}", }, { type: "image_url", image_url: image_url }, ], diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index 918a3973..5b77cf1c 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -69,7 +69,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do ) persona.create_user! - persona.update!(default_llm: "claude-2", mentionable: true) + persona.update!(default_llm: "anthropic:claude-2", mentionable: true) persona end @@ -92,6 +92,9 @@ RSpec.describe DiscourseAi::AiBot::Playground do end it "allows mentioning a persona" do + # we still should be able to mention with no bots + SiteSetting.ai_bot_enabled_chat_bots = "" + post = nil DiscourseAi::Completions::Llm.with_prepared_responses(["Yes I can"]) do post = @@ -107,6 +110,33 @@ RSpec.describe DiscourseAi::AiBot::Playground do expect(last_post.user_id).to eq(persona.user_id) end + it "allows PMing a persona even when no particular bots are enabled" do + SiteSetting.ai_bot_enabled = true + SiteSetting.ai_bot_enabled_chat_bots = "" + post = nil + + DiscourseAi::Completions::Llm.with_prepared_responses( + ["Magic title", "Yes I can"], + llm: "anthropic:claude-2", + ) do + post = + create_post( + title: "I just made a PM", + raw: "Hey there #{persona.user.username}, can you help me?", + target_usernames: "#{user.username},#{persona.user.username}", + archetype: Archetype.private_message, + user: admin, + ) + end + + last_post = post.topic.posts.order(:post_number).last + expect(last_post.raw).to eq("Yes I can") + expect(last_post.user_id).to eq(persona.user_id) + + last_post.topic.reload + expect(last_post.topic.allowed_users.pluck(:user_id)).to include(persona.user_id) + end + it "picks the correct llm for persona in PMs" do # If you start a PM with GPT 3.5 bot, replies should come from it, not from Claude SiteSetting.ai_bot_enabled = true