FIX: context repairs for @mentioned bot (#608)

When the bot is @mentioned, we need to be a lot more careful
about constructing context otherwise bot gets ultra confused.

This changes multiple things:

1. We were omitting all thread first messages (fixed)
2. Include thread title (if available) in context
3. Construct context in a clearer way separating user request from data
This commit is contained in:
Sam 2024-05-08 18:44:04 +10:00 committed by GitHub
parent ab4544d897
commit cf34838a09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 148 additions and 6 deletions

View File

@ -220,6 +220,13 @@ module DiscourseAi
def chat_context(message, channel, persona_user) def chat_context(message, channel, persona_user)
has_vision = bot.persona.class.vision_enabled has_vision = bot.persona.class.vision_enabled
include_thread_titles = !channel.direct_message_channel? && !message.thread_id
current_id = message.id
if !channel.direct_message_channel?
# we are interacting via mentions ... strip mention
instruction_message = message.message.gsub(/@#{bot.bot_user.username}/i, "").strip
end
messages = nil messages = nil
@ -233,7 +240,11 @@ module DiscourseAi
elsif !channel.direct_message_channel? && !message.thread_id elsif !channel.direct_message_channel? && !message.thread_id
messages = messages =
Chat::Message Chat::Message
.where(chat_channel_id: channel.id, thread_id: nil) .joins("left join chat_threads on chat_threads.id = chat_messages.thread_id")
.where(chat_channel_id: channel.id)
.where(
"chat_messages.thread_id IS NULL OR chat_threads.original_message_id = chat_messages.id",
)
.order(id: :desc) .order(id: :desc)
.limit(max_messages) .limit(max_messages)
.to_a .to_a
@ -250,21 +261,30 @@ module DiscourseAi
builder = DiscourseAi::Completions::PromptMessagesBuilder.new builder = DiscourseAi::Completions::PromptMessagesBuilder.new
messages.each do |m| messages.each do |m|
# restore stripped message
m.message = instruction_message if m.id == current_id && instruction_message
if available_bot_user_ids.include?(m.user_id) if available_bot_user_ids.include?(m.user_id)
builder.push(type: :model, content: m.message) builder.push(type: :model, content: m.message)
else else
upload_ids = nil upload_ids = nil
upload_ids = m.uploads.map(&:id) if has_vision && m.uploads.present? upload_ids = m.uploads.map(&:id) if has_vision && m.uploads.present?
mapped_message = m.message
thread_title = nil
thread_title = m.thread&.title if include_thread_titles && m.thread_id
mapped_message = "(#{thread_title})\n#{m.message}" if thread_title
builder.push( builder.push(
type: :user, type: :user,
content: m.message, content: mapped_message,
name: m.user.username, name: m.user.username,
upload_ids: upload_ids, upload_ids: upload_ids,
) )
end end
end end
builder.to_a(limit: max_messages) builder.to_a(limit: max_messages, style: channel.direct_message_channel? ? :default : :chat)
end end
def reply_to_chat_message(message, channel) def reply_to_chat_message(message, channel)
@ -283,6 +303,9 @@ module DiscourseAi
reply = nil reply = nil
guardian = Guardian.new(persona_user) guardian = Guardian.new(persona_user)
force_thread = message.thread_id.nil? && channel.direct_message_channel?
in_reply_to_id = channel.direct_message_channel? ? message.id : nil
new_prompts = new_prompts =
bot.reply(context) do |partial, cancel, placeholder| bot.reply(context) do |partial, cancel, placeholder|
if !reply if !reply
@ -294,8 +317,8 @@ module DiscourseAi
thread_id: message.thread_id, thread_id: message.thread_id,
channel_id: channel.id, channel_id: channel.id,
guardian: guardian, guardian: guardian,
in_reply_to_id: message.id, in_reply_to_id: in_reply_to_id,
force_thread: message.thread_id.nil? && channel.direct_message_channel?, force_thread: force_thread,
enforce_membership: !channel.direct_message_channel?, enforce_membership: !channel.direct_message_channel?,
) )
ChatSDK::Message.start_stream(message_id: reply.id, guardian: guardian) ChatSDK::Message.start_stream(message_id: reply.id, guardian: guardian)

