FEATURE: support custom instructions for persona streaming (#890)

This allows us to inject information into the system prompt
which can help shape replies without repeating over and over
in messages.
This commit is contained in:
Sam 2024-11-05 07:43:26 +11:00 committed by GitHub
parent fa7ca8bc31
commit 98022d7d96
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 61 additions and 43 deletions

View File

@ -111,55 +111,26 @@ module DiscourseAi
topic_id = params[:topic_id].to_i topic_id = params[:topic_id].to_i
topic = nil topic = nil
post = nil
if topic_id > 0 if topic_id > 0
topic = Topic.find(topic_id) topic = Topic.find(topic_id)
raise Discourse::NotFound if topic.nil?
if topic.topic_allowed_users.where(user_id: user.id).empty? if topic.topic_allowed_users.where(user_id: user.id).empty?
return render_json_error(I18n.t("discourse_ai.errors.user_not_allowed")) return render_json_error(I18n.t("discourse_ai.errors.user_not_allowed"))
end end
post =
PostCreator.create!(
user,
topic_id: topic_id,
raw: params[:query],
skip_validations: true,
custom_fields: {
DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true,
},
)
else
post =
PostCreator.create!(
user,
title: I18n.t("discourse_ai.ai_bot.default_pm_prefix"),
raw: params[:query],
archetype: Archetype.private_message,
target_usernames: "#{user.username},#{persona.user.username}",
skip_validations: true,
custom_fields: {
DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true,
},
)
topic = post.topic
end end
hijack = request.env["rack.hijack"] hijack = request.env["rack.hijack"]
io = hijack.call io = hijack.call
user = current_user
DiscourseAi::AiBot::ResponseHttpStreamer.queue_streamed_reply( DiscourseAi::AiBot::ResponseHttpStreamer.queue_streamed_reply(
io, io: io,
persona, persona: persona,
user, user: user,
topic, topic: topic,
post, query: params[:query].to_s,
custom_instructions: params[:custom_instructions].to_s,
current_user: current_user,
) )
end end

View File

@ -171,6 +171,11 @@ module DiscourseAi
DiscourseAi::Completions::Llm.proxy(self.class.question_consolidator_llm) DiscourseAi::Completions::Llm.proxy(self.class.question_consolidator_llm)
end end
if context[:custom_instructions].present?
prompt_insts << "\n"
prompt_insts << context[:custom_instructions]
end
fragments_guidance = fragments_guidance =
rag_fragments_prompt( rag_fragments_prompt(
context[:conversation_context].to_a, context[:conversation_context].to_a,

View File

@ -392,7 +392,7 @@ module DiscourseAi
result result
end end
def reply_to(post, &blk) def reply_to(post, custom_instructions: nil, &blk)
# this is a multithreading issue # this is a multithreading issue
# post custom prompt is needed and it may not # post custom prompt is needed and it may not
# be properly loaded, ensure it is loaded # be properly loaded, ensure it is loaded
@ -413,6 +413,7 @@ module DiscourseAi
context[:post_id] = post.id context[:post_id] = post.id
context[:topic_id] = post.topic_id context[:topic_id] = post.topic_id
context[:private_message] = post.topic.private_message? context[:private_message] = post.topic.private_message?
context[:custom_instructions] = custom_instructions
reply_user = bot.bot_user reply_user = bot.bot_user
if bot.persona.class.respond_to?(:user_id) if bot.persona.class.respond_to?(:user_id)

View File

@ -31,9 +31,36 @@ module DiscourseAi
# keeping this in a static method so we don't capture ENV and other bits # keeping this in a static method so we don't capture ENV and other bits
# this allows us to release memory earlier # this allows us to release memory earlier
def queue_streamed_reply(io, persona, user, topic, post) def queue_streamed_reply(
io:,
persona:,
user:,
topic:,
query:,
custom_instructions:,
current_user:
)
schedule_block do schedule_block do
begin begin
post_params = {
raw: query,
skip_validations: true,
custom_fields: {
DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true,
},
}
if topic
post_params[:topic_id] = topic.id
else
post_params[:title] = I18n.t("discourse_ai.ai_bot.default_pm_prefix")
post_params[:archetype] = Archetype.private_message
post_params[:target_usernames] = "#{user.username},#{persona.user.username}"
end
post = PostCreator.create!(user, post_params)
topic = post.topic
io.write "HTTP/1.1 200 OK" io.write "HTTP/1.1 200 OK"
io.write CRLF io.write CRLF
io.write "Content-Type: text/plain; charset=utf-8" io.write "Content-Type: text/plain; charset=utf-8"
@ -52,7 +79,7 @@ module DiscourseAi
io.flush io.flush
persona_class = persona_class =
DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: user) DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: current_user)
bot = DiscourseAi::AiBot::Bot.as(persona.user, persona: persona_class.new) bot = DiscourseAi::AiBot::Bot.as(persona.user, persona: persona_class.new)
data = data =
@ -69,7 +96,7 @@ module DiscourseAi
DiscourseAi::AiBot::Playground DiscourseAi::AiBot::Playground
.new(bot) .new(bot)
.reply_to(post) do |partial| .reply_to(post, custom_instructions: custom_instructions) do |partial|
next if partial.length == 0 next if partial.length == 0
data = { partial: partial }.to_json + "\n\n" data = { partial: partial }.to_json + "\n\n"
@ -88,11 +115,11 @@ module DiscourseAi
io.write CRLF io.write CRLF
io.flush io.flush
io.done io.done if io.respond_to?(:done)
rescue StandardError => e rescue StandardError => e
# make it a tiny bit easier to debug in dev, this is tricky # make it a tiny bit easier to debug in dev, this is tricky
# multi-threaded code that exhibits various limitations in rails # multi-threaded code that exhibits various limitations in rails
p e if Rails.env.development? p e if Rails.env.development? || Rails.env.test?
Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply") Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply")
ensure ensure
io.close io.close

