diff --git a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb index 4558c5d6..8a65a187 100644 --- a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb +++ b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb @@ -124,6 +124,9 @@ module DiscourseAi # otherwise we may end up streaming the data to the wrong client raise Discourse::InvalidParameters.new(:client_id) if params[:client_id].blank? + channel_id = next_channel_id + progress_channel = "discourse_ai_helper/stream_suggestions/#{channel_id}" + if location == "composer" Jobs.enqueue( :stream_composer_helper, @@ -133,6 +136,7 @@ module DiscourseAi custom_prompt: params[:custom_prompt], force_default_locale: params[:force_default_locale] || false, client_id: params[:client_id], + progress_channel:, ) else post_id = get_post_param! @@ -148,10 +152,11 @@ module DiscourseAi prompt: params[:mode], custom_prompt: params[:custom_prompt], client_id: params[:client_id], + progress_channel:, ) end - render json: { success: true }, status: 200 + render json: { success: true, progress_channel: }, status: 200 rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"), status: 502 @@ -192,6 +197,18 @@ module DiscourseAi private + CHANNEL_ID_KEY = "discourse_ai_helper_next_channel_id" + + def next_channel_id + Discourse + .redis + .pipelined do |pipeline| + pipeline.incr(CHANNEL_ID_KEY) + pipeline.expire(CHANNEL_ID_KEY, 1.day) + end + .first + end + def get_text_param! params[:text].tap { |t| raise Discourse::InvalidParameters.new(:text) if t.blank? } end diff --git a/app/jobs/regular/stream_composer_helper.rb b/app/jobs/regular/stream_composer_helper.rb index ae3c5017..b7e9b33f 100644 --- a/app/jobs/regular/stream_composer_helper.rb +++ b/app/jobs/regular/stream_composer_helper.rb @@ -9,6 +9,7 @@ module Jobs return unless user = User.find_by(id: args[:user_id]) return unless args[:text] return unless args[:client_id] + return unless args[:progress_channel] helper_mode = args[:prompt] @@ -16,7 +17,7 @@ module Jobs helper_mode, args[:text], user, - "/discourse-ai/ai-helper/stream_composer_suggestion", + args[:progress_channel], force_default_locale: args[:force_default_locale], client_id: args[:client_id], custom_prompt: args[:custom_prompt], diff --git a/app/jobs/regular/stream_post_helper.rb b/app/jobs/regular/stream_post_helper.rb index 56a3149f..a1cca62c 100644 --- a/app/jobs/regular/stream_post_helper.rb +++ b/app/jobs/regular/stream_post_helper.rb @@ -8,6 +8,8 @@ module Jobs return unless post = Post.includes(:topic).find_by(id: args[:post_id]) return unless user = User.find_by(id: args[:user_id]) return unless args[:text] + return unless args[:progress_channel] + return unless args[:client_id] topic = post.topic reply_to = post.reply_to_post @@ -31,8 +33,9 @@ module Jobs helper_mode, input, user, - "/discourse-ai/ai-helper/stream_suggestion/#{post.id}", + args[:progress_channel], custom_prompt: args[:custom_prompt], + client_id: args[:client_id], ) end end diff --git a/assets/javascripts/discourse/components/ai-post-helper-menu.gjs b/assets/javascripts/discourse/components/ai-post-helper-menu.gjs index 8edacb48..e22ee061 100644 --- a/assets/javascripts/discourse/components/ai-post-helper-menu.gjs +++ b/assets/javascripts/discourse/components/ai-post-helper-menu.gjs @@ -1,7 +1,6 @@ import Component from "@glimmer/component"; import { tracked } from "@glimmer/tracking"; import { action } from "@ember/object"; -import didInsert from "@ember/render-modifiers/modifiers/did-insert"; import willDestroy from "@ember/render-modifiers/modifiers/will-destroy"; import { service } from "@ember/service"; import { modifier } from "ember-modifier"; @@ -43,9 +42,6 @@ export default class AiPostHelperMenu extends Component { @tracked lastSelectedOption = null; @tracked isSavingFootnote = false; @tracked supportsAddFootnote = this.args.data.supportsFastEdit; - @tracked - channel = - `/discourse-ai/ai-helper/stream_suggestion/${this.args.data.quoteState.postId}`; @tracked smoothStreamer = new SmoothStreamer( @@ -150,19 +146,25 @@ export default class AiPostHelperMenu extends Component { return sanitize(text); } - @bind + set progressChannel(value) { + if (this._progressChannel) { + this.unsubscribe(); + } + this._progressChannel = value; + this.subscribe(); + } + subscribe() { - this.messageBus.subscribe( - this.channel, - (data) => this._updateResult(data), - this.args.data.post - .discourse_ai_helper_stream_suggestion_last_message_bus_id - ); + this.messageBus.subscribe(this._progressChannel, this._updateResult, 0); } @bind unsubscribe() { - this.messageBus.unsubscribe(this.channel, this._updateResult); + if (!this._progressChannel) { + return; + } + this.messageBus.unsubscribe(this._progressChannel, this._updateResult); + this._progressChannel = null; } @bind @@ -182,33 +184,38 @@ export default class AiPostHelperMenu extends Component { this.lastSelectedOption = option; const streamableOptions = ["explain", "translate", "custom_prompt"]; - if (streamableOptions.includes(option.name)) { - return this._handleStreamedResult(option); - } else { - this._activeAiRequest = ajax("/discourse-ai/ai-helper/suggest", { - method: "POST", - data: { - mode: option.name, - text: this.args.data.quoteState.buffer, - custom_prompt: this.customPromptValue, - }, - }); + try { + if (streamableOptions.includes(option.name)) { + const streamedResult = await this._handleStreamedResult(option); + this.progressChannel = streamedResult.progress_channel; + return; + } else { + this._activeAiRequest = ajax("/discourse-ai/ai-helper/suggest", { + method: "POST", + data: { + mode: option.name, + text: this.args.data.quoteState.buffer, + custom_prompt: this.customPromptValue, + }, + }); + } + + this._activeAiRequest + .then(({ suggestions }) => { + this.suggestion = suggestions[0].trim(); + + if (option.name === "proofread") { + return this._handleProofreadOption(); + } + }) + .finally(() => { + this.loading = false; + this.menuState = this.MENU_STATES.result; + }); + } catch (error) { + popupAjaxError(error); } - this._activeAiRequest - .then(({ suggestions }) => { - this.suggestion = suggestions[0].trim(); - - if (option.name === "proofread") { - return this._handleProofreadOption(); - } - }) - .catch(popupAjaxError) - .finally(() => { - this.loading = false; - this.menuState = this.MENU_STATES.result; - }); - return this._activeAiRequest; } @@ -340,7 +347,6 @@ export default class AiPostHelperMenu extends Component { {{else if (eq this.menuState this.MENU_STATES.result)}}