FEATURE: support Chat with AI Persona via a DM (#488)

Add support for chat with AI personas

- Allow enabling chat for AI personas that have an associated user
- Add new setting `allow_chat` to AI persona to enable/disable chat
- When a message is created in a DM channel with an allowed AI persona user, schedule a reply job
- AI replies to chat messages using the persona's `max_context_posts` setting to determine context
- Store tool calls and custom prompts used to generate a chat reply on the `ChatMessageCustomPrompt` table
- Add tests for AI chat replies with tools and context

At the moment unlike posts we do not carry tool calls in the context.

No @mention support yet for ai personas in channels, this is future work
This commit is contained in:
Sam 2024-05-06 09:49:02 +10:00 committed by GitHub
parent 8875830f6a
commit e4b326c711
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 575 additions and 86 deletions

View File

@ -125,6 +125,7 @@ module DiscourseAi
:rag_chunk_overlap_tokens, :rag_chunk_overlap_tokens,
:rag_conversation_chunks, :rag_conversation_chunks,
:question_consolidator_llm, :question_consolidator_llm,
:allow_chat,
allowed_group_ids: [], allowed_group_ids: [],
rag_uploads: [:id], rag_uploads: [:id],
) )

View File

@ -0,0 +1,24 @@
# frozen_string_literal: true
module ::Jobs
class CreateAiChatReply < ::Jobs::Base
sidekiq_options retry: false
def execute(args)
channel = ::Chat::Channel.find_by(id: args[:channel_id])
return if channel.blank?
message = ::Chat::Message.find_by(id: args[:message_id])
return if message.blank?
personaClass =
DiscourseAi::AiBot::Personas::Persona.find_by(id: args[:persona_id], user: message.user)
return if personaClass.blank?
user = User.find_by(id: personaClass.user_id)
bot = DiscourseAi::AiBot::Bot.as(user, persona: personaClass.new)
DiscourseAi::AiBot::Playground.new(bot).reply_to_chat_message(message, channel)
end
end
end

View File

