FEATURE: new endpoint for directly accessing a persona (#876)
The new `/admin/plugins/discourse-ai/ai-personas/stream-reply.json` endpoint was added. This endpoint streams data directly from a persona and can be used to access a persona from remote systems, leaving a paper trail in PMs about the conversation that happened. This endpoint is only accessible to admins. --------- Co-authored-by: Gabriel Grubba <70247653+Grubba27@users.noreply.github.com> Co-authored-by: Keegan George <kgeorge13@gmail.com>
This commit is contained in:
parent
05790a6a40
commit
be0b78cacd
|
@ -74,8 +74,205 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
class << self
  # Maximum number of concurrently streaming replies served by this process.
  POOL_SIZE = 10

  # Lazily builds the shared thread pool used to stream persona replies off
  # the request thread. Idle threads are reclaimed after 30 seconds.
  def thread_pool
    @thread_pool ||=
      Concurrent::CachedThreadPool.new(min_threads: 0, max_threads: POOL_SIZE, idletime: 30)
  end

  # Runs +block+ on the thread pool, wrapped in the multisite connection for
  # the database that was current at scheduling time. In the test environment
  # the block runs inline to avoid cross-thread database connection issues.
  # Errors are logged (via Discourse.warn_exception), never raised to the pool.
  def schedule_block(&block)
    # think about a better way to handle cross thread connections
    if Rails.env.test?
      block.call
      return
    end

    db = RailsMultisite::ConnectionManagement.current_db
    thread_pool.post do
      begin
        RailsMultisite::ConnectionManagement.with_connection(db) { block.call }
      rescue StandardError => e
        Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply")
      end
    end
  end
end
|
||||
|
||||
# Line terminator used when hand-writing the HTTP response over the hijacked socket.
CRLF = "\r\n"
|
||||
|
||||
# Streams a persona's reply directly over the raw socket using rack hijack
# and chunked transfer encoding.
#
# Params:
#   persona_name / persona_id - identifies the AiPersona (name is tried first)
#   query                     - the user's message (required)
#   username / user_unique_id - identifies an existing user, or stages one
#   preferred_username        - optional username hint when staging a user
#   topic_id                  - optional existing PM topic to continue
#
# Renders a JSON error for any validation failure; otherwise hijacks the
# socket, hands it to a background thread, and returns with no body.
def stream_reply
  persona =
    AiPersona.find_by(name: params[:persona_name]) ||
      AiPersona.find_by(id: params[:persona_id])
  return render_json_error(I18n.t("discourse_ai.errors.persona_not_found")) if persona.nil?

  return render_json_error(I18n.t("discourse_ai.errors.persona_disabled")) if !persona.enabled

  if persona.default_llm.blank?
    return render_json_error(I18n.t("discourse_ai.errors.no_default_llm"))
  end

  if params[:query].blank?
    return render_json_error(I18n.t("discourse_ai.errors.no_query_specified"))
  end

  # the persona must own a bot user to post replies as
  if !persona.user_id
    return render_json_error(I18n.t("discourse_ai.errors.no_user_for_persona"))
  end

  if !params[:username] && !params[:user_unique_id]
    return render_json_error(I18n.t("discourse_ai.errors.no_user_specified"))
  end

  user = nil

  if params[:username]
    user = User.find_by_username(params[:username])
    return render_json_error(I18n.t("discourse_ai.errors.user_not_found")) if user.nil?
  elsif params[:user_unique_id]
    user = stage_user
  end

  raise Discourse::NotFound if user.nil?

  topic_id = params[:topic_id].to_i
  topic = nil
  post = nil

  if topic_id > 0
    topic = Topic.find(topic_id)

    # NOTE(review): Topic.find raises ActiveRecord::RecordNotFound rather than
    # returning nil, so this check looks redundant — confirm.
    raise Discourse::NotFound if topic.nil?

    # only users already party to the PM may continue the conversation
    if topic.topic_allowed_users.where(user_id: user.id).empty?
      return render_json_error(I18n.t("discourse_ai.errors.user_not_allowed"))
    end

    post =
      PostCreator.create!(
        user,
        topic_id: topic_id,
        raw: params[:query],
        skip_validations: true,
      )
  else
    # no topic given: open a fresh PM between the acting user and the persona
    post =
      PostCreator.create!(
        user,
        title: I18n.t("discourse_ai.ai_bot.default_pm_prefix"),
        raw: params[:query],
        archetype: Archetype.private_message,
        target_usernames: "#{user.username},#{persona.user.username}",
        skip_validations: true,
      )

    topic = post.topic
  end

  # take over the socket so the reply can be streamed as raw chunked HTTP
  hijack = request.env["rack.hijack"]
  io = hijack.call

  # NOTE(review): this discards the user resolved above (staged user or
  # username lookup) and substitutes current_user before queueing the
  # streamed reply — confirm this is intentional.
  user = current_user

  self.class.queue_streamed_reply(io, persona, user, topic, post)
end
|
||||
|
||||
private

# Name of the user custom field that links a staged user to the remote
# caller's user_unique_id, so repeat calls reuse the same staged user.
AI_STREAM_CONVERSATION_UNIQUE_ID = "ai-stream-conversation-unique-id"
|
||||
|
||||
# keeping this in a static method so we don't capture ENV and other bits
|
||||
# this allows us to release memory earlier
|
||||
# keeping this in a static method so we don't capture ENV and other bits
# this allows us to release memory earlier
#
# Writes a hand-rolled HTTP/1.1 chunked response to the hijacked +io+:
# status line + headers, then a JSON envelope chunk with topic/bot/persona
# ids, then one JSON chunk per streamed partial, then the zero-length
# terminating chunk. The io is always closed via ensure.
def self.queue_streamed_reply(io, persona, user, topic, post)
  schedule_block do
    begin
      # we own the raw socket, so the full HTTP response (status line and
      # headers) must be written by hand
      io.write "HTTP/1.1 200 OK"
      io.write CRLF
      io.write "Content-Type: text/plain; charset=utf-8"
      io.write CRLF
      io.write "Transfer-Encoding: chunked"
      io.write CRLF
      io.write "Cache-Control: no-cache, no-store, must-revalidate"
      io.write CRLF
      io.write "Connection: close"
      io.write CRLF
      # disable nginx proxy buffering so chunks reach the client immediately
      io.write "X-Accel-Buffering: no"
      io.write CRLF
      io.write "X-Content-Type-Options: nosniff"
      io.write CRLF
      io.write CRLF
      io.flush

      persona_class =
        DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: user)
      bot = DiscourseAi::AiBot::Bot.as(persona.user, persona: persona_class.new)

      # first chunk: context envelope so the caller knows where the
      # conversation lives
      data =
        { topic_id: topic.id, bot_user_id: persona.user.id, persona_id: persona.id }.to_json +
          "\n\n"

      # chunked transfer encoding: hex size line, CRLF, payload, CRLF
      io.write data.bytesize.to_s(16)
      io.write CRLF
      io.write data
      io.write CRLF

      DiscourseAi::AiBot::Playground
        .new(bot)
        .reply_to(post) do |partial|
          next if partial.length == 0

          data = { partial: partial }.to_json + "\n\n"

          data.force_encoding("UTF-8")

          io.write data.bytesize.to_s(16)
          io.write CRLF
          io.write data
          io.write CRLF
          io.flush
        end

      # zero-length chunk terminates the chunked response
      io.write "0"
      io.write CRLF
      io.write CRLF

      io.flush
      io.done
    rescue StandardError => e
      # make it a tiny bit easier to debug in dev, this is tricky
      # multi-threaded code that exhibits various limitations in rails
      p e if Rails.env.development?
      Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply")
    ensure
      io.close
    end
  end
