From 98022d7d9625f4f059d33ea3e0cd845f21128cb1 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 5 Nov 2024 07:43:26 +1100 Subject: [PATCH] FEATURE: support custom instructions for persona streaming (#890) This allows us to inject information into the system prompt which can help shape replies without repeating over and over in messages. --- .../admin/ai_personas_controller.rb | 43 +++---------------- lib/ai_bot/personas/persona.rb | 5 +++ lib/ai_bot/playground.rb | 3 +- lib/ai_bot/response_http_streamer.rb | 37 +++++++++++++--- lib/completions/endpoints/fake.rb | 10 ++++- .../admin/ai_personas_controller_spec.rb | 6 +++ 6 files changed, 61 insertions(+), 43 deletions(-) diff --git a/app/controllers/discourse_ai/admin/ai_personas_controller.rb b/app/controllers/discourse_ai/admin/ai_personas_controller.rb index 8ef466f1..cac72815 100644 --- a/app/controllers/discourse_ai/admin/ai_personas_controller.rb +++ b/app/controllers/discourse_ai/admin/ai_personas_controller.rb @@ -111,55 +111,26 @@ module DiscourseAi topic_id = params[:topic_id].to_i topic = nil - post = nil if topic_id > 0 topic = Topic.find(topic_id) - raise Discourse::NotFound if topic.nil? - if topic.topic_allowed_users.where(user_id: user.id).empty? return render_json_error(I18n.t("discourse_ai.errors.user_not_allowed")) end - - post = - PostCreator.create!( - user, - topic_id: topic_id, - raw: params[:query], - skip_validations: true, - custom_fields: { - DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true, - }, - ) - else - post = - PostCreator.create!( - user, - title: I18n.t("discourse_ai.ai_bot.default_pm_prefix"), - raw: params[:query], - archetype: Archetype.private_message, - target_usernames: "#{user.username},#{persona.user.username}", - skip_validations: true, - custom_fields: { - DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true, - }, - ) - - topic = post.topic end hijack = request.env["rack.hijack"] io = hijack.call - user = current_user - DiscourseAi::AiBot::ResponseHttpStreamer.queue_streamed_reply( - io, - persona, - user, - topic, - post, + io: io, + persona: persona, + user: user, + topic: topic, + query: params[:query].to_s, + custom_instructions: params[:custom_instructions].to_s, + current_user: current_user, ) end diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index ca47df0a..73224808 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -171,6 +171,11 @@ module DiscourseAi DiscourseAi::Completions::Llm.proxy(self.class.question_consolidator_llm) end + if context[:custom_instructions].present? + prompt_insts << "\n" + prompt_insts << context[:custom_instructions] + end + fragments_guidance = rag_fragments_prompt( context[:conversation_context].to_a, diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index 4f0648cb..af18f989 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -392,7 +392,7 @@ module DiscourseAi result end - def reply_to(post, &blk) + def reply_to(post, custom_instructions: nil, &blk) # this is a multithreading issue # post custom prompt is needed and it may not # be properly loaded, ensure it is loaded @@ -413,6 +413,7 @@ module DiscourseAi context[:post_id] = post.id context[:topic_id] = post.topic_id context[:private_message] = post.topic.private_message? + context[:custom_instructions] = custom_instructions reply_user = bot.bot_user if bot.persona.class.respond_to?(:user_id) diff --git a/lib/ai_bot/response_http_streamer.rb b/lib/ai_bot/response_http_streamer.rb index 382c47f6..ac1758ff 100644 --- a/lib/ai_bot/response_http_streamer.rb +++ b/lib/ai_bot/response_http_streamer.rb @@ -31,9 +31,36 @@ module DiscourseAi # keeping this in a static method so we don't capture ENV and other bits # this allows us to release memory earlier - def queue_streamed_reply(io, persona, user, topic, post) + def queue_streamed_reply( + io:, + persona:, + user:, + topic:, + query:, + custom_instructions:, + current_user: + ) schedule_block do begin + post_params = { + raw: query, + skip_validations: true, + custom_fields: { + DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true, + }, + } + + if topic + post_params[:topic_id] = topic.id + else + post_params[:title] = I18n.t("discourse_ai.ai_bot.default_pm_prefix") + post_params[:archetype] = Archetype.private_message + post_params[:target_usernames] = "#{user.username},#{persona.user.username}" + end + + post = PostCreator.create!(user, post_params) + topic = post.topic + io.write "HTTP/1.1 200 OK" io.write CRLF io.write "Content-Type: text/plain; charset=utf-8" @@ -52,7 +79,7 @@ module DiscourseAi io.flush persona_class = - DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: user) + DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: current_user) bot = DiscourseAi::AiBot::Bot.as(persona.user, persona: persona_class.new) data = @@ -69,7 +96,7 @@ module DiscourseAi DiscourseAi::AiBot::Playground .new(bot) - .reply_to(post) do |partial| + .reply_to(post, custom_instructions: custom_instructions) do |partial| next if partial.length == 0 data = { partial: partial }.to_json + "\n\n" @@ -88,11 +115,11 @@ module DiscourseAi io.write CRLF io.flush - io.done + io.done if io.respond_to?(:done) rescue StandardError => e # make it a tiny bit easier to debug in dev, this is tricky # multi-threaded code that exhibits various limitations in rails - p e if Rails.env.development? + p e if Rails.env.development? || Rails.env.test? Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply") ensure io.close diff --git a/lib/completions/endpoints/fake.rb b/lib/completions/endpoints/fake.rb index d057fb15..a51ff3ac 100644 --- a/lib/completions/endpoints/fake.rb +++ b/lib/completions/endpoints/fake.rb @@ -104,6 +104,10 @@ module DiscourseAi @last_call = params end + def self.previous_calls + @previous_calls ||= [] + end + def self.reset! @last_call = nil @fake_content = nil @@ -118,7 +122,11 @@ module DiscourseAi feature_name: nil, feature_context: nil ) - self.class.last_call = { dialect: dialect, user: user, model_params: model_params } + last_call = { dialect: dialect, user: user, model_params: model_params } + self.class.last_call = last_call + self.class.previous_calls << last_call + # guard memory in test + self.class.previous_calls.shift if self.class.previous_calls.length > 10 content = self.class.fake_content diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb index 50f71009..fb42506e 100644 --- a/spec/requests/admin/ai_personas_controller_spec.rb +++ b/spec/requests/admin/ai_personas_controller_spec.rb @@ -500,6 +500,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]], default_llm: "custom:#{llm.id}", allow_personal_messages: true, + system_prompt: "you are a helpful bot", ) io_out, io_in = IO.pipe @@ -510,6 +511,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do query: "how are you today?", user_unique_id: "site:test.com:user_id:1", preferred_username: "test_user", + custom_instructions: "To be appended to system prompt", }, env: { "rack.hijack" => lambda { io_in }, @@ -521,6 +523,10 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do raw = io_out.read context_info = validate_streamed_response(raw, "This is a test! Testing!") + system_prompt = fake_endpoint.previous_calls[-2][:dialect].prompt.messages.first[:content] + + expect(system_prompt).to eq("you are a helpful bot\nTo be appended to system prompt") + expect(context_info["topic_id"]).to be_present topic = Topic.find(context_info["topic_id"]) last_post = topic.posts.order(:created_at).last