@ -8,6 +8,7 @@ class AiPersona < ActiveRecord::Base
validates :description, presence: true, length: { maximum: 2000 } validates :description, presence: true, length: { maximum: 2000 }
validates :system_prompt, presence: true, length: { maximum: 10_000_000 } validates :system_prompt, presence: true, length: { maximum: 10_000_000 }
validate :system_persona_unchangeable, on: :update, if: :system validate :system_persona_unchangeable, on: :update, if: :system
validate :chat_preconditions
validates :max_context_posts, numericality: { greater_than: 0 }, allow_nil: true validates :max_context_posts, numericality: { greater_than: 0 }, allow_nil: true
# leaves some room for growth but sets a maximum to avoid memory issues # leaves some room for growth but sets a maximum to avoid memory issues
# we may want to revisit this in the future # we may want to revisit this in the future
@ -52,9 +53,9 @@ class AiPersona < ActiveRecord::Base
.where(enabled: true) .where(enabled: true)
.joins(:user) .joins(:user)
.pluck( .pluck(
"ai_personas.id, users.id, users.username_lower, allowed_group_ids, default_llm, mentionable", "ai_personas.id, users.id, users.username_lower, allowed_group_ids, default_llm, mentionable, allow_chat",
) )
.map do |id, user_id, username, allowed_group_ids, default_llm, mentionable| .map do |id, user_id, username, allowed_group_ids, default_llm, mentionable, allow_chat|
{ {
id: id, id: id,
user_id: user_id, user_id: user_id,
@ -62,6 +63,7 @@ class AiPersona < ActiveRecord::Base
allowed_group_ids: allowed_group_ids, allowed_group_ids: allowed_group_ids,
default_llm: default_llm, default_llm: default_llm,
mentionable: mentionable, mentionable: mentionable,
allow_chat: allow_chat,
} }
end end
@ -72,23 +74,20 @@ class AiPersona < ActiveRecord::Base
end end
end end
def self.allowed_chat(user: nil)
personas = persona_cache[:allowed_chat] ||= persona_users.select { |u| u[:allow_chat] }
if user
personas.select { |u| user.in_any_groups?(u[:allowed_group_ids]) }
else
personas
end
end
def self.mentionables(user: nil) def self.mentionables(user: nil)
all_mentionables = all_mentionables =
persona_cache[:mentionable_usernames] ||= AiPersona persona_cache[:mentionables] ||= persona_users.select do |mentionable|
.where(mentionable: true) mentionable[:mentionable]
.where(enabled: true) end
.joins(:user)
.pluck("ai_personas.id, users.id, users.username_lower, allowed_group_ids, default_llm")
.map do |id, user_id, username, allowed_group_ids, default_llm|
{
id: id,
user_id: user_id,
username: username,
allowed_group_ids: allowed_group_ids,
default_llm: default_llm,
}
end
if user if user
all_mentionables.select { |mentionable| user.in_any_groups?(mentionable[:allowed_group_ids]) } all_mentionables.select { |mentionable| user.in_any_groups?(mentionable[:allowed_group_ids]) }
else else
@ -114,6 +113,7 @@ class AiPersona < ActiveRecord::Base
vision_max_pixels = self.vision_max_pixels vision_max_pixels = self.vision_max_pixels
rag_conversation_chunks = self.rag_conversation_chunks rag_conversation_chunks = self.rag_conversation_chunks
question_consolidator_llm = self.question_consolidator_llm question_consolidator_llm = self.question_consolidator_llm
allow_chat = self.allow_chat
persona_class = DiscourseAi::AiBot::Personas::Persona.system_personas_by_id[self.id] persona_class = DiscourseAi::AiBot::Personas::Persona.system_personas_by_id[self.id]
if persona_class if persona_class
@ -133,6 +133,10 @@ class AiPersona < ActiveRecord::Base
user_id user_id
end end
persona_class.define_singleton_method :allow_chat do
allow_chat
end
persona_class.define_singleton_method :mentionable do persona_class.define_singleton_method :mentionable do
mentionable mentionable
end end
@ -252,6 +256,10 @@ class AiPersona < ActiveRecord::Base
question_consolidator_llm question_consolidator_llm
end end
define_singleton_method :allow_chat do
allow_chat
end
define_singleton_method :to_s do define_singleton_method :to_s do
"#<DiscourseAi::AiBot::Personas::Persona::Custom @name=#{self.name} @allowed_group_ids=#{self.allowed_group_ids.join(",")}>" "#<DiscourseAi::AiBot::Personas::Persona::Custom @name=#{self.name} @allowed_group_ids=#{self.allowed_group_ids.join(",")}>"
end end
@ -342,6 +350,12 @@ class AiPersona < ActiveRecord::Base
private private
def chat_preconditions
if allow_chat && !default_llm
errors.add(:default_llm, I18n.t("discourse_ai.ai_bot.personas.default_llm_required"))
end
end
def system_persona_unchangeable def system_persona_unchangeable
if top_p_changed? || temperature_changed? || system_prompt_changed? || commands_changed? || if top_p_changed? || temperature_changed? || system_prompt_changed? || commands_changed? ||
name_changed? || description_changed? name_changed? || description_changed?
@ -386,7 +400,14 @@ end
# rag_chunk_tokens :integer default(374), not null # rag_chunk_tokens :integer default(374), not null
# rag_chunk_overlap_tokens :integer default(10), not null # rag_chunk_overlap_tokens :integer default(10), not null
# rag_conversation_chunks :integer default(10), not null # rag_conversation_chunks :integer default(10), not null
# role :enum default("bot"), not null
# role_category_ids :integer default([]), not null, is an Array
# role_tags :string default([]), not null, is an Array
# role_group_ids :integer default([]), not null, is an Array
# role_whispers :boolean default(FALSE), not null
# role_max_responses_per_hour :integer default(50), not null
# question_consolidator_llm :text # question_consolidator_llm :text
# allow_chat :boolean default(FALSE), not null
# #
# Indexes # Indexes
# #

View File

@ -0,0 +1,20 @@
# frozen_string_literal: true
class ChatMessageCustomPrompt < ActiveRecord::Base
# belongs_to chat message but going to avoid the cross dependency for now
end
# == Schema Information
#
# Table name: message_custom_prompts
#
# id :bigint not null, primary key
# message_id :bigint not null
# custom_prompt :json not null
# created_at :datetime not null
# updated_at :datetime not null
#
# Indexes
#
# index_message_custom_prompts_on_message_id (message_id) UNIQUE
#

View File

@ -23,7 +23,8 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer
:rag_chunk_tokens, :rag_chunk_tokens,
:rag_chunk_overlap_tokens, :rag_chunk_overlap_tokens,
:rag_conversation_chunks, :rag_conversation_chunks,
:question_consolidator_llm :question_consolidator_llm,
:allow_chat
has_one :user, serializer: BasicUserSerializer, embed: :object has_one :user, serializer: BasicUserSerializer, embed: :object
has_many :rag_uploads, serializer: UploadSerializer, embed: :object has_many :rag_uploads, serializer: UploadSerializer, embed: :object

View File