end
|
||||
|
||||
# Finds or creates the staged user tied to params[:user_unique_id].
#
# The unique id is persisted as a user custom field
# (AI_STREAM_CONVERSATION_UNIQUE_ID), so subsequent calls with the same id
# resolve to the same staged user. New users are created staged and
# inactive, with a throwaway email and a username suggested from
# params[:preferred_username] (falling back to the unique id itself).
def stage_user
  unique_id = params[:user_unique_id].to_s
  existing_field =
    UserCustomField.find_by(name: AI_STREAM_CONVERSATION_UNIQUE_ID, value: unique_id)

  return existing_field.user if existing_field

  suggested_name = UserNameSuggester.suggest(params[:preferred_username] || unique_id)

  staged_user =
    User.new(
      username: suggested_name,
      email: "#{SecureRandom.hex}@invalid.com",
      staged: true,
      active: false,
    )
  staged_user.custom_fields[AI_STREAM_CONVERSATION_UNIQUE_ID] = unique_id
  staged_user.save!
  staged_user
end
|
||||
|
||||
# before_action helper: loads the persona for member routes into
# @ai_persona, raising ActiveRecord::RecordNotFound (rendered as 404)
# when params[:id] does not exist.
def find_ai_persona
  @ai_persona = AiPersona.find(params[:id])
