FEATURE: new endpoint for directly accessing a persona (#876)

This adds the new `/admin/plugins/discourse-ai/ai-personas/stream-reply.json` endpoint.

This endpoint streams data directly from a persona and can be used
to access a persona from remote systems, leaving a paper trail in
PMs that documents the conversation.

This endpoint is only accessible to admins.
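
As a rough illustration, a remote system could call it like this (a minimal
sketch: the host, API key, persona name, and user ids are placeholders;
`Api-Key`/`Api-Username` are the standard Discourse API headers, and the key
can be one scoped to the new `stream_completion` scope registered below):

require "net/http"
require "json"

uri = URI("https://forum.example.com/admin/plugins/discourse-ai/ai-personas/stream-reply.json")

req = Net::HTTP::Post.new(uri)
req["Api-Key"] = ENV["DISCOURSE_API_KEY"] # placeholder; needs admin access
req["Api-Username"] = "system"
req.set_form_data(
  "persona_name" => "helper-bot",                   # or persona_id
  "query" => "how are you today?",
  "user_unique_id" => "site:example.com:user_id:1", # or username
  "preferred_username" => "remote_user",
)

Net::HTTP.start(uri.host, uri.port, use_ssl: true) do |http|
  http.request(req) do |res|
    # the body streams JSON objects, each terminated by a blank line; the
    # first carries topic_id/bot_user_id/persona_id, later ones carry partials
    buffer = +""
    res.read_body do |fragment|
      buffer << fragment
      while (idx = buffer.index("\n\n"))
        line = buffer.slice!(0, idx + 2).strip
        next if line.empty?
        data = JSON.parse(line)
        print data["partial"] if data["partial"]
      end
    end
  end
end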

---------

Co-authored-by: Gabriel Grubba <70247653+Grubba27@users.noreply.github.com>
Co-authored-by: Keegan George <kgeorge13@gmail.com>

@@ -74,8 +74,205 @@ module DiscourseAi
end
end
class << self
POOL_SIZE = 10
def thread_pool
@thread_pool ||=
Concurrent::CachedThreadPool.new(min_threads: 0, max_threads: POOL_SIZE, idletime: 30)
end
def schedule_block(&block)
# TODO: think about a better way to handle cross-thread connections
if Rails.env.test?
block.call
return
end
db = RailsMultisite::ConnectionManagement.current_db
thread_pool.post do
begin
RailsMultisite::ConnectionManagement.with_connection(db) { block.call }
rescue StandardError => e
Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply")
end
end
end
end
CRLF = "\r\n"
def stream_reply
persona =
AiPersona.find_by(name: params[:persona_name]) ||
AiPersona.find_by(id: params[:persona_id])
return render_json_error(I18n.t("discourse_ai.errors.persona_not_found")) if persona.nil?
return render_json_error(I18n.t("discourse_ai.errors.persona_disabled")) if !persona.enabled
if persona.default_llm.blank?
return render_json_error(I18n.t("discourse_ai.errors.no_default_llm"))
end
if params[:query].blank?
return render_json_error(I18n.t("discourse_ai.errors.no_query_specified"))
end
if !persona.user_id
return render_json_error(I18n.t("discourse_ai.errors.no_user_for_persona"))
end
if !params[:username] && !params[:user_unique_id]
return render_json_error(I18n.t("discourse_ai.errors.no_user_specified"))
end
user = nil
if params[:username]
user = User.find_by_username(params[:username])
return render_json_error(I18n.t("discourse_ai.errors.user_not_found")) if user.nil?
elsif params[:user_unique_id]
user = stage_user
end
raise Discourse::NotFound if user.nil?
topic_id = params[:topic_id].to_i
topic = nil
post = nil
if topic_id > 0
topic = Topic.find(topic_id)
raise Discourse::NotFound if topic.nil?
if topic.topic_allowed_users.where(user_id: user.id).empty?
return render_json_error(I18n.t("discourse_ai.errors.user_not_allowed"))
end
post =
PostCreator.create!(
user,
topic_id: topic_id,
raw: params[:query],
skip_validations: true,
)
else
post =
PostCreator.create!(
user,
title: I18n.t("discourse_ai.ai_bot.default_pm_prefix"),
raw: params[:query],
archetype: Archetype.private_message,
target_usernames: "#{user.username},#{persona.user.username}",
skip_validations: true,
)
topic = post.topic
end
hijack = request.env["rack.hijack"]
io = hijack.call
user = current_user
self.class.queue_streamed_reply(io, persona, user, topic, post)
end
private
AI_STREAM_CONVERSATION_UNIQUE_ID = "ai-stream-conversation-unique-id"
# keeping this in a class method so we don't capture the env and other bits;
# this allows us to release memory earlier
def self.queue_streamed_reply(io, persona, user, topic, post)
schedule_block do
begin
io.write "HTTP/1.1 200 OK"
io.write CRLF
io.write "Content-Type: text/plain; charset=utf-8"
io.write CRLF
io.write "Transfer-Encoding: chunked"
io.write CRLF
io.write "Cache-Control: no-cache, no-store, must-revalidate"
io.write CRLF
io.write "Connection: close"
io.write CRLF
io.write "X-Accel-Buffering: no"
io.write CRLF
io.write "X-Content-Type-Options: nosniff"
io.write CRLF
io.write CRLF
io.flush
persona_class =
DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: user)
bot = DiscourseAi::AiBot::Bot.as(persona.user, persona: persona_class.new)
data =
{ topic_id: topic.id, bot_user_id: persona.user.id, persona_id: persona.id }.to_json +
"\n\n"
io.write data.bytesize.to_s(16)
io.write CRLF
io.write data
io.write CRLF
DiscourseAi::AiBot::Playground
.new(bot)
.reply_to(post) do |partial|
next if partial.length == 0
data = { partial: partial }.to_json + "\n\n"
data.force_encoding("UTF-8")
io.write data.bytesize.to_s(16)
io.write CRLF
io.write data
io.write CRLF
io.flush
end
io.write "0"
io.write CRLF
io.write CRLF
io.flush
io.done
rescue StandardError => e
# make it a tiny bit easier to debug in dev; this is tricky
# multi-threaded code that exhibits various limitations in Rails
p e if Rails.env.development?
Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply")
ensure
io.close
end
end
end
def stage_user
unique_id = params[:user_unique_id].to_s
field = UserCustomField.find_by(name: AI_STREAM_CONVERSATION_UNIQUE_ID, value: unique_id)
if field
field.user
else
preferred_username = params[:preferred_username]
username = UserNameSuggester.suggest(preferred_username || unique_id)
user =
User.new(
username: username,
email: "#{SecureRandom.hex}@invalid.com",
staged: true,
active: false,
)
user.custom_fields[AI_STREAM_CONVERSATION_UNIQUE_ID] = unique_id
user.save!
user
end
end
def find_ai_persona
@ai_persona = AiPersona.find(params[:id])
end
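
For reference, each `io.write` pair in `queue_streamed_reply` above maps
directly onto chunked transfer-encoding: a frame is the payload's byte size in
hex, a CRLF, the payload (a JSON object followed by a blank line), and a
closing CRLF; a zero-size frame terminates the stream. Schematically, with
frame sizes and field values elided:

HTTP/1.1 200 OK
Content-Type: text/plain; charset=utf-8
Transfer-Encoding: chunked
...remaining headers...

<hex size>\r\n
{"topic_id":…,"bot_user_id":…,"persona_id":…}\n\n\r\n
<hex size>\r\n
{"partial":"…"}\n\n\r\n
...more partial frames...
0\r\n
\r\n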


@@ -56,7 +56,7 @@ class CompletionPrompt < ActiveRecord::Base
messages.each_with_index do |msg, idx|
next if msg["content"].length <= 1000
errors.add(:messages, I18n.t("errors.prompt_message_length", idx: idx + 1))
errors.add(:messages, I18n.t("discourse_ai.errors.prompt_message_length", idx: idx + 1))
end
end
end


@@ -5,7 +5,8 @@ en:
scopes:
descriptions:
discourse_ai:
search: "Allows semantic search via the /discourse-ai/embeddings/semantic-search endpoint."
search: "Allows semantic search"
stream_completion: "Allows streaming AI persona completions"
site_settings:
categories:


@@ -106,10 +106,6 @@ en:
flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic.
flagged_by_nsfw: The AI plugin flagged this after classifying at least one of the attached images as NSFW.
errors:
prompt_message_length: The message %{idx} is over the 1000 character limit.
invalid_prompt_role: The message %{idx} has an invalid role.
reports:
overall_sentiment:
title: "Overall sentiment"
@@ -169,6 +165,7 @@ en:
failed_to_share: "Failed to share the conversation"
conversation_deleted: "Conversation share deleted successfully"
ai_bot:
default_pm_prefix: "[Untitled AI bot PM]"
personas:
default_llm_required: "Default LLM model is required prior to enabling Chat"
cannot_delete_system_persona: "System personas cannot be deleted, please disable it instead"
@@ -347,3 +344,14 @@ en:
llm_models:
missing_provider_param: "%{param} can't be blank"
bedrock_invalid_url: "Please complete all the fields to contact this model."
errors:
no_query_specified: The query parameter is required; please specify it.
no_user_for_persona: The persona specified does not have a user associated with it.
persona_not_found: The persona specified does not exist. Check the persona_name or persona_id params.
no_user_specified: The username or the user_unique_id parameter is required; please specify it.
user_not_found: The user specified does not exist. Check the username param.
persona_disabled: The persona specified is disabled. Check the persona_name or persona_id params.
no_default_llm: The persona must have a default_llm defined.
user_not_allowed: The user is not allowed to participate in the topic.
prompt_message_length: The message %{idx} is over the 1000 character limit.


@@ -50,6 +50,8 @@ Discourse::Application.routes.draw do
path: "ai-personas",
controller: "discourse_ai/admin/ai_personas"
post "/ai-personas/stream-reply" => "discourse_ai/admin/ai_personas#stream_reply"
resources(
:ai_tools,
only: %i[index create show update destroy],


@@ -0,0 +1,9 @@
# frozen_string_literal: true
class AddUniqueAiStreamConversationUserIdIndex < ActiveRecord::Migration[7.1]
def change
add_index :user_custom_fields,
[:value],
unique: true,
where: "name = 'ai-stream-conversation-unique-id'"
end
end


@@ -113,7 +113,7 @@ module DiscourseAi
tool_found = true
# a bit hacky, but extra newlines do no harm
if needs_newlines
update_blk.call("\n\n", cancel, nil)
update_blk.call("\n\n", cancel)
needs_newlines = false
end
@@ -123,7 +123,7 @@
end
else
needs_newlines = true
update_blk.call(partial, cancel, nil)
update_blk.call(partial, cancel)
end
end
@@ -191,9 +191,9 @@ module DiscourseAi
tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw)
if context[:skip_tool_details] && tool.custom_raw.present?
update_blk.call(tool.custom_raw, cancel, nil)
update_blk.call(tool.custom_raw, cancel, nil, :custom_raw)
elsif !context[:skip_tool_details]
update_blk.call(tool_details, cancel, nil)
update_blk.call(tool_details, cancel, nil, :tool_details)
end
result


@@ -189,6 +189,11 @@ module DiscourseAi
plugin.register_editable_topic_custom_field(:ai_persona_id)
end
plugin.add_api_key_scope(
:discourse_ai,
{ stream_completion: { actions: %w[discourse_ai/admin/ai_personas#stream_reply] } },
)
plugin.on(:site_setting_changed) do |name, old_value, new_value|
if name == :ai_embeddings_model && SiteSetting.ai_embeddings_enabled? &&
new_value != old_value


@@ -390,7 +390,12 @@ module DiscourseAi
result
end
def reply_to(post)
def reply_to(post, &blk)
# this works around a multithreading issue:
# the post custom prompt class is needed and may not
# be properly loaded yet, so ensure it is loaded
PostCustomPrompt.none
reply = +""
start = Time.now
@@ -441,11 +446,13 @@
context[:skip_tool_details] ||= !bot.persona.class.tool_details
new_custom_prompts =
bot.reply(context) do |partial, cancel, placeholder|
bot.reply(context) do |partial, cancel, placeholder, type|
reply << partial
raw = reply.dup
raw << "\n\n" << placeholder if placeholder.present? && !context[:skip_tool_details]
blk.call(partial) if blk && type != :tool_details
if stream_reply && !Discourse.redis.get(redis_stream_key)
cancel&.call
reply_post.update!(raw: reply, cooked: PrettyText.cook(reply))


@@ -72,6 +72,10 @@ module DiscourseAi
@fake_content = nil
end
def self.fake_content=(content)
@fake_content = content
end
def self.fake_content
@fake_content || STOCK_CONTENT
end
@@ -100,6 +104,13 @@
@last_call = params
end
def self.reset!
@last_call = nil
@fake_content = nil
@delays = nil
@chunk_count = nil
end
def perform_completion!(
dialect,
user,
@@ -111,6 +122,8 @@
content = self.class.fake_content
content = content.shift if content.is_a?(Array)
if block_given?
split_indices = (1...content.length).to_a.sample(self.class.chunk_count - 1).sort
indexes = [0, *split_indices, content.length]


@@ -407,4 +407,178 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
}.not_to change(AiPersona, :count)
end
end
describe "#stream_reply" do
fab!(:llm) { Fabricate(:llm_model, name: "fake_llm", provider: "fake") }
let(:fake_endpoint) { DiscourseAi::Completions::Endpoints::Fake }
before { fake_endpoint.delays = [] }
after { fake_endpoint.reset! }
it "ensures persona exists" do
post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json"
expect(response).to have_http_status(:unprocessable_entity)
# this ensures the localization key is actually in the yaml
expect(response.body).to include("persona_name")
end
it "ensures question exists" do
ai_persona.update!(default_llm: "custom:#{llm.id}")
post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json",
params: {
persona_id: ai_persona.id,
user_unique_id: "site:test.com:user_id:1",
}
expect(response).to have_http_status(:unprocessable_entity)
expect(response.body).to include("query")
end
it "ensure persona has a user specified" do
ai_persona.update!(default_llm: "custom:#{llm.id}")
post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json",
params: {
persona_id: ai_persona.id,
query: "how are you today?",
user_unique_id: "site:test.com:user_id:1",
}
expect(response).to have_http_status(:unprocessable_entity)
expect(response.body).to include("associated")
end
def validate_streamed_response(raw_http, expected)
lines = raw_http.split("\r\n")
header_lines, _, payload_lines = lines.chunk { |l| l == "" }.map(&:last)
preamble = (<<~PREAMBLE).strip
HTTP/1.1 200 OK
Content-Type: text/plain; charset=utf-8
Transfer-Encoding: chunked
Cache-Control: no-cache, no-store, must-revalidate
Connection: close
X-Accel-Buffering: no
X-Content-Type-Options: nosniff
PREAMBLE
expect(header_lines.join("\n")).to eq(preamble)
parsed = +""
context_info = nil
payload_lines.each_slice(2) do |size, data|
size = size.to_i(16)
data = data.to_s
expect(data.bytesize).to eq(size)
if size > 0
json = JSON.parse(data)
parsed << json["partial"].to_s
context_info = json if json["topic_id"]
end
end
expect(parsed).to eq(expected)
context_info
end
it "is able to create a new conversation" do
Jobs.run_immediately!
fake_endpoint.fake_content = ["This is a test! Testing!", "An amazing title"]
ai_persona.create_user!
ai_persona.update!(
allowed_group_ids: [Group::AUTO_GROUPS[:staff]],
default_llm: "custom:#{llm.id}",
)
io_out, io_in = IO.pipe
post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json",
params: {
persona_name: ai_persona.name,
query: "how are you today?",
user_unique_id: "site:test.com:user_id:1",
preferred_username: "test_user",
},
env: {
"rack.hijack" => lambda { io_in },
}
# this is a fake response but it catches errors
expect(response).to have_http_status(:no_content)
raw = io_out.read
context_info = validate_streamed_response(raw, "This is a test! Testing!")
expect(context_info["topic_id"]).to be_present
topic = Topic.find(context_info["topic_id"])
last_post = topic.posts.order(:created_at).last
expect(last_post.raw).to eq("This is a test! Testing!")
user_post = topic.posts.find_by(post_number: 1)
expect(user_post.raw).to eq("how are you today?")
# need both the AI persona and the user
expect(topic.topic_allowed_users.count).to eq(2)
expect(topic.archetype).to eq(Archetype.private_message)
expect(topic.title).to eq("An amazing title")
# now let's try to make a reply with a tool call
function_call = <<~XML
<function_calls>
<invoke>
<tool_name>categories</tool_name>
</invoke>
</function_calls>
XML
fake_endpoint.fake_content = [function_call, "this is the response after the tool"]
# this simplifies function calls
fake_endpoint.chunk_count = 1
ai_persona.update!(tools: ["Categories"])
io_out, io_in = IO.pipe
post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json",
params: {
persona_id: ai_persona.id,
query: "how are you now?",
user_unique_id: "site:test.com:user_id:1",
preferred_username: "test_user",
topic_id: context_info["topic_id"],
},
env: {
"rack.hijack" => lambda { io_in },
}
# this is a fake response but it catches errors
expect(response).to have_http_status(:no_content)
raw = io_out.read
context_info = validate_streamed_response(raw, "this is the response after the tool")
topic = topic.reload
last_post = topic.posts.order(:created_at).last
expect(last_post.raw).to end_with("this is the response after the tool")
# function call is visible in the post
expect(last_post.raw[0..8]).to eq("<details>")
user_post = topic.posts.find_by(post_number: 3)
expect(user_post.raw).to eq("how are you now?")
expect(user_post.user.username).to eq("test_user")
expect(user_post.user.custom_fields).to eq(
{ "ai-stream-conversation-unique-id" => "site:test.com:user_id:1" },
)
end
end
end