From be0b78cacdc67c1d6057e5d88ee854f45b4432e7 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 30 Oct 2024 10:28:20 +1100 Subject: [PATCH] FEATURE: new endpoint for directly accessing a persona (#876) The new `/admin/plugins/discourse-ai/ai-personas/stream-reply.json` was added. This endpoint streams data directly from a persona and can be used to access a persona from remote systems, leaving a paper trail in PMs about the conversation that happened. This endpoint is only accessible to admins. --------- Co-authored-by: Gabriel Grubba <70247653+Grubba27@users.noreply.github.com> Co-authored-by: Keegan George --- .../admin/ai_personas_controller.rb | 197 ++++++++++++++++++ app/models/completion_prompt.rb | 2 +- config/locales/client.en.yml | 3 +- config/locales/server.en.yml | 16 +- config/routes.rb | 2 + ...ue_ai_stream_conversation_user_id_index.rb | 9 + lib/ai_bot/bot.rb | 8 +- lib/ai_bot/entry_point.rb | 5 + lib/ai_bot/playground.rb | 11 +- lib/completions/endpoints/fake.rb | 13 ++ .../admin/ai_personas_controller_spec.rb | 174 ++++++++++++++++ 11 files changed, 428 insertions(+), 12 deletions(-) create mode 100644 db/migrate/20241028034232_add_unique_ai_stream_conversation_user_id_index.rb diff --git a/app/controllers/discourse_ai/admin/ai_personas_controller.rb b/app/controllers/discourse_ai/admin/ai_personas_controller.rb index b4fd66b1..35e9a338 100644 --- a/app/controllers/discourse_ai/admin/ai_personas_controller.rb +++ b/app/controllers/discourse_ai/admin/ai_personas_controller.rb @@ -74,8 +74,205 @@ module DiscourseAi end end + class << self + POOL_SIZE = 10 + def thread_pool + @thread_pool ||= + Concurrent::CachedThreadPool.new(min_threads: 0, max_threads: POOL_SIZE, idletime: 30) + end + + def schedule_block(&block) + # think about a better way to handle cross thread connections + if Rails.env.test? 
+ block.call + return + end + + db = RailsMultisite::ConnectionManagement.current_db + thread_pool.post do + begin + RailsMultisite::ConnectionManagement.with_connection(db) { block.call } + rescue StandardError => e + Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply") + end + end + end + end + + CRLF = "\r\n" + + def stream_reply + persona = + AiPersona.find_by(name: params[:persona_name]) || + AiPersona.find_by(id: params[:persona_id]) + return render_json_error(I18n.t("discourse_ai.errors.persona_not_found")) if persona.nil? + + return render_json_error(I18n.t("discourse_ai.errors.persona_disabled")) if !persona.enabled + + if persona.default_llm.blank? + return render_json_error(I18n.t("discourse_ai.errors.no_default_llm")) + end + + if params[:query].blank? + return render_json_error(I18n.t("discourse_ai.errors.no_query_specified")) + end + + if !persona.user_id + return render_json_error(I18n.t("discourse_ai.errors.no_user_for_persona")) + end + + if !params[:username] && !params[:user_unique_id] + return render_json_error(I18n.t("discourse_ai.errors.no_user_specified")) + end + + user = nil + + if params[:username] + user = User.find_by_username(params[:username]) + return render_json_error(I18n.t("discourse_ai.errors.user_not_found")) if user.nil? + elsif params[:user_unique_id] + user = stage_user + end + + raise Discourse::NotFound if user.nil? + + topic_id = params[:topic_id].to_i + topic = nil + post = nil + + if topic_id > 0 + topic = Topic.find(topic_id) + + raise Discourse::NotFound if topic.nil? + + if topic.topic_allowed_users.where(user_id: user.id).empty? 
+ return render_json_error(I18n.t("discourse_ai.errors.user_not_allowed")) + end + + post = + PostCreator.create!( + user, + topic_id: topic_id, + raw: params[:query], + skip_validations: true, + ) + else + post = + PostCreator.create!( + user, + title: I18n.t("discourse_ai.ai_bot.default_pm_prefix"), + raw: params[:query], + archetype: Archetype.private_message, + target_usernames: "#{user.username},#{persona.user.username}", + skip_validations: true, + ) + + topic = post.topic + end + + hijack = request.env["rack.hijack"] + io = hijack.call + + user = current_user + + self.class.queue_streamed_reply(io, persona, user, topic, post) + end + private + AI_STREAM_CONVERSATION_UNIQUE_ID = "ai-stream-conversation-unique-id" + + # keeping this in a static method so we don't capture ENV and other bits + # this allows us to release memory earlier + def self.queue_streamed_reply(io, persona, user, topic, post) + schedule_block do + begin + io.write "HTTP/1.1 200 OK" + io.write CRLF + io.write "Content-Type: text/plain; charset=utf-8" + io.write CRLF + io.write "Transfer-Encoding: chunked" + io.write CRLF + io.write "Cache-Control: no-cache, no-store, must-revalidate" + io.write CRLF + io.write "Connection: close" + io.write CRLF + io.write "X-Accel-Buffering: no" + io.write CRLF + io.write "X-Content-Type-Options: nosniff" + io.write CRLF + io.write CRLF + io.flush + + persona_class = + DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: user) + bot = DiscourseAi::AiBot::Bot.as(persona.user, persona: persona_class.new) + + data = + { topic_id: topic.id, bot_user_id: persona.user.id, persona_id: persona.id }.to_json + + "\n\n" + + io.write data.bytesize.to_s(16) + io.write CRLF + io.write data + io.write CRLF + + DiscourseAi::AiBot::Playground + .new(bot) + .reply_to(post) do |partial| + next if partial.length == 0 + + data = { partial: partial }.to_json + "\n\n" + + data.force_encoding("UTF-8") + + io.write data.bytesize.to_s(16) + io.write CRLF + io.write 
data + io.write CRLF + io.flush + end + + io.write "0" + io.write CRLF + io.write CRLF + + io.flush + io.done + rescue StandardError => e + # make it a tiny bit easier to debug in dev, this is tricky + # multi-threaded code that exhibits various limitations in rails + p e if Rails.env.development? + Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply") + ensure + io.close + end + end + end + + def stage_user + unique_id = params[:user_unique_id].to_s + field = UserCustomField.find_by(name: AI_STREAM_CONVERSATION_UNIQUE_ID, value: unique_id) + + if field + field.user + else + preferred_username = params[:preferred_username] + username = UserNameSuggester.suggest(preferred_username || unique_id) + + user = + User.new( + username: username, + email: "#{SecureRandom.hex}@invalid.com", + staged: true, + active: false, + ) + user.custom_fields[AI_STREAM_CONVERSATION_UNIQUE_ID] = unique_id + user.save! + user + end + end + def find_ai_persona @ai_persona = AiPersona.find(params[:id]) end diff --git a/app/models/completion_prompt.rb b/app/models/completion_prompt.rb index 0cd1f71d..4d2c74da 100644 --- a/app/models/completion_prompt.rb +++ b/app/models/completion_prompt.rb @@ -56,7 +56,7 @@ class CompletionPrompt < ActiveRecord::Base messages.each_with_index do |msg, idx| next if msg["content"].length <= 1000 - errors.add(:messages, I18n.t("errors.prompt_message_length", idx: idx + 1)) + errors.add(:messages, I18n.t("discourse_ai.errors.prompt_message_length", idx: idx + 1)) end end end diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index 471baa0d..2d517946 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -5,7 +5,8 @@ en: scopes: descriptions: discourse_ai: - search: "Allows semantic search via the /discourse-ai/embeddings/semantic-search endpoint." 
+ search: "Allows semantic search" + stream_completion: "Allows streaming ai persona completions" site_settings: categories: diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index 1ede2ad0..d383a0a3 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -106,10 +106,6 @@ en: flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic. flagged_by_nsfw: The AI plugin flagged this after classifying at least one of the attached images as NSFW. - errors: - prompt_message_length: The message %{idx} is over the 1000 character limit. - invalid_prompt_role: The message %{idx} has an invalid role. - reports: overall_sentiment: title: "Overall sentiment" @@ -169,6 +165,7 @@ en: failed_to_share: "Failed to share the conversation" conversation_deleted: "Conversation share deleted successfully" ai_bot: + default_pm_prefix: "[Untitled AI bot PM]" personas: default_llm_required: "Default LLM model is required prior to enabling Chat" cannot_delete_system_persona: "System personas cannot be deleted, please disable it instead" @@ -347,3 +344,14 @@ en: llm_models: missing_provider_param: "%{param} can't be blank" bedrock_invalid_url: "Please complete all the fields to contact this model." + + errors: + no_query_specified: The query parameter is required, please specify it. + no_user_for_persona: The persona specified does not have a user associated with it. + persona_not_found: The persona specified does not exist. Check the persona_name or persona_id params. + no_user_specified: The username or the user_unique_id parameter is required, please specify it. + user_not_found: The user specified does not exist. Check the username param. + persona_disabled: The persona specified is disabled. Check the persona_name or persona_id params. + no_default_llm: The persona must have a default_llm defined. + user_not_allowed: The user is not allowed to participate in the topic. 
+ prompt_message_length: The message %{idx} is over the 1000 character limit. diff --git a/config/routes.rb b/config/routes.rb index e806f741..a5c009ff 100644 --- a/config/routes.rb +++ b/config/routes.rb @@ -50,6 +50,8 @@ Discourse::Application.routes.draw do path: "ai-personas", controller: "discourse_ai/admin/ai_personas" + post "/ai-personas/stream-reply" => "discourse_ai/admin/ai_personas#stream_reply" + resources( :ai_tools, only: %i[index create show update destroy], diff --git a/db/migrate/20241028034232_add_unique_ai_stream_conversation_user_id_index.rb b/db/migrate/20241028034232_add_unique_ai_stream_conversation_user_id_index.rb new file mode 100644 index 00000000..cff38477 --- /dev/null +++ b/db/migrate/20241028034232_add_unique_ai_stream_conversation_user_id_index.rb @@ -0,0 +1,9 @@ +# frozen_string_literal: true +class AddUniqueAiStreamConversationUserIdIndex < ActiveRecord::Migration[7.1] + def change + add_index :user_custom_fields, + [:value], + unique: true, + where: "name = 'ai-stream-conversation-unique-id'" + end +end diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb index e8e1eadd..834ae059 100644 --- a/lib/ai_bot/bot.rb +++ b/lib/ai_bot/bot.rb @@ -113,7 +113,7 @@ module DiscourseAi tool_found = true # a bit hacky, but extra newlines do no harm if needs_newlines - update_blk.call("\n\n", cancel, nil) + update_blk.call("\n\n", cancel) needs_newlines = false end @@ -123,7 +123,7 @@ module DiscourseAi end else needs_newlines = true - update_blk.call(partial, cancel, nil) + update_blk.call(partial, cancel) end end @@ -191,9 +191,9 @@ module DiscourseAi tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw) if context[:skip_tool_details] && tool.custom_raw.present? 
- update_blk.call(tool.custom_raw, cancel, nil) + update_blk.call(tool.custom_raw, cancel, nil, :custom_raw) elsif !context[:skip_tool_details] - update_blk.call(tool_details, cancel, nil) + update_blk.call(tool_details, cancel, nil, :tool_details) end result diff --git a/lib/ai_bot/entry_point.rb b/lib/ai_bot/entry_point.rb index 64fd9133..ff5bf69f 100644 --- a/lib/ai_bot/entry_point.rb +++ b/lib/ai_bot/entry_point.rb @@ -189,6 +189,11 @@ module DiscourseAi plugin.register_editable_topic_custom_field(:ai_persona_id) end + plugin.add_api_key_scope( + :discourse_ai, + { stream_completion: { actions: %w[discourse_ai/admin/ai_personas#stream_reply] } }, + ) + plugin.on(:site_setting_changed) do |name, old_value, new_value| if name == :ai_embeddings_model && SiteSetting.ai_embeddings_enabled? && new_value != old_value diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index 13834c92..a4e696db 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -390,7 +390,12 @@ module DiscourseAi result end - def reply_to(post) + def reply_to(post, &blk) + # this is a multithreading issue + # post custom prompt is needed and it may not + # be properly loaded, ensure it is loaded + PostCustomPrompt.none + reply = +"" start = Time.now @@ -441,11 +446,13 @@ module DiscourseAi context[:skip_tool_details] ||= !bot.persona.class.tool_details new_custom_prompts = - bot.reply(context) do |partial, cancel, placeholder| + bot.reply(context) do |partial, cancel, placeholder, type| reply << partial raw = reply.dup raw << "\n\n" << placeholder if placeholder.present? 
&& !context[:skip_tool_details] + blk.call(partial) if blk && type != :tool_details + if stream_reply && !Discourse.redis.get(redis_stream_key) cancel&.call reply_post.update!(raw: reply, cooked: PrettyText.cook(reply)) diff --git a/lib/completions/endpoints/fake.rb b/lib/completions/endpoints/fake.rb index 2beec61a..d057fb15 100644 --- a/lib/completions/endpoints/fake.rb +++ b/lib/completions/endpoints/fake.rb @@ -72,6 +72,10 @@ module DiscourseAi @fake_content = nil end + def self.fake_content=(content) + @fake_content = content + end + def self.fake_content @fake_content || STOCK_CONTENT end @@ -100,6 +104,13 @@ module DiscourseAi @last_call = params end + def self.reset! + @last_call = nil + @fake_content = nil + @delays = nil + @chunk_count = nil + end + def perform_completion!( dialect, user, @@ -111,6 +122,8 @@ module DiscourseAi content = self.class.fake_content + content = content.shift if content.is_a?(Array) + if block_given? split_indices = (1...content.length).to_a.sample(self.class.chunk_count - 1).sort indexes = [0, *split_indices, content.length] diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb index 1dee9b73..bc7d10ef 100644 --- a/spec/requests/admin/ai_personas_controller_spec.rb +++ b/spec/requests/admin/ai_personas_controller_spec.rb @@ -407,4 +407,178 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do }.not_to change(AiPersona, :count) end end + + describe "#stream_reply" do + fab!(:llm) { Fabricate(:llm_model, name: "fake_llm", provider: "fake") } + let(:fake_endpoint) { DiscourseAi::Completions::Endpoints::Fake } + + before { fake_endpoint.delays = [] } + + after { fake_endpoint.reset! 
} + + it "ensures persona exists" do + post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json" + expect(response).to have_http_status(:unprocessable_entity) + # this ensures localization key is actually in the yaml + expect(response.body).to include("persona_name") + end + + it "ensures question exists" do + ai_persona.update!(default_llm: "custom:#{llm.id}") + + post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json", + params: { + persona_id: ai_persona.id, + user_unique_id: "site:test.com:user_id:1", + } + expect(response).to have_http_status(:unprocessable_entity) + expect(response.body).to include("query") + end + + it "ensure persona has a user specified" do + ai_persona.update!(default_llm: "custom:#{llm.id}") + + post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json", + params: { + persona_id: ai_persona.id, + query: "how are you today?", + user_unique_id: "site:test.com:user_id:1", + } + + expect(response).to have_http_status(:unprocessable_entity) + expect(response.body).to include("associated") + end + + def validate_streamed_response(raw_http, expected) + lines = raw_http.split("\r\n") + + header_lines, _, payload_lines = lines.chunk { |l| l == "" }.map(&:last) + + preamble = (<<~PREAMBLE).strip + HTTP/1.1 200 OK + Content-Type: text/plain; charset=utf-8 + Transfer-Encoding: chunked + Cache-Control: no-cache, no-store, must-revalidate + Connection: close + X-Accel-Buffering: no + X-Content-Type-Options: nosniff + PREAMBLE + + expect(header_lines.join("\n")).to eq(preamble) + + parsed = +"" + + context_info = nil + + payload_lines.each_slice(2) do |size, data| + size = size.to_i(16) + data = data.to_s + expect(data.bytesize).to eq(size) + + if size > 0 + json = JSON.parse(data) + parsed << json["partial"].to_s + + context_info = json if json["topic_id"] + end + end + + expect(parsed).to eq(expected) + + context_info + end + + it "is able to create a new conversation" do + Jobs.run_immediately! 
+ + fake_endpoint.fake_content = ["This is a test! Testing!", "An amazing title"] + + ai_persona.create_user! + ai_persona.update!( + allowed_group_ids: [Group::AUTO_GROUPS[:staff]], + default_llm: "custom:#{llm.id}", + ) + + io_out, io_in = IO.pipe + + post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json", + params: { + persona_name: ai_persona.name, + query: "how are you today?", + user_unique_id: "site:test.com:user_id:1", + preferred_username: "test_user", + }, + env: { + "rack.hijack" => lambda { io_in }, + } + + # this is a fake response but it catches errors + expect(response).to have_http_status(:no_content) + + raw = io_out.read + context_info = validate_streamed_response(raw, "This is a test! Testing!") + + expect(context_info["topic_id"]).to be_present + topic = Topic.find(context_info["topic_id"]) + last_post = topic.posts.order(:created_at).last + expect(last_post.raw).to eq("This is a test! Testing!") + + user_post = topic.posts.find_by(post_number: 1) + expect(user_post.raw).to eq("how are you today?") + + # need ai persona and user + expect(topic.topic_allowed_users.count).to eq(2) + expect(topic.archetype).to eq(Archetype.private_message) + expect(topic.title).to eq("An amazing title") + + # now let's try to make a reply with a tool call + function_call = <<~XML + + + categories + + + XML + + fake_endpoint.fake_content = [function_call, "this is the response after the tool"] + # this simplifies function calls + fake_endpoint.chunk_count = 1 + + ai_persona.update!(tools: ["Categories"]) + + io_out, io_in = IO.pipe + + post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json", + params: { + persona_id: ai_persona.id, + query: "how are you now?", + user_unique_id: "site:test.com:user_id:1", + preferred_username: "test_user", + topic_id: context_info["topic_id"], + }, + env: { + "rack.hijack" => lambda { io_in }, + } + + # this is a fake response but it catches errors + expect(response).to have_http_status(:no_content) + + raw = 
io_out.read + context_info = validate_streamed_response(raw, "this is the response after the tool") + + topic = topic.reload + last_post = topic.posts.order(:created_at).last + + expect(last_post.raw).to end_with("this is the response after the tool") + # function call is visible in the post + expect(last_post.raw[0..8]).to eq("
") + + user_post = topic.posts.find_by(post_number: 3) + expect(user_post.raw).to eq("how are you now?") + expect(user_post.user.username).to eq("test_user") + expect(user_post.user.custom_fields).to eq( + { "ai-stream-conversation-unique-id" => "site:test.com:user_id:1" }, + ) + end + end end