From 34a59b623e55488c3f91c32776dc27f53cca35f5 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 30 Oct 2024 20:24:39 +1100 Subject: [PATCH] FIX: ensure replies are never double streamed (#879) The custom field "discourse_ai_bypass_ai_reply" was added so we can signal the post created hook to bypass replying even if it thinks it should. Otherwise there are cases where we double answer user questions leading to much confusion. This also slightly refactors code making the controller smaller --- .../admin/ai_personas_controller.rb | 109 +++--------------- lib/ai_bot/playground.rb | 3 + lib/ai_bot/response_http_streamer.rb | 105 +++++++++++++++++ .../admin/ai_personas_controller_spec.rb | 18 ++- 4 files changed, 138 insertions(+), 97 deletions(-) create mode 100644 lib/ai_bot/response_http_streamer.rb diff --git a/app/controllers/discourse_ai/admin/ai_personas_controller.rb b/app/controllers/discourse_ai/admin/ai_personas_controller.rb index 35e9a338..8ef466f1 100644 --- a/app/controllers/discourse_ai/admin/ai_personas_controller.rb +++ b/app/controllers/discourse_ai/admin/ai_personas_controller.rb @@ -74,33 +74,6 @@ module DiscourseAi end end - class << self - POOL_SIZE = 10 - def thread_pool - @thread_pool ||= - Concurrent::CachedThreadPool.new(min_threads: 0, max_threads: POOL_SIZE, idletime: 30) - end - - def schedule_block(&block) - # think about a better way to handle cross thread connections - if Rails.env.test? - block.call - return - end - - db = RailsMultisite::ConnectionManagement.current_db - thread_pool.post do - begin - RailsMultisite::ConnectionManagement.with_connection(db) { block.call } - rescue StandardError => e - Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply") - end - end - end - end - - CRLF = "\r\n" - def stream_reply persona = AiPersona.find_by(name: params[:persona_name]) || @@ -155,6 +128,9 @@ module DiscourseAi topic_id: topic_id, raw: params[:query], skip_validations: true, + custom_fields: { + DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true, + }, ) else post = @@ -165,6 +141,9 @@ module DiscourseAi archetype: Archetype.private_message, target_usernames: "#{user.username},#{persona.user.username}", skip_validations: true, + custom_fields: { + DiscourseAi::AiBot::Playground::BYPASS_AI_REPLY_CUSTOM_FIELD => true, + }, ) topic = post.topic @@ -175,81 +154,19 @@ module DiscourseAi user = current_user - self.class.queue_streamed_reply(io, persona, user, topic, post) + DiscourseAi::AiBot::ResponseHttpStreamer.queue_streamed_reply( + io, + persona, + user, + topic, + post, + ) end private AI_STREAM_CONVERSATION_UNIQUE_ID = "ai-stream-conversation-unique-id" - # keeping this in a static method so we don't capture ENV and other bits - # this allows us to release memory earlier - def self.queue_streamed_reply(io, persona, user, topic, post) - schedule_block do - begin - io.write "HTTP/1.1 200 OK" - io.write CRLF - io.write "Content-Type: text/plain; charset=utf-8" - io.write CRLF - io.write "Transfer-Encoding: chunked" - io.write CRLF - io.write "Cache-Control: no-cache, no-store, must-revalidate" - io.write CRLF - io.write "Connection: close" - io.write CRLF - io.write "X-Accel-Buffering: no" - io.write CRLF - io.write "X-Content-Type-Options: nosniff" - io.write CRLF - io.write CRLF - io.flush - - persona_class = - DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: user) - bot = DiscourseAi::AiBot::Bot.as(persona.user, persona: persona_class.new) - - data = - { topic_id: topic.id, bot_user_id: persona.user.id, persona_id: persona.id }.to_json + - "\n\n" - - io.write data.bytesize.to_s(16) - io.write CRLF - io.write data - io.write CRLF - - DiscourseAi::AiBot::Playground - .new(bot) - .reply_to(post) do |partial| - next if partial.length == 0 - - data = { partial: partial }.to_json + "\n\n" - - data.force_encoding("UTF-8") - - io.write data.bytesize.to_s(16) - io.write CRLF - io.write data - io.write CRLF - io.flush - end - - io.write "0" - io.write CRLF - io.write CRLF - - io.flush - io.done - rescue StandardError => e - # make it a tiny bit easier to debug in dev, this is tricky - # multi-threaded code that exhibits various limitations in rails - p e if Rails.env.development? - Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply") - ensure - io.close - end - end - end - def stage_user unique_id = params[:user_unique_id].to_s field = UserCustomField.find_by(name: AI_STREAM_CONVERSATION_UNIQUE_ID, value: unique_id) diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index a4e696db..4f0648cb 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -3,6 +3,8 @@ module DiscourseAi module AiBot class Playground + BYPASS_AI_REPLY_CUSTOM_FIELD = "discourse_ai_bypass_ai_reply" + attr_reader :bot # An abstraction to manage the bot and topic interactions. @@ -550,6 +552,7 @@ module DiscourseAi return false if bot.bot_user.nil? return false if post.topic.private_message? && post.post_type != Post.types[:regular] return false if (SiteSetting.ai_bot_allowed_groups_map & post.user.group_ids).blank? + return false if post.custom_fields[BYPASS_AI_REPLY_CUSTOM_FIELD].present? true end diff --git a/lib/ai_bot/response_http_streamer.rb b/lib/ai_bot/response_http_streamer.rb new file mode 100644 index 00000000..382c47f6 --- /dev/null +++ b/lib/ai_bot/response_http_streamer.rb @@ -0,0 +1,105 @@ +# frozen_string_literal: true + +module DiscourseAi + module AiBot + class ResponseHttpStreamer + CRLF = "\r\n" + POOL_SIZE = 10 + + class << self + def thread_pool + @thread_pool ||= + Concurrent::CachedThreadPool.new(min_threads: 0, max_threads: POOL_SIZE, idletime: 30) + end + + def schedule_block(&block) + # think about a better way to handle cross thread connections + if Rails.env.test? + block.call + return + end + + db = RailsMultisite::ConnectionManagement.current_db + thread_pool.post do + begin + RailsMultisite::ConnectionManagement.with_connection(db) { block.call } + rescue StandardError => e + Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply") + end + end + end + + # keeping this in a static method so we don't capture ENV and other bits + # this allows us to release memory earlier + def queue_streamed_reply(io, persona, user, topic, post) + schedule_block do + begin + io.write "HTTP/1.1 200 OK" + io.write CRLF + io.write "Content-Type: text/plain; charset=utf-8" + io.write CRLF + io.write "Transfer-Encoding: chunked" + io.write CRLF + io.write "Cache-Control: no-cache, no-store, must-revalidate" + io.write CRLF + io.write "Connection: close" + io.write CRLF + io.write "X-Accel-Buffering: no" + io.write CRLF + io.write "X-Content-Type-Options: nosniff" + io.write CRLF + io.write CRLF + io.flush + + persona_class = + DiscourseAi::AiBot::Personas::Persona.find_by(id: persona.id, user: user) + bot = DiscourseAi::AiBot::Bot.as(persona.user, persona: persona_class.new) + + data = + { + topic_id: topic.id, + bot_user_id: persona.user.id, + persona_id: persona.id, + }.to_json + "\n\n" + + io.write data.bytesize.to_s(16) + io.write CRLF + io.write data + io.write CRLF + + DiscourseAi::AiBot::Playground + .new(bot) + .reply_to(post) do |partial| + next if partial.length == 0 + + data = { partial: partial }.to_json + "\n\n" + + data.force_encoding("UTF-8") + + io.write data.bytesize.to_s(16) + io.write CRLF + io.write data + io.write CRLF + io.flush + end + + io.write "0" + io.write CRLF + io.write CRLF + + io.flush + io.done + rescue StandardError => e + # make it a tiny bit easier to debug in dev, this is tricky + # multi-threaded code that exhibits various limitations in rails + p e if Rails.env.development? + Discourse.warn_exception(e, message: "Discourse AI: Unable to stream reply") + ensure + io.close + end + end + end + end + end + end +end diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb index bc7d10ef..50f71009 100644 --- a/spec/requests/admin/ai_personas_controller_spec.rb +++ b/spec/requests/admin/ai_personas_controller_spec.rb @@ -490,13 +490,16 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do it "is able to create a new conversation" do Jobs.run_immediately! + # trust level 0 + SiteSetting.ai_bot_allowed_groups = "10" fake_endpoint.fake_content = ["This is a test! Testing!", "An amazing title"] ai_persona.create_user! ai_persona.update!( - allowed_group_ids: [Group::AUTO_GROUPS[:staff]], + allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]], default_llm: "custom:#{llm.id}", + allow_personal_messages: true, ) io_out, io_in = IO.pipe @@ -530,6 +533,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do expect(topic.topic_allowed_users.count).to eq(2) expect(topic.archetype).to eq(Archetype.private_message) expect(topic.title).to eq("An amazing title") + expect(topic.posts.count).to eq(2) # now let's try to make a reply with a tool call function_call = <<~XML @@ -546,6 +550,16 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do ai_persona.update!(tools: ["Categories"]) + # lets also unstage the user and add the user to tl0 + # this will ensure there are no feedback loops + new_user = user_post.user + new_user.update!(staged: false) + Group.user_trust_level_change!(new_user.id, new_user.trust_level) + + # double check this happened and user is in group + personas = AiPersona.allowed_modalities(user: new_user.reload, allow_personal_messages: true) + expect(personas.count).to eq(1) + io_out, io_in = IO.pipe post "/admin/plugins/discourse-ai/ai-personas/stream-reply.json", @@ -579,6 +593,8 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do expect(user_post.user.custom_fields).to eq( { "ai-stream-conversation-unique-id" => "site:test.com:user_id:1" }, ) + + expect(topic.posts.count).to eq(4) end end end