DEV: Improve diff streaming accuracy with safety checker (#1338)

This update adds a safety checker which scans the streamed updates. It ensures that incomplete segments of text are not sent yet over message bus as this will cause breakage with the diff streamer. It also updates the diff streamer to handle a thinking state for when we are waiting for message bus updates.
This commit is contained in:
Keegan George 2025-05-15 11:38:46 -07:00 committed by GitHub
parent ff2e18f9ca
commit dfea784fc4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 240 additions and 14 deletions

View File

@ -22,6 +22,7 @@ export default class ModalDiffModal extends Component {
@service messageBus;
@tracked loading = false;
@tracked finalResult = "";
@tracked diffStreamer = new DiffStreamer(this.args.model.selectedText);
@tracked suggestion = "";
@tracked
@ -65,6 +66,10 @@ export default class ModalDiffModal extends Component {
async updateResult(result) {
this.loading = false;
if (result.done) {
this.finalResult = result.result;
}
if (this.args.model.showResultAsDiff) {
this.diffStreamer.updateResult(result, "result");
} else {
@ -105,10 +110,14 @@ export default class ModalDiffModal extends Component {
);
}
if (this.args.model.showResultAsDiff && this.diffStreamer.suggestion) {
const finalResult =
this.finalResult?.length > 0
? this.finalResult
: this.diffStreamer.suggestion;
if (this.args.model.showResultAsDiff && finalResult) {
this.args.model.toolbarEvent.replaceText(
this.args.model.selectedText,
this.diffStreamer.suggestion
finalResult
);
}
}
@ -131,6 +140,7 @@ export default class ModalDiffModal extends Component {
"composer-ai-helper-modal__suggestion"
"streamable-content"
(if this.isStreaming "streaming")
(if this.diffStreamer.isThinking "thinking")
(if @model.showResultAsDiff "inline-diff")
}}
>

View File

@ -12,6 +12,8 @@ export default class DiffStreamer {
@tracked lastResultText = "";
@tracked diff = "";
@tracked suggestion = "";
@tracked isDone = false;
@tracked isThinking = false;
typingTimer = null;
currentWordIndex = 0;
@ -35,6 +37,7 @@ export default class DiffStreamer {
const newText = result[newTextKey];
const diffText = newText.slice(this.lastResultText.length).trim();
const newWords = diffText.split(/\s+/).filter(Boolean);
this.isDone = result?.done;
if (newWords.length > 0) {
this.isStreaming = true;
@ -64,7 +67,12 @@ export default class DiffStreamer {
* Highlights the current word if streaming is ongoing.
*/
#streamNextWord() {
if (this.currentWordIndex === this.words.length) {
if (this.currentWordIndex === this.words.length && !this.isDone) {
this.isThinking = true;
}
if (this.currentWordIndex === this.words.length && this.isDone) {
this.isThinking = false;
this.diff = this.#compareText(this.selectedText, this.suggestion, {
markLastWord: false,
});
@ -72,6 +80,7 @@ export default class DiffStreamer {
}
if (this.currentWordIndex < this.words.length) {
this.isThinking = false;
this.suggestion += this.words[this.currentWordIndex] + " ";
this.diff = this.#compareText(this.selectedText, this.suggestion, {
markLastWord: true,
@ -99,22 +108,36 @@ export default class DiffStreamer {
const oldWords = oldText.trim().split(/\s+/);
const newWords = newText.trim().split(/\s+/);
// Track where the line breaks are in the original oldText
const lineBreakMap = (() => {
const lines = oldText.trim().split("\n");
const map = new Set();
let wordIndex = 0;
for (const line of lines) {
const wordsInLine = line.trim().split(/\s+/);
wordIndex += wordsInLine.length;
map.add(wordIndex - 1); // Mark the last word in each line
}
return map;
})();
const diff = [];
let i = 0;
while (i < oldWords.length) {
while (i < oldWords.length || i < newWords.length) {
const oldWord = oldWords[i];
const newWord = newWords[i];
let wordHTML = "";
let originalWordHTML = `<span class="ghost">${oldWord}</span>`;
if (newWord === undefined) {
wordHTML = originalWordHTML;
wordHTML = `<span class="ghost">${oldWord}</span>`;
} else if (oldWord === newWord) {
wordHTML = `<span class="same-word">${newWord}</span>`;
} else if (oldWord !== newWord) {
wordHTML = `<del>${oldWord}</del> <ins>${newWord}</ins>`;
wordHTML = `<del>${oldWord ?? ""}</del> <ins>${newWord ?? ""}</ins>`;
}
if (i === newWords.length - 1 && opts.markLastWord) {
@ -122,6 +145,12 @@ export default class DiffStreamer {
}
diff.push(wordHTML);
// Add a line break after this word if it ended a line in the original text
if (lineBreakMap.has(i)) {
diff.push("<br>");
}
i++;
}

View File

@ -79,3 +79,18 @@ article.streaming .cooked {
}
}
}
@keyframes mark-blink {
0%,
100% {
border-color: var(--highlight-high);
}
50% {
border-color: transparent;
}
}
.composer-ai-helper-modal__suggestion.thinking mark.highlight {
animation: mark-blink 1s step-start 0s infinite;
}

View File

@ -181,14 +181,15 @@ module DiscourseAi
streamed_diff = parse_diff(input, partial_response) if completion_prompt.diff?
# Throttle the updates and
# checking length prevents partial tags
# that aren't sanitized correctly yet (i.e. '<output')
# from being sent in the stream
# Throttle updates and check for safe stream points
if (streamed_result.length > 10 && (Time.now - start > 0.3)) || Rails.env.test?
payload = { result: sanitize_result(streamed_result), diff: streamed_diff, done: false }
publish_update(channel, payload, user)
start = Time.now
sanitized = sanitize_result(streamed_result)
if DiscourseAi::Utils::DiffUtils::SafetyChecker.safe_to_stream?(sanitized)
payload = { result: sanitized, diff: streamed_diff, done: false }
publish_update(channel, payload, user)
start = Time.now
end
end
end

View File

@ -0,0 +1,91 @@
# frozen_string_literal: true
require "cgi"
module DiscourseAi
module Utils
module DiffUtils
class SafetyChecker
def self.safe_to_stream?(html_text)
new(html_text).safe?
end
def initialize(html_text)
@original_html = html_text
@text = sanitize(html_text)
end
def safe?
return false if unclosed_markdown_links?
return false if unclosed_raw_html_tag?
return false if trailing_incomplete_url?
return false if unclosed_backticks?
return false if unbalanced_bold_or_italic?
return false if incomplete_image_markdown?
return false if unbalanced_quote_blocks?
return false if unclosed_triple_backticks?
return false if partial_emoji?
true
end
private
def sanitize(html)
text = html.gsub(%r{</?[^>]+>}, "") # remove tags like <span>, <del>, etc.
CGI.unescapeHTML(text)
end
def unclosed_markdown_links?
open_brackets = @text.count("[")
close_brackets = @text.count("]")
open_parens = @text.count("(")
close_parens = @text.count(")")
open_brackets != close_brackets || open_parens != close_parens
end
def unclosed_raw_html_tag?
last_lt = @text.rindex("<")
last_gt = @text.rindex(">")
last_lt && (!last_gt || last_gt < last_lt)
end
def trailing_incomplete_url?
last_word = @text.split(/\s/).last
last_word =~ %r{\Ahttps?://[^\s]*\z} && last_word !~ /[)\].,!?:;'"]\z/
end
def unclosed_backticks?
@text.count("`").odd?
end
def unbalanced_bold_or_italic?
@text.scan(/\*\*/).count.odd? || @text.scan(/\*(?!\*)/).count.odd? ||
@text.scan(/_/).count.odd?
end
def incomplete_image_markdown?
last_image = @text[/!\[.*?\]\(.*?$/, 0]
last_image && last_image[-1] != ")"
end
def unbalanced_quote_blocks?
opens = @text.scan(/\[quote(=.*?)?\]/i).count
closes = @text.scan(%r{\[/quote\]}i).count
opens > closes
end
def unclosed_triple_backticks?
@text.scan(/```/).count.odd?
end
def partial_emoji?
text = @text.gsub(/!\[.*?\]\(.*?\)/, "").gsub(%r{https?://[^\s]+}, "")
tokens = text.scan(/:[a-z0-9_+\-\.]+:?/i)
tokens.any? { |token| token.start_with?(":") && !token.end_with?(":") }
end
end
end
end
end

View File

@ -0,0 +1,80 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Utils::DiffUtils::SafetyChecker do
describe "#safe?" do
subject { described_class.new(text).safe? }
context "with safe text" do
let(:text) { "This is a simple safe text without issues." }
it { is_expected.to eq(true) }
context "with normal HTML tags" do
let(:text) { "Here is <strong>bold</strong> and <em>italic</em> text." }
it { is_expected.to eq(true) }
end
context "with balanced markdown and no partial emoji" do
let(:text) { "This is **bold**, *italic*, and a smiley :smile:!" }
it { is_expected.to eq(true) }
end
context "with balanced quote blocks" do
let(:text) { "[quote]Quoted text[/quote]" }
it { is_expected.to eq(true) }
end
context "with complete image markdown" do
let(:text) { "![alt text](https://example.com/image.png)" }
it { is_expected.to eq(true) }
end
end
context "with unsafe text" do
context "with unclosed markdown link" do
let(:text) { "This is a [link(https://example.com)" }
it { is_expected.to eq(false) }
end
context "with unclosed raw HTML tag" do
let(:text) { "Text with <div unclosed tag" }
it { is_expected.to eq(false) }
end
context "with trailing incomplete URL" do
let(:text) { "Check this out https://example.com/something" } # no closing punctuation
it { is_expected.to eq(false) }
end
context "with unclosed backticks" do
let(:text) { "Here is some `inline code without closing" }
it { is_expected.to eq(false) }
end
context "with unbalanced bold or italic markdown" do
let(:text) { "This is *italic without closing" }
it { is_expected.to eq(false) }
end
context "with incomplete image markdown" do
let(:text) { "Image ![alt text](https://example.com/image.png" } # missing closing )
it { is_expected.to eq(false) }
end
context "with unbalanced quote blocks" do
let(:text) { "[quote]Unclosed quote block" }
it { is_expected.to eq(false) }
end
context "with unclosed triple backticks" do
let(:text) { "```code block without closing" }
it { is_expected.to eq(false) }
end
context "with partial emoji" do
let(:text) { "A partial emoji :smile" }
it { is_expected.to eq(false) }
end
end
end
end