diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index 28d06c28..260331f3 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -11,28 +11,40 @@ module DiscourseAi REQUIRE_TITLE_UPDATE = "discourse-ai-title-update" - def self.schedule_chat_reply(message, channel, user, context) + def self.find_chat_persona(message, channel, user) if channel.direct_message_channel? - allowed_user_ids = channel.allowed_user_ids - - return if AiPersona.allowed_chat.any? { |m| m[:user_id] == user.id } - - persona = - AiPersona.allowed_chat.find do |p| - p[:user_id].in?(allowed_user_ids) && (user.group_ids & p[:allowed_group_ids]) + AiPersona.allowed_chat.find do |p| + p[:user_id].in?(channel.allowed_user_ids) && (user.group_ids & p[:allowed_group_ids]) + end + else + # let's defer on the parse if there is no @ in the message + if message.message.include?("@") + mentions = message.parsed_mentions.parsed_direct_mentions + if mentions.present? + AiPersona.allowed_chat.find do |p| + p[:username].in?(mentions) && (user.group_ids & p[:allowed_group_ids]) + end end - - if persona - ::Jobs.enqueue( - :create_ai_chat_reply, - channel_id: channel.id, - message_id: message.id, - persona_id: persona[:id], - ) end end end + def self.schedule_chat_reply(message, channel, user, context) + return if !SiteSetting.ai_bot_enabled + return if AiPersona.allowed_chat.blank? + return if AiPersona.allowed_chat.any? { |m| m[:user_id] == user.id } + + persona = find_chat_persona(message, channel, user) + return if !persona + + ::Jobs.enqueue( + :create_ai_chat_reply, + channel_id: channel.id, + message_id: message.id, + persona_id: persona[:id], + ) + end + def self.is_bot_user_id?(user_id) # this will catch everything and avoid any feedback loops # we could get feedback loops between say discobot and ai-bot or third party plugins @@ -209,20 +221,26 @@ module DiscourseAi def chat_context(message, channel, persona_user) has_vision = bot.persona.class.vision_enabled - if !message.thread_id - hash = { type: :user, content: message.message } - hash[:upload_ids] = message.uploads.map(&:id) if has_vision && message.uploads.present? - return [hash] - end + messages = nil max_messages = 40 if bot.persona.class.respond_to?(:max_context_posts) max_messages = bot.persona.class.max_context_posts || 40 end - # I would like to use a guardian however membership for - # persona_user is far in future - thread_messages = + if !message.thread_id && channel.direct_message_channel? + messages = [message] + elsif !channel.direct_message_channel? && !message.thread_id + messages = + Chat::Message + .where(chat_channel_id: channel.id, thread_id: nil) + .order(id: :desc) + .limit(max_messages) + .to_a + .reverse + end + + messages ||= ChatSDK::Thread.last_messages( thread_id: message.thread_id, guardian: Discourse.system_user.guardian, @@ -231,7 +249,7 @@ module DiscourseAi builder = DiscourseAi::Completions::PromptMessagesBuilder.new - thread_messages.each do |m| + messages.each do |m| if available_bot_user_ids.include?(m.user_id) builder.push(type: :model, content: m.message) else @@ -277,7 +295,8 @@ module DiscourseAi channel_id: channel.id, guardian: guardian, in_reply_to_id: message.id, - force_thread: message.thread_id.nil?, + force_thread: message.thread_id.nil? && channel.direct_message_channel?, + enforce_membership: !channel.direct_message_channel?, ) ChatSDK::Message.start_stream(message_id: reply.id, guardian: guardian) else diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index ce88c397..677e8124 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -131,7 +131,54 @@ RSpec.describe DiscourseAi::AiBot::Playground do persona end - context "with chat" do + context "with chat channels" do + fab!(:channel) { Fabricate(:chat_channel) } + + fab!(:membership) do + Fabricate(:user_chat_channel_membership, user: user, chat_channel: channel) + end + + let(:guardian) { Guardian.new(user) } + + before do + SiteSetting.ai_bot_enabled = true + SiteSetting.chat_allowed_groups = "#{Group::AUTO_GROUPS[:trust_level_0]}" + Group.refresh_automatic_groups! + persona.update!(allow_chat: true, mentionable: true, default_llm: "anthropic:claude-3-opus") + end + + it "should reply to a mention if properly enabled" do + prompts = nil + + ChatSDK::Message.create( + channel_id: channel.id, + raw: "This is a story about stuff", + guardian: guardian, + ) + + DiscourseAi::Completions::Llm.with_prepared_responses(["world"]) do |_, _, _prompts| + ChatSDK::Message.create( + channel_id: channel.id, + raw: "Hello @#{persona.user.username}", + guardian: guardian, + ) + + prompts = _prompts + end + + expect(prompts.length).to eq(1) + prompt = prompts[0] + + expect(prompt.messages.length).to eq(2) + expect(prompt.messages[1][:content]).to include("story about stuff") + expect(prompt.messages[1][:content]).to include("Hello") + + last_message = Chat::Message.where(chat_channel_id: channel.id).order("id desc").first + expect(last_message.message).to eq("world") + end + end + + context "with chat dms" do fab!(:dm_channel) { Fabricate(:direct_message_channel, users: [user, persona.user]) } before do @@ -142,6 +189,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do mentionable: false, default_llm: "anthropic:claude-3-opus", ) + SiteSetting.ai_bot_enabled = true end let(:guardian) { Guardian.new(user) }