mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-07-05 05:52:16 +00:00
FIX: cross talk when in ai helper (#1478)
Previous to this change we reused channels for proofreading progress and ai helper progress The new changeset ensures each POST to stream progress gets a dedicated message bus channel This fixes a class of issues where the wrong information could be displayed to end users on subsequent proofreading or helper calls * fix tests * fix implementation (got to subscribe at 0)
This commit is contained in:
parent
897f31e564
commit
40fa527633
@ -124,6 +124,9 @@ module DiscourseAi
|
|||||||
# otherwise we may end up streaming the data to the wrong client
|
# otherwise we may end up streaming the data to the wrong client
|
||||||
raise Discourse::InvalidParameters.new(:client_id) if params[:client_id].blank?
|
raise Discourse::InvalidParameters.new(:client_id) if params[:client_id].blank?
|
||||||
|
|
||||||
|
channel_id = next_channel_id
|
||||||
|
progress_channel = "discourse_ai_helper/stream_suggestions/#{channel_id}"
|
||||||
|
|
||||||
if location == "composer"
|
if location == "composer"
|
||||||
Jobs.enqueue(
|
Jobs.enqueue(
|
||||||
:stream_composer_helper,
|
:stream_composer_helper,
|
||||||
@ -133,6 +136,7 @@ module DiscourseAi
|
|||||||
custom_prompt: params[:custom_prompt],
|
custom_prompt: params[:custom_prompt],
|
||||||
force_default_locale: params[:force_default_locale] || false,
|
force_default_locale: params[:force_default_locale] || false,
|
||||||
client_id: params[:client_id],
|
client_id: params[:client_id],
|
||||||
|
progress_channel:,
|
||||||
)
|
)
|
||||||
else
|
else
|
||||||
post_id = get_post_param!
|
post_id = get_post_param!
|
||||||
@ -148,10 +152,11 @@ module DiscourseAi
|
|||||||
prompt: params[:mode],
|
prompt: params[:mode],
|
||||||
custom_prompt: params[:custom_prompt],
|
custom_prompt: params[:custom_prompt],
|
||||||
client_id: params[:client_id],
|
client_id: params[:client_id],
|
||||||
|
progress_channel:,
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
render json: { success: true }, status: 200
|
render json: { success: true, progress_channel: }, status: 200
|
||||||
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
|
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
|
||||||
render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
|
render_json_error I18n.t("discourse_ai.ai_helper.errors.completion_request_failed"),
|
||||||
status: 502
|
status: 502
|
||||||
@ -192,6 +197,18 @@ module DiscourseAi
|
|||||||
|
|
||||||
private
|
private
|
||||||
|
|
||||||
|
CHANNEL_ID_KEY = "discourse_ai_helper_next_channel_id"
|
||||||
|
|
||||||
|
def next_channel_id
|
||||||
|
Discourse
|
||||||
|
.redis
|
||||||
|
.pipelined do |pipeline|
|
||||||
|
pipeline.incr(CHANNEL_ID_KEY)
|
||||||
|
pipeline.expire(CHANNEL_ID_KEY, 1.day)
|
||||||
|
end
|
||||||
|
.first
|
||||||
|
end
|
||||||
|
|
||||||
def get_text_param!
|
def get_text_param!
|
||||||
params[:text].tap { |t| raise Discourse::InvalidParameters.new(:text) if t.blank? }
|
params[:text].tap { |t| raise Discourse::InvalidParameters.new(:text) if t.blank? }
|
||||||
end
|
end
|
||||||
|
@ -9,6 +9,7 @@ module Jobs
|
|||||||
return unless user = User.find_by(id: args[:user_id])
|
return unless user = User.find_by(id: args[:user_id])
|
||||||
return unless args[:text]
|
return unless args[:text]
|
||||||
return unless args[:client_id]
|
return unless args[:client_id]
|
||||||
|
return unless args[:progress_channel]
|
||||||
|
|
||||||
helper_mode = args[:prompt]
|
helper_mode = args[:prompt]
|
||||||
|
|
||||||
@ -16,7 +17,7 @@ module Jobs
|
|||||||
helper_mode,
|
helper_mode,
|
||||||
args[:text],
|
args[:text],
|
||||||
user,
|
user,
|
||||||
"/discourse-ai/ai-helper/stream_composer_suggestion",
|
args[:progress_channel],
|
||||||
force_default_locale: args[:force_default_locale],
|
force_default_locale: args[:force_default_locale],
|
||||||
client_id: args[:client_id],
|
client_id: args[:client_id],
|
||||||
custom_prompt: args[:custom_prompt],
|
custom_prompt: args[:custom_prompt],
|
||||||
|
@ -8,6 +8,8 @@ module Jobs
|
|||||||
return unless post = Post.includes(:topic).find_by(id: args[:post_id])
|
return unless post = Post.includes(:topic).find_by(id: args[:post_id])
|
||||||
return unless user = User.find_by(id: args[:user_id])
|
return unless user = User.find_by(id: args[:user_id])
|
||||||
return unless args[:text]
|
return unless args[:text]
|
||||||
|
return unless args[:progress_channel]
|
||||||
|
return unless args[:client_id]
|
||||||
|
|
||||||
topic = post.topic
|
topic = post.topic
|
||||||
reply_to = post.reply_to_post
|
reply_to = post.reply_to_post
|
||||||
@ -31,8 +33,9 @@ module Jobs
|
|||||||
helper_mode,
|
helper_mode,
|
||||||
input,
|
input,
|
||||||
user,
|
user,
|
||||||
"/discourse-ai/ai-helper/stream_suggestion/#{post.id}",
|
args[:progress_channel],
|
||||||
custom_prompt: args[:custom_prompt],
|
custom_prompt: args[:custom_prompt],
|
||||||
|
client_id: args[:client_id],
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import Component from "@glimmer/component";
|
import Component from "@glimmer/component";
|
||||||
import { tracked } from "@glimmer/tracking";
|
import { tracked } from "@glimmer/tracking";
|
||||||
import { action } from "@ember/object";
|
import { action } from "@ember/object";
|
||||||
import didInsert from "@ember/render-modifiers/modifiers/did-insert";
|
|
||||||
import willDestroy from "@ember/render-modifiers/modifiers/will-destroy";
|
import willDestroy from "@ember/render-modifiers/modifiers/will-destroy";
|
||||||
import { service } from "@ember/service";
|
import { service } from "@ember/service";
|
||||||
import { modifier } from "ember-modifier";
|
import { modifier } from "ember-modifier";
|
||||||
@ -43,9 +42,6 @@ export default class AiPostHelperMenu extends Component {
|
|||||||
@tracked lastSelectedOption = null;
|
@tracked lastSelectedOption = null;
|
||||||
@tracked isSavingFootnote = false;
|
@tracked isSavingFootnote = false;
|
||||||
@tracked supportsAddFootnote = this.args.data.supportsFastEdit;
|
@tracked supportsAddFootnote = this.args.data.supportsFastEdit;
|
||||||
@tracked
|
|
||||||
channel =
|
|
||||||
`/discourse-ai/ai-helper/stream_suggestion/${this.args.data.quoteState.postId}`;
|
|
||||||
|
|
||||||
@tracked
|
@tracked
|
||||||
smoothStreamer = new SmoothStreamer(
|
smoothStreamer = new SmoothStreamer(
|
||||||
@ -150,19 +146,25 @@ export default class AiPostHelperMenu extends Component {
|
|||||||
return sanitize(text);
|
return sanitize(text);
|
||||||
}
|
}
|
||||||
|
|
||||||
@bind
|
set progressChannel(value) {
|
||||||
|
if (this._progressChannel) {
|
||||||
|
this.unsubscribe();
|
||||||
|
}
|
||||||
|
this._progressChannel = value;
|
||||||
|
this.subscribe();
|
||||||
|
}
|
||||||
|
|
||||||
subscribe() {
|
subscribe() {
|
||||||
this.messageBus.subscribe(
|
this.messageBus.subscribe(this._progressChannel, this._updateResult, 0);
|
||||||
this.channel,
|
|
||||||
(data) => this._updateResult(data),
|
|
||||||
this.args.data.post
|
|
||||||
.discourse_ai_helper_stream_suggestion_last_message_bus_id
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@bind
|
@bind
|
||||||
unsubscribe() {
|
unsubscribe() {
|
||||||
this.messageBus.unsubscribe(this.channel, this._updateResult);
|
if (!this._progressChannel) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
this.messageBus.unsubscribe(this._progressChannel, this._updateResult);
|
||||||
|
this._progressChannel = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@bind
|
@bind
|
||||||
@ -182,8 +184,11 @@ export default class AiPostHelperMenu extends Component {
|
|||||||
this.lastSelectedOption = option;
|
this.lastSelectedOption = option;
|
||||||
const streamableOptions = ["explain", "translate", "custom_prompt"];
|
const streamableOptions = ["explain", "translate", "custom_prompt"];
|
||||||
|
|
||||||
|
try {
|
||||||
if (streamableOptions.includes(option.name)) {
|
if (streamableOptions.includes(option.name)) {
|
||||||
return this._handleStreamedResult(option);
|
const streamedResult = await this._handleStreamedResult(option);
|
||||||
|
this.progressChannel = streamedResult.progress_channel;
|
||||||
|
return;
|
||||||
} else {
|
} else {
|
||||||
this._activeAiRequest = ajax("/discourse-ai/ai-helper/suggest", {
|
this._activeAiRequest = ajax("/discourse-ai/ai-helper/suggest", {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
@ -203,11 +208,13 @@ export default class AiPostHelperMenu extends Component {
|
|||||||
return this._handleProofreadOption();
|
return this._handleProofreadOption();
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.catch(popupAjaxError)
|
|
||||||
.finally(() => {
|
.finally(() => {
|
||||||
this.loading = false;
|
this.loading = false;
|
||||||
this.menuState = this.MENU_STATES.result;
|
this.menuState = this.MENU_STATES.result;
|
||||||
});
|
});
|
||||||
|
} catch (error) {
|
||||||
|
popupAjaxError(error);
|
||||||
|
}
|
||||||
|
|
||||||
return this._activeAiRequest;
|
return this._activeAiRequest;
|
||||||
}
|
}
|
||||||
@ -340,7 +347,6 @@ export default class AiPostHelperMenu extends Component {
|
|||||||
{{else if (eq this.menuState this.MENU_STATES.result)}}
|
{{else if (eq this.menuState this.MENU_STATES.result)}}
|
||||||
<div
|
<div
|
||||||
class="ai-post-helper__suggestion"
|
class="ai-post-helper__suggestion"
|
||||||
{{didInsert this.subscribe}}
|
|
||||||
{{willDestroy this.unsubscribe}}
|
{{willDestroy this.unsubscribe}}
|
||||||
>
|
>
|
||||||
{{#if this.suggestion}}
|
{{#if this.suggestion}}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import Component from "@glimmer/component";
|
import Component from "@glimmer/component";
|
||||||
import { tracked } from "@glimmer/tracking";
|
import { tracked } from "@glimmer/tracking";
|
||||||
import { action } from "@ember/object";
|
import { action } from "@ember/object";
|
||||||
import didInsert from "@ember/render-modifiers/modifiers/did-insert";
|
|
||||||
import willDestroy from "@ember/render-modifiers/modifiers/will-destroy";
|
import willDestroy from "@ember/render-modifiers/modifiers/will-destroy";
|
||||||
import { service } from "@ember/service";
|
import { service } from "@ember/service";
|
||||||
import { htmlSafe } from "@ember/template";
|
import { htmlSafe } from "@ember/template";
|
||||||
@ -19,8 +18,6 @@ import DiffStreamer from "../../lib/diff-streamer";
|
|||||||
import SmoothStreamer from "../../lib/smooth-streamer";
|
import SmoothStreamer from "../../lib/smooth-streamer";
|
||||||
import AiIndicatorWave from "../ai-indicator-wave";
|
import AiIndicatorWave from "../ai-indicator-wave";
|
||||||
|
|
||||||
const CHANNEL = "/discourse-ai/ai-helper/stream_composer_suggestion";
|
|
||||||
|
|
||||||
export default class ModalDiffModal extends Component {
|
export default class ModalDiffModal extends Component {
|
||||||
@service currentUser;
|
@service currentUser;
|
||||||
@service messageBus;
|
@service messageBus;
|
||||||
@ -83,21 +80,26 @@ export default class ModalDiffModal extends Component {
|
|||||||
return this.loading || this.isStreaming;
|
return this.loading || this.isStreaming;
|
||||||
}
|
}
|
||||||
|
|
||||||
@bind
|
set progressChannel(value) {
|
||||||
|
if (this._progressChannel) {
|
||||||
|
this.messageBus.unsubscribe(this._progressChannel, this.updateResult);
|
||||||
|
}
|
||||||
|
this._progressChannel = value;
|
||||||
|
this.subscribe();
|
||||||
|
}
|
||||||
|
|
||||||
subscribe() {
|
subscribe() {
|
||||||
this.messageBus.subscribe(
|
// we have 1 channel per operation so we can safely subscribe at head
|
||||||
CHANNEL,
|
this.messageBus.subscribe(this._progressChannel, this.updateResult, 0);
|
||||||
this.updateResult,
|
|
||||||
this.currentUser
|
|
||||||
?.discourse_ai_helper_stream_composer_suggestion_last_message_bus_id
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@bind
|
@bind
|
||||||
cleanup() {
|
cleanup() {
|
||||||
// stop all callbacks so it does not end up streaming pointlessly
|
// stop all callbacks so it does not end up streaming pointlessly
|
||||||
this.#resetState();
|
this.#resetState();
|
||||||
this.messageBus.unsubscribe(CHANNEL, this.updateResult);
|
if (this._progressChannel) {
|
||||||
|
this.messageBus.unsubscribe(this._progressChannel, this.updateResult);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@action
|
@action
|
||||||
@ -122,7 +124,7 @@ export default class ModalDiffModal extends Component {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
this.loading = true;
|
this.loading = true;
|
||||||
return await ajax("/discourse-ai/ai-helper/stream_suggestion", {
|
const result = await ajax("/discourse-ai/ai-helper/stream_suggestion", {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
data: {
|
data: {
|
||||||
location: "composer",
|
location: "composer",
|
||||||
@ -133,6 +135,8 @@ export default class ModalDiffModal extends Component {
|
|||||||
client_id: this.messageBus.clientId,
|
client_id: this.messageBus.clientId,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
this.progressChannel = result.progress_channel;
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
popupAjaxError(e);
|
popupAjaxError(e);
|
||||||
}
|
}
|
||||||
@ -183,11 +187,7 @@ export default class ModalDiffModal extends Component {
|
|||||||
@closeModal={{this.cleanupAndClose}}
|
@closeModal={{this.cleanupAndClose}}
|
||||||
>
|
>
|
||||||
<:body>
|
<:body>
|
||||||
<div
|
<div {{willDestroy this.cleanup}} class="text-preview">
|
||||||
{{didInsert this.subscribe}}
|
|
||||||
{{willDestroy this.cleanup}}
|
|
||||||
class="text-preview"
|
|
||||||
>
|
|
||||||
<div
|
<div
|
||||||
class={{concatClass
|
class={{concatClass
|
||||||
"composer-ai-helper-modal__suggestion"
|
"composer-ai-helper-modal__suggestion"
|
||||||
|
@ -73,18 +73,6 @@ module DiscourseAi
|
|||||||
scope.user.in_any_groups?(SiteSetting.ai_auto_image_caption_allowed_groups_map)
|
scope.user.in_any_groups?(SiteSetting.ai_auto_image_caption_allowed_groups_map)
|
||||||
end,
|
end,
|
||||||
) { object.auto_image_caption }
|
) { object.auto_image_caption }
|
||||||
|
|
||||||
plugin.add_to_serializer(
|
|
||||||
:post,
|
|
||||||
:discourse_ai_helper_stream_suggestion_last_message_bus_id,
|
|
||||||
include_condition: -> { SiteSetting.ai_helper_enabled && scope.authenticated? },
|
|
||||||
) { MessageBus.last_id("/discourse-ai/ai-helper/stream_suggestion/#{object.id}") }
|
|
||||||
|
|
||||||
plugin.add_to_serializer(
|
|
||||||
:current_user,
|
|
||||||
:discourse_ai_helper_stream_composer_suggestion_last_message_bus_id,
|
|
||||||
include_condition: -> { SiteSetting.ai_helper_enabled && scope.authenticated? },
|
|
||||||
) { MessageBus.last_id("/discourse-ai/ai-helper/stream_composer_suggestion") }
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -18,23 +18,33 @@ RSpec.describe Jobs::StreamComposerHelper do
|
|||||||
let(:mode) { DiscourseAi::AiHelper::Assistant::PROOFREAD }
|
let(:mode) { DiscourseAi::AiHelper::Assistant::PROOFREAD }
|
||||||
|
|
||||||
it "does nothing if there is no user" do
|
it "does nothing if there is no user" do
|
||||||
|
channel = "/some/channel"
|
||||||
messages =
|
messages =
|
||||||
MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion") do
|
MessageBus.track_publish(channel) do
|
||||||
job.execute(user_id: nil, text: input, prompt: mode, force_default_locale: false)
|
job.execute(
|
||||||
|
user_id: nil,
|
||||||
|
text: input,
|
||||||
|
prompt: mode,
|
||||||
|
force_default_locale: false,
|
||||||
|
client_id: "123",
|
||||||
|
progress_channel: channel,
|
||||||
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
expect(messages).to be_empty
|
expect(messages).to be_empty
|
||||||
end
|
end
|
||||||
|
|
||||||
it "does nothing if there is no text" do
|
it "does nothing if there is no text" do
|
||||||
|
channel = "/some/channel"
|
||||||
messages =
|
messages =
|
||||||
MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion") do
|
MessageBus.track_publish(channel) do
|
||||||
job.execute(
|
job.execute(
|
||||||
user_id: user.id,
|
user_id: user.id,
|
||||||
text: nil,
|
text: nil,
|
||||||
prompt: mode,
|
prompt: mode,
|
||||||
force_default_locale: false,
|
force_default_locale: false,
|
||||||
client_id: "123",
|
client_id: "123",
|
||||||
|
progress_channel: channel,
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -47,16 +57,18 @@ RSpec.describe Jobs::StreamComposerHelper do
|
|||||||
|
|
||||||
it "publishes updates with a partial result" do
|
it "publishes updates with a partial result" do
|
||||||
proofread_result = "I like to eat pie for breakfast because it is delicious."
|
proofread_result = "I like to eat pie for breakfast because it is delicious."
|
||||||
|
channel = "/channel/123"
|
||||||
|
|
||||||
DiscourseAi::Completions::Llm.with_prepared_responses([proofread_result]) do
|
DiscourseAi::Completions::Llm.with_prepared_responses([proofread_result]) do
|
||||||
messages =
|
messages =
|
||||||
MessageBus.track_publish("/discourse-ai/ai-helper/stream_composer_suggestion") do
|
MessageBus.track_publish(channel) do
|
||||||
job.execute(
|
job.execute(
|
||||||
user_id: user.id,
|
user_id: user.id,
|
||||||
text: input,
|
text: input,
|
||||||
prompt: mode,
|
prompt: mode,
|
||||||
force_default_locale: true,
|
force_default_locale: true,
|
||||||
client_id: "123",
|
client_id: "123",
|
||||||
|
progress_channel: channel,
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -68,16 +80,18 @@ RSpec.describe Jobs::StreamComposerHelper do
|
|||||||
|
|
||||||
it "publishes a final update to signal we're done" do
|
it "publishes a final update to signal we're done" do
|
||||||
proofread_result = "I like to eat pie for breakfast because it is delicious."
|
proofread_result = "I like to eat pie for breakfast because it is delicious."
|
||||||
|
channel = "/channel/123"
|
||||||
|
|
||||||
DiscourseAi::Completions::Llm.with_prepared_responses([proofread_result]) do
|
DiscourseAi::Completions::Llm.with_prepared_responses([proofread_result]) do
|
||||||
messages =
|
messages =
|
||||||
MessageBus.track_publish("/discourse-ai/ai-helper/stream_composer_suggestion") do
|
MessageBus.track_publish(channel) do
|
||||||
job.execute(
|
job.execute(
|
||||||
user_id: user.id,
|
user_id: user.id,
|
||||||
text: input,
|
text: input,
|
||||||
prompt: mode,
|
prompt: mode,
|
||||||
force_default_locale: true,
|
force_default_locale: true,
|
||||||
client_id: "123",
|
client_id: "123",
|
||||||
|
progress_channel: channel,
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -60,10 +60,18 @@ RSpec.describe Jobs::StreamPostHelper do
|
|||||||
explanation =
|
explanation =
|
||||||
"In this context, \"pie\" refers to a baked dessert typically consisting of a pastry crust and filling."
|
"In this context, \"pie\" refers to a baked dessert typically consisting of a pastry crust and filling."
|
||||||
|
|
||||||
|
channel = "/my/channel"
|
||||||
DiscourseAi::Completions::Llm.with_prepared_responses([explanation]) do
|
DiscourseAi::Completions::Llm.with_prepared_responses([explanation]) do
|
||||||
messages =
|
messages =
|
||||||
MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion/#{post.id}") do
|
MessageBus.track_publish(channel) do
|
||||||
job.execute(post_id: post.id, user_id: user.id, text: "pie", prompt: mode)
|
job.execute(
|
||||||
|
post_id: post.id,
|
||||||
|
user_id: user.id,
|
||||||
|
text: "pie",
|
||||||
|
prompt: mode,
|
||||||
|
progress_channel: channel,
|
||||||
|
client_id: "test_client_id",
|
||||||
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
partial_result_update = messages.first.data
|
partial_result_update = messages.first.data
|
||||||
@ -76,10 +84,19 @@ RSpec.describe Jobs::StreamPostHelper do
|
|||||||
explanation =
|
explanation =
|
||||||
"In this context, \"pie\" refers to a baked dessert typically consisting of a pastry crust and filling."
|
"In this context, \"pie\" refers to a baked dessert typically consisting of a pastry crust and filling."
|
||||||
|
|
||||||
|
channel = "/my/channel"
|
||||||
|
|
||||||
DiscourseAi::Completions::Llm.with_prepared_responses([explanation]) do
|
DiscourseAi::Completions::Llm.with_prepared_responses([explanation]) do
|
||||||
messages =
|
messages =
|
||||||
MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion/#{post.id}") do
|
MessageBus.track_publish(channel) do
|
||||||
job.execute(post_id: post.id, user_id: user.id, text: "pie", prompt: mode)
|
job.execute(
|
||||||
|
post_id: post.id,
|
||||||
|
user_id: user.id,
|
||||||
|
text: "pie",
|
||||||
|
prompt: mode,
|
||||||
|
client_id: "test_client_id",
|
||||||
|
progress_channel: channel,
|
||||||
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
final_update = messages.last.data
|
final_update = messages.last.data
|
||||||
@ -96,10 +113,18 @@ RSpec.describe Jobs::StreamPostHelper do
|
|||||||
sentence = "I like to eat pie."
|
sentence = "I like to eat pie."
|
||||||
translation = "Me gusta comer pastel."
|
translation = "Me gusta comer pastel."
|
||||||
|
|
||||||
|
channel = "/my/channel"
|
||||||
DiscourseAi::Completions::Llm.with_prepared_responses([translation]) do
|
DiscourseAi::Completions::Llm.with_prepared_responses([translation]) do
|
||||||
messages =
|
messages =
|
||||||
MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion/#{post.id}") do
|
MessageBus.track_publish(channel) do
|
||||||
job.execute(post_id: post.id, user_id: user.id, text: sentence, prompt: mode)
|
job.execute(
|
||||||
|
post_id: post.id,
|
||||||
|
user_id: user.id,
|
||||||
|
text: sentence,
|
||||||
|
prompt: mode,
|
||||||
|
progress_channel: channel,
|
||||||
|
client_id: "test_client_id",
|
||||||
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
partial_result_update = messages.first.data
|
partial_result_update = messages.first.data
|
||||||
@ -111,11 +136,19 @@ RSpec.describe Jobs::StreamPostHelper do
|
|||||||
it "publishes a final update to signal we're done" do
|
it "publishes a final update to signal we're done" do
|
||||||
sentence = "I like to eat pie."
|
sentence = "I like to eat pie."
|
||||||
translation = "Me gusta comer pastel."
|
translation = "Me gusta comer pastel."
|
||||||
|
channel = "/my/channel"
|
||||||
|
|
||||||
DiscourseAi::Completions::Llm.with_prepared_responses([translation]) do
|
DiscourseAi::Completions::Llm.with_prepared_responses([translation]) do
|
||||||
messages =
|
messages =
|
||||||
MessageBus.track_publish("/discourse-ai/ai-helper/stream_suggestion/#{post.id}") do
|
MessageBus.track_publish(channel) do
|
||||||
job.execute(post_id: post.id, user_id: user.id, text: sentence, prompt: mode)
|
job.execute(
|
||||||
|
post_id: post.id,
|
||||||
|
user_id: user.id,
|
||||||
|
text: sentence,
|
||||||
|
prompt: mode,
|
||||||
|
progress_channel: channel,
|
||||||
|
client_id: "test_client_id",
|
||||||
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
final_update = messages.last.data
|
final_update = messages.last.data
|
||||||
|
@ -11,10 +11,50 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
|
|||||||
SiteSetting.composer_ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_0]
|
SiteSetting.composer_ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_0]
|
||||||
end
|
end
|
||||||
|
|
||||||
it "is able to stream suggestions back on appropriate channel" do
|
it "is able to stream suggestions to helper" do
|
||||||
sign_in(user)
|
sign_in(user)
|
||||||
|
|
||||||
|
my_post = Fabricate(:post)
|
||||||
|
|
||||||
|
channel = nil
|
||||||
messages =
|
messages =
|
||||||
MessageBus.track_publish("/discourse-ai/ai-helper/stream_composer_suggestion") do
|
MessageBus.track_publish do
|
||||||
|
results = [["hello ", "world"]]
|
||||||
|
DiscourseAi::Completions::Llm.with_prepared_responses(results) do
|
||||||
|
post "/discourse-ai/ai-helper/stream_suggestion.json",
|
||||||
|
params: {
|
||||||
|
text: "hello wrld",
|
||||||
|
location: "helper",
|
||||||
|
client_id: "1234",
|
||||||
|
post_id: my_post.id,
|
||||||
|
custom_prompt: "Translate to Spanish",
|
||||||
|
mode: DiscourseAi::AiHelper::Assistant::CUSTOM_PROMPT,
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(response.status).to eq(200)
|
||||||
|
channel = response.parsed_body["progress_channel"]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
# we only have the channel after we make the request
|
||||||
|
# so we can not filter till now
|
||||||
|
messages = messages.select { |m| m.channel == channel }
|
||||||
|
expect(messages).not_to be_empty
|
||||||
|
|
||||||
|
last_message = messages.last
|
||||||
|
expect(messages.all? { |m| m.client_ids == ["1234"] }).to eq(true)
|
||||||
|
expect(messages.all? { |m| m == last_message || !m.data[:done] }).to eq(true)
|
||||||
|
|
||||||
|
expect(last_message.channel).to eq(channel)
|
||||||
|
expect(last_message.data[:result]).to eq("hello world")
|
||||||
|
expect(last_message.data[:done]).to eq(true)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "is able to stream suggestions to composer" do
|
||||||
|
sign_in(user)
|
||||||
|
channel = nil
|
||||||
|
messages =
|
||||||
|
MessageBus.track_publish do
|
||||||
results = [["hello ", "world"]]
|
results = [["hello ", "world"]]
|
||||||
DiscourseAi::Completions::Llm.with_prepared_responses(results) do
|
DiscourseAi::Completions::Llm.with_prepared_responses(results) do
|
||||||
post "/discourse-ai/ai-helper/stream_suggestion.json",
|
post "/discourse-ai/ai-helper/stream_suggestion.json",
|
||||||
@ -26,13 +66,19 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
|
|||||||
}
|
}
|
||||||
|
|
||||||
expect(response.status).to eq(200)
|
expect(response.status).to eq(200)
|
||||||
|
channel = response.parsed_body["progress_channel"]
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# we only have the channel after we make the request
|
||||||
|
# so we can not filter till now
|
||||||
|
messages = messages.select { |m| m.channel == channel }
|
||||||
|
|
||||||
last_message = messages.last
|
last_message = messages.last
|
||||||
expect(messages.all? { |m| m.client_ids == ["1234"] }).to eq(true)
|
expect(messages.all? { |m| m.client_ids == ["1234"] }).to eq(true)
|
||||||
expect(messages.all? { |m| m == last_message || !m.data[:done] }).to eq(true)
|
expect(messages.all? { |m| m == last_message || !m.data[:done] }).to eq(true)
|
||||||
|
|
||||||
|
expect(last_message.channel).to eq(channel)
|
||||||
expect(last_message.data[:result]).to eq("hello world")
|
expect(last_message.data[:result]).to eq("hello world")
|
||||||
expect(last_message.data[:done]).to eq(true)
|
expect(last_message.data[:done]).to eq(true)
|
||||||
end
|
end
|
||||||
|
@ -48,6 +48,7 @@ acceptance("AI Helper - Post Helper Menu", function (needs) {
|
|||||||
return helper.response({
|
return helper.response({
|
||||||
result: "This is a suggestio",
|
result: "This is a suggestio",
|
||||||
done: false,
|
done: false,
|
||||||
|
progress_channel: "/some/progress/channel",
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -61,13 +62,10 @@ acceptance("AI Helper - Post Helper Menu", function (needs) {
|
|||||||
await selectText(textNode, 9);
|
await selectText(textNode, 9);
|
||||||
await click(".ai-post-helper__trigger");
|
await click(".ai-post-helper__trigger");
|
||||||
await click(".ai-helper-options__button[data-name='explain']");
|
await click(".ai-helper-options__button[data-name='explain']");
|
||||||
await publishToMessageBus(
|
await publishToMessageBus(`/some/progress/channel`, {
|
||||||
`/discourse-ai/ai-helper/stream_suggestion/118591`,
|
|
||||||
{
|
|
||||||
done: true,
|
done: true,
|
||||||
result: suggestion,
|
result: suggestion,
|
||||||
}
|
});
|
||||||
);
|
|
||||||
assert.dom(".ai-post-helper__suggestion__text").hasText(suggestion);
|
assert.dom(".ai-post-helper__suggestion__text").hasText(suggestion);
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -91,13 +89,10 @@ acceptance("AI Helper - Post Helper Menu", function (needs) {
|
|||||||
await selectSpecificText(textNode, 72, 77);
|
await selectSpecificText(textNode, 72, 77);
|
||||||
await click(".ai-post-helper__trigger");
|
await click(".ai-post-helper__trigger");
|
||||||
await click(".ai-helper-options__button[data-name='explain']");
|
await click(".ai-helper-options__button[data-name='explain']");
|
||||||
await publishToMessageBus(
|
await publishToMessageBus(`/some/progress/channel`, {
|
||||||
`/discourse-ai/ai-helper/stream_suggestion/118591`,
|
|
||||||
{
|
|
||||||
done: true,
|
done: true,
|
||||||
result: suggestion,
|
result: suggestion,
|
||||||
}
|
});
|
||||||
);
|
|
||||||
|
|
||||||
assert.dom(".ai-post-helper__suggestion__insert-footnote").isDisabled();
|
assert.dom(".ai-post-helper__suggestion__insert-footnote").isDisabled();
|
||||||
});
|
});
|
||||||
@ -108,13 +103,10 @@ acceptance("AI Helper - Post Helper Menu", function (needs) {
|
|||||||
await selectText(query("#post_1 .cooked p"));
|
await selectText(query("#post_1 .cooked p"));
|
||||||
await click(".ai-post-helper__trigger");
|
await click(".ai-post-helper__trigger");
|
||||||
await click(".ai-helper-options__button[data-name='translate']");
|
await click(".ai-helper-options__button[data-name='translate']");
|
||||||
await publishToMessageBus(
|
await publishToMessageBus(`/some/progress/channel`, {
|
||||||
`/discourse-ai/ai-helper/stream_suggestion/118591`,
|
|
||||||
{
|
|
||||||
done: true,
|
done: true,
|
||||||
result: translated,
|
result: translated,
|
||||||
}
|
});
|
||||||
);
|
|
||||||
assert.dom(".ai-post-helper__suggestion__text").hasText(translated);
|
assert.dom(".ai-post-helper__suggestion__text").hasText(translated);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
Loading…
x
Reference in New Issue
Block a user