View File

@ -3,11 +3,14 @@
module DiscourseAi module DiscourseAi
module Completions module Completions
class PromptMessagesBuilder class PromptMessagesBuilder
MAX_CHAT_UPLOADS = 5
def initialize def initialize
@raw_messages = [] @raw_messages = []
end end
def to_a(limit: nil) def to_a(limit: nil, style: nil)
return chat_array(limit: limit) if style == :chat
result = [] result = []
# this will create a "valid" messages array # this will create a "valid" messages array
@ -68,6 +71,51 @@ module DiscourseAi
@raw_messages << message @raw_messages << message
end end
private
def chat_array(limit:)
buffer = +""
if @raw_messages.length > 1
buffer << (<<~TEXT).strip
You are replying inside a Discourse chat. Here is a summary of the conversation so far:
{{{
TEXT
upload_ids = []
@raw_messages[0..-2].each do |message|
buffer << "\n"
upload_ids.concat(message[:upload_ids]) if message[:upload_ids].present?
if message[:type] == :user
buffer << "#{message[:name] || "User"}: "
else
buffer << "Bot: "
end
buffer << message[:content]
end
buffer << "\n}}}"
buffer << "\n\n"
buffer << "Your instructions:"
buffer << "\n"
end
last_message = @raw_messages[-1]
buffer << "#{last_message[:name] || "User"} said #{last_message[:content]} "
message = { type: :user, content: buffer }
upload_ids.concat(last_message[:upload_ids]) if last_message[:upload_ids].present?
message[:upload_ids] = upload_ids[-MAX_CHAT_UPLOADS..-1] ||
upload_ids if upload_ids.present?
[message]
end
end end
end end
end end

View File

@ -147,6 +147,77 @@ RSpec.describe DiscourseAi::AiBot::Playground do
persona.update!(allow_chat: true, mentionable: true, default_llm: "anthropic:claude-3-opus") persona.update!(allow_chat: true, mentionable: true, default_llm: "anthropic:claude-3-opus")
end end
it "should behave in a sane way when threading is enabled" do
channel.update!(threading_enabled: true)
message =
ChatSDK::Message.create(
channel_id: channel.id,
raw: "thread 1 message 1",
guardian: guardian,
)
message =
ChatSDK::Message.create(
channel_id: channel.id,
raw: "thread 1 message 2",
in_reply_to_id: message.id,
guardian: guardian,
)
thread = message.thread
thread.update!(title: "a magic thread")
message =
ChatSDK::Message.create(
channel_id: channel.id,
raw: "thread 2 message 1",
guardian: guardian,
)
message =
ChatSDK::Message.create(
channel_id: channel.id,
raw: "thread 2 message 2",
in_reply_to_id: message.id,
guardian: guardian,
)
prompts = nil
DiscourseAi::Completions::Llm.with_prepared_responses(["world"]) do |_, _, _prompts|
message =
ChatSDK::Message.create(
channel_id: channel.id,
raw: "Hello @#{persona.user.username}",
guardian: guardian,
)
prompts = _prompts
end
# don't start a thread cause it will get confusing
message.reload
expect(message.thread_id).to be_nil
prompt = prompts[0]
content = prompt.messages[1][:content]
# this is fragile by design, mainly so the example can be ultra clear
expected = (<<~TEXT).strip
You are replying inside a Discourse chat. Here is a summary of the conversation so far:
{{{
#{user.username}: (a magic thread)
thread 1 message 1
#{user.username}: thread 2 message 1
}}}
Your instructions:
#{user.username} said Hello
TEXT
expect(content.strip).to eq(expected)
end
it "should reply to a mention if properly enabled" do it "should reply to a mention if properly enabled" do
prompts = nil prompts = nil