FEATURE: support custom instructions for persona streaming (#890)
This allows us to inject information into the system prompt, which can help shape replies without repeating it over and over in messages.
This commit is contained in:
parent
fa7ca8bc31
commit
98022d7d96
|
@ -111,55 +111,26 @@ module DiscourseAi
|
|||
|
||||
topic_id = params[:topic_id].to_i
|
||||
topic = nil
|
||||
post = nil
|
||||
|
||||
if topic_id > 0
|
||||
topic = Topic.find(topic_id)
|
||||
|
||||
raise Discourse::NotFound if topic.nil?
|
||||
|
||||
if topic.topic_allowed_users.where(user_id: user.id).empty?
|
||||
return render_json_error(I18n.t("discourse_ai.errors.user_not_allowed"))
|
||||
end
|
||||
|
||||
post =
|
||||
PostCreator.create!(
|
||||
user,
|
||||
topic_id: topic_id,
|
||||
raw: params[:query],
|
||||
skip_validations: true,
|
||||
custom_fields: {
|
||||
DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true,
|
||||
},
|
||||
)
|
||||
else
|
||||
post =
|
||||
PostCreator.create!(
|
||||
user,
|
||||
title: I18n.t("discourse_ai.ai_bot.default_pm_prefix"),
|
||||
raw: params[:query],
|
||||
archetype: Archetype.private_message,
|
||||
target_usernames: "#{user.username},#{persona.user.username}",
|
||||
skip_validations: true,
|
||||
custom_fields: {
|
||||
DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true,
|
||||
},
|
||||
)
|
||||
|
||||
topic = post.topic
|
||||
end
|
||||
|
||||
hijack = request.env["rack.hijack"]
|
||||
io = hijack.call
|
||||
|
||||
user = current_user
|
||||
|
||||
DiscourseAi::AiBot::ResponseHttpStreamer.queue_streamed_reply(
|
||||
io,
|
||||
persona,
|
||||
user,
|
||||
topic,
|
||||
post,
|
||||
io: io,
|
||||
persona: persona,
|
||||
user: user,
|
||||
topic: topic,
|
||||
query: params[:query].to_s,
|
||||
custom_instructions: params[:custom_instructions].to_s,
|
||||
current_user: current_user,
|
||||
)
|
||||
end
|
||||
|
||||
|
|
|
@ -171,6 +171,11 @@ module DiscourseAi
|
|||
DiscourseAi::Completions::Llm.proxy(self.class.question_consolidator_llm)
|
||||
end
|
||||
|
||||
if context[:custom_instructions].present?
|
||||
prompt_insts << "\n"
|
||||
prompt_insts << context[:custom_instructions]
|
||||
end
|
||||
|
||||
fragments_guidance =
|
||||
rag_fragments_prompt(
|
||||
context[:conversation_context].to_a,
|
||||
|
|
|
@ -392,7 +392,7 @@ module DiscourseAi
|
|||
result
|
||||
end
|
||||
|
||||
def reply_to(post, &blk)
|
||||
def reply_to(post, custom_instructions: nil, &blk)
|
||||
# this is a multithreading issue
|
||||
# post custom prompt is needed and it may not
|
||||
# be properly loaded, ensure it is loaded
|
||||
|
@ -413,6 +413,7 @@ module DiscourseAi
|
|||
context[:post_id] = post.id
|
||||
context[:topic_id] = post.topic_id
|
||||
context[:private_message] = post.topic.private_message?
|
||||
context[:custom_instructions] = custom_instructions
|
||||
|
||||
reply_user = bot.bot_user
|
||||
if bot.persona.class.respond_to?(:user_id)
|
||||
|
|
|
@ -31,9 +31,36 @@ module DiscourseAi
|
|||
|
||||
# keeping this in a static method so we don't capture ENV and other bits
|
||||
# this allows us to release memory earlier
|
||||
def queue_streamed_reply(io, persona, user, topic, post)
|
||||
def queue_streamed_reply(
|
||||
io:,
|
||||
persona:,
|
||||
user:,
|
||||
topic:,
|
||||
query:,
|
||||
custom_instructions:,
|
||||
current_user:
|
||||
)
|
||||
schedule_block do
|
||||
begin
|
||||
post_params = {
|
||||
raw: query,
|
||||
skip_validations: true,
|
||||
custom_fields: {
|
||||
DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true,
|
||||
},
|
||||
}
|
||||
|
||||
if topic
|
||||
post_params[:topic_id] = topic.id
|
||||
else
|
||||
post_params[:title] = I18n.t("discourse_ai.ai_bot.default_pm_prefix")
|
||||
post_params[:archetype] = Archetype.private_message
|
||||
post_params[:target_usernames] = "#{user.username},#{persona.user.username}"
|
||||
end
|
||||
|
||||
post = PostCreator.create!(user, post_params)
|
||||
topic = post.topic
|
||||
|
||||
io.write "HTTP/1.1 200 OK"
|
||||
io.write CRLF
|
||||
io.write "Content-Type: text/plain; charset=utf-8"
|
||||
|
@ -52,7 +79,7 @@ module DiscourseAi
|
|||
io.flush
|
||||
|
||||
persona_class =
|
||||
DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: user)
|
||||
DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: current_user)
|
||||
bot = DiscourseAi::AiBot::Bot.as(persona.user, persona: persona_class.new)
|
||||
|
||||
data =
|
||||
|
@ -69,7 +96,7 @@ module DiscourseAi
|
|||
|
||||
DiscourseAi::AiBot::Playground
|
||||
.new(bot)
|
||||
.reply_to(post) do |partial|
|
||||
.reply_to(post, custom_instructions: custom_instructions) do |partial|
|
||||
next if partial.length == 0
|
||||
|
||||
data = { partial: partial }.to_json + "\n\n"
|
||||
|
@ -88,11 +115,11 @@ module DiscourseAi
|
|||
io.write CRLF
|
||||
|
||||
io.flush
|
||||
io.done
|
||||
io.done if io.respond_to?(:done)
|
||||
rescue StandardError => e
|
||||
# make it a tiny bit easier to debug in dev, this is tricky
|
||||
# multi-threaded code that exhibits various limitations in rails
|
||||
p e if Rails.env.development?
|
||||
p e if Rails.env.development? || Rails.env.test?
|
||||
Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply")
|
||||
ensure
|
||||
io.close
|
||||
|
|
|
@ -104,6 +104,10 @@ module DiscourseAi
|
|||
@last_call = params
|
||||
end
|
||||
|
||||
# Returns the class-level array of recorded calls, creating an empty array on
# first access and returning that same array on subsequent accesses.
def self.previous_calls
|
||||
@previous_calls ||= []
|
||||
end
|
||||
|
||||
def self.reset!
|
||||
@last_call = nil
|
||||
@fake_content = nil
|
||||
|
@ -118,7 +122,11 @@ module DiscourseAi
|
|||
feature_name: nil,
|
||||
feature_context: nil
|
||||
)
|
||||
self.class.last_call = { dialect: dialect, user: user, model_params: model_params }
|
||||
last_call = { dialect: dialect, user: user, model_params: model_params }
|
||||
self.class.last_call = last_call
|
||||
self.class.previous_calls << last_call
|
||||
# guard memory in test
|
||||
self.class.previous_calls.shift if self.class.previous_calls.length > 10
|
||||
|
||||
content = self.class.fake_content
|
||||
|
||||
|
|
|
@ -500,6 +500,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
|
||||
default_llm: "custom:#{llm.id}",
|
||||
allow_personal_messages: true,
|
||||
system_prompt: "you are a helpful bot",
|
||||
)
|
||||
|
||||
io_out, io_in = IO.pipe
|
||||
|
@ -510,6 +511,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
query: "how are you today?",
|
||||
user_unique_id: "site:test.com:user_id:1",
|
||||
preferred_username: "test_user",
|
||||
custom_instructions: "To be appended to system prompt",
|
||||
},
|
||||
env: {
|
||||
"rack.hijack" => lambda { io_in },
|
||||
|
@ -521,6 +523,10 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
raw = io_out.read
|
||||
context_info = validate_streamed_response(raw, "This is a test! Testing!")
|
||||
|
||||
system_prompt = fake_endpoint.previous_calls[-2][:dialect].prompt.messages.first[:content]
|
||||
|
||||
expect(system_prompt).to eq("you are a helpful bot\nTo be appended to system prompt")
|
||||
|
||||
expect(context_info["topic_id"]).to be_present
|
||||
topic = Topic.find(context_info["topic_id"])
|
||||
last_post = topic.posts.order(:created_at).last
|
||||
|
|
Loading…
Reference in New Issue