@ -26,6 +26,7 @@ const CREATE_ATTRIBUTES = [
"rag_chunk_overlap_tokens", "rag_chunk_overlap_tokens",
"rag_conversation_chunks", "rag_conversation_chunks",
"question_consolidator_llm", "question_consolidator_llm",
"allow_chat",
]; ];
const SYSTEM_ATTRIBUTES = [ const SYSTEM_ATTRIBUTES = [
@ -46,6 +47,7 @@ const SYSTEM_ATTRIBUTES = [
"rag_chunk_overlap_tokens", "rag_chunk_overlap_tokens",
"rag_conversation_chunks", "rag_conversation_chunks",
"question_consolidator_llm", "question_consolidator_llm",
"allow_chat",
]; ];
class CommandOption { class CommandOption {

View File

@ -40,6 +40,10 @@ export default class PersonaEditor extends Component {
@tracked ragIndexingStatuses = null; @tracked ragIndexingStatuses = null;
@tracked showIndexingOptions = false; @tracked showIndexingOptions = false;
get chatPluginEnabled() {
return this.siteSettings.chat_enabled;
}
@action @action
updateModel() { updateModel() {
this.editingModel = this.args.model.workingCopy(); this.editingModel = this.args.model.workingCopy();
@ -202,6 +206,11 @@ export default class PersonaEditor extends Component {
await this.toggleField("mentionable"); await this.toggleField("mentionable");
} }
@action
async toggleAllowChat() {
await this.toggleField("allow_chat");
}
@action @action
async toggleVisionEnabled() { async toggleVisionEnabled() {
await this.toggleField("vision_enabled"); await this.toggleField("vision_enabled");
@ -295,6 +304,20 @@ export default class PersonaEditor extends Component {
/> />
</div> </div>
{{#if this.editingModel.user}} {{#if this.editingModel.user}}
{{#if this.chatPluginEnabled}}
<div class="control-group ai-persona-editor__allow_chat">
<DToggleSwitch
class="ai-persona-editor__allow_chat_toggle"
@state={{@model.allow_chat}}
@label="discourse_ai.ai_persona.allow_chat"
{{on "click" this.toggleAllowChat}}
/>
<DTooltip
@icon="question-circle"
@content={{I18n.t "discourse_ai.ai_persona.allow_chat_help"}}
/>
</div>
{{/if}}
<div class="control-group ai-persona-editor__mentionable"> <div class="control-group ai-persona-editor__mentionable">
<DToggleSwitch <DToggleSwitch
class="ai-persona-editor__mentionable_toggle" class="ai-persona-editor__mentionable_toggle"

View File

@ -72,6 +72,11 @@
align-items: center; align-items: center;
} }
&__allow_chat {
display: flex;
align-items: center;
}
&__vision_enabled { &__vision_enabled {
display: flex; display: flex;
align-items: center; align-items: center;

View File

@ -134,8 +134,8 @@ en:
low: Low Quality - cheapest (256x256) low: Low Quality - cheapest (256x256)
medium: Medium Quality (512x512) medium: Medium Quality (512x512)
high: High Quality - slowest (1024x1024) high: High Quality - slowest (1024x1024)
mentionable: Mentionable mentionable: Allow Mentions
mentionable_help: If enabled, users in allowed groups can mention this user in posts and messages, the AI will respond as this persona. mentionable_help: If enabled, users in allowed groups can mention this user in posts, the AI will respond as this persona.
user: User user: User
create_user: Create User create_user: Create User
create_user_help: You can optionally attach a user to this persona. If you do, the AI will use this user to respond to requests. create_user_help: You can optionally attach a user to this persona. If you do, the AI will use this user to respond to requests.
@ -146,6 +146,8 @@ en:
system_prompt: System Prompt system_prompt: System Prompt
show_indexing_options: "Show Upload Options" show_indexing_options: "Show Upload Options"
hide_indexing_options: "Hide Upload Options" hide_indexing_options: "Hide Upload Options"
allow_chat: "Allow Chat"
allow_chat_help: "If enabled, users in allowed groups can DM this persona"
save: Save save: Save
saved: AI Persona Saved saved: AI Persona Saved
enabled: "Enabled?" enabled: "Enabled?"

View File

@ -170,6 +170,7 @@ en:
conversation_deleted: "Conversation share deleted successfully" conversation_deleted: "Conversation share deleted successfully"
ai_bot: ai_bot:
personas: personas:
default_llm_required: "Default LLM model is required prior to enabling Chat"
cannot_delete_system_persona: "System personas cannot be deleted, please disable it instead" cannot_delete_system_persona: "System personas cannot be deleted, please disable it instead"
cannot_edit_system_persona: "System personas can only be renamed, you may not edit commands or system prompt, instead disable and make a copy" cannot_edit_system_persona: "System personas can only be renamed, you may not edit commands or system prompt, instead disable and make a copy"
github_helper: github_helper:

View File

@ -0,0 +1,7 @@
# frozen_string_literal: true
class AddAllowChatToAiPersona < ActiveRecord::Migration[7.0]
def change
add_column :ai_personas, :allow_chat, :boolean, default: false, null: false
end
end

View File

@ -0,0 +1,12 @@
# frozen_string_literal: true
class AddChatMessageCustomPrompt < ActiveRecord::Migration[7.0]
def change
create_table :chat_message_custom_prompts do |t|
t.bigint :message_id, null: false
t.json :custom_prompt, null: false
t.timestamps
end
add_index :chat_message_custom_prompts, :message_id, unique: true
end
end

View File

@ -74,7 +74,7 @@ module DiscourseAi
tool_found = true tool_found = true
tools[0..MAX_TOOLS].each do |tool| tools[0..MAX_TOOLS].each do |tool|
ongoing_chain &&= tool.chain_next_response? ongoing_chain &&= tool.chain_next_response?
process_tool(tool, raw_context, llm, cancel, update_blk, prompt) process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
end end
else else
update_blk.call(partial, cancel, nil) update_blk.call(partial, cancel, nil)
@ -96,9 +96,9 @@ module DiscourseAi
private private
def process_tool(tool, raw_context, llm, cancel, update_blk, prompt) def process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
tool_call_id = tool.tool_call_id tool_call_id = tool.tool_call_id
invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json invocation_result_json = invoke_tool(tool, llm, cancel, context, &update_blk).to_json
tool_call_message = { tool_call_message = {
type: :tool_call, type: :tool_call,
@ -133,7 +133,7 @@ module DiscourseAi
raw_context << [invocation_result_json, tool_call_id, "tool", tool.name] raw_context << [invocation_result_json, tool_call_id, "tool", tool.name]
end end
def invoke_tool(tool, llm, cancel, &update_blk) def invoke_tool(tool, llm, cancel, context, &update_blk)
update_blk.call("", cancel, build_placeholder(tool.summary, "")) update_blk.call("", cancel, build_placeholder(tool.summary, ""))
result = result =
@ -143,7 +143,7 @@ module DiscourseAi
end end
tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw) tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw)
update_blk.call(tool_details, cancel, nil) update_blk.call(tool_details, cancel, nil) if !context[:skip_tool_details]
result result
end end

View File

@ -99,6 +99,15 @@ module DiscourseAi
end end
def inject_into(plugin) def inject_into(plugin)
plugin.register_modifier(:chat_allowed_bot_user_ids) do |user_ids, guardian|
if guardian.user
mentionables = AiPersona.mentionables(user: guardian.user)
allowed_bot_ids = mentionables.map { |mentionable| mentionable[:user_id] }
user_ids.concat(allowed_bot_ids)
end
user_ids
end
plugin.on(:site_setting_changed) do |name, _old_value, _new_value| plugin.on(:site_setting_changed) do |name, _old_value, _new_value|
if name == :ai_bot_enabled_chat_bots || name == :ai_bot_enabled || if name == :ai_bot_enabled_chat_bots || name == :ai_bot_enabled ||
name == :discourse_ai_enabled name == :discourse_ai_enabled
@ -219,6 +228,10 @@ module DiscourseAi
plugin.on(:post_created) { |post| DiscourseAi::AiBot::Playground.schedule_reply(post) } plugin.on(:post_created) { |post| DiscourseAi::AiBot::Playground.schedule_reply(post) }
plugin.on(:chat_message_created) do |chat_message, channel, user, context|
DiscourseAi::AiBot::Playground.schedule_chat_reply(chat_message, channel, user, context)
end
if plugin.respond_to?(:register_editable_topic_custom_field) if plugin.respond_to?(:register_editable_topic_custom_field)
plugin.register_editable_topic_custom_field(:ai_persona_id) plugin.register_editable_topic_custom_field(:ai_persona_id)
end end

View File

@ -21,6 +21,10 @@ module DiscourseAi
nil nil
end end
def allow_chat
false
end
def system_personas def system_personas
@system_personas ||= { @system_personas ||= {
Personas::General => -1, Personas::General => -1,

View File

@ -11,6 +11,28 @@ module DiscourseAi
REQUIRE_TITLE_UPDATE = "discourse-ai-title-update" REQUIRE_TITLE_UPDATE = "discourse-ai-title-update"
def self.schedule_chat_reply(message, channel, user, context)
if channel.direct_message_channel?
allowed_user_ids = channel.allowed_user_ids
return if AiPersona.allowed_chat.any? { |m| m[:user_id] == user.id }
persona =
AiPersona.allowed_chat.find do |p|
p[:user_id].in?(allowed_user_ids) && (user.group_ids & p[:allowed_group_ids])
end
if persona
::Jobs.enqueue(
:create_ai_chat_reply,
channel_id: channel.id,
message_id: message.id,
persona_id: persona[:id],
)
end
end
end
def self.is_bot_user_id?(user_id) def self.is_bot_user_id?(user_id)
# this will catch everything and avoid any feedback loops # this will catch everything and avoid any feedback loops
# we could get feedback loops between say discobot and ai-bot or third party plugins # we could get feedback loops between say discobot and ai-bot or third party plugins
@ -128,7 +150,7 @@ module DiscourseAi
) as upload_ids", ) as upload_ids",
) )
result = [] builder = DiscourseAi::Completions::PromptMessagesBuilder.new
context.reverse_each do |raw, username, custom_prompt, upload_ids| context.reverse_each do |raw, username, custom_prompt, upload_ids|
custom_prompt_translation = custom_prompt_translation =
@ -144,7 +166,7 @@ module DiscourseAi
custom_context[:id] = message[1] if custom_context[:type] != :model custom_context[:id] = message[1] if custom_context[:type] != :model
custom_context[:name] = message[3] if message[3] custom_context[:name] = message[3] if message[3]
result << custom_context builder.push(**custom_context)
end end
end end
@ -162,11 +184,11 @@ module DiscourseAi
context[:upload_ids] = upload_ids.compact context[:upload_ids] = upload_ids.compact
end end
result << context builder.push(**context)
end end
end end
result builder.to_a
end end
def title_playground(post) def title_playground(post)
@ -184,6 +206,116 @@ module DiscourseAi
end end
end end
def chat_context(message, channel, persona_user)
has_vision = bot.persona.class.vision_enabled
if !message.thread_id
hash = { type: :user, content: message.message }
hash[:upload_ids] = message.uploads.map(&:id) if has_vision && message.uploads.present?
return [hash]
end
max_messages = 40
if bot.persona.class.respond_to?(:max_context_posts)
max_messages = bot.persona.class.max_context_posts || 40
end
# I would like to use a guardian however membership for
# persona_user is far in future
thread_messages =
ChatSDK::Thread.last_messages(
thread_id: message.thread_id,
guardian: Discourse.system_user.guardian,
page_size: max_messages,
)
builder = DiscourseAi::Completions::PromptMessagesBuilder.new
thread_messages.each do |m|
if available_bot_user_ids.include?(m.user_id)
builder.push(type: :model, content: m.message)
else
upload_ids = nil
upload_ids = m.uploads.map(&:id) if has_vision && m.uploads.present?
builder.push(
type: :user,
content: m.message,
name: m.user.username,
upload_ids: upload_ids,
)
end
end
builder.to_a(limit: max_messages)
end
def reply_to_chat_message(message, channel)
persona_user = User.find(bot.persona.class.user_id)
participants = channel.user_chat_channel_memberships.map { |m| m.user.username }
context =
get_context(
participants: participants.join(", "),
conversation_context: chat_context(message, channel, persona_user),
user: message.user,
skip_tool_details: true,
)
reply = nil
guardian = Guardian.new(persona_user)
new_prompts =
bot.reply(context) do |partial, cancel, placeholder|
if !reply
# just eat all leading spaces we can not create the message
next if partial.blank?
reply =
ChatSDK::Message.create(
raw: partial,
thread_id: message.thread_id,
channel_id: channel.id,
guardian: guardian,
in_reply_to_id: message.id,
force_thread: message.thread_id.nil?,
)
ChatSDK::Message.start_stream(message_id: reply.id, guardian: guardian)
else
streaming =
ChatSDK::Message.stream(message_id: reply.id, raw: partial, guardian: guardian)
if !streaming
cancel&.call
break
end
end
end
if new_prompts.length > 1 && reply.id
ChatMessageCustomPrompt.create!(message_id: reply.id, custom_prompt: new_prompts)
end
ChatSDK::Message.stop_stream(message_id: reply.id, guardian: guardian) if reply
reply
end
def get_context(participants:, conversation_context:, user:, skip_tool_details: nil)
result = {
site_url: Discourse.base_url,
site_title: SiteSetting.title,
site_description: SiteSetting.site_description,
time: Time.zone.now,
participants: participants,
conversation_context: conversation_context,
user: user,
}
result[:skip_tool_details] = true if skip_tool_details
result
end
def reply_to(post) def reply_to(post)
reply = +"" reply = +""
start = Time.now start = Time.now
@ -191,17 +323,14 @@ module DiscourseAi
post_type = post_type =
post.post_type == Post.types[:whisper] ? Post.types[:whisper] : Post.types[:regular] post.post_type == Post.types[:whisper] ? Post.types[:whisper] : Post.types[:regular]
context = { context =
site_url: Discourse.base_url, get_context(
site_title: SiteSetting.title, participants: post.topic.allowed_users.map(&:username).join(", "),
site_description: SiteSetting.site_description, conversation_context: conversation_context(post),
time: Time.zone.now, user: post.user,
participants: post.topic.allowed_users.map(&:username).join(", "), )
conversation_context: conversation_context(post), context[:post_id] = post.id
user: post.user, context[:topic_id] = post.topic_id
post_id: post.id,
topic_id: post.topic_id,
}
reply_user = bot.bot_user reply_user = bot.bot_user
if bot.persona.class.respond_to?(:user_id) if bot.persona.class.respond_to?(:user_id)
@ -282,7 +411,7 @@ module DiscourseAi
) )
end end
# not need to add a custom prompt for a single reply # we do not need to add a custom prompt for a single reply
if new_custom_prompts.length > 1 if new_custom_prompts.length > 1
reply_post.post_custom_prompt ||= reply_post.build_post_custom_prompt(custom_prompt: []) reply_post.post_custom_prompt ||= reply_post.build_post_custom_prompt(custom_prompt: [])
prompt = reply_post.post_custom_prompt.custom_prompt || [] prompt = reply_post.post_custom_prompt.custom_prompt || []
@ -309,6 +438,14 @@ module DiscourseAi
.concat(DiscourseAi::AiBot::EntryPoint::BOTS.map(&:second)) .concat(DiscourseAi::AiBot::EntryPoint::BOTS.map(&:second))
end end
def available_bot_user_ids
@bot_ids ||=
AiPersona
.joins(:user)
.pluck("users.id")
.concat(DiscourseAi::AiBot::EntryPoint::BOTS.map(&:first))
end
private private
def publish_final_update(reply_post) def publish_final_update(reply_post)

View File

@ -0,0 +1,73 @@
# frozen_string_literal: true
#
module DiscourseAi
module Completions
class PromptMessagesBuilder
def initialize
@raw_messages = []
end
def to_a(limit: nil)
result = []
# this will create a "valid" messages array
# 1. ensures we always start with a user message
# 2. ensures we always end with a user message
# 3. ensures we always interleave user and model messages
last_type = nil
@raw_messages.each do |message|
next if !last_type && message[:type] != :user
if last_type == :tool_call && message[:type] != :tool
result.pop
last_type = result.length > 0 ? result[-1][:type] : nil
end
next if message[:type] == :tool && last_type != :tool_call
if message[:type] == last_type
# merge the message for :user message
# replace the message for other messages
last_message = result[-1]
if message[:type] == :user
old_name = last_message.delete(:name)
last_message[:content] = "#{old_name}: #{last_message[:content]}" if old_name
new_content = message[:content]
new_content = "#{message[:name]}: #{new_content}" if message[:name]
last_message[:content] += "\n#{new_content}"
else
last_message[:content] = message[:content]
end
else
result << message
end
last_type = message[:type]
end
if limit
result[0..limit]
else
result
end
end
def push(type:, content:, name: nil, upload_ids: nil, id: nil)
if !%i[user model tool tool_call system].include?(type)
raise ArgumentError, "type must be either :user, :model, :tool, :tool_call or :system"
end
raise ArgumentError, "upload_ids must be an array" if upload_ids && !upload_ids.is_a?(Array)
message = { type: type, content: content }
message[:name] = name.to_s if name
message[:upload_ids] = upload_ids if upload_ids
message[:id] = id.to_s if id
@raw_messages << message
end
end
end
end

View File

@ -0,0 +1,43 @@
# frozen_string_literal: true
describe DiscourseAi::Completions::PromptMessagesBuilder do
let(:builder) { DiscourseAi::Completions::PromptMessagesBuilder.new }
it "should allow merging user messages" do
builder.push(type: :user, content: "Hello", name: "Alice")
builder.push(type: :user, content: "World", name: "Bob")
expect(builder.to_a).to eq([{ type: :user, content: "Alice: Hello\nBob: World" }])
end
it "should allow adding uploads" do
builder.push(type: :user, content: "Hello", name: "Alice", upload_ids: [1, 2])
expect(builder.to_a).to eq(
[{ type: :user, name: "Alice", content: "Hello", upload_ids: [1, 2] }],
)
end
it "should support function calls" do
builder.push(type: :user, content: "Echo 123 please", name: "Alice")
builder.push(type: :tool_call, content: "echo(123)", name: "echo", id: 1)
builder.push(type: :tool, content: "123", name: "echo", id: 1)
builder.push(type: :user, content: "Hello", name: "Alice")
expected = [
{ type: :user, content: "Echo 123 please", name: "Alice" },
{ type: :tool_call, content: "echo(123)", name: "echo", id: "1" },
{ type: :tool, content: "123", name: "echo", id: "1" },
{ type: :user, content: "Hello", name: "Alice" },
]
expect(builder.to_a).to eq(expected)
end
it "should drop a tool call if it is not followed by tool" do
builder.push(type: :user, content: "Echo 123 please", name: "Alice")
builder.push(type: :tool_call, content: "echo(123)", name: "echo", id: 1)
builder.push(type: :user, content: "OK", name: "James")
expected = [{ type: :user, content: "Alice: Echo 123 please\nJames: OK" }]
expect(builder.to_a).to eq(expected)
end
end

View File

@ -131,6 +131,121 @@ RSpec.describe DiscourseAi::AiBot::Playground do
persona persona
end end
context "with chat" do
fab!(:dm_channel) { Fabricate(:direct_message_channel, users: [user, persona.user]) }
before do
SiteSetting.chat_allowed_groups = "#{Group::AUTO_GROUPS[:trust_level_0]}"
Group.refresh_automatic_groups!
persona.update!(
allow_chat: true,
mentionable: false,
default_llm: "anthropic:claude-3-opus",
)
end
let(:guardian) { Guardian.new(user) }
it "can run tools" do
persona.update!(commands: ["TimeCommand"])
responses = [
"<function_calls><invoke><tool_name>time</tool_name><tool_id>time</tool_id><parameters><timezone>Buenos Aires</timezone></parameters></invoke></function_calls>",
"The time is 2023-12-14 17:24:00 -0300",
]
message =
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do
ChatSDK::Message.create(channel_id: dm_channel.id, raw: "Hello", guardian: guardian)
end
message.reload
expect(message.thread_id).to be_present
reply = ChatSDK::Thread.messages(thread_id: message.thread_id, guardian: guardian).last
expect(reply.message).to eq("The time is 2023-12-14 17:24:00 -0300")
# it also needs to have tool details now set on message
prompt = ChatMessageCustomPrompt.find_by(message_id: reply.id)
expect(prompt.custom_prompt.length).to eq(3)
# TODO in chat I am mixed on including this in the context, but I guess maybe?
# thinking about this
end
it "can reply to a chat message" do
message =
DiscourseAi::Completions::Llm.with_prepared_responses(["World"]) do
ChatSDK::Message.create(channel_id: dm_channel.id, raw: "Hello", guardian: guardian)
end
message.reload
expect(message.thread_id).to be_present
thread_messages = ChatSDK::Thread.messages(thread_id: message.thread_id, guardian: guardian)
expect(thread_messages.length).to eq(2)
expect(thread_messages.last.message).to eq("World")
# it also needs to include history per config - first feed some history
persona.update!(enabled: false)
persona_guardian = Guardian.new(persona.user)
4.times do |i|
ChatSDK::Message.create(
channel_id: dm_channel.id,
thread_id: message.thread_id,
raw: "request #{i}",
guardian: guardian,
)
ChatSDK::Message.create(
channel_id: dm_channel.id,
thread_id: message.thread_id,
raw: "response #{i}",
guardian: persona_guardian,
)
end
persona.update!(max_context_posts: 4, enabled: true)
prompts = nil
DiscourseAi::Completions::Llm.with_prepared_responses(
["World 2"],
) do |_response, _llm, _prompts|
ChatSDK::Message.create(
channel_id: dm_channel.id,
thread_id: message.thread_id,
raw: "Hello",
guardian: guardian,
)
prompts = _prompts
end
expect(prompts.length).to eq(1)
mapped =
prompts[0]
.messages
.map { |m| "#{m[:type]}: #{m[:content]}" if m[:type] != :system }
.compact
.join("\n")
.strip
# why?
# 1. we set context to 4
# 2. however PromptMessagesBuilder will enforce rules of starting with :user and ending with it
# so one of the model messages is dropped
expected = (<<~TEXT).strip
user: request 3
model: response 3
user: Hello
TEXT
expect(mapped).to eq(expected)
end
end
it "replies to whispers with a whisper" do it "replies to whispers with a whisper" do
post = nil post = nil
DiscourseAi::Completions::Llm.with_prepared_responses(["Yes I can"]) do DiscourseAi::Completions::Llm.with_prepared_responses(["Yes I can"]) do
@ -458,28 +573,26 @@ RSpec.describe DiscourseAi::AiBot::Playground do
context = playground.conversation_context(third_post) context = playground.conversation_context(third_post)
# skips leading model reply which makes no sense cause first post was whisper
expect(context).to contain_exactly( expect(context).to contain_exactly(
*[ *[{ type: :user, id: user.username, content: third_post.raw }],
{ type: :user, id: user.username, content: third_post.raw },
{ type: :model, content: second_post.raw },
],
) )
end end
context "with custom prompts" do context "with custom prompts" do
it "When post custom prompt is present, we use that instead of the post content" do it "When post custom prompt is present, we use that instead of the post content" do
custom_prompt = [ custom_prompt = [
[
{ args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json,
"time",
"tool",
],
[ [
{ name: "time", arguments: { name: "time", timezone: "Buenos Aires" } }.to_json, { name: "time", arguments: { name: "time", timezone: "Buenos Aires" } }.to_json,
"time", "time",
"tool_call", "tool_call",
], ],
["I replied this thanks to the time command", bot_user.username], [
{ args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json,
"time",
"tool",
],
["I replied to the time command", bot_user.username],
] ]
PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt) PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt)
@ -488,43 +601,11 @@ RSpec.describe DiscourseAi::AiBot::Playground do
expect(context).to contain_exactly( expect(context).to contain_exactly(
*[ *[
{ type: :user, id: user.username, content: third_post.raw },
{ type: :model, content: custom_prompt.third.first },
{ type: :tool_call, content: custom_prompt.second.first, id: "time" },
{ type: :tool, id: "time", content: custom_prompt.first.first },
{ type: :user, id: user.username, content: first_post.raw }, { type: :user, id: user.username, content: first_post.raw },
], { type: :tool_call, content: custom_prompt.first.first, id: "time" },
) { type: :tool, id: "time", content: custom_prompt.second.first },
end
it "include replies generated from tools" do
custom_prompt = [
[
{ args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json,
"time",
"tool",
],
[
{ name: "time", arguments: { name: "time", timezone: "Buenos Aires" } }.to_json,
"time",
"tool_call",
],
["I replied", bot_user.username],
]
PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt)
PostCustomPrompt.create!(post: first_post, custom_prompt: custom_prompt)
context = playground.conversation_context(third_post)
expect(context).to contain_exactly(
*[
{ type: :user, id: user.username, content: third_post.raw },
{ type: :model, content: custom_prompt.third.first }, { type: :model, content: custom_prompt.third.first },
{ type: :tool_call, content: custom_prompt.second.first, id: "time" }, { type: :user, id: user.username, content: third_post.raw },
{ type: :tool, id: "time", content: custom_prompt.first.first },
{ type: :tool_call, content: custom_prompt.second.first, id: "time" },
{ type: :tool, id: "time", content: custom_prompt.first.first },
{ type: :model, content: "I replied" },
], ],
) )
end end

View File

@ -113,6 +113,23 @@ RSpec.describe AiPersona do
expect(klass.max_context_posts).to eq(3) expect(klass.max_context_posts).to eq(3)
end end
it "does not allow setting allow_chat without a default_llm" do
persona =
AiPersona.create(
name: "test",
description: "test",
system_prompt: "test",
allowed_group_ids: [],
default_llm: nil,
allow_chat: true,
)
expect(persona.valid?).to eq(false)
expect(persona.errors[:default_llm].first).to eq(
I18n.t("discourse_ai.ai_bot.personas.default_llm_required"),
)
end
it "does not leak caches between sites" do it "does not leak caches between sites" do
AiPersona.create!( AiPersona.create!(
name: "pun_bot", name: "pun_bot",

View File

@ -53,6 +53,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
rag_chunk_overlap_tokens: 10, rag_chunk_overlap_tokens: 10,
rag_conversation_chunks: 10, rag_conversation_chunks: 10,
question_consolidator_llm: "Question Consolidator LLM", question_consolidator_llm: "Question Consolidator LLM",
allow_chat: false,
}; };
const aiPersona = AiPersona.create({ ...properties }); const aiPersona = AiPersona.create({ ...properties });
@ -92,6 +93,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
rag_chunk_overlap_tokens: 10, rag_chunk_overlap_tokens: 10,
rag_conversation_chunks: 10, rag_conversation_chunks: 10,
question_consolidator_llm: "Question Consolidator LLM", question_consolidator_llm: "Question Consolidator LLM",
allow_chat: false,
}; };
const aiPersona = AiPersona.create({ ...properties }); const aiPersona = AiPersona.create({ ...properties });