From 40fa527633ad8394377154ea997435e75e4f9960 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 1 Jul 2025 18:02:16 +1000 Subject: [PATCH] FIX: cross talk when in ai helper (#1478) Previous to this change we reused channels for proofreading progress and ai helper progress The new changeset ensures each POST to stream progress gets a dedicated message bus channel This fixes a class of issues where the wrong information could be displayed to end users on subsequent proofreading or helper calls * fix tests * fix implementation (got to subscribe at 0) --- .../ai_helper/assistant_controller.rb | 19 ++++- app/jobs/regular/stream_composer_helper.rb | 3 +- app/jobs/regular/stream_post_helper.rb | 5 +- .../components/ai-post-helper-menu.gjs | 82 ++++++++++--------- .../discourse/components/modal/diff-modal.gjs | 34 ++++---- lib/ai_helper/entry_point.rb | 12 --- .../regular/stream_composer_helper_spec.rb | 24 ++++-- spec/jobs/regular/stream_post_helper_spec.rb | 49 +++++++++-- .../ai_helper/assistant_controller_spec.rb | 50 ++++++++++- .../acceptance/post-helper-menu-test.js | 34 +++----- 10 files changed, 206 insertions(+), 106 deletions(-) diff --git a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb index 4558c5d6..8a65a187 100644 --- a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb +++ b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb @@ -124,6 +124,9 @@ module DiscourseAi # otherwise we may end up streaming the data to the wrong client raise Discourse::InvalidParameters.new(:client_id) if params[:client_id].blank? + channel_id = next_channel_id + progress_channel = "discourse_ai_helper/stream_suggestions/#{channel_id}" + if location == "composer" Jobs.enqueue( :stream_composer_helper, @@ -133,6 +136,7 @@ module DiscourseAi custom_prompt: params[:custom_prompt], force_default_locale: params[:force_default_locale] || false, client_id: params[:client_id], + progress_channel:, ) else post_id = get_post_param! @@ -148,10 +152,11 @@ module DiscourseAi prompt: params[:mode], custom_prompt: params[:custom_prompt], client_id: params[:client_id], + progress_channel:, ) end - render json: { success: true }, status: 200 + render json: { success: true, progress_channel: }, status: 200 rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"), status: 502 @@ -192,6 +197,18 @@ module DiscourseAi private + CHANNEL_ID_KEY = "discourse_ai_helper_next_channel_id" + + def next_channel_id + Discourse + .redis + .pipelined do |pipeline| + pipeline.incr(CHANNEL_ID_KEY) + pipeline.expire(CHANNEL_ID_KEY, 1.day) + end + .first + end + def get_text_param! params[:text].tap { |t| raise Discourse::InvalidParameters.new(:text) if t.blank? } end diff --git a/app/jobs/regular/stream_composer_helper.rb b/app/jobs/regular/stream_composer_helper.rb index ae3c5017..b7e9b33f 100644 --- a/app/jobs/regular/stream_composer_helper.rb +++ b/app/jobs/regular/stream_composer_helper.rb @@ -9,6 +9,7 @@ module Jobs return unless user = User.find_by(id: args[:user_id]) return unless args[:text] return unless args[:client_id] + return unless args[:progress_channel] helper_mode = args[:prompt] @@ -16,7 +17,7 @@ module Jobs helper_mode, args[:text], user, - "/discourse-ai/ai-helper/stream_composer_suggestion", + args[:progress_channel], force_default_locale: args[:force_default_locale], client_id: args[:client_id], custom_prompt: args[:custom_prompt], diff --git a/app/jobs/regular/stream_post_helper.rb b/app/jobs/regular/stream_post_helper.rb index 56a3149f..a1cca62c 100644 --- a/app/jobs/regular/stream_post_helper.rb +++ b/app/jobs/regular/stream_post_helper.rb @@ -8,6 +8,8 @@ module Jobs return unless post = Post.includes(:topic).find_by(id: args[:post_id]) return unless user = User.find_by(id: args[:user_id]) return unless args[:text] + return unless args[:progress_channel] + return unless args[:client_id] topic = post.topic reply_to = post.reply_to_post @@ -31,8 +33,9 @@ module Jobs helper_mode, input, user, - "/discourse-ai/ai-helper/stream_suggestion/#{post.id}", + args[:progress_channel], custom_prompt: args[:custom_prompt], + client_id: args[:client_id], ) end end diff --git a/assets/javascripts/discourse/components/ai-post-helper-menu.gjs b/assets/javascripts/discourse/components/ai-post-helper-menu.gjs index 8edacb48..e22ee061 100644 --- a/assets/javascripts/discourse/components/ai-post-helper-menu.gjs +++ b/assets/javascripts/discourse/components/ai-post-helper-menu.gjs @@ -1,7 +1,6 @@ import Component from "@glimmer/component"; import { tracked } from "@glimmer/tracking"; import { action } from "@ember/object"; -import didInsert from "@ember/render-modifiers/modifiers/did-insert"; import willDestroy from "@ember/render-modifiers/modifiers/will-destroy"; import { service } from "@ember/service"; import { modifier } from "ember-modifier"; @@ -43,9 +42,6 @@ export default class AiPostHelperMenu extends Component { @tracked lastSelectedOption = null; @tracked isSavingFootnote = false; @tracked supportsAddFootnote = this.args.data.supportsFastEdit; - @tracked - channel = - `/discourse-ai/ai-helper/stream_suggestion/${this.args.data.quoteState.postId}`; @tracked smoothStreamer = new SmoothStreamer( @@ -150,19 +146,25 @@ export default class AiPostHelperMenu extends Component { return sanitize(text); } - @bind + set progressChannel(value) { + if (this._progressChannel) { + this.unsubscribe(); + } + this._progressChannel = value; + this.subscribe(); + } + subscribe() { - this.messageBus.subscribe( - this.channel, - (data) => this._updateResult(data), - this.args.data.post - .discourse_ai_helper_stream_suggestion_last_message_bus_id - ); + this.messageBus.subscribe(this._progressChannel, this._updateResult, 0); } @bind unsubscribe() { - this.messageBus.unsubscribe(this.channel, this._updateResult); + if (!this._progressChannel) { + return; + } + this.messageBus.unsubscribe(this._progressChannel, this._updateResult); + this._progressChannel = null; } @bind @@ -182,33 +184,38 @@ export default class AiPostHelperMenu extends Component { this.lastSelectedOption = option; const streamableOptions = ["explain", "translate", "custom_prompt"]; - if (streamableOptions.includes(option.name)) { - return this._handleStreamedResult(option); - } else { - this._activeAiRequest = ajax("/discourse-ai/ai-helper/suggest", { - method: "POST", - data: { - mode: option.name, - text: this.args.data.quoteState.buffer, - custom_prompt: this.customPromptValue, - }, - }); + try { + if (streamableOptions.includes(option.name)) { + const streamedResult = await this._handleStreamedResult(option); + this.progressChannel = streamedResult.progress_channel; + return; + } else { + this._activeAiRequest = ajax("/discourse-ai/ai-helper/suggest", { + method: "POST", + data: { + mode: option.name, + text: this.args.data.quoteState.buffer, + custom_prompt: this.customPromptValue, + }, + }); + } + + this._activeAiRequest + .then(({ suggestions }) => { + this.suggestion = suggestions[0].trim(); + + if (option.name === "proofread") { + return this._handleProofreadOption(); + } + }) + .finally(() => { + this.loading = false; + this.menuState = this.MENU_STATES.result; + }); + } catch (error) { + popupAjaxError(error); } - this._activeAiRequest - .then(({ suggestions }) => { - this.suggestion = suggestions[0].trim(); - - if (option.name === "proofread") { - return this._handleProofreadOption(); - } - }) - .catch(popupAjaxError) - .finally(() => { - this.loading = false; - this.menuState = this.MENU_STATES.result; - }); - return this._activeAiRequest; } @@ -340,7 +347,6 @@ export default class AiPostHelperMenu extends Component { {{else if (eq this.menuState this.MENU_STATES.result)}}
{{#if this.suggestion}} diff --git a/assets/javascripts/discourse/components/modal/diff-modal.gjs b/assets/javascripts/discourse/components/modal/diff-modal.gjs index 6d504aba..eaceee5e 100644 --- a/assets/javascripts/discourse/components/modal/diff-modal.gjs +++ b/assets/javascripts/discourse/components/modal/diff-modal.gjs @@ -1,7 +1,6 @@ import Component from "@glimmer/component"; import { tracked } from "@glimmer/tracking"; import { action } from "@ember/object"; -import didInsert from "@ember/render-modifiers/modifiers/did-insert"; import willDestroy from "@ember/render-modifiers/modifiers/will-destroy"; import { service } from "@ember/service"; import { htmlSafe } from "@ember/template"; @@ -19,8 +18,6 @@ import DiffStreamer from "../../lib/diff-streamer"; import SmoothStreamer from "../../lib/smooth-streamer"; import AiIndicatorWave from "../ai-indicator-wave"; -const CHANNEL = "/discourse-ai/ai-helper/stream_composer_suggestion"; - export default class ModalDiffModal extends Component { @service currentUser; @service messageBus; @@ -83,21 +80,26 @@ export default class ModalDiffModal extends Component { return this.loading || this.isStreaming; } - @bind + set progressChannel(value) { + if (this._progressChannel) { + this.messageBus.unsubscribe(this._progressChannel, this.updateResult); + } + this._progressChannel = value; + this.subscribe(); + } + subscribe() { - this.messageBus.subscribe( - CHANNEL, - this.updateResult, - this.currentUser - ?.discourse_ai_helper_stream_composer_suggestion_last_message_bus_id - ); + // we have 1 channel per operation so we can safely subscribe at head + this.messageBus.subscribe(this._progressChannel, this.updateResult, 0); } @bind cleanup() { // stop all callbacks so it does not end up streaming pointlessly this.#resetState(); - this.messageBus.unsubscribe(CHANNEL, this.updateResult); + if (this._progressChannel) { + this.messageBus.unsubscribe(this._progressChannel, this.updateResult); + } } @action @@ -122,7 +124,7 @@ export default class ModalDiffModal extends Component { try { this.loading = true; - return await ajax("/discourse-ai/ai-helper/stream_suggestion", { + const result = await ajax("/discourse-ai/ai-helper/stream_suggestion", { method: "POST", data: { location: "composer", @@ -133,6 +135,8 @@ export default class ModalDiffModal extends Component { client_id: this.messageBus.clientId, }, }); + + this.progressChannel = result.progress_channel; } catch (e) { popupAjaxError(e); } @@ -183,11 +187,7 @@ export default class ModalDiffModal extends Component { @closeModal={{this.cleanupAndClose}} > <:body> -
+
{ SiteSetting.ai_helper_enabled && scope.authenticated? }, - ) { MessageBus.last_id("/discourse-ai/ai-helper/stream_suggestion/#{object.id}") } - - plugin.add_to_serializer( - :current_user, - :discourse_ai_helper_stream_composer_suggestion_last_message_bus_id, - include_condition: -> { SiteSetting.ai_helper_enabled && scope.authenticated? }, - ) { MessageBus.last_id("/discourse-ai/ai-helper/stream_composer_suggestion") } end end end diff --git a/spec/jobs/regular/stream_composer_helper_spec.rb b/spec/jobs/regular/stream_composer_helper_spec.rb index 2d1cb623..350f758f 100644 --- a/spec/jobs/regular/stream_composer_helper_spec.rb +++ b/spec/jobs/regular/stream_composer_helper_spec.rb @@ -18,23 +18,33 @@ RSpec.describe Jobs::StreamComposerHelper do let(:mode) { DiscourseAi::AiHelper::Assistant::PROOFREAD } it "does nothing if there is no user" do + channel = "/some/channel" messages = - MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion") do - job.execute(user_id: nil, text: input, prompt: mode, force_default_locale: false) + MessageBus.track_publish(channel) do + job.execute( + user_id: nil, + text: input, + prompt: mode, + force_default_locale: false, + client_id: "123", + progress_channel: channel, + ) end expect(messages).to be_empty end it "does nothing if there is no text" do + channel = "/some/channel" messages = - MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion") do + MessageBus.track_publish(channel) do job.execute( user_id: user.id, text: nil, prompt: mode, force_default_locale: false, client_id: "123", + progress_channel: channel, ) end @@ -47,16 +57,18 @@ RSpec.describe Jobs::StreamComposerHelper do it "publishes updates with a partial result" do proofread_result = "I like to eat pie for breakfast because it is delicious." + channel = "/channel/123" DiscourseAi::Completions::Llm.with_prepared_responses([proofread_result]) do messages = - MessageBus.track_publish("/discourse-ai/ai-helper/stream_composer_suggestion") do + MessageBus.track_publish(channel) do job.execute( user_id: user.id, text: input, prompt: mode, force_default_locale: true, client_id: "123", + progress_channel: channel, ) end @@ -68,16 +80,18 @@ RSpec.describe Jobs::StreamComposerHelper do it "publishes a final update to signal we're done" do proofread_result = "I like to eat pie for breakfast because it is delicious." + channel = "/channel/123" DiscourseAi::Completions::Llm.with_prepared_responses([proofread_result]) do messages = - MessageBus.track_publish("/discourse-ai/ai-helper/stream_composer_suggestion") do + MessageBus.track_publish(channel) do job.execute( user_id: user.id, text: input, prompt: mode, force_default_locale: true, client_id: "123", + progress_channel: channel, ) end diff --git a/spec/jobs/regular/stream_post_helper_spec.rb b/spec/jobs/regular/stream_post_helper_spec.rb index 3c492929..06cb69d5 100644 --- a/spec/jobs/regular/stream_post_helper_spec.rb +++ b/spec/jobs/regular/stream_post_helper_spec.rb @@ -60,10 +60,18 @@ RSpec.describe Jobs::StreamPostHelper do explanation = "In this context, \"pie\" refers to a baked dessert typically consisting of a pastry crust and filling." + channel = "/my/channel" DiscourseAi::Completions::Llm.with_prepared_responses([explanation]) do messages = - MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion/#{post.id}") do - job.execute(post_id: post.id, user_id: user.id, text: "pie", prompt: mode) + MessageBus.track_publish(channel) do + job.execute( + post_id: post.id, + user_id: user.id, + text: "pie", + prompt: mode, + progress_channel: channel, + client_id: "test_client_id", + ) end partial_result_update = messages.first.data @@ -76,10 +84,19 @@ RSpec.describe Jobs::StreamPostHelper do explanation = "In this context, \"pie\" refers to a baked dessert typically consisting of a pastry crust and filling." + channel = "/my/channel" + DiscourseAi::Completions::Llm.with_prepared_responses([explanation]) do messages = - MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion/#{post.id}") do - job.execute(post_id: post.id, user_id: user.id, text: "pie", prompt: mode) + MessageBus.track_publish(channel) do + job.execute( + post_id: post.id, + user_id: user.id, + text: "pie", + prompt: mode, + client_id: "test_client_id", + progress_channel: channel, + ) end final_update = messages.last.data @@ -96,10 +113,18 @@ RSpec.describe Jobs::StreamPostHelper do sentence = "I like to eat pie." translation = "Me gusta comer pastel." + channel = "/my/channel" DiscourseAi::Completions::Llm.with_prepared_responses([translation]) do messages = - MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion/#{post.id}") do - job.execute(post_id: post.id, user_id: user.id, text: sentence, prompt: mode) + MessageBus.track_publish(channel) do + job.execute( + post_id: post.id, + user_id: user.id, + text: sentence, + prompt: mode, + progress_channel: channel, + client_id: "test_client_id", + ) end partial_result_update = messages.first.data @@ -111,11 +136,19 @@ RSpec.describe Jobs::StreamPostHelper do it "publishes a final update to signal we're done" do sentence = "I like to eat pie." translation = "Me gusta comer pastel." + channel = "/my/channel" DiscourseAi::Completions::Llm.with_prepared_responses([translation]) do messages = - MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion/#{post.id}") do - job.execute(post_id: post.id, user_id: user.id, text: sentence, prompt: mode) + MessageBus.track_publish(channel) do + job.execute( + post_id: post.id, + user_id: user.id, + text: sentence, + prompt: mode, + progress_channel: channel, + client_id: "test_client_id", + ) end final_update = messages.last.data diff --git a/spec/requests/ai_helper/assistant_controller_spec.rb b/spec/requests/ai_helper/assistant_controller_spec.rb index 7c9ea8b2..47d5d7d7 100644 --- a/spec/requests/ai_helper/assistant_controller_spec.rb +++ b/spec/requests/ai_helper/assistant_controller_spec.rb @@ -11,10 +11,50 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do SiteSetting.composer_ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_0] end - it "is able to stream suggestions back on appropriate channel" do + it "is able to stream suggestions to helper" do sign_in(user) + + my_post = Fabricate(:post) + + channel = nil messages = - MessageBus.track_publish("/discourse-ai/ai-helper/stream_composer_suggestion") do + MessageBus.track_publish do + results = [["hello ", "world"]] + DiscourseAi::Completions::Llm.with_prepared_responses(results) do + post "/discourse-ai/ai-helper/stream_suggestion.json", + params: { + text: "hello wrld", + location: "helper", + client_id: "1234", + post_id: my_post.id, + custom_prompt: "Translate to Spanish", + mode: DiscourseAi::AiHelper::Assistant::CUSTOM_PROMPT, + } + + expect(response.status).to eq(200) + channel = response.parsed_body["progress_channel"] + end + end + + # we only have the channel after we make the request + # so we can not filter till now + messages = messages.select { |m| m.channel == channel } + expect(messages).not_to be_empty + + last_message = messages.last + expect(messages.all? { |m| m.client_ids == ["1234"] }).to eq(true) + expect(messages.all? { |m| m == last_message || !m.data[:done] }).to eq(true) + + expect(last_message.channel).to eq(channel) + expect(last_message.data[:result]).to eq("hello world") + expect(last_message.data[:done]).to eq(true) + end + + it "is able to stream suggestions to composer" do + sign_in(user) + channel = nil + messages = + MessageBus.track_publish do results = [["hello ", "world"]] DiscourseAi::Completions::Llm.with_prepared_responses(results) do post "/discourse-ai/ai-helper/stream_suggestion.json", @@ -26,13 +66,19 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do } expect(response.status).to eq(200) + channel = response.parsed_body["progress_channel"] end end + # we only have the channel after we make the request + # so we can not filter till now + messages = messages.select { |m| m.channel == channel } + last_message = messages.last expect(messages.all? { |m| m.client_ids == ["1234"] }).to eq(true) expect(messages.all? { |m| m == last_message || !m.data[:done] }).to eq(true) + expect(last_message.channel).to eq(channel) expect(last_message.data[:result]).to eq("hello world") expect(last_message.data[:done]).to eq(true) end diff --git a/test/javascripts/acceptance/post-helper-menu-test.js b/test/javascripts/acceptance/post-helper-menu-test.js index 6b1002f5..1f17d8f7 100644 --- a/test/javascripts/acceptance/post-helper-menu-test.js +++ b/test/javascripts/acceptance/post-helper-menu-test.js @@ -48,6 +48,7 @@ acceptance("AI Helper - Post Helper Menu", function (needs) { return helper.response({ result: "This is a suggestio", done: false, + progress_channel: "/some/progress/channel", }); }); @@ -61,13 +62,10 @@ acceptance("AI Helper - Post Helper Menu", function (needs) { await selectText(textNode, 9); await click(".ai-post-helper__trigger"); await click(".ai-helper-options__button[data-name='explain']"); - await publishToMessageBus( - `/discourse-ai/ai-helper/stream_suggestion/118591`, - { - done: true, - result: suggestion, - } - ); + await publishToMessageBus(`/some/progress/channel`, { + done: true, + result: suggestion, + }); assert.dom(".ai-post-helper__suggestion__text").hasText(suggestion); }); @@ -91,13 +89,10 @@ acceptance("AI Helper - Post Helper Menu", function (needs) { await selectSpecificText(textNode, 72, 77); await click(".ai-post-helper__trigger"); await click(".ai-helper-options__button[data-name='explain']"); - await publishToMessageBus( - `/discourse-ai/ai-helper/stream_suggestion/118591`, - { - done: true, - result: suggestion, - } - ); + await publishToMessageBus(`/some/progress/channel`, { + done: true, + result: suggestion, + }); assert.dom(".ai-post-helper__suggestion__insert-footnote").isDisabled(); }); @@ -108,13 +103,10 @@ acceptance("AI Helper - Post Helper Menu", function (needs) { await selectText(query("#post_1 .cooked p")); await click(".ai-post-helper__trigger"); await click(".ai-helper-options__button[data-name='translate']"); - await publishToMessageBus( - `/discourse-ai/ai-helper/stream_suggestion/118591`, - { - done: true, - result: translated, - } - ); + await publishToMessageBus(`/some/progress/channel`, { + done: true, + result: translated, + }); assert.dom(".ai-post-helper__suggestion__text").hasText(translated); }); });