end
|
||||
|
|
|
@ -56,7 +56,7 @@ class CompletionPrompt < ActiveRecord::Base
|
|||
messages.each_with_index do |msg, idx|
|
||||
next if msg["content"].length <= 1000
|
||||
|
||||
errors.add(:messages, I18n.t("errors.prompt_message_length", idx: idx + 1))
|
||||
errors.add(:messages, I18n.t("discourse_ai.errors.prompt_message_length", idx: idx + 1))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -5,7 +5,8 @@ en:
|
|||
scopes:
|
||||
descriptions:
|
||||
discourse_ai:
|
||||
search: "Allows semantic search via the /discourse-ai/embeddings/semantic-search endpoint."
|
||||
search: "Allows semantic search"
|
||||
stream_completion: "Allows streaming ai persona completions"
|
||||
|
||||
site_settings:
|
||||
categories:
|
||||
|
|
|
@ -106,10 +106,6 @@ en:
|
|||
flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic.
|
||||
flagged_by_nsfw: The AI plugin flagged this after classifying at least one of the attached images as NSFW.
|
||||
|
||||
errors:
|
||||
prompt_message_length: The message %{idx} is over the 1000 character limit.
|
||||
invalid_prompt_role: The message %{idx} has an invalid role.
|
||||
|
||||
reports:
|
||||
overall_sentiment:
|
||||
title: "Overall sentiment"
|
||||
|
@ -169,6 +165,7 @@ en:
|
|||
failed_to_share: "Failed to share the conversation"
|
||||
conversation_deleted: "Conversation share deleted successfully"
|
||||
ai_bot:
|
||||
default_pm_prefix: "[Untitled AI bot PM]"
|
||||
personas:
|
||||
default_llm_required: "Default LLM model is required prior to enabling Chat"
|
||||
cannot_delete_system_persona: "System personas cannot be deleted, please disable it instead"
|
||||
|
@ -347,3 +344,14 @@ en:
|
|||
llm_models:
|
||||
missing_provider_param: "%{param} can't be blank"
|
||||
bedrock_invalid_url: "Please complete all the fields to contact this model."
|
||||
|
||||
errors:
|
||||
no_query_specified: The query parameter is required, please specify it.
|
||||
no_user_for_persona: The persona specified does not have a user associated with it.
|
||||
persona_not_found: The persona specified does not exist. Check the persona_name or persona_id params.
|
||||
no_user_specified: The username or the user_unique_id parameter is required, please specify it.
|
||||
user_not_found: The user specified does not exist. Check the username param.
|
||||
persona_disabled: The persona specified is disabled. Check the persona_name or persona_id params.
|
||||
no_default_llm: The persona must have a default_llm defined.
|
||||
user_not_allowed: The user is not allowed to participate in the topic.
|
||||
prompt_message_length: The message %{idx} is over the 1000 character limit.
|
||||
|
|
|
@ -50,6 +50,8 @@ Discourse::Application.routes.draw do
|
|||
path: "ai-personas",
|
||||
controller: "discourse_ai/admin/ai_personas"
|
||||
|
||||
post "/ai-personas/stream-reply" => "discourse_ai/admin/ai_personas#stream_reply"
|
||||
|
||||
resources(
|
||||
:ai_tools,
|
||||
only: %i[index create show update destroy],
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
# frozen_string_literal: true

# Ensures each ai-stream-conversation-unique-id value maps to at most one
# user, so concurrent or repeated stream-reply calls with the same
# user_unique_id cannot create duplicate staged users.
class AddUniqueAiStreamConversationUserIdIndex < ActiveRecord::Migration[7.1]
  def change
    # partial unique index: uniqueness is enforced only for rows holding
    # this specific custom field name
    add_index :user_custom_fields,
              [:value],
              unique: true,
              where: "name = 'ai-stream-conversation-unique-id'"
  end
end
|
|
@ -113,7 +113,7 @@ module DiscourseAi
|
|||
tool_found = true
|
||||
# a bit hacky, but extra newlines do no harm
|
||||
if needs_newlines
|
||||
update_blk.call("\n\n", cancel, nil)
|
||||
update_blk.call("\n\n", cancel)
|
||||
needs_newlines = false
|
||||
end
|
||||
|
||||
|
@ -123,7 +123,7 @@ module DiscourseAi
|
|||
end
|
||||
else
|
||||
needs_newlines = true
|
||||
update_blk.call(partial, cancel, nil)
|
||||
update_blk.call(partial, cancel)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -191,9 +191,9 @@ module DiscourseAi
|
|||
tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw)
|
||||
|
||||
if context[:skip_tool_details] && tool.custom_raw.present?
|
||||
update_blk.call(tool.custom_raw, cancel, nil)
|
||||
update_blk.call(tool.custom_raw, cancel, nil, :custom_raw)
|
||||
elsif !context[:skip_tool_details]
|
||||
update_blk.call(tool_details, cancel, nil)
|
||||
update_blk.call(tool_details, cancel, nil, :tool_details)
|
||||
end
|
||||
|
||||
result
|
||||
|
|
|
@ -189,6 +189,11 @@ module DiscourseAi
|
|||
plugin.register_editable_topic_custom_field(:ai_persona_id)
|
||||
end
|
||||
|
||||
plugin.add_api_key_scope(
|
||||
:discourse_ai,
|
||||
{ stream_completion: { actions: %w[discourse_ai/admin/ai_personas#stream_reply] } },
|
||||
)
|
||||
|
||||
plugin.on(:site_setting_changed) do |name, old_value, new_value|
|
||||
if name == :ai_embeddings_model && SiteSetting.ai_embeddings_enabled? &&
|
||||
new_value != old_value
|
||||
|
|
|
@ -390,7 +390,12 @@ module DiscourseAi
|
|||
result
|
||||
end
|
||||
|
||||
def reply_to(post)
|
||||
def reply_to(post, &blk)
|
||||
# this is a multithreading issue
|
||||
# post custom prompt is needed and it may not
|
||||
# be properly loaded, ensure it is loaded
|
||||
PostCustomPrompt.none
|
||||
|
||||
reply = +""
|
||||
start = Time.now
|
||||
|
||||
|
@ -441,11 +446,13 @@ module DiscourseAi
|
|||
context[:skip_tool_details] ||= !bot.persona.class.tool_details
|
||||
|
||||
new_custom_prompts =
|
||||
bot.reply(context) do |partial, cancel, placeholder|
|
||||
bot.reply(context) do |partial, cancel, placeholder, type|
|
||||
reply << partial
|
||||
raw = reply.dup
|
||||
raw << "\n\n" << placeholder if placeholder.present? && !context[:skip_tool_details]
|
||||
|
||||
blk.call(partial) if blk && type != :tool_details
|
||||
|
||||
if stream_reply && !Discourse.redis.get(redis_stream_key)
|
||||
cancel&.call
|
||||
reply_post.update!(raw: reply, cooked: PrettyText.cook(reply))
|
||||
|
|
|
@ -72,6 +72,10 @@ module DiscourseAi
|
|||
@fake_content = nil
|
||||
end
|
||||
|
||||
# Test hook: overrides the canned completion content. Accepts a String,
# or an Array of strings consumed one element per completion call.
def self.fake_content=(content)
  @fake_content = content
end
|
||||
|
||||
# Returns the configured fake content, falling back to STOCK_CONTENT when
# no override was set via fake_content=.
def self.fake_content
  @fake_content || STOCK_CONTENT
end
|
||||
|
@ -100,6 +104,13 @@ module DiscourseAi
|
|||
@last_call = params
|
||||
end
|
||||
|
||||
# Clears all test-injected state (recorded call params, canned content,
# artificial delays and chunk counts) back to defaults. Returns nil.
def self.reset!
  %i[@last_call @fake_content @delays @chunk_count].each do |ivar|
    instance_variable_set(ivar, nil)
  end
  nil
end
|
||||
|
||||
def perform_completion!(
|
||||
dialect,
|
||||
user,
|
||||
|
@ -111,6 +122,8 @@ module DiscourseAi
|
|||
|
||||
content = self.class.fake_content
|
||||
|
||||
content = content.shift if content.is_a?(Array)
|
||||
|
||||
if block_given?
|
||||
split_indices = (1...content.length).to_a.sample(self.class.chunk_count - 1).sort
|
||||
indexes = [0, *split_indices, content.length]
|
||||
|
|
|
@ -407,4 +407,178 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
}.not_to change(AiPersona, :count)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#stream_reply" do
|
||||
fab!(:llm) { Fabricate(:llm_model, name: "fake_llm", provider: "fake") }
|
||||
let(:fake_endpoint) { DiscourseAi::Completions::Endpoints::Fake }
|
||||
|
||||
before { fake_endpoint.delays = [] }
|
||||
|
||||
after { fake_endpoint.reset! }
|
||||
|
||||
it "ensures persona exists" do
|
||||
post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json"
|
||||
expect(response).to have_http_status(:unprocessable_entity)
|
||||
# this ensures localization key is actually in the yaml
|
||||
expect(response.body).to include("persona_name")
|
||||
end
|
||||
|
||||
it "ensures question exists" do
|
||||
ai_persona.update!(default_llm: "custom:#{llm.id}")
|
||||
|
||||
post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json",
|
||||
params: {
|
||||
persona_id: ai_persona.id,
|
||||
user_unique_id: "site:test.com:user_id:1",
|
||||
}
|
||||
expect(response).to have_http_status(:unprocessable_entity)
|
||||
expect(response.body).to include("query")
|
||||
end
|
||||
|
||||
it "ensure persona has a user specified" do
|
||||
ai_persona.update!(default_llm: "custom:#{llm.id}")
|
||||
|
||||
post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json",
|
||||
params: {
|
||||
persona_id: ai_persona.id,
|
||||
query: "how are you today?",
|
||||
user_unique_id: "site:test.com:user_id:1",
|
||||
}
|
||||
|
||||
expect(response).to have_http_status(:unprocessable_entity)
|
||||
expect(response.body).to include("associated")
|
||||
end
|
||||
|
||||
# Spec helper: parses a raw chunked HTTP response captured from the
# hijacked socket. Verifies the status line and headers, then decodes
# each chunk (hex size line followed by a JSON payload), asserting the
# advertised size matches the payload byte size and concatenating the
# "partial" fields. Asserts the concatenated partials equal +expected+
# and returns the envelope JSON that carried the topic_id.
def validate_streamed_response(raw_http, expected)
  lines = raw_http.split("\r\n")

  # chunk on blank lines splits into: headers / separator / chunked payload
  header_lines, _, payload_lines = lines.chunk { |l| l == "" }.map(&:last)

  preamble = (<<~PREAMBLE).strip
    HTTP/1.1 200 OK
    Content-Type: text/plain; charset=utf-8
    Transfer-Encoding: chunked
    Cache-Control: no-cache, no-store, must-revalidate
    Connection: close
    X-Accel-Buffering: no
    X-Content-Type-Options: nosniff
  PREAMBLE

  expect(header_lines.join("\n")).to eq(preamble)

  parsed = +""

  context_info = nil

  # chunked encoding alternates size-line / data-line pairs
  payload_lines.each_slice(2) do |size, data|
    size = size.to_i(16)
    data = data.to_s
    expect(data.bytesize).to eq(size)

    if size > 0
      json = JSON.parse(data)
      parsed << json["partial"].to_s

      # the envelope chunk is the one carrying topic_id/bot_user_id/persona_id
      context_info = json if json["topic_id"]
    end
  end

  expect(parsed).to eq(expected)

  context_info
end
|
||||
|
||||
it "is able to create a new conversation" do
|
||||
Jobs.run_immediately!
|
||||
|
||||
fake_endpoint.fake_content = ["This is a test! Testing!", "An amazing title"]
|
||||
|
||||
ai_persona.create_user!
|
||||
ai_persona.update!(
|
||||
allowed_group_ids: [Group::AUTO_GROUPS[:staff]],
|
||||
default_llm: "custom:#{llm.id}",
|
||||
)
|
||||
|
||||
io_out, io_in = IO.pipe
|
||||
|
||||
post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json",
|
||||
params: {
|
||||
persona_name: ai_persona.name,
|
||||
query: "how are you today?",
|
||||
user_unique_id: "site:test.com:user_id:1",
|
||||
preferred_username: "test_user",
|
||||
},
|
||||
env: {
|
||||
"rack.hijack" => lambda { io_in },
|
||||
}
|
||||
|
||||
# this is a fake response but it catches errors
|
||||
expect(response).to have_http_status(:no_content)
|
||||
|
||||
raw = io_out.read
|
||||
context_info = validate_streamed_response(raw, "This is a test! Testing!")
|
||||
|
||||
expect(context_info["topic_id"]).to be_present
|
||||
topic = Topic.find(context_info["topic_id"])
|
||||
last_post = topic.posts.order(:created_at).last
|
||||
expect(last_post.raw).to eq("This is a test! Testing!")
|
||||
|
||||
user_post = topic.posts.find_by(post_number: 1)
|
||||
expect(user_post.raw).to eq("how are you today?")
|
||||
|
||||
# need ai persona and user
|
||||
expect(topic.topic_allowed_users.count).to eq(2)
|
||||
expect(topic.archetype).to eq(Archetype.private_message)
|
||||
expect(topic.title).to eq("An amazing title")
|
||||
|
||||
# now let's try to make a reply with a tool call
|
||||
function_call = <<~XML
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>categories</tool_name>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
XML
|
||||
|
||||
fake_endpoint.fake_content = [function_call, "this is the response after the tool"]
|
||||
# this simplifies function calls
|
||||
fake_endpoint.chunk_count = 1
|
||||
|
||||
ai_persona.update!(tools: ["Categories"])
|
||||
|
||||
io_out, io_in = IO.pipe
|
||||
|
||||
post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json",
|
||||
params: {
|
||||
persona_id: ai_persona.id,
|
||||
query: "how are you now?",
|
||||
user_unique_id: "site:test.com:user_id:1",
|
||||
preferred_username: "test_user",
|
||||
topic_id: context_info["topic_id"],
|
||||
},
|
||||
env: {
|
||||
"rack.hijack" => lambda { io_in },
|
||||
}
|
||||
|
||||
# this is a fake response but it catches errors
|
||||
expect(response).to have_http_status(:no_content)
|
||||
|
||||
raw = io_out.read
|
||||
context_info = validate_streamed_response(raw, "this is the response after the tool")
|
||||
|
||||
topic = topic.reload
|
||||
last_post = topic.posts.order(:created_at).last
|
||||
|
||||
expect(last_post.raw).to end_with("this is the response after the tool")
|
||||
# function call is visible in the post
|
||||
expect(last_post.raw[0..8]).to eq("<details>")
|
||||
|
||||
user_post = topic.posts.find_by(post_number: 3)
|
||||
expect(user_post.raw).to eq("how are you now?")
|
||||
expect(user_post.user.username).to eq("test_user")
|
||||
expect(user_post.user.custom_fields).to eq(
|
||||
{ "ai-stream-conversation-unique-id" => "site:test.com:user_id:1" },
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue