diff --git a/app/controllers/discourse_ai/admin/ai_personas_controller.rb b/app/controllers/discourse_ai/admin/ai_personas_controller.rb
index b4fd66b1..35e9a338 100644
--- a/app/controllers/discourse_ai/admin/ai_personas_controller.rb
+++ b/app/controllers/discourse_ai/admin/ai_personas_controller.rb
@@ -74,8 +74,205 @@ module DiscourseAi
end
end
+ class << self
+ POOL_SIZE = 10
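+      # lazily build a shared pool: threads are created on demand up to POOL_SIZE
+      # and reaped after 30 seconds of idleness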
+ def thread_pool
+ @thread_pool ||=
+ Concurrent::CachedThreadPool.new(min_threads: 0, max_threads: POOL_SIZE, idletime: 30)
+ end
+
+ def schedule_block(&block)
+      # TODO: think about a better way to handle cross-thread connections
+ if Rails.env.test?
+ block.call
+ return
+ end
+
+ db = RailsMultisite::ConnectionManagement.current_db
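+      # the hijacked IO outlives the request cycle, so the pooled thread must
+      # re-establish the current site's DB connection itself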
+ thread_pool.post do
+ begin
+ RailsMultisite::ConnectionManagement.with_connection(db) { block.call }
+ rescue StandardError => e
+ Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply")
+ end
+ end
+ end
+ end
+
+ CRLF = "\r\n"
+
+ def stream_reply
+ persona =
+ AiPersona.find_by(name: params[:persona_name]) ||
+ AiPersona.find_by(id: params[:persona_id])
+ return render_json_error(I18n.t("discourse_ai.errors.persona_not_found")) if persona.nil?
+
+ return render_json_error(I18n.t("discourse_ai.errors.persona_disabled")) if !persona.enabled
+
+ if persona.default_llm.blank?
+ return render_json_error(I18n.t("discourse_ai.errors.no_default_llm"))
+ end
+
+ if params[:query].blank?
+ return render_json_error(I18n.t("discourse_ai.errors.no_query_specified"))
+ end
+
+ if !persona.user_id
+ return render_json_error(I18n.t("discourse_ai.errors.no_user_for_persona"))
+ end
+
+ if !params[:username] && !params[:user_unique_id]
+ return render_json_error(I18n.t("discourse_ai.errors.no_user_specified"))
+ end
+
+ user = nil
+
+ if params[:username]
+ user = User.find_by_username(params[:username])
+ return render_json_error(I18n.t("discourse_ai.errors.user_not_found")) if user.nil?
+ elsif params[:user_unique_id]
+ user = stage_user
+ end
+
+ raise Discourse::NotFound if user.nil?
+
+ topic_id = params[:topic_id].to_i
+ topic = nil
+ post = nil
+
+ if topic_id > 0
+      topic = Topic.find_by(id: topic_id)
+
+ raise Discourse::NotFound if topic.nil?
+
+ if topic.topic_allowed_users.where(user_id: user.id).empty?
+ return render_json_error(I18n.t("discourse_ai.errors.user_not_allowed"))
+ end
+
+ post =
+ PostCreator.create!(
+ user,
+ topic_id: topic_id,
+ raw: params[:query],
+ skip_validations: true,
+ )
+ else
+ post =
+ PostCreator.create!(
+ user,
+ title: I18n.t("discourse_ai.ai_bot.default_pm_prefix"),
+ raw: params[:query],
+ archetype: Archetype.private_message,
+ target_usernames: "#{user.username},#{persona.user.username}",
+ skip_validations: true,
+ )
+
+ topic = post.topic
+ end
+
+ hijack = request.env["rack.hijack"]
+ io = hijack.call
+
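+    # the persona lookup in queue_streamed_reply is scoped to the API caller,
+    # not to the (possibly staged) user who authored the post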
+ user = current_user
+
+ self.class.queue_streamed_reply(io, persona, user, topic, post)
+ end
+
private
+ AI_STREAM_CONVERSATION_UNIQUE_ID = "ai-stream-conversation-unique-id"
+
+  # keeping this in a class method so the closure doesn't capture the controller
+  # instance (request env and other state), allowing that memory to be released earlier
+ def self.queue_streamed_reply(io, persona, user, topic, post)
+ schedule_block do
+ begin
+ io.write "HTTP/1.1 200 OK"
+ io.write CRLF
+ io.write "Content-Type: text/plain; charset=utf-8"
+ io.write CRLF
+ io.write "Transfer-Encoding: chunked"
+ io.write CRLF
+ io.write "Cache-Control: no-cache, no-store, must-revalidate"
+ io.write CRLF
+ io.write "Connection: close"
+ io.write CRLF
+ io.write "X-Accel-Buffering: no"
+ io.write CRLF
+ io.write "X-Content-Type-Options: nosniff"
+ io.write CRLF
+ io.write CRLF
+ io.flush
+
+ persona_class =
+ DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: user)
+ bot = DiscourseAi::AiBot::Bot.as(persona.user, persona: persona_class.new)
+
+ data =
+ { topic_id: topic.id, bot_user_id: persona.user.id, persona_id: persona.id }.to_json +
+ "\n\n"
+
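+          # each chunk is framed per HTTP/1.1 chunked encoding: the payload size
+          # in hex, CRLF, the payload bytes, CRLF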
+ io.write data.bytesize.to_s(16)
+ io.write CRLF
+ io.write data
+ io.write CRLF
+
+ DiscourseAi::AiBot::Playground
+ .new(bot)
+ .reply_to(post) do |partial|
+ next if partial.length == 0
+
+ data = { partial: partial }.to_json + "\n\n"
+
+ data.force_encoding("UTF-8")
+
+ io.write data.bytesize.to_s(16)
+ io.write CRLF
+ io.write data
+ io.write CRLF
+ io.flush
+ end
+
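+          # a zero-length chunk terminates the chunked body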
+ io.write "0"
+ io.write CRLF
+ io.write CRLF
+
+ io.flush
+ io.done
+ rescue StandardError => e
+          # make debugging in dev a tiny bit easier; this is tricky multi-threaded
+          # code that runs into various limitations in Rails
+ p e if Rails.env.development?
+ Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply")
+ ensure
+ io.close
+ end
+ end
+ end
+
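+  # maps an external user_unique_id to a local staged user, creating one on first
+  # use; the unique index on the custom field means a concurrent duplicate save!
+  # will raise instead of staging two users for the same id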
+ def stage_user
+ unique_id = params[:user_unique_id].to_s
+ field = UserCustomField.find_by(name: AI_STREAM_CONVERSATION_UNIQUE_ID, value: unique_id)
+
+ if field
+ field.user
+ else
+ preferred_username = params[:preferred_username]
+ username = UserNameSuggester.suggest(preferred_username || unique_id)
+
+ user =
+ User.new(
+ username: username,
+ email: "#{SecureRandom.hex}@invalid.com",
+ staged: true,
+ active: false,
+ )
+ user.custom_fields[AI_STREAM_CONVERSATION_UNIQUE_ID] = unique_id
+ user.save!
+ user
+ end
+ end
+
def find_ai_persona
@ai_persona = AiPersona.find(params[:id])
end
diff --git a/app/models/completion_prompt.rb b/app/models/completion_prompt.rb
index 0cd1f71d..4d2c74da 100644
--- a/app/models/completion_prompt.rb
+++ b/app/models/completion_prompt.rb
@@ -56,7 +56,7 @@ class CompletionPrompt < ActiveRecord::Base
messages.each_with_index do |msg, idx|
next if msg["content"].length <= 1000
- errors.add(:messages, I18n.t("errors.prompt_message_length", idx: idx + 1))
+ errors.add(:messages, I18n.t("discourse_ai.errors.prompt_message_length", idx: idx + 1))
end
end
end
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index 471baa0d..2d517946 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -5,7 +5,8 @@ en:
scopes:
descriptions:
discourse_ai:
- search: "Allows semantic search via the /discourse-ai/embeddings/semantic-search endpoint."
+ search: "Allows semantic search"
+            stream_completion: "Allows streaming AI persona completions"
site_settings:
categories:
diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml
index 1ede2ad0..d383a0a3 100644
--- a/config/locales/server.en.yml
+++ b/config/locales/server.en.yml
@@ -106,10 +106,6 @@ en:
flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic.
flagged_by_nsfw: The AI plugin flagged this after classifying at least one of the attached images as NSFW.
- errors:
- prompt_message_length: The message %{idx} is over the 1000 character limit.
- invalid_prompt_role: The message %{idx} has an invalid role.
-
reports:
overall_sentiment:
title: "Overall sentiment"
@@ -169,6 +165,7 @@ en:
failed_to_share: "Failed to share the conversation"
conversation_deleted: "Conversation share deleted successfully"
ai_bot:
+ default_pm_prefix: "[Untitled AI bot PM]"
personas:
default_llm_required: "Default LLM model is required prior to enabling Chat"
cannot_delete_system_persona: "System personas cannot be deleted, please disable it instead"
@@ -347,3 +344,14 @@ en:
llm_models:
missing_provider_param: "%{param} can't be blank"
bedrock_invalid_url: "Please complete all the fields to contact this model."
+
+ errors:
+    no_query_specified: The query parameter is required; please specify it.
+    no_user_for_persona: The persona specified does not have a user associated with it.
+    persona_not_found: The persona specified does not exist. Check the persona_name or persona_id params.
+    no_user_specified: The username or the user_unique_id parameter is required; please specify it.
+ user_not_found: The user specified does not exist. Check the username param.
+ persona_disabled: The persona specified is disabled. Check the persona_name or persona_id params.
+ no_default_llm: The persona must have a default_llm defined.
+ user_not_allowed: The user is not allowed to participate in the topic.
+ prompt_message_length: The message %{idx} is over the 1000 character limit.
diff --git a/config/routes.rb b/config/routes.rb
index e806f741..a5c009ff 100644
--- a/config/routes.rb
+++ b/config/routes.rb
@@ -50,6 +50,8 @@ Discourse::Application.routes.draw do
path: "ai-personas",
controller: "discourse_ai/admin/ai_personas"
+ post "/ai-personas/stream-reply" => "discourse_ai/admin/ai_personas#stream_reply"
+
resources(
:ai_tools,
only: %i[index create show update destroy],
diff --git a/db/migrate/20241028034232_add_unique_ai_stream_conversation_user_id_index.rb b/db/migrate/20241028034232_add_unique_ai_stream_conversation_user_id_index.rb
new file mode 100644
index 00000000..cff38477
--- /dev/null
+++ b/db/migrate/20241028034232_add_unique_ai_stream_conversation_user_id_index.rb
@@ -0,0 +1,9 @@
+# frozen_string_literal: true
+class AddUniqueAiStreamConversationUserIdIndex < ActiveRecord::Migration[7.1]
+ def change
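+    # enforce one staged user per external conversation id; this backs the lookup
+    # in AiPersonasController#stage_user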
+ add_index :user_custom_fields,
+ [:value],
+ unique: true,
+ where: "name = 'ai-stream-conversation-unique-id'"
+ end
+end
diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index e8e1eadd..834ae059 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -113,7 +113,7 @@ module DiscourseAi
tool_found = true
# a bit hacky, but extra newlines do no harm
if needs_newlines
- update_blk.call("\n\n", cancel, nil)
+ update_blk.call("\n\n", cancel)
needs_newlines = false
end
@@ -123,7 +123,7 @@ module DiscourseAi
end
else
needs_newlines = true
- update_blk.call(partial, cancel, nil)
+ update_blk.call(partial, cancel)
end
end
@@ -191,9 +191,9 @@ module DiscourseAi
tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw)
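+      # tag these partials so downstream consumers (e.g. the streaming endpoint)
+      # can tell tool output apart from regular reply text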
if context[:skip_tool_details] && tool.custom_raw.present?
- update_blk.call(tool.custom_raw, cancel, nil)
+ update_blk.call(tool.custom_raw, cancel, nil, :custom_raw)
elsif !context[:skip_tool_details]
- update_blk.call(tool_details, cancel, nil)
+ update_blk.call(tool_details, cancel, nil, :tool_details)
end
result
diff --git a/lib/ai_bot/entry_point.rb b/lib/ai_bot/entry_point.rb
index 64fd9133..ff5bf69f 100644
--- a/lib/ai_bot/entry_point.rb
+++ b/lib/ai_bot/entry_point.rb
@@ -189,6 +189,11 @@ module DiscourseAi
plugin.register_editable_topic_custom_field(:ai_persona_id)
end
+ plugin.add_api_key_scope(
+ :discourse_ai,
+ { stream_completion: { actions: %w[discourse_ai/admin/ai_personas#stream_reply] } },
+ )
+
plugin.on(:site_setting_changed) do |name, old_value, new_value|
if name == :ai_embeddings_model && SiteSetting.ai_embeddings_enabled? &&
new_value != old_value
diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb
index 13834c92..a4e696db 100644
--- a/lib/ai_bot/playground.rb
+++ b/lib/ai_bot/playground.rb
@@ -390,7 +390,12 @@ module DiscourseAi
result
end
- def reply_to(post)
+ def reply_to(post, &blk)
+      # works around a multithreading issue: PostCustomPrompt is needed later and
+      # may not be loaded yet in this thread, so force it to load now
+ PostCustomPrompt.none
+
reply = +""
start = Time.now
@@ -441,11 +446,13 @@ module DiscourseAi
context[:skip_tool_details] ||= !bot.persona.class.tool_details
new_custom_prompts =
- bot.reply(context) do |partial, cancel, placeholder|
+ bot.reply(context) do |partial, cancel, placeholder, type|
reply << partial
raw = reply.dup
raw << "\n\n" << placeholder if placeholder.present? && !context[:skip_tool_details]
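+          # tool detail placeholders are persisted to the post but are not passed
+          # to the external streaming block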
+ blk.call(partial) if blk && type != :tool_details
+
if stream_reply && !Discourse.redis.get(redis_stream_key)
cancel&.call
reply_post.update!(raw: reply, cooked: PrettyText.cook(reply))
diff --git a/lib/completions/endpoints/fake.rb b/lib/completions/endpoints/fake.rb
index 2beec61a..d057fb15 100644
--- a/lib/completions/endpoints/fake.rb
+++ b/lib/completions/endpoints/fake.rb
@@ -72,6 +72,10 @@ module DiscourseAi
@fake_content = nil
end
+ def self.fake_content=(content)
+ @fake_content = content
+ end
+
def self.fake_content
@fake_content || STOCK_CONTENT
end
@@ -100,6 +104,13 @@ module DiscourseAi
@last_call = params
end
+ def self.reset!
+ @last_call = nil
+ @fake_content = nil
+ @delays = nil
+ @chunk_count = nil
+ end
+
def perform_completion!(
dialect,
user,
@@ -111,6 +122,8 @@ module DiscourseAi
content = self.class.fake_content
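+        # when fake_content is an array it acts as a queue: each completion call
+        # consumes one canned response (e.g. the reply body, then the topic title)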
+ content = content.shift if content.is_a?(Array)
+
if block_given?
split_indices = (1...content.length).to_a.sample(self.class.chunk_count - 1).sort
indexes = [0, *split_indices, content.length]
diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb
index 1dee9b73..bc7d10ef 100644
--- a/spec/requests/admin/ai_personas_controller_spec.rb
+++ b/spec/requests/admin/ai_personas_controller_spec.rb
@@ -407,4 +407,178 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
}.not_to change(AiPersona, :count)
end
end
+
+ describe "#stream_reply" do
+ fab!(:llm) { Fabricate(:llm_model, name: "fake_llm", provider: "fake") }
+ let(:fake_endpoint) { DiscourseAi::Completions::Endpoints::Fake }
+
+ before { fake_endpoint.delays = [] }
+
+ after { fake_endpoint.reset! }
+
+ it "ensures persona exists" do
+ post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json"
+ expect(response).to have_http_status(:unprocessable_entity)
+      # this ensures the localization key actually exists in the yaml
+ expect(response.body).to include("persona_name")
+ end
+
+ it "ensures question exists" do
+ ai_persona.update!(default_llm: "custom:#{llm.id}")
+
+ post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json",
+ params: {
+ persona_id: ai_persona.id,
+ user_unique_id: "site:test.com:user_id:1",
+ }
+ expect(response).to have_http_status(:unprocessable_entity)
+ expect(response.body).to include("query")
+ end
+
+ it "ensure persona has a user specified" do
+ ai_persona.update!(default_llm: "custom:#{llm.id}")
+
+ post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json",
+ params: {
+ persona_id: ai_persona.id,
+ query: "how are you today?",
+ user_unique_id: "site:test.com:user_id:1",
+ }
+
+ expect(response).to have_http_status(:unprocessable_entity)
+ expect(response.body).to include("associated")
+ end
+
+ def validate_streamed_response(raw_http, expected)
+ lines = raw_http.split("\r\n")
+
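+      # the double CRLF that ends the header block yields a blank line here,
+      # splitting headers from the chunked body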
+ header_lines, _, payload_lines = lines.chunk { |l| l == "" }.map(&:last)
+
+ preamble = (<<~PREAMBLE).strip
+        HTTP/1.1 200 OK
+        Content-Type: text/plain; charset=utf-8
+        Transfer-Encoding: chunked
+        Cache-Control: no-cache, no-store, must-revalidate
+        Connection: close
+        X-Accel-Buffering: no
+        X-Content-Type-Options: nosniff
+      PREAMBLE
+
+ expect(header_lines.join("\n")).to eq(preamble)
+
+ parsed = +""
+
+ context_info = nil
+
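+      # chunked bodies alternate a hex size line with a data line; pair them up
+      # and check each declared size against the data's actual bytesize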
+ payload_lines.each_slice(2) do |size, data|
+ size = size.to_i(16)
+ data = data.to_s
+ expect(data.bytesize).to eq(size)
+
+ if size > 0
+ json = JSON.parse(data)
+ parsed << json["partial"].to_s
+
+ context_info = json if json["topic_id"]
+ end
+ end
+
+ expect(parsed).to eq(expected)
+
+ context_info
+ end
+
+ it "is able to create a new conversation" do
+ Jobs.run_immediately!
+
+ fake_endpoint.fake_content = ["This is a test! Testing!", "An amazing title"]
+
+ ai_persona.create_user!
+ ai_persona.update!(
+ allowed_group_ids: [Group::AUTO_GROUPS[:staff]],
+ default_llm: "custom:#{llm.id}",
+ )
+
+ io_out, io_in = IO.pipe
+
+ post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json",
+ params: {
+ persona_name: ai_persona.name,
+ query: "how are you today?",
+ user_unique_id: "site:test.com:user_id:1",
+ preferred_username: "test_user",
+ },
+ env: {
+ "rack.hijack" => lambda { io_in },
+ }
+
+      # the real reply is streamed over the hijacked socket; this placeholder
+      # response still surfaces controller errors
+ expect(response).to have_http_status(:no_content)
+
+ raw = io_out.read
+ context_info = validate_streamed_response(raw, "This is a test! Testing!")
+
+ expect(context_info["topic_id"]).to be_present
+ topic = Topic.find(context_info["topic_id"])
+ last_post = topic.posts.order(:created_at).last
+ expect(last_post.raw).to eq("This is a test! Testing!")
+
+ user_post = topic.posts.find_by(post_number: 1)
+ expect(user_post.raw).to eq("how are you today?")
+
+      # the topic should allow exactly two users: the persona's user and the staged user
+ expect(topic.topic_allowed_users.count).to eq(2)
+ expect(topic.archetype).to eq(Archetype.private_message)
+ expect(topic.title).to eq("An amazing title")
+
+ # now let's try to make a reply with a tool call
+ function_call = <<~XML
+        <function_calls>
+        <invoke>
+        <tool_name>categories</tool_name>
+        </invoke>
+        </function_calls>
+ XML
+
+ fake_endpoint.fake_content = [function_call, "this is the response after the tool"]
+      # a single chunk keeps the fake function call XML from being split mid-tag
+ fake_endpoint.chunk_count = 1
+
+ ai_persona.update!(tools: ["Categories"])
+
+ io_out, io_in = IO.pipe
+
+ post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json",
+ params: {
+ persona_id: ai_persona.id,
+ query: "how are you now?",
+ user_unique_id: "site:test.com:user_id:1",
+ preferred_username: "test_user",
+ topic_id: context_info["topic_id"],
+ },
+ env: {
+ "rack.hijack" => lambda { io_in },
+ }
+
+      # the real reply is streamed over the hijacked socket; this placeholder
+      # response still surfaces controller errors
+ expect(response).to have_http_status(:no_content)
+
+ raw = io_out.read
+ context_info = validate_streamed_response(raw, "this is the response after the tool")
+
+ topic = topic.reload
+ last_post = topic.posts.order(:created_at).last
+
+ expect(last_post.raw).to end_with("this is the response after the tool")
+ # function call is visible in the post
+      expect(last_post.raw[0..8]).to eq("<details>")
+
+ user_post = topic.posts.find_by(post_number: 3)
+ expect(user_post.raw).to eq("how are you now?")
+ expect(user_post.user.username).to eq("test_user")
+ expect(user_post.user.custom_fields).to eq(
+ { "ai-stream-conversation-unique-id" => "site:test.com:user_id:1" },
+ )
+ end
+ end
end