FIX: cross talk when in ai helper (#1478)

Previous to this change we reused channels for proofreading progress and
ai helper progress

The new changeset ensures each POST to stream progress gets a dedicated
message bus channel

This fixes a class of issues where the wrong information could be displayed
to end users on subsequent proofreading or helper calls

* fix tests

* fix implementation (got to subscribe at 0)
This commit is contained in:
Sam 2025-07-01 18:02:16 +10:00 committed by GitHub
parent 897f31e564
commit 40fa527633
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 206 additions and 106 deletions

View File

@ -124,6 +124,9 @@ module DiscourseAi
# otherwise we may end up streaming the data to the wrong client
raise Discourse::InvalidParameters.new(:client_id) if params[:client_id].blank?
channel_id = next_channel_id
progress_channel = "discourse_ai_helper/stream_suggestions/#{channel_id}"
if location == "composer"
Jobs.enqueue(
:stream_composer_helper,
@ -133,6 +136,7 @@ module DiscourseAi
custom_prompt: params[:custom_prompt],
force_default_locale: params[:force_default_locale] || false,
client_id: params[:client_id],
progress_channel:,
)
else
post_id = get_post_param!
@ -148,10 +152,11 @@ module DiscourseAi
prompt: params[:mode],
custom_prompt: params[:custom_prompt],
client_id: params[:client_id],
progress_channel:,
)
end
render json: { success: true }, status: 200
render json: { success: true, progress_channel: }, status: 200
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
status: 502
@ -192,6 +197,18 @@ module DiscourseAi
private
CHANNEL_ID_KEY = "discourse_ai_helper_next_channel_id"
def next_channel_id
Discourse
.redis
.pipelined do |pipeline|
pipeline.incr(CHANNEL_ID_KEY)
pipeline.expire(CHANNEL_ID_KEY, 1.day)
end
.first
end
def get_text_param!
params[:text].tap { |t| raise Discourse::InvalidParameters.new(:text) if t.blank? }
end

View File

@ -9,6 +9,7 @@ module Jobs
return unless user = User.find_by(id: args[:user_id])
return unless args[:text]
return unless args[:client_id]
return unless args[:progress_channel]
helper_mode = args[:prompt]
@ -16,7 +17,7 @@ module Jobs
helper_mode,
args[:text],
user,
"/discourse-ai/ai-helper/stream_composer_suggestion",
args[:progress_channel],
force_default_locale: args[:force_default_locale],
client_id: args[:client_id],
custom_prompt: args[:custom_prompt],

View File

@ -8,6 +8,8 @@ module Jobs
return unless post = Post.includes(:topic).find_by(id: args[:post_id])
return unless user = User.find_by(id: args[:user_id])
return unless args[:text]
return unless args[:progress_channel]
return unless args[:client_id]
topic = post.topic
reply_to = post.reply_to_post
@ -31,8 +33,9 @@ module Jobs
helper_mode,
input,
user,
"/discourse-ai/ai-helper/stream_suggestion/#{post.id}",
args[:progress_channel],
custom_prompt: args[:custom_prompt],
client_id: args[:client_id],
)
end
end

View File

