FEATURE: support custom instructions for persona streaming (#890)
This allows us to inject information into the system prompt, which can help shape replies without having to repeat the same instructions over and over in individual messages.
This commit is contained in:
parent
fa7ca8bc31
commit
98022d7d96
|
@ -111,55 +111,26 @@ module DiscourseAi
|
||||||
|
|
||||||
topic_id = params[:topic_id].to_i
|
topic_id = params[:topic_id].to_i
|
||||||
topic = nil
|
topic = nil
|
||||||
post = nil
|
|
||||||
|
|
||||||
if topic_id > 0
|
if topic_id > 0
|
||||||
topic = Topic.find(topic_id)
|
topic = Topic.find(topic_id)
|
||||||
|
|
||||||
raise Discourse::NotFound if topic.nil?
|
|
||||||
|
|
||||||
if topic.topic_allowed_users.where(user_id: user.id).empty?
|
if topic.topic_allowed_users.where(user_id: user.id).empty?
|
||||||
return render_json_error(I18n.t("discourse_ai.errors.user_not_allowed"))
|
return render_json_error(I18n.t("discourse_ai.errors.user_not_allowed"))
|
||||||
end
|
end
|
||||||
|
|
||||||
post =
|
|
||||||
PostCreator.create!(
|
|
||||||
user,
|
|
||||||
topic_id: topic_id,
|
|
||||||
raw: params[:query],
|
|
||||||
skip_validations: true,
|
|
||||||
custom_fields: {
|
|
||||||
DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
else
|
|
||||||
post =
|
|
||||||
PostCreator.create!(
|
|
||||||
user,
|
|
||||||
title: I18n.t("discourse_ai.ai_bot.default_pm_prefix"),
|
|
||||||
raw: params[:query],
|
|
||||||
archetype: Archetype.private_message,
|
|
||||||
target_usernames: "#{user.username},#{persona.user.username}",
|
|
||||||
skip_validations: true,
|
|
||||||
custom_fields: {
|
|
||||||
DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
topic = post.topic
|
|
||||||
end
|
end
|
||||||
|
|
||||||
hijack = request.env["rack.hijack"]
|
hijack = request.env["rack.hijack"]
|
||||||
io = hijack.call
|
io = hijack.call
|
||||||
|
|
||||||
user = current_user
|
|
||||||
|
|
||||||
DiscourseAi::AiBot::ResponseHttpStreamer.queue_streamed_reply(
|
DiscourseAi::AiBot::ResponseHttpStreamer.queue_streamed_reply(
|
||||||
io,
|
io: io,
|
||||||
persona,
|
persona: persona,
|
||||||
user,
|
user: user,
|
||||||
topic,
|
topic: topic,
|
||||||
post,
|
query: params[:query].to_s,
|
||||||
|
custom_instructions: params[:custom_instructions].to_s,
|
||||||
|
current_user: current_user,
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -171,6 +171,11 @@ module DiscourseAi
|
||||||
DiscourseAi::Completions::Llm.proxy(self.class.question_consolidator_llm)
|
DiscourseAi::Completions::Llm.proxy(self.class.question_consolidator_llm)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
if context[:custom_instructions].present?
|
||||||
|
prompt_insts << "\n"
|
||||||
|
prompt_insts << context[:custom_instructions]
|
||||||
|
end
|
||||||
|
|
||||||
fragments_guidance =
|
fragments_guidance =
|
||||||
rag_fragments_prompt(
|
rag_fragments_prompt(
|
||||||
context[:conversation_context].to_a,
|
context[:conversation_context].to_a,
|
||||||
|
|
|
@ -392,7 +392,7 @@ module DiscourseAi
|
||||||
result
|
result
|
||||||
end
|
end
|
||||||
|
|
||||||
def reply_to(post, &blk)
|
def reply_to(post, custom_instructions: nil, &blk)
|
||||||
# this is a multithreading issue
|
# this is a multithreading issue
|
||||||
# post custom prompt is needed and it may not
|
# post custom prompt is needed and it may not
|
||||||
# be properly loaded, ensure it is loaded
|
# be properly loaded, ensure it is loaded
|
||||||
|
@ -413,6 +413,7 @@ module DiscourseAi
|
||||||
context[:post_id] = post.id
|
context[:post_id] = post.id
|
||||||
context[:topic_id] = post.topic_id
|
context[:topic_id] = post.topic_id
|
||||||
context[:private_message] = post.topic.private_message?
|
context[:private_message] = post.topic.private_message?
|
||||||
|
context[:custom_instructions] = custom_instructions
|
||||||
|
|
||||||
reply_user = bot.bot_user
|
reply_user = bot.bot_user
|
||||||
if bot.persona.class.respond_to?(:user_id)
|
if bot.persona.class.respond_to?(:user_id)
|
||||||
|
|
|
@ -31,9 +31,36 @@ module DiscourseAi
|
||||||
|
|
||||||
# keeping this in a static method so we don't capture ENV and other bits
|
# keeping this in a static method so we don't capture ENV and other bits
|
||||||
# this allows us to release memory earlier
|
# this allows us to release memory earlier
|
||||||
def queue_streamed_reply(io, persona, user, topic, post)
|
def queue_streamed_reply(
|
||||||
|
io:,
|
||||||
|
persona:,
|
||||||
|
user:,
|
||||||
|
topic:,
|
||||||
|
query:,
|
||||||
|
custom_instructions:,
|
||||||
|
current_user:
|
||||||
|
)
|
||||||
schedule_block do
|
schedule_block do
|
||||||
begin
|
begin
|
||||||
|
post_params = {
|
||||||
|
raw: query,
|
||||||
|
skip_validations: true,
|
||||||
|
custom_fields: {
|
||||||
|
DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if topic
|
||||||
|
post_params[:topic_id] = topic.id
|
||||||
|
else
|
||||||
|
post_params[:title] = I18n.t("discourse_ai.ai_bot.default_pm_prefix")
|
||||||
|
post_params[:archetype] = Archetype.private_message
|
||||||
|
post_params[:target_usernames] = "#{user.username},#{persona.user.username}"
|
||||||
|
end
|
||||||
|
|
||||||
|
post = PostCreator.create!(user, post_params)
|
||||||
|
topic = post.topic
|
||||||
|
|
||||||
io.write "HTTP/1.1 200 OK"
|
io.write "HTTP/1.1 200 OK"
|
||||||
io.write CRLF
|
io.write CRLF
|
||||||
io.write "Content-Type: text/plain; charset=utf-8"
|
io.write "Content-Type: text/plain; charset=utf-8"
|
||||||
|
@ -52,7 +79,7 @@ module DiscourseAi
|
||||||
io.flush
|
io.flush
|
||||||
|
|
||||||
persona_class =
|
persona_class =
|
||||||
DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: user)
|
DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: current_user)
|
||||||
bot = DiscourseAi::AiBot::Bot.as(persona.user, persona: persona_class.new)
|
bot = DiscourseAi::AiBot::Bot.as(persona.user, persona: persona_class.new)
|
||||||
|
|
||||||
data =
|
data =
|
||||||
|
@ -69,7 +96,7 @@ module DiscourseAi
|
||||||
|
|
||||||
DiscourseAi::AiBot::Playground
|
DiscourseAi::AiBot::Playground
|
||||||
.new(bot)
|
.new(bot)
|
||||||
.reply_to(post) do |partial|
|
.reply_to(post, custom_instructions: custom_instructions) do |partial|
|
||||||
next if partial.length == 0
|
next if partial.length == 0
|
||||||
|
|
||||||
data = { partial: partial }.to_json + "\n\n"
|
data = { partial: partial }.to_json + "\n\n"
|
||||||
|
@ -88,11 +115,11 @@ module DiscourseAi
|
||||||
io.write CRLF
|
io.write CRLF
|
||||||
|
|
||||||
io.flush
|
io.flush
|
||||||
io.done
|
io.done if io.respond_to?(:done)
|
||||||
rescue StandardError => e
|
rescue StandardError => e
|
||||||
# make it a tiny bit easier to debug in dev, this is tricky
|
# make it a tiny bit easier to debug in dev, this is tricky
|
||||||
# multi-threaded code that exhibits various limitations in rails
|
# multi-threaded code that exhibits various limitations in rails
|
||||||
p e if Rails.env.development?
|
p e if Rails.env.development? || Rails.env.test?
|
||||||
Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply")
|
Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply")
|
||||||
ensure
|
ensure
|
||||||
io.close
|
io.close
|
||||||
|
|
|
@ -104,6 +104,10 @@ module DiscourseAi
|
||||||
@last_call = params
|
@last_call = params
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def self.previous_calls
|
||||||
|
@previous_calls ||= []
|
||||||
|
end
|
||||||
|
|
||||||
def self.reset!
|
def self.reset!
|
||||||
@last_call = nil
|
@last_call = nil
|
||||||
@fake_content = nil
|
@fake_content = nil
|
||||||
|
@ -118,7 +122,11 @@ module DiscourseAi
|
||||||
feature_name: nil,
|
feature_name: nil,
|
||||||
feature_context: nil
|
feature_context: nil
|
||||||
)
|
)
|
||||||
self.class.last_call = { dialect: dialect, user: user, model_params: model_params }
|
last_call = { dialect: dialect, user: user, model_params: model_params }
|
||||||
|
self.class.last_call = last_call
|
||||||
|
self.class.previous_calls << last_call
|
||||||
|
# guard memory in test
|
||||||
|
self.class.previous_calls.shift if self.class.previous_calls.length > 10
|
||||||
|
|
||||||
content = self.class.fake_content
|
content = self.class.fake_content
|
||||||
|
|
||||||
|
|
|
@ -500,6 +500,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
||||||
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
|
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
|
||||||
default_llm: "custom:#{llm.id}",
|
default_llm: "custom:#{llm.id}",
|
||||||
allow_personal_messages: true,
|
allow_personal_messages: true,
|
||||||
|
system_prompt: "you are a helpful bot",
|
||||||
)
|
)
|
||||||
|
|
||||||
io_out, io_in = IO.pipe
|
io_out, io_in = IO.pipe
|
||||||
|
@ -510,6 +511,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
||||||
query: "how are you today?",
|
query: "how are you today?",
|
||||||
user_unique_id: "site:test.com:user_id:1",
|
user_unique_id: "site:test.com:user_id:1",
|
||||||
preferred_username: "test_user",
|
preferred_username: "test_user",
|
||||||
|
custom_instructions: "To be appended to system prompt",
|
||||||
},
|
},
|
||||||
env: {
|
env: {
|
||||||
"rack.hijack" => lambda { io_in },
|
"rack.hijack" => lambda { io_in },
|
||||||
|
@ -521,6 +523,10 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
||||||
raw = io_out.read
|
raw = io_out.read
|
||||||
context_info = validate_streamed_response(raw, "This is a test! Testing!")
|
context_info = validate_streamed_response(raw, "This is a test! Testing!")
|
||||||
|
|
||||||
|
system_prompt = fake_endpoint.previous_calls[-2][:dialect].prompt.messages.first[:content]
|
||||||
|
|
||||||
|
expect(system_prompt).to eq("you are a helpful bot\nTo be appended to system prompt")
|
||||||
|
|
||||||
expect(context_info["topic_id"]).to be_present
|
expect(context_info["topic_id"]).to be_present
|
||||||
topic = Topic.find(context_info["topic_id"])
|
topic = Topic.find(context_info["topic_id"])
|
||||||
last_post = topic.posts.order(:created_at).last
|
last_post = topic.posts.order(:created_at).last
|
||||||
|
|
Loading…
Reference in New Issue