From cf34838a0967cb0aaad33518330434d8e626d3b1 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 8 May 2024 18:44:04 +1000 Subject: [PATCH] FIX: context repairs for @mentioned bot (#608) When the bot is @mentioned, we need to be a lot more careful about constructing context otherwise bot gets ultra confused. This changes multiple things: 1. We were omitting all thread first messages (fixed) 2. Include thread title (if available) in context 3. Construct context in a clearer way separating user request from data --- lib/ai_bot/playground.rb | 33 ++++++++-- lib/completions/prompt_messages_builder.rb | 50 ++++++++++++++- spec/lib/modules/ai_bot/playground_spec.rb | 71 ++++++++++++++++++++++ 3 files changed, 148 insertions(+), 6 deletions(-) diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index 834a231a..d6418bd5 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -220,6 +220,13 @@ module DiscourseAi def chat_context(message, channel, persona_user) has_vision = bot.persona.class.vision_enabled + include_thread_titles = !channel.direct_message_channel? && !message.thread_id + + current_id = message.id + if !channel.direct_message_channel? + # we are interacting via mentions ... strip mention + instruction_message = message.message.gsub(/@#{bot.bot_user.username}/i, "").strip + end messages = nil @@ -233,7 +240,11 @@ module DiscourseAi elsif !channel.direct_message_channel? && !message.thread_id messages = Chat::Message - .where(chat_channel_id: channel.id, thread_id: nil) + .joins("left join chat_threads on chat_threads.id = chat_messages.thread_id") + .where(chat_channel_id: channel.id) + .where( + "chat_messages.thread_id IS NULL OR chat_threads.original_message_id = chat_messages.id", + ) .order(id: :desc) .limit(max_messages) .to_a @@ -250,21 +261,30 @@ module DiscourseAi builder = DiscourseAi::Completions::PromptMessagesBuilder.new messages.each do |m| + # restore stripped message + m.message = instruction_message if m.id == current_id && instruction_message + if available_bot_user_ids.include?(m.user_id) builder.push(type: :model, content: m.message) else upload_ids = nil upload_ids = m.uploads.map(&:id) if has_vision && m.uploads.present? + mapped_message = m.message + + thread_title = nil + thread_title = m.thread&.title if include_thread_titles && m.thread_id + mapped_message = "(#{thread_title})\n#{m.message}" if thread_title + builder.push( type: :user, - content: m.message, + content: mapped_message, name: m.user.username, upload_ids: upload_ids, ) end end - builder.to_a(limit: max_messages) + builder.to_a(limit: max_messages, style: channel.direct_message_channel? ? :default : :chat) end def reply_to_chat_message(message, channel) @@ -283,6 +303,9 @@ module DiscourseAi reply = nil guardian = Guardian.new(persona_user) + force_thread = message.thread_id.nil? && channel.direct_message_channel? + in_reply_to_id = channel.direct_message_channel? ? message.id : nil + new_prompts = bot.reply(context) do |partial, cancel, placeholder| if !reply @@ -294,8 +317,8 @@ module DiscourseAi thread_id: message.thread_id, channel_id: channel.id, guardian: guardian, - in_reply_to_id: message.id, - force_thread: message.thread_id.nil? && channel.direct_message_channel?, + in_reply_to_id: in_reply_to_id, + force_thread: force_thread, enforce_membership: !channel.direct_message_channel?, ) ChatSDK::Message.start_stream(message_id: reply.id, guardian: guardian) diff --git a/lib/completions/prompt_messages_builder.rb b/lib/completions/prompt_messages_builder.rb index 851ebcae..faed03d1 100644 --- a/lib/completions/prompt_messages_builder.rb +++ b/lib/completions/prompt_messages_builder.rb @@ -3,11 +3,14 @@ module DiscourseAi module Completions class PromptMessagesBuilder + MAX_CHAT_UPLOADS = 5 + def initialize @raw_messages = [] end - def to_a(limit: nil) + def to_a(limit: nil, style: nil) + return chat_array(limit: limit) if style == :chat result = [] # this will create a "valid" messages array @@ -68,6 +71,51 @@ module DiscourseAi @raw_messages << message end + + private + + def chat_array(limit:) + buffer = +"" + + if @raw_messages.length > 1 + buffer << (<<~TEXT).strip + You are replying inside a Discourse chat. Here is a summary of the conversation so far: + {{{ + TEXT + + upload_ids = [] + + @raw_messages[0..-2].each do |message| + buffer << "\n" + + upload_ids.concat(message[:upload_ids]) if message[:upload_ids].present? + + if message[:type] == :user + buffer << "#{message[:name] || "User"}: " + else + buffer << "Bot: " + end + + buffer << message[:content] + end + + buffer << "\n}}}" + buffer << "\n\n" + buffer << "Your instructions:" + buffer << "\n" + end + + last_message = @raw_messages[-1] + buffer << "#{last_message[:name] || "User"} said #{last_message[:content]} " + + message = { type: :user, content: buffer } + upload_ids.concat(last_message[:upload_ids]) if last_message[:upload_ids].present? + + message[:upload_ids] = upload_ids[-MAX_CHAT_UPLOADS..-1] || + upload_ids if upload_ids.present? + + [message] + end end end end diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index 677e8124..ea5d1a32 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -147,6 +147,77 @@ RSpec.describe DiscourseAi::AiBot::Playground do persona.update!(allow_chat: true, mentionable: true, default_llm: "anthropic:claude-3-opus") end + it "should behave in a sane way when threading is enabled" do + channel.update!(threading_enabled: true) + + message = + ChatSDK::Message.create( + channel_id: channel.id, + raw: "thread 1 message 1", + guardian: guardian, + ) + + message = + ChatSDK::Message.create( + channel_id: channel.id, + raw: "thread 1 message 2", + in_reply_to_id: message.id, + guardian: guardian, + ) + + thread = message.thread + thread.update!(title: "a magic thread") + + message = + ChatSDK::Message.create( + channel_id: channel.id, + raw: "thread 2 message 1", + guardian: guardian, + ) + + message = + ChatSDK::Message.create( + channel_id: channel.id, + raw: "thread 2 message 2", + in_reply_to_id: message.id, + guardian: guardian, + ) + + prompts = nil + DiscourseAi::Completions::Llm.with_prepared_responses(["world"]) do |_, _, _prompts| + message = + ChatSDK::Message.create( + channel_id: channel.id, + raw: "Hello @#{persona.user.username}", + guardian: guardian, + ) + + prompts = _prompts + end + + # don't start a thread cause it will get confusing + message.reload + expect(message.thread_id).to be_nil + + prompt = prompts[0] + + content = prompt.messages[1][:content] + # this is fragile by design, mainly so the example can be ultra clear + expected = (<<~TEXT).strip + You are replying inside a Discourse chat. Here is a summary of the conversation so far: + {{{ + #{user.username}: (a magic thread) + thread 1 message 1 + #{user.username}: thread 2 message 1 + }}} + + Your instructions: + #{user.username} said Hello + TEXT + + expect(content.strip).to eq(expected) + end + it "should reply to a mention if properly enabled" do prompts = nil