View File

@ -104,6 +104,10 @@ module DiscourseAi
@last_call = params @last_call = params
end end
def self.previous_calls
@previous_calls ||= []
end
def self.reset! def self.reset!
@last_call = nil @last_call = nil
@fake_content = nil @fake_content = nil
@ -118,7 +122,11 @@ module DiscourseAi
feature_name: nil, feature_name: nil,
feature_context: nil feature_context: nil
) )
self.class.last_call = { dialect: dialect, user: user, model_params: model_params } last_call = { dialect: dialect, user: user, model_params: model_params }
self.class.last_call = last_call
self.class.previous_calls << last_call
# guard memory in test
self.class.previous_calls.shift if self.class.previous_calls.length > 10
content = self.class.fake_content content = self.class.fake_content

View File

@ -500,6 +500,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]], allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
default_llm: "custom:#{llm.id}", default_llm: "custom:#{llm.id}",
allow_personal_messages: true, allow_personal_messages: true,
system_prompt: "you are a helpful bot",
) )
io_out, io_in = IO.pipe io_out, io_in = IO.pipe
@ -510,6 +511,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
query: "how are you today?", query: "how are you today?",
user_unique_id: "site:test.com:user_id:1", user_unique_id: "site:test.com:user_id:1",
preferred_username: "test_user", preferred_username: "test_user",
custom_instructions: "To be appended to system prompt",
}, },
env: { env: {
"rack.hijack" => lambda { io_in }, "rack.hijack" => lambda { io_in },
@ -521,6 +523,10 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
raw = io_out.read raw = io_out.read
context_info = validate_streamed_response(raw, "This is a test! Testing!") context_info = validate_streamed_response(raw, "This is a test! Testing!")
system_prompt = fake_endpoint.previous_calls[-2][:dialect].prompt.messages.first[:content]
expect(system_prompt).to eq("you are a helpful bot\nTo be appended to system prompt")
expect(context_info["topic_id"]).to be_present expect(context_info["topic_id"]).to be_present
topic = Topic.find(context_info["topic_id"]) topic = Topic.find(context_info["topic_id"])
last_post = topic.posts.order(:created_at).last last_post = topic.posts.order(:created_at).last