FIX: context repairs for @mentioned bot (#608)
When the bot is @mentioned, we need to be a lot more careful about constructing context otherwise bot gets ultra confused. This changes multiple things: 1. We were omitting all thread first messages (fixed) 2. Include thread title (if available) in context 3. Construct context in a clearer way separating user request from data
This commit is contained in:
parent
ab4544d897
commit
cf34838a09
|
@ -220,6 +220,13 @@ module DiscourseAi
|
||||||
|
|
||||||
def chat_context(message, channel, persona_user)
|
def chat_context(message, channel, persona_user)
|
||||||
has_vision = bot.persona.class.vision_enabled
|
has_vision = bot.persona.class.vision_enabled
|
||||||
|
include_thread_titles = !channel.direct_message_channel? && !message.thread_id
|
||||||
|
|
||||||
|
current_id = message.id
|
||||||
|
if !channel.direct_message_channel?
|
||||||
|
# we are interacting via mentions ... strip mention
|
||||||
|
instruction_message = message.message.gsub(/@#{bot.bot_user.username}/i, "").strip
|
||||||
|
end
|
||||||
|
|
||||||
messages = nil
|
messages = nil
|
||||||
|
|
||||||
|
@ -233,7 +240,11 @@ module DiscourseAi
|
||||||
elsif !channel.direct_message_channel? && !message.thread_id
|
elsif !channel.direct_message_channel? && !message.thread_id
|
||||||
messages =
|
messages =
|
||||||
Chat::Message
|
Chat::Message
|
||||||
.where(chat_channel_id: channel.id, thread_id: nil)
|
.joins("left join chat_threads on chat_threads.id = chat_messages.thread_id")
|
||||||
|
.where(chat_channel_id: channel.id)
|
||||||
|
.where(
|
||||||
|
"chat_messages.thread_id IS NULL OR chat_threads.original_message_id = chat_messages.id",
|
||||||
|
)
|
||||||
.order(id: :desc)
|
.order(id: :desc)
|
||||||
.limit(max_messages)
|
.limit(max_messages)
|
||||||
.to_a
|
.to_a
|
||||||
|
@ -250,21 +261,30 @@ module DiscourseAi
|
||||||
builder = DiscourseAi::Completions::PromptMessagesBuilder.new
|
builder = DiscourseAi::Completions::PromptMessagesBuilder.new
|
||||||
|
|
||||||
messages.each do |m|
|
messages.each do |m|
|
||||||
|
# restore stripped message
|
||||||
|
m.message = instruction_message if m.id == current_id && instruction_message
|
||||||
|
|
||||||
if available_bot_user_ids.include?(m.user_id)
|
if available_bot_user_ids.include?(m.user_id)
|
||||||
builder.push(type: :model, content: m.message)
|
builder.push(type: :model, content: m.message)
|
||||||
else
|
else
|
||||||
upload_ids = nil
|
upload_ids = nil
|
||||||
upload_ids = m.uploads.map(&:id) if has_vision && m.uploads.present?
|
upload_ids = m.uploads.map(&:id) if has_vision && m.uploads.present?
|
||||||
|
mapped_message = m.message
|
||||||
|
|
||||||
|
thread_title = nil
|
||||||
|
thread_title = m.thread&.title if include_thread_titles && m.thread_id
|
||||||
|
mapped_message = "(#{thread_title})\n#{m.message}" if thread_title
|
||||||
|
|
||||||
builder.push(
|
builder.push(
|
||||||
type: :user,
|
type: :user,
|
||||||
content: m.message,
|
content: mapped_message,
|
||||||
name: m.user.username,
|
name: m.user.username,
|
||||||
upload_ids: upload_ids,
|
upload_ids: upload_ids,
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
builder.to_a(limit: max_messages)
|
builder.to_a(limit: max_messages, style: channel.direct_message_channel? ? :default : :chat)
|
||||||
end
|
end
|
||||||
|
|
||||||
def reply_to_chat_message(message, channel)
|
def reply_to_chat_message(message, channel)
|
||||||
|
@ -283,6 +303,9 @@ module DiscourseAi
|
||||||
reply = nil
|
reply = nil
|
||||||
guardian = Guardian.new(persona_user)
|
guardian = Guardian.new(persona_user)
|
||||||
|
|
||||||
|
force_thread = message.thread_id.nil? && channel.direct_message_channel?
|
||||||
|
in_reply_to_id = channel.direct_message_channel? ? message.id : nil
|
||||||
|
|
||||||
new_prompts =
|
new_prompts =
|
||||||
bot.reply(context) do |partial, cancel, placeholder|
|
bot.reply(context) do |partial, cancel, placeholder|
|
||||||
if !reply
|
if !reply
|
||||||
|
@ -294,8 +317,8 @@ module DiscourseAi
|
||||||
thread_id: message.thread_id,
|
thread_id: message.thread_id,
|
||||||
channel_id: channel.id,
|
channel_id: channel.id,
|
||||||
guardian: guardian,
|
guardian: guardian,
|
||||||
in_reply_to_id: message.id,
|
in_reply_to_id: in_reply_to_id,
|
||||||
force_thread: message.thread_id.nil? && channel.direct_message_channel?,
|
force_thread: force_thread,
|
||||||
enforce_membership: !channel.direct_message_channel?,
|
enforce_membership: !channel.direct_message_channel?,
|
||||||
)
|
)
|
||||||
ChatSDK::Message.start_stream(message_id: reply.id, guardian: guardian)
|
ChatSDK::Message.start_stream(message_id: reply.id, guardian: guardian)
|
||||||
|
|
|
@ -3,11 +3,14 @@
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module Completions
|
module Completions
|
||||||
class PromptMessagesBuilder
|
class PromptMessagesBuilder
|
||||||
|
MAX_CHAT_UPLOADS = 5
|
||||||
|
|
||||||
def initialize
|
def initialize
|
||||||
@raw_messages = []
|
@raw_messages = []
|
||||||
end
|
end
|
||||||
|
|
||||||
def to_a(limit: nil)
|
def to_a(limit: nil, style: nil)
|
||||||
|
return chat_array(limit: limit) if style == :chat
|
||||||
result = []
|
result = []
|
||||||
|
|
||||||
# this will create a "valid" messages array
|
# this will create a "valid" messages array
|
||||||
|
@ -68,6 +71,51 @@ module DiscourseAi
|
||||||
|
|
||||||
@raw_messages << message
|
@raw_messages << message
|
||||||
end
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
def chat_array(limit:)
|
||||||
|
buffer = +""
|
||||||
|
|
||||||
|
if @raw_messages.length > 1
|
||||||
|
buffer << (<<~TEXT).strip
|
||||||
|
You are replying inside a Discourse chat. Here is a summary of the conversation so far:
|
||||||
|
{{{
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
upload_ids = []
|
||||||
|
|
||||||
|
@raw_messages[0..-2].each do |message|
|
||||||
|
buffer << "\n"
|
||||||
|
|
||||||
|
upload_ids.concat(message[:upload_ids]) if message[:upload_ids].present?
|
||||||
|
|
||||||
|
if message[:type] == :user
|
||||||
|
buffer << "#{message[:name] || "User"}: "
|
||||||
|
else
|
||||||
|
buffer << "Bot: "
|
||||||
|
end
|
||||||
|
|
||||||
|
buffer << message[:content]
|
||||||
|
end
|
||||||
|
|
||||||
|
buffer << "\n}}}"
|
||||||
|
buffer << "\n\n"
|
||||||
|
buffer << "Your instructions:"
|
||||||
|
buffer << "\n"
|
||||||
|
end
|
||||||
|
|
||||||
|
last_message = @raw_messages[-1]
|
||||||
|
buffer << "#{last_message[:name] || "User"} said #{last_message[:content]} "
|
||||||
|
|
||||||
|
message = { type: :user, content: buffer }
|
||||||
|
upload_ids.concat(last_message[:upload_ids]) if last_message[:upload_ids].present?
|
||||||
|
|
||||||
|
message[:upload_ids] = upload_ids[-MAX_CHAT_UPLOADS..-1] ||
|
||||||
|
upload_ids if upload_ids.present?
|
||||||
|
|
||||||
|
[message]
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -147,6 +147,77 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
||||||
persona.update!(allow_chat: true, mentionable: true, default_llm: "anthropic:claude-3-opus")
|
persona.update!(allow_chat: true, mentionable: true, default_llm: "anthropic:claude-3-opus")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
it "should behave in a sane way when threading is enabled" do
|
||||||
|
channel.update!(threading_enabled: true)
|
||||||
|
|
||||||
|
message =
|
||||||
|
ChatSDK::Message.create(
|
||||||
|
channel_id: channel.id,
|
||||||
|
raw: "thread 1 message 1",
|
||||||
|
guardian: guardian,
|
||||||
|
)
|
||||||
|
|
||||||
|
message =
|
||||||
|
ChatSDK::Message.create(
|
||||||
|
channel_id: channel.id,
|
||||||
|
raw: "thread 1 message 2",
|
||||||
|
in_reply_to_id: message.id,
|
||||||
|
guardian: guardian,
|
||||||
|
)
|
||||||
|
|
||||||
|
thread = message.thread
|
||||||
|
thread.update!(title: "a magic thread")
|
||||||
|
|
||||||
|
message =
|
||||||
|
ChatSDK::Message.create(
|
||||||
|
channel_id: channel.id,
|
||||||
|
raw: "thread 2 message 1",
|
||||||
|
guardian: guardian,
|
||||||
|
)
|
||||||
|
|
||||||
|
message =
|
||||||
|
ChatSDK::Message.create(
|
||||||
|
channel_id: channel.id,
|
||||||
|
raw: "thread 2 message 2",
|
||||||
|
in_reply_to_id: message.id,
|
||||||
|
guardian: guardian,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompts = nil
|
||||||
|
DiscourseAi::Completions::Llm.with_prepared_responses(["world"]) do |_, _, _prompts|
|
||||||
|
message =
|
||||||
|
ChatSDK::Message.create(
|
||||||
|
channel_id: channel.id,
|
||||||
|
raw: "Hello @#{persona.user.username}",
|
||||||
|
guardian: guardian,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompts = _prompts
|
||||||
|
end
|
||||||
|
|
||||||
|
# don't start a thread cause it will get confusing
|
||||||
|
message.reload
|
||||||
|
expect(message.thread_id).to be_nil
|
||||||
|
|
||||||
|
prompt = prompts[0]
|
||||||
|
|
||||||
|
content = prompt.messages[1][:content]
|
||||||
|
# this is fragile by design, mainly so the example can be ultra clear
|
||||||
|
expected = (<<~TEXT).strip
|
||||||
|
You are replying inside a Discourse chat. Here is a summary of the conversation so far:
|
||||||
|
{{{
|
||||||
|
#{user.username}: (a magic thread)
|
||||||
|
thread 1 message 1
|
||||||
|
#{user.username}: thread 2 message 1
|
||||||
|
}}}
|
||||||
|
|
||||||
|
Your instructions:
|
||||||
|
#{user.username} said Hello
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
expect(content.strip).to eq(expected)
|
||||||
|
end
|
||||||
|
|
||||||
it "should reply to a mention if properly enabled" do
|
it "should reply to a mention if properly enabled" do
|
||||||
prompts = nil
|
prompts = nil
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue