From e4b326c7114939467e04a325addfc363087316fd Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 6 May 2024 09:49:02 +1000 Subject: [PATCH] FEATURE: support Chat with AI Persona via a DM (#488) Add support for chat with AI personas - Allow enabling chat for AI personas that have an associated user - Add new setting `allow_chat` to AI persona to enable/disable chat - When a message is created in a DM channel with an allowed AI persona user, schedule a reply job - AI replies to chat messages using the persona's `max_context_posts` setting to determine context - Store tool calls and custom prompts used to generate a chat reply on the `ChatMessageCustomPrompt` table - Add tests for AI chat replies with tools and context At the moment unlike posts we do not carry tool calls in the context. No @mention support yet for ai personas in channels, this is future work --- .../admin/ai_personas_controller.rb | 1 + app/jobs/regular/create_ai_chat_reply.rb | 24 +++ app/models/ai_persona.rb | 55 ++++-- app/models/chat_message_custom_prompt.rb | 20 ++ .../localized_ai_persona_serializer.rb | 3 +- .../discourse/admin/models/ai-persona.js | 2 + .../components/ai-persona-editor.gjs | 23 +++ .../modules/ai-bot/common/ai-persona.scss | 5 + config/locales/client.en.yml | 6 +- config/locales/server.en.yml | 1 + ...0503034946_add_allow_chat_to_ai_persona.rb | 7 + ...03042558_add_chat_message_custom_prompt.rb | 12 ++ lib/ai_bot/bot.rb | 10 +- lib/ai_bot/entry_point.rb | 13 ++ lib/ai_bot/personas/persona.rb | 4 + lib/ai_bot/playground.rb | 169 +++++++++++++++-- lib/completions/prompt_messages_builder.rb | 73 ++++++++ .../prompt_messages_builder_spec.rb | 43 +++++ spec/lib/modules/ai_bot/playground_spec.rb | 171 +++++++++++++----- spec/models/ai_persona_spec.rb | 17 ++ .../unit/models/ai-persona-test.js | 2 + 21 files changed, 575 insertions(+), 86 deletions(-) create mode 100644 app/jobs/regular/create_ai_chat_reply.rb create mode 100644 app/models/chat_message_custom_prompt.rb create mode 100644 db/migrate/20240503034946_add_allow_chat_to_ai_persona.rb create mode 100644 db/migrate/20240503042558_add_chat_message_custom_prompt.rb create mode 100644 lib/completions/prompt_messages_builder.rb create mode 100644 spec/lib/completions/prompt_messages_builder_spec.rb diff --git a/app/controllers/discourse_ai/admin/ai_personas_controller.rb b/app/controllers/discourse_ai/admin/ai_personas_controller.rb index ca92ed50..ca52ea56 100644 --- a/app/controllers/discourse_ai/admin/ai_personas_controller.rb +++ b/app/controllers/discourse_ai/admin/ai_personas_controller.rb @@ -125,6 +125,7 @@ module DiscourseAi :rag_chunk_overlap_tokens, :rag_conversation_chunks, :question_consolidator_llm, + :allow_chat, allowed_group_ids: [], rag_uploads: [:id], ) diff --git a/app/jobs/regular/create_ai_chat_reply.rb b/app/jobs/regular/create_ai_chat_reply.rb new file mode 100644 index 00000000..48258c72 --- /dev/null +++ b/app/jobs/regular/create_ai_chat_reply.rb @@ -0,0 +1,24 @@ +# frozen_string_literal: true + +module ::Jobs + class CreateAiChatReply < ::Jobs::Base + sidekiq_options retry: false + + def execute(args) + channel = ::Chat::Channel.find_by(id: args[:channel_id]) + return if channel.blank? + + message = ::Chat::Message.find_by(id: args[:message_id]) + return if message.blank? + + personaClass = + DiscourseAi::AiBot::Personas::Persona.find_by(id: args[:persona_id], user: message.user) + return if personaClass.blank? + + user = User.find_by(id: personaClass.user_id) + bot = DiscourseAi::AiBot::Bot.as(user, persona: personaClass.new) + + DiscourseAi::AiBot::Playground.new(bot).reply_to_chat_message(message, channel) + end + end +end diff --git a/app/models/ai_persona.rb b/app/models/ai_persona.rb index a1ed5f88..7db31800 100644 --- a/app/models/ai_persona.rb +++ b/app/models/ai_persona.rb @@ -8,6 +8,7 @@ class AiPersona < ActiveRecord::Base validates :description, presence: true, length: { maximum: 2000 } validates :system_prompt, presence: true, length: { maximum: 10_000_000 } validate :system_persona_unchangeable, on: :update, if: :system + validate :chat_preconditions validates :max_context_posts, numericality: { greater_than: 0 }, allow_nil: true # leaves some room for growth but sets a maximum to avoid memory issues # we may want to revisit this in the future @@ -52,9 +53,9 @@ class AiPersona < ActiveRecord::Base .where(enabled: true) .joins(:user) .pluck( - "ai_personas.id, users.id, users.username_lower, allowed_group_ids, default_llm, mentionable", + "ai_personas.id, users.id, users.username_lower, allowed_group_ids, default_llm, mentionable, allow_chat", ) - .map do |id, user_id, username, allowed_group_ids, default_llm, mentionable| + .map do |id, user_id, username, allowed_group_ids, default_llm, mentionable, allow_chat| { id: id, user_id: user_id, @@ -62,6 +63,7 @@ class AiPersona < ActiveRecord::Base allowed_group_ids: allowed_group_ids, default_llm: default_llm, mentionable: mentionable, + allow_chat: allow_chat, } end @@ -72,23 +74,20 @@ class AiPersona < ActiveRecord::Base end end + def self.allowed_chat(user: nil) + personas = persona_cache[:allowed_chat] ||= persona_users.select { |u| u[:allow_chat] } + if user + personas.select { |u| user.in_any_groups?(u[:allowed_group_ids]) } + else + personas + end + end + def self.mentionables(user: nil) all_mentionables = - persona_cache[:mentionable_usernames] ||= AiPersona - .where(mentionable: true) - .where(enabled: true) - .joins(:user) - .pluck("ai_personas.id, users.id, users.username_lower, allowed_group_ids, default_llm") - .map do |id, user_id, username, allowed_group_ids, default_llm| - { - id: id, - user_id: user_id, - username: username, - allowed_group_ids: allowed_group_ids, - default_llm: default_llm, - } - end - + persona_cache[:mentionables] ||= persona_users.select do |mentionable| + mentionable[:mentionable] + end if user all_mentionables.select { |mentionable| user.in_any_groups?(mentionable[:allowed_group_ids]) } else @@ -114,6 +113,7 @@ class AiPersona < ActiveRecord::Base vision_max_pixels = self.vision_max_pixels rag_conversation_chunks = self.rag_conversation_chunks question_consolidator_llm = self.question_consolidator_llm + allow_chat = self.allow_chat persona_class = DiscourseAi::AiBot::Personas::Persona.system_personas_by_id[self.id] if persona_class @@ -133,6 +133,10 @@ class AiPersona < ActiveRecord::Base user_id end + persona_class.define_singleton_method :allow_chat do + allow_chat + end + persona_class.define_singleton_method :mentionable do mentionable end @@ -252,6 +256,10 @@ class AiPersona < ActiveRecord::Base question_consolidator_llm end + define_singleton_method :allow_chat do + allow_chat + end + define_singleton_method :to_s do "#" end @@ -342,6 +350,12 @@ class AiPersona < ActiveRecord::Base private + def chat_preconditions + if allow_chat && !default_llm + errors.add(:default_llm, I18n.t("discourse_ai.ai_bot.personas.default_llm_required")) + end + end + def system_persona_unchangeable if top_p_changed? || temperature_changed? || system_prompt_changed? || commands_changed? || name_changed? || description_changed? @@ -386,7 +400,14 @@ end # rag_chunk_tokens :integer default(374), not null # rag_chunk_overlap_tokens :integer default(10), not null # rag_conversation_chunks :integer default(10), not null +# role :enum default("bot"), not null +# role_category_ids :integer default([]), not null, is an Array +# role_tags :string default([]), not null, is an Array +# role_group_ids :integer default([]), not null, is an Array +# role_whispers :boolean default(FALSE), not null +# role_max_responses_per_hour :integer default(50), not null # question_consolidator_llm :text +# allow_chat :boolean default(FALSE), not null # # Indexes # diff --git a/app/models/chat_message_custom_prompt.rb b/app/models/chat_message_custom_prompt.rb new file mode 100644 index 00000000..30b28533 --- /dev/null +++ b/app/models/chat_message_custom_prompt.rb @@ -0,0 +1,20 @@ +# frozen_string_literal: true + +class ChatMessageCustomPrompt < ActiveRecord::Base + # belongs_to chat message but going to avoid the cross dependency for now +end + +# == Schema Information +# +# Table name: message_custom_prompts +# +# id :bigint not null, primary key +# message_id :bigint not null +# custom_prompt :json not null +# created_at :datetime not null +# updated_at :datetime not null +# +# Indexes +# +# index_message_custom_prompts_on_message_id (message_id) UNIQUE +# diff --git a/app/serializers/localized_ai_persona_serializer.rb b/app/serializers/localized_ai_persona_serializer.rb index d37a2276..a3e41b26 100644 --- a/app/serializers/localized_ai_persona_serializer.rb +++ b/app/serializers/localized_ai_persona_serializer.rb @@ -23,7 +23,8 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer :rag_chunk_tokens, :rag_chunk_overlap_tokens, :rag_conversation_chunks, - :question_consolidator_llm + :question_consolidator_llm, + :allow_chat has_one :user, serializer: BasicUserSerializer, embed: :object has_many :rag_uploads, serializer: UploadSerializer, embed: :object diff --git a/assets/javascripts/discourse/admin/models/ai-persona.js b/assets/javascripts/discourse/admin/models/ai-persona.js index 9b0046d8..7b06baf2 100644 --- a/assets/javascripts/discourse/admin/models/ai-persona.js +++ b/assets/javascripts/discourse/admin/models/ai-persona.js @@ -26,6 +26,7 @@ const CREATE_ATTRIBUTES = [ "rag_chunk_overlap_tokens", "rag_conversation_chunks", "question_consolidator_llm", + "allow_chat", ]; const SYSTEM_ATTRIBUTES = [ @@ -46,6 +47,7 @@ const SYSTEM_ATTRIBUTES = [ "rag_chunk_overlap_tokens", "rag_conversation_chunks", "question_consolidator_llm", + "allow_chat", ]; class CommandOption { diff --git a/assets/javascripts/discourse/components/ai-persona-editor.gjs b/assets/javascripts/discourse/components/ai-persona-editor.gjs index 1ed1474b..cdd2eabf 100644 --- a/assets/javascripts/discourse/components/ai-persona-editor.gjs +++ b/assets/javascripts/discourse/components/ai-persona-editor.gjs @@ -40,6 +40,10 @@ export default class PersonaEditor extends Component { @tracked ragIndexingStatuses = null; @tracked showIndexingOptions = false; + get chatPluginEnabled() { + return this.siteSettings.chat_enabled; + } + @action updateModel() { this.editingModel = this.args.model.workingCopy(); @@ -202,6 +206,11 @@ export default class PersonaEditor extends Component { await this.toggleField("mentionable"); } + @action + async toggleAllowChat() { + await this.toggleField("allow_chat"); + } + @action async toggleVisionEnabled() { await this.toggleField("vision_enabled"); @@ -295,6 +304,20 @@ export default class PersonaEditor extends Component { /> {{#if this.editingModel.user}} + {{#if this.chatPluginEnabled}} +
+ + +
+ {{/if}}
-1, diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index d84bd82f..28d06c28 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -11,6 +11,28 @@ module DiscourseAi REQUIRE_TITLE_UPDATE = "discourse-ai-title-update" + def self.schedule_chat_reply(message, channel, user, context) + if channel.direct_message_channel? + allowed_user_ids = channel.allowed_user_ids + + return if AiPersona.allowed_chat.any? { |m| m[:user_id] == user.id } + + persona = + AiPersona.allowed_chat.find do |p| + p[:user_id].in?(allowed_user_ids) && (user.group_ids & p[:allowed_group_ids]) + end + + if persona + ::Jobs.enqueue( + :create_ai_chat_reply, + channel_id: channel.id, + message_id: message.id, + persona_id: persona[:id], + ) + end + end + end + def self.is_bot_user_id?(user_id) # this will catch everything and avoid any feedback loops # we could get feedback loops between say discobot and ai-bot or third party plugins @@ -128,7 +150,7 @@ module DiscourseAi ) as upload_ids", ) - result = [] + builder = DiscourseAi::Completions::PromptMessagesBuilder.new context.reverse_each do |raw, username, custom_prompt, upload_ids| custom_prompt_translation = @@ -144,7 +166,7 @@ module DiscourseAi custom_context[:id] = message[1] if custom_context[:type] != :model custom_context[:name] = message[3] if message[3] - result << custom_context + builder.push(**custom_context) end end @@ -162,11 +184,11 @@ module DiscourseAi context[:upload_ids] = upload_ids.compact end - result << context + builder.push(**context) end end - result + builder.to_a end def title_playground(post) @@ -184,6 +206,116 @@ module DiscourseAi end end + def chat_context(message, channel, persona_user) + has_vision = bot.persona.class.vision_enabled + + if !message.thread_id + hash = { type: :user, content: message.message } + hash[:upload_ids] = message.uploads.map(&:id) if has_vision && message.uploads.present? + return [hash] + end + + max_messages = 40 + if bot.persona.class.respond_to?(:max_context_posts) + max_messages = bot.persona.class.max_context_posts || 40 + end + + # I would like to use a guardian however membership for + # persona_user is far in future + thread_messages = + ChatSDK::Thread.last_messages( + thread_id: message.thread_id, + guardian: Discourse.system_user.guardian, + page_size: max_messages, + ) + + builder = DiscourseAi::Completions::PromptMessagesBuilder.new + + thread_messages.each do |m| + if available_bot_user_ids.include?(m.user_id) + builder.push(type: :model, content: m.message) + else + upload_ids = nil + upload_ids = m.uploads.map(&:id) if has_vision && m.uploads.present? + builder.push( + type: :user, + content: m.message, + name: m.user.username, + upload_ids: upload_ids, + ) + end + end + + builder.to_a(limit: max_messages) + end + + def reply_to_chat_message(message, channel) + persona_user = User.find(bot.persona.class.user_id) + + participants = channel.user_chat_channel_memberships.map { |m| m.user.username } + + context = + get_context( + participants: participants.join(", "), + conversation_context: chat_context(message, channel, persona_user), + user: message.user, + skip_tool_details: true, + ) + + reply = nil + guardian = Guardian.new(persona_user) + + new_prompts = + bot.reply(context) do |partial, cancel, placeholder| + if !reply + # just eat all leading spaces we can not create the message + next if partial.blank? + reply = + ChatSDK::Message.create( + raw: partial, + thread_id: message.thread_id, + channel_id: channel.id, + guardian: guardian, + in_reply_to_id: message.id, + force_thread: message.thread_id.nil?, + ) + ChatSDK::Message.start_stream(message_id: reply.id, guardian: guardian) + else + streaming = + ChatSDK::Message.stream(message_id: reply.id, raw: partial, guardian: guardian) + + if !streaming + cancel&.call + break + end + end + end + + if new_prompts.length > 1 && reply.id + ChatMessageCustomPrompt.create!(message_id: reply.id, custom_prompt: new_prompts) + end + + ChatSDK::Message.stop_stream(message_id: reply.id, guardian: guardian) if reply + + reply + end + + def get_context(participants:, conversation_context:, user:, skip_tool_details: nil) + result = { + site_url: Discourse.base_url, + site_title: SiteSetting.title, + site_description: SiteSetting.site_description, + time: Time.zone.now, + participants: participants, + conversation_context: conversation_context, + user: user, + } + + result[:skip_tool_details] = true if skip_tool_details + + result + end + def reply_to(post) reply = +"" start = Time.now @@ -191,17 +323,14 @@ module DiscourseAi post_type = post.post_type == Post.types[:whisper] ? Post.types[:whisper] : Post.types[:regular] - context = { - site_url: Discourse.base_url, - site_title: SiteSetting.title, - site_description: SiteSetting.site_description, - time: Time.zone.now, - participants: post.topic.allowed_users.map(&:username).join(", "), - conversation_context: conversation_context(post), - user: post.user, - post_id: post.id, - topic_id: post.topic_id, - } + context = + get_context( + participants: post.topic.allowed_users.map(&:username).join(", "), + conversation_context: conversation_context(post), + user: post.user, + ) + context[:post_id] = post.id + context[:topic_id] = post.topic_id reply_user = bot.bot_user if bot.persona.class.respond_to?(:user_id) @@ -282,7 +411,7 @@ module DiscourseAi ) end - # not need to add a custom prompt for a single reply + # we do not need to add a custom prompt for a single reply if new_custom_prompts.length > 1 reply_post.post_custom_prompt ||= reply_post.build_post_custom_prompt(custom_prompt: []) prompt = reply_post.post_custom_prompt.custom_prompt || [] @@ -309,6 +438,14 @@ module DiscourseAi .concat(DiscourseAi::AiBot::EntryPoint::BOTS.map(&:second)) end + def available_bot_user_ids + @bot_ids ||= + AiPersona + .joins(:user) + .pluck("users.id") + .concat(DiscourseAi::AiBot::EntryPoint::BOTS.map(&:first)) + end + private def publish_final_update(reply_post) diff --git a/lib/completions/prompt_messages_builder.rb b/lib/completions/prompt_messages_builder.rb new file mode 100644 index 00000000..851ebcae --- /dev/null +++ b/lib/completions/prompt_messages_builder.rb @@ -0,0 +1,73 @@ +# frozen_string_literal: true +# +module DiscourseAi + module Completions + class PromptMessagesBuilder + def initialize + @raw_messages = [] + end + + def to_a(limit: nil) + result = [] + + # this will create a "valid" messages array + # 1. ensures we always start with a user message + # 2. ensures we always end with a user message + # 3. ensures we always interleave user and model messages + last_type = nil + @raw_messages.each do |message| + next if !last_type && message[:type] != :user + + if last_type == :tool_call && message[:type] != :tool + result.pop + last_type = result.length > 0 ? result[-1][:type] : nil + end + + next if message[:type] == :tool && last_type != :tool_call + + if message[:type] == last_type + # merge the message for :user message + # replace the message for other messages + last_message = result[-1] + + if message[:type] == :user + old_name = last_message.delete(:name) + last_message[:content] = "#{old_name}: #{last_message[:content]}" if old_name + + new_content = message[:content] + new_content = "#{message[:name]}: #{new_content}" if message[:name] + + last_message[:content] += "\n#{new_content}" + else + last_message[:content] = message[:content] + end + else + result << message + end + + last_type = message[:type] + end + + if limit + result[0..limit] + else + result + end + end + + def push(type:, content:, name: nil, upload_ids: nil, id: nil) + if !%i[user model tool tool_call system].include?(type) + raise ArgumentError, "type must be either :user, :model, :tool, :tool_call or :system" + end + raise ArgumentError, "upload_ids must be an array" if upload_ids && !upload_ids.is_a?(Array) + + message = { type: type, content: content } + message[:name] = name.to_s if name + message[:upload_ids] = upload_ids if upload_ids + message[:id] = id.to_s if id + + @raw_messages << message + end + end + end +end diff --git a/spec/lib/completions/prompt_messages_builder_spec.rb b/spec/lib/completions/prompt_messages_builder_spec.rb new file mode 100644 index 00000000..7e758e7a --- /dev/null +++ b/spec/lib/completions/prompt_messages_builder_spec.rb @@ -0,0 +1,43 @@ +# frozen_string_literal: true + +describe DiscourseAi::Completions::PromptMessagesBuilder do + let(:builder) { DiscourseAi::Completions::PromptMessagesBuilder.new } + + it "should allow merging user messages" do + builder.push(type: :user, content: "Hello", name: "Alice") + builder.push(type: :user, content: "World", name: "Bob") + + expect(builder.to_a).to eq([{ type: :user, content: "Alice: Hello\nBob: World" }]) + end + + it "should allow adding uploads" do + builder.push(type: :user, content: "Hello", name: "Alice", upload_ids: [1, 2]) + + expect(builder.to_a).to eq( + [{ type: :user, name: "Alice", content: "Hello", upload_ids: [1, 2] }], + ) + end + + it "should support function calls" do + builder.push(type: :user, content: "Echo 123 please", name: "Alice") + builder.push(type: :tool_call, content: "echo(123)", name: "echo", id: 1) + builder.push(type: :tool, content: "123", name: "echo", id: 1) + builder.push(type: :user, content: "Hello", name: "Alice") + expected = [ + { type: :user, content: "Echo 123 please", name: "Alice" }, + { type: :tool_call, content: "echo(123)", name: "echo", id: "1" }, + { type: :tool, content: "123", name: "echo", id: "1" }, + { type: :user, content: "Hello", name: "Alice" }, + ] + expect(builder.to_a).to eq(expected) + end + + it "should drop a tool call if it is not followed by tool" do + builder.push(type: :user, content: "Echo 123 please", name: "Alice") + builder.push(type: :tool_call, content: "echo(123)", name: "echo", id: 1) + builder.push(type: :user, content: "OK", name: "James") + + expected = [{ type: :user, content: "Alice: Echo 123 please\nJames: OK" }] + expect(builder.to_a).to eq(expected) + end +end diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index 29ffc968..ce88c397 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -131,6 +131,121 @@ RSpec.describe DiscourseAi::AiBot::Playground do persona end + context "with chat" do + fab!(:dm_channel) { Fabricate(:direct_message_channel, users: [user, persona.user]) } + + before do + SiteSetting.chat_allowed_groups = "#{Group::AUTO_GROUPS[:trust_level_0]}" + Group.refresh_automatic_groups! + persona.update!( + allow_chat: true, + mentionable: false, + default_llm: "anthropic:claude-3-opus", + ) + end + + let(:guardian) { Guardian.new(user) } + + it "can run tools" do + persona.update!(commands: ["TimeCommand"]) + + responses = [ + "timetimeBuenos Aires", + "The time is 2023-12-14 17:24:00 -0300", + ] + + message = + DiscourseAi::Completions::Llm.with_prepared_responses(responses) do + ChatSDK::Message.create(channel_id: dm_channel.id, raw: "Hello", guardian: guardian) + end + + message.reload + expect(message.thread_id).to be_present + reply = ChatSDK::Thread.messages(thread_id: message.thread_id, guardian: guardian).last + + expect(reply.message).to eq("The time is 2023-12-14 17:24:00 -0300") + + # it also needs to have tool details now set on message + prompt = ChatMessageCustomPrompt.find_by(message_id: reply.id) + expect(prompt.custom_prompt.length).to eq(3) + + # TODO in chat I am mixed on including this in the context, but I guess maybe? + # thinking about this + end + + it "can reply to a chat message" do + message = + DiscourseAi::Completions::Llm.with_prepared_responses(["World"]) do + ChatSDK::Message.create(channel_id: dm_channel.id, raw: "Hello", guardian: guardian) + end + + message.reload + expect(message.thread_id).to be_present + + thread_messages = ChatSDK::Thread.messages(thread_id: message.thread_id, guardian: guardian) + expect(thread_messages.length).to eq(2) + expect(thread_messages.last.message).to eq("World") + + # it also needs to include history per config - first feed some history + persona.update!(enabled: false) + + persona_guardian = Guardian.new(persona.user) + + 4.times do |i| + ChatSDK::Message.create( + channel_id: dm_channel.id, + thread_id: message.thread_id, + raw: "request #{i}", + guardian: guardian, + ) + + ChatSDK::Message.create( + channel_id: dm_channel.id, + thread_id: message.thread_id, + raw: "response #{i}", + guardian: persona_guardian, + ) + end + + persona.update!(max_context_posts: 4, enabled: true) + + prompts = nil + DiscourseAi::Completions::Llm.with_prepared_responses( + ["World 2"], + ) do |_response, _llm, _prompts| + ChatSDK::Message.create( + channel_id: dm_channel.id, + thread_id: message.thread_id, + raw: "Hello", + guardian: guardian, + ) + prompts = _prompts + end + + expect(prompts.length).to eq(1) + + mapped = + prompts[0] + .messages + .map { |m| "#{m[:type]}: #{m[:content]}" if m[:type] != :system } + .compact + .join("\n") + .strip + + # why? + # 1. we set context to 4 + # 2. however PromptMessagesBuilder will enforce rules of starting with :user and ending with it + # so one of the model messages is dropped + expected = (<<~TEXT).strip + user: request 3 + model: response 3 + user: Hello + TEXT + + expect(mapped).to eq(expected) + end + end + it "replies to whispers with a whisper" do post = nil DiscourseAi::Completions::Llm.with_prepared_responses(["Yes I can"]) do @@ -458,28 +573,26 @@ RSpec.describe DiscourseAi::AiBot::Playground do context = playground.conversation_context(third_post) + # skips leading model reply which makes no sense cause first post was whisper expect(context).to contain_exactly( - *[ - { type: :user, id: user.username, content: third_post.raw }, - { type: :model, content: second_post.raw }, - ], + *[{ type: :user, id: user.username, content: third_post.raw }], ) end context "with custom prompts" do it "When post custom prompt is present, we use that instead of the post content" do custom_prompt = [ - [ - { args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json, - "time", - "tool", - ], [ { name: "time", arguments: { name: "time", timezone: "Buenos Aires" } }.to_json, "time", "tool_call", ], - ["I replied this thanks to the time command", bot_user.username], + [ + { args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json, + "time", + "tool", + ], + ["I replied to the time command", bot_user.username], ] PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt) @@ -488,43 +601,11 @@ RSpec.describe DiscourseAi::AiBot::Playground do expect(context).to contain_exactly( *[ - { type: :user, id: user.username, content: third_post.raw }, - { type: :model, content: custom_prompt.third.first }, - { type: :tool_call, content: custom_prompt.second.first, id: "time" }, - { type: :tool, id: "time", content: custom_prompt.first.first }, { type: :user, id: user.username, content: first_post.raw }, - ], - ) - end - - it "include replies generated from tools" do - custom_prompt = [ - [ - { args: { timezone: "Buenos Aires" }, time: "2023-12-14 17:24:00 -0300" }.to_json, - "time", - "tool", - ], - [ - { name: "time", arguments: { name: "time", timezone: "Buenos Aires" } }.to_json, - "time", - "tool_call", - ], - ["I replied", bot_user.username], - ] - PostCustomPrompt.create!(post: second_post, custom_prompt: custom_prompt) - PostCustomPrompt.create!(post: first_post, custom_prompt: custom_prompt) - - context = playground.conversation_context(third_post) - - expect(context).to contain_exactly( - *[ - { type: :user, id: user.username, content: third_post.raw }, + { type: :tool_call, content: custom_prompt.first.first, id: "time" }, + { type: :tool, id: "time", content: custom_prompt.second.first }, { type: :model, content: custom_prompt.third.first }, - { type: :tool_call, content: custom_prompt.second.first, id: "time" }, - { type: :tool, id: "time", content: custom_prompt.first.first }, - { type: :tool_call, content: custom_prompt.second.first, id: "time" }, - { type: :tool, id: "time", content: custom_prompt.first.first }, - { type: :model, content: "I replied" }, + { type: :user, id: user.username, content: third_post.raw }, ], ) end diff --git a/spec/models/ai_persona_spec.rb b/spec/models/ai_persona_spec.rb index 00f86f37..bc9cd9ae 100644 --- a/spec/models/ai_persona_spec.rb +++ b/spec/models/ai_persona_spec.rb @@ -113,6 +113,23 @@ RSpec.describe AiPersona do expect(klass.max_context_posts).to eq(3) end + it "does not allow setting allow_chat without a default_llm" do + persona = + AiPersona.create( + name: "test", + description: "test", + system_prompt: "test", + allowed_group_ids: [], + default_llm: nil, + allow_chat: true, + ) + + expect(persona.valid?).to eq(false) + expect(persona.errors[:default_llm].first).to eq( + I18n.t("discourse_ai.ai_bot.personas.default_llm_required"), + ) + end + it "does not leak caches between sites" do AiPersona.create!( name: "pun_bot", diff --git a/test/javascripts/unit/models/ai-persona-test.js b/test/javascripts/unit/models/ai-persona-test.js index 737d5a7b..ac1c5530 100644 --- a/test/javascripts/unit/models/ai-persona-test.js +++ b/test/javascripts/unit/models/ai-persona-test.js @@ -53,6 +53,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () { rag_chunk_overlap_tokens: 10, rag_conversation_chunks: 10, question_consolidator_llm: "Question Consolidator LLM", + allow_chat: false, }; const aiPersona = AiPersona.create({ ...properties }); @@ -92,6 +93,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () { rag_chunk_overlap_tokens: 10, rag_conversation_chunks: 10, question_consolidator_llm: "Question Consolidator LLM", + allow_chat: false, }; const aiPersona = AiPersona.create({ ...properties });