@ -1,7 +1,6 @@
import Component from "@glimmer/component";
import { tracked } from "@glimmer/tracking";
import { action } from "@ember/object";
import didInsert from "@ember/render-modifiers/modifiers/did-insert";
import willDestroy from "@ember/render-modifiers/modifiers/will-destroy";
import { service } from "@ember/service";
import { modifier } from "ember-modifier";
@ -43,9 +42,6 @@ export default class AiPostHelperMenu extends Component {
@tracked lastSelectedOption = null;
@tracked isSavingFootnote = false;
@tracked supportsAddFootnote = this.args.data.supportsFastEdit;
@tracked
channel =
`/discourse-ai/ai-helper/stream_suggestion/${this.args.data.quoteState.postId}`;
@tracked
smoothStreamer = new SmoothStreamer(
@ -150,19 +146,25 @@ export default class AiPostHelperMenu extends Component {
return sanitize(text);
}
@bind
set progressChannel(value) {
if (this._progressChannel) {
this.unsubscribe();
}
this._progressChannel = value;
this.subscribe();
}
subscribe() {
this.messageBus.subscribe(
this.channel,
(data) => this._updateResult(data),
this.args.data.post
.discourse_ai_helper_stream_suggestion_last_message_bus_id
);
this.messageBus.subscribe(this._progressChannel, this._updateResult, 0);
}
@bind
unsubscribe() {
this.messageBus.unsubscribe(this.channel, this._updateResult);
if (!this._progressChannel) {
return;
}
this.messageBus.unsubscribe(this._progressChannel, this._updateResult);
this._progressChannel = null;
}
@bind
@ -182,33 +184,38 @@ export default class AiPostHelperMenu extends Component {
this.lastSelectedOption = option;
const streamableOptions = ["explain", "translate", "custom_prompt"];
if (streamableOptions.includes(option.name)) {
return this._handleStreamedResult(option);
} else {
this._activeAiRequest = ajax("/discourse-ai/ai-helper/suggest", {
method: "POST",
data: {
mode: option.name,
text: this.args.data.quoteState.buffer,
custom_prompt: this.customPromptValue,
},
});
try {
if (streamableOptions.includes(option.name)) {
const streamedResult = await this._handleStreamedResult(option);
this.progressChannel = streamedResult.progress_channel;
return;
} else {
this._activeAiRequest = ajax("/discourse-ai/ai-helper/suggest", {
method: "POST",
data: {
mode: option.name,
text: this.args.data.quoteState.buffer,
custom_prompt: this.customPromptValue,
},
});
}
this._activeAiRequest
.then(({ suggestions }) => {
this.suggestion = suggestions[0].trim();
if (option.name === "proofread") {
return this._handleProofreadOption();
}
})
.finally(() => {
this.loading = false;
this.menuState = this.MENU_STATES.result;
});
} catch (error) {
popupAjaxError(error);
}
this._activeAiRequest
.then(({ suggestions }) => {
this.suggestion = suggestions[0].trim();
if (option.name === "proofread") {
return this._handleProofreadOption();
}
})
.catch(popupAjaxError)
.finally(() => {
this.loading = false;
this.menuState = this.MENU_STATES.result;
});
return this._activeAiRequest;
}
@ -340,7 +347,6 @@ export default class AiPostHelperMenu extends Component {
{{else if (eq this.menuState this.MENU_STATES.result)}}
<div
class="ai-post-helper__suggestion"
{{didInsert this.subscribe}}
{{willDestroy this.unsubscribe}}
>
{{#if this.suggestion}}

View File

@ -1,7 +1,6 @@
import Component from "@glimmer/component";
import { tracked } from "@glimmer/tracking";
import { action } from "@ember/object";
import didInsert from "@ember/render-modifiers/modifiers/did-insert";
import willDestroy from "@ember/render-modifiers/modifiers/will-destroy";
import { service } from "@ember/service";
import { htmlSafe } from "@ember/template";
@ -19,8 +18,6 @@ import DiffStreamer from "../../lib/diff-streamer";
import SmoothStreamer from "../../lib/smooth-streamer";
import AiIndicatorWave from "../ai-indicator-wave";
const CHANNEL = "/discourse-ai/ai-helper/stream_composer_suggestion";
export default class ModalDiffModal extends Component {
@service currentUser;
@service messageBus;
@ -83,21 +80,26 @@ export default class ModalDiffModal extends Component {
return this.loading || this.isStreaming;
}
@bind
set progressChannel(value) {
if (this._progressChannel) {
this.messageBus.unsubscribe(this._progressChannel, this.updateResult);
}
this._progressChannel = value;
this.subscribe();
}
subscribe() {
this.messageBus.subscribe(
CHANNEL,
this.updateResult,
this.currentUser
?.discourse_ai_helper_stream_composer_suggestion_last_message_bus_id
);
// we have 1 channel per operation so we can safely subscribe at head
this.messageBus.subscribe(this._progressChannel, this.updateResult, 0);
}
@bind
cleanup() {
// stop all callbacks so it does not end up streaming pointlessly
this.#resetState();
this.messageBus.unsubscribe(CHANNEL, this.updateResult);
if (this._progressChannel) {
this.messageBus.unsubscribe(this._progressChannel, this.updateResult);
}
}
@action
@ -122,7 +124,7 @@ export default class ModalDiffModal extends Component {
try {
this.loading = true;
return await ajax("/discourse-ai/ai-helper/stream_suggestion", {
const result = await ajax("/discourse-ai/ai-helper/stream_suggestion", {
method: "POST",
data: {
location: "composer",
@ -133,6 +135,8 @@ export default class ModalDiffModal extends Component {
client_id: this.messageBus.clientId,
},
});
this.progressChannel = result.progress_channel;
} catch (e) {
popupAjaxError(e);
}
@ -183,11 +187,7 @@ export default class ModalDiffModal extends Component {
@closeModal={{this.cleanupAndClose}}
>
<:body>
<div
{{didInsert this.subscribe}}
{{willDestroy this.cleanup}}
class="text-preview"
>
<div {{willDestroy this.cleanup}} class="text-preview">
<div
class={{concatClass
"composer-ai-helper-modal__suggestion"

View File

@ -73,18 +73,6 @@ module DiscourseAi
scope.user.in_any_groups?(SiteSetting.ai_auto_image_caption_allowed_groups_map)
end,
) { object.auto_image_caption }
plugin.add_to_serializer(
:post,
:discourse_ai_helper_stream_suggestion_last_message_bus_id,
include_condition: -> { SiteSetting.ai_helper_enabled && scope.authenticated? },
) { MessageBus.last_id("/discourse-ai/ai-helper/stream_suggestion/#{object.id}") }
plugin.add_to_serializer(
:current_user,
:discourse_ai_helper_stream_composer_suggestion_last_message_bus_id,
include_condition: -> { SiteSetting.ai_helper_enabled && scope.authenticated? },
) { MessageBus.last_id("/discourse-ai/ai-helper/stream_composer_suggestion") }
end
end
end

View File

@ -18,23 +18,33 @@ RSpec.describe Jobs::StreamComposerHelper do
let(:mode) { DiscourseAi::AiHelper::Assistant::PROOFREAD }
it "does nothing if there is no user" do
channel = "/some/channel"
messages =
MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion") do
job.execute(user_id: nil, text: input, prompt: mode, force_default_locale: false)
MessageBus.track_publish(channel) do
job.execute(
user_id: nil,
text: input,
prompt: mode,
force_default_locale: false,
client_id: "123",
progress_channel: channel,
)
end
expect(messages).to be_empty
end
it "does nothing if there is no text" do
channel = "/some/channel"
messages =
MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion") do
MessageBus.track_publish(channel) do
job.execute(
user_id: user.id,
text: nil,
prompt: mode,
force_default_locale: false,
client_id: "123",
progress_channel: channel,
)
end
@ -47,16 +57,18 @@ RSpec.describe Jobs::StreamComposerHelper do
it "publishes updates with a partial result" do
proofread_result = "I like to eat pie for breakfast because it is delicious."
channel = "/channel/123"
DiscourseAi::Completions::Llm.with_prepared_responses([proofread_result]) do
messages =
MessageBus.track_publish("/discourse-ai/ai-helper/stream_composer_suggestion") do
MessageBus.track_publish(channel) do
job.execute(
user_id: user.id,
text: input,
prompt: mode,
force_default_locale: true,
client_id: "123",
progress_channel: channel,
)
end
@ -68,16 +80,18 @@ RSpec.describe Jobs::StreamComposerHelper do
it "publishes a final update to signal we're done" do
proofread_result = "I like to eat pie for breakfast because it is delicious."
channel = "/channel/123"
DiscourseAi::Completions::Llm.with_prepared_responses([proofread_result]) do
messages =
MessageBus.track_publish("/discourse-ai/ai-helper/stream_composer_suggestion") do
MessageBus.track_publish(channel) do
job.execute(
user_id: user.id,
text: input,
prompt: mode,
force_default_locale: true,
client_id: "123",
progress_channel: channel,
)
end

View File

@ -60,10 +60,18 @@ RSpec.describe Jobs::StreamPostHelper do
explanation =
"In this context, \"pie\" refers to a baked dessert typically consisting of a pastry crust and filling."
channel = "/my/channel"
DiscourseAi::Completions::Llm.with_prepared_responses([explanation]) do
messages =
MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion/#{post.id}") do
job.execute(post_id: post.id, user_id: user.id, text: "pie", prompt: mode)
MessageBus.track_publish(channel) do
job.execute(
post_id: post.id,
user_id: user.id,
text: "pie",
prompt: mode,
progress_channel: channel,
client_id: "test_client_id",
)
end
partial_result_update = messages.first.data
@ -76,10 +84,19 @@ RSpec.describe Jobs::StreamPostHelper do
explanation =
"In this context, \"pie\" refers to a baked dessert typically consisting of a pastry crust and filling."
channel = "/my/channel"
DiscourseAi::Completions::Llm.with_prepared_responses([explanation]) do
messages =
MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion/#{post.id}") do
job.execute(post_id: post.id, user_id: user.id, text: "pie", prompt: mode)
MessageBus.track_publish(channel) do
job.execute(
post_id: post.id,
user_id: user.id,
text: "pie",
prompt: mode,
client_id: "test_client_id",
progress_channel: channel,
)
end
final_update = messages.last.data
@ -96,10 +113,18 @@ RSpec.describe Jobs::StreamPostHelper do
sentence = "I like to eat pie."
translation = "Me gusta comer pastel."
channel = "/my/channel"
DiscourseAi::Completions::Llm.with_prepared_responses([translation]) do
messages =
MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion/#{post.id}") do
job.execute(post_id: post.id, user_id: user.id, text: sentence, prompt: mode)
MessageBus.track_publish(channel) do
job.execute(
post_id: post.id,
user_id: user.id,
text: sentence,
prompt: mode,
progress_channel: channel,
client_id: "test_client_id",
)
end
partial_result_update = messages.first.data
@ -111,11 +136,19 @@ RSpec.describe Jobs::StreamPostHelper do
it "publishes a final update to signal we're done" do
sentence = "I like to eat pie."
translation = "Me gusta comer pastel."
channel = "/my/channel"
DiscourseAi::Completions::Llm.with_prepared_responses([translation]) do
messages =
MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion/#{post.id}") do
job.execute(post_id: post.id, user_id: user.id, text: sentence, prompt: mode)
MessageBus.track_publish(channel) do
job.execute(
post_id: post.id,
user_id: user.id,
text: sentence,
prompt: mode,
progress_channel: channel,
client_id: "test_client_id",
)
end
final_update = messages.last.data

View File

@ -11,10 +11,50 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
SiteSetting.composer_ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_0]
end
it "is able to stream suggestions back on appropriate channel" do
it "is able to stream suggestions to helper" do
sign_in(user)
my_post = Fabricate(:post)
channel = nil
messages =
MessageBus.track_publish("/discourse-ai/ai-helper/stream_composer_suggestion") do
MessageBus.track_publish do
results = [["hello ", "world"]]
DiscourseAi::Completions::Llm.with_prepared_responses(results) do
post "/discourse-ai/ai-helper/stream_suggestion.json",
params: {
text: "hello wrld",
location: "helper",
client_id: "1234",
post_id: my_post.id,
custom_prompt: "Translate to Spanish",
mode: DiscourseAi::AiHelper::Assistant::CUSTOM_PROMPT,
}
expect(response.status).to eq(200)
channel = response.parsed_body["progress_channel"]
end
end
# we only have the channel after we make the request
# so we can not filter till now
messages = messages.select { |m| m.channel == channel }
expect(messages).not_to be_empty
last_message = messages.last
expect(messages.all? { |m| m.client_ids == ["1234"] }).to eq(true)
expect(messages.all? { |m| m == last_message || !m.data[:done] }).to eq(true)
expect(last_message.channel).to eq(channel)
expect(last_message.data[:result]).to eq("hello world")
expect(last_message.data[:done]).to eq(true)
end
it "is able to stream suggestions to composer" do
sign_in(user)
channel = nil
messages =
MessageBus.track_publish do
results = [["hello ", "world"]]
DiscourseAi::Completions::Llm.with_prepared_responses(results) do
post "/discourse-ai/ai-helper/stream_suggestion.json",
@ -26,13 +66,19 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
}
expect(response.status).to eq(200)
channel = response.parsed_body["progress_channel"]
end
end
# we only have the channel after we make the request
# so we can not filter till now
messages = messages.select { |m| m.channel == channel }
last_message = messages.last
expect(messages.all? { |m| m.client_ids == ["1234"] }).to eq(true)
expect(messages.all? { |m| m == last_message || !m.data[:done] }).to eq(true)
expect(last_message.channel).to eq(channel)
expect(last_message.data[:result]).to eq("hello world")
expect(last_message.data[:done]).to eq(true)
end

View File

@ -48,6 +48,7 @@ acceptance("AI Helper - Post Helper Menu", function (needs) {
return helper.response({
result: "This is a suggestio",
done: false,
progress_channel: "/some/progress/channel",
});
});
@ -61,13 +62,10 @@ acceptance("AI Helper - Post Helper Menu", function (needs) {
await selectText(textNode, 9);
await click(".ai-post-helper__trigger");
await click(".ai-helper-options__button[data-name='explain']");
await publishToMessageBus(
`/discourse-ai/ai-helper/stream_suggestion/118591`,
{
done: true,
result: suggestion,
}
);
await publishToMessageBus(`/some/progress/channel`, {
done: true,
result: suggestion,
});
assert.dom(".ai-post-helper__suggestion__text").hasText(suggestion);
});
@ -91,13 +89,10 @@ acceptance("AI Helper - Post Helper Menu", function (needs) {
await selectSpecificText(textNode, 72, 77);
await click(".ai-post-helper__trigger");
await click(".ai-helper-options__button[data-name='explain']");
await publishToMessageBus(
`/discourse-ai/ai-helper/stream_suggestion/118591`,
{
done: true,
result: suggestion,
}
);
await publishToMessageBus(`/some/progress/channel`, {
done: true,
result: suggestion,
});
assert.dom(".ai-post-helper__suggestion__insert-footnote").isDisabled();
});
@ -108,13 +103,10 @@ acceptance("AI Helper - Post Helper Menu", function (needs) {
await selectText(query("#post_1 .cooked p"));
await click(".ai-post-helper__trigger");
await click(".ai-helper-options__button[data-name='translate']");
await publishToMessageBus(
`/discourse-ai/ai-helper/stream_suggestion/118591`,
{
done: true,
result: translated,
}
);
await publishToMessageBus(`/some/progress/channel`, {
done: true,
result: translated,
});
assert.dom(".ai-post-helper__suggestion__text").hasText(translated);
});
});