DEV: Improve diff streaming accuracy with safety checker (#1338)

This update adds a safety checker which scans the streamed updates. It ensures that incomplete segments of text are not sent yet over message bus as this will cause breakage with the diff streamer. It also updates the diff streamer to handle a thinking state for when we are waiting for message bus updates.
This commit is contained in:
Keegan George 2025-05-15 11:38:46 -07:00 committed by GitHub
parent ff2e18f9ca
commit dfea784fc4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 240 additions and 14 deletions

View File

@ -22,6 +22,7 @@ export default class ModalDiffModal extends Component {
@service messageBus; @service messageBus;
@tracked loading = false; @tracked loading = false;
@tracked finalResult = "";
@tracked diffStreamer = new DiffStreamer(this.args.model.selectedText); @tracked diffStreamer = new DiffStreamer(this.args.model.selectedText);
@tracked suggestion = ""; @tracked suggestion = "";
@tracked @tracked
@ -65,6 +66,10 @@ export default class ModalDiffModal extends Component {
async updateResult(result) { async updateResult(result) {
this.loading = false; this.loading = false;
if (result.done) {
this.finalResult = result.result;
}
if (this.args.model.showResultAsDiff) { if (this.args.model.showResultAsDiff) {
this.diffStreamer.updateResult(result, "result"); this.diffStreamer.updateResult(result, "result");
} else { } else {
@ -105,10 +110,14 @@ export default class ModalDiffModal extends Component {
); );
} }
if (this.args.model.showResultAsDiff && this.diffStreamer.suggestion) { const finalResult =
this.finalResult?.length > 0
? this.finalResult
: this.diffStreamer.suggestion;
if (this.args.model.showResultAsDiff && finalResult) {
this.args.model.toolbarEvent.replaceText( this.args.model.toolbarEvent.replaceText(
this.args.model.selectedText, this.args.model.selectedText,
this.diffStreamer.suggestion finalResult
); );
} }
} }
@ -131,6 +140,7 @@ export default class ModalDiffModal extends Component {
"composer-ai-helper-modal__suggestion" "composer-ai-helper-modal__suggestion"
"streamable-content" "streamable-content"
(if this.isStreaming "streaming") (if this.isStreaming "streaming")
(if this.diffStreamer.isThinking "thinking")
(if @model.showResultAsDiff "inline-diff") (if @model.showResultAsDiff "inline-diff")
}} }}
> >

View File

@ -12,6 +12,8 @@ export default class DiffStreamer {
@tracked lastResultText = ""; @tracked lastResultText = "";
@tracked diff = ""; @tracked diff = "";
@tracked suggestion = ""; @tracked suggestion = "";
@tracked isDone = false;
@tracked isThinking = false;
typingTimer = null; typingTimer = null;
currentWordIndex = 0; currentWordIndex = 0;
@ -35,6 +37,7 @@ export default class DiffStreamer {
const newText = result[newTextKey]; const newText = result[newTextKey];
const diffText = newText.slice(this.lastResultText.length).trim(); const diffText = newText.slice(this.lastResultText.length).trim();
const newWords = diffText.split(/\s+/).filter(Boolean); const newWords = diffText.split(/\s+/).filter(Boolean);
this.isDone = result?.done;
if (newWords.length > 0) { if (newWords.length > 0) {
this.isStreaming = true; this.isStreaming = true;
@ -64,7 +67,12 @@ export default class DiffStreamer {
* Highlights the current word if streaming is ongoing. * Highlights the current word if streaming is ongoing.
*/ */
#streamNextWord() { #streamNextWord() {
if (this.currentWordIndex === this.words.length) { if (this.currentWordIndex === this.words.length && !this.isDone) {
this.isThinking = true;
}
if (this.currentWordIndex === this.words.length && this.isDone) {
this.isThinking = false;
this.diff = this.#compareText(this.selectedText, this.suggestion, { this.diff = this.#compareText(this.selectedText, this.suggestion, {
markLastWord: false, markLastWord: false,
}); });
@ -72,6 +80,7 @@ export default class DiffStreamer {
} }
if (this.currentWordIndex < this.words.length) { if (this.currentWordIndex < this.words.length) {
this.isThinking = false;
this.suggestion += this.words[this.currentWordIndex] + " "; this.suggestion += this.words[this.currentWordIndex] + " ";
this.diff = this.#compareText(this.selectedText, this.suggestion, { this.diff = this.#compareText(this.selectedText, this.suggestion, {
markLastWord: true, markLastWord: true,
@ -99,22 +108,36 @@ export default class DiffStreamer {
const oldWords = oldText.trim().split(/\s+/); const oldWords = oldText.trim().split(/\s+/);
const newWords = newText.trim().split(/\s+/); const newWords = newText.trim().split(/\s+/);
// Track where the line breaks are in the original oldText
const lineBreakMap = (() => {
const lines = oldText.trim().split("\n");
const map = new Set();
let wordIndex = 0;
for (const line of lines) {
const wordsInLine = line.trim().split(/\s+/);
wordIndex += wordsInLine.length;
map.add(wordIndex - 1); // Mark the last word in each line
}
return map;
})();
const diff = []; const diff = [];
let i = 0; let i = 0;
while (i < oldWords.length) { while (i < oldWords.length || i < newWords.length) {
const oldWord = oldWords[i]; const oldWord = oldWords[i];
const newWord = newWords[i]; const newWord = newWords[i];
let wordHTML = ""; let wordHTML = "";
let originalWordHTML = `<span class="ghost">${oldWord}</span>`;
if (newWord === undefined) { if (newWord === undefined) {
wordHTML = originalWordHTML; wordHTML = `<span class="ghost">${oldWord}</span>`;
} else if (oldWord === newWord) { } else if (oldWord === newWord) {
wordHTML = `<span class="same-word">${newWord}</span>`; wordHTML = `<span class="same-word">${newWord}</span>`;
} else if (oldWord !== newWord) { } else if (oldWord !== newWord) {
wordHTML = `<del>${oldWord}</del> <ins>${newWord}</ins>`; wordHTML = `<del>${oldWord ?? ""}</del> <ins>${newWord ?? ""}</ins>`;
} }
if (i === newWords.length - 1 && opts.markLastWord) { if (i === newWords.length - 1 && opts.markLastWord) {
@ -122,6 +145,12 @@ export default class DiffStreamer {
} }
diff.push(wordHTML); diff.push(wordHTML);
// Add a line break after this word if it ended a line in the original text
if (lineBreakMap.has(i)) {
diff.push("<br>");
}
i++; i++;
} }

View File

@ -79,3 +79,18 @@ article.streaming .cooked {
} }
} }
} }
@keyframes mark-blink {
0%,
100% {
border-color: var(--highlight-high);
}
50% {
border-color: transparent;
}
}
.composer-ai-helper-modal__suggestion.thinking mark.highlight {
animation: mark-blink 1s step-start 0s infinite;
}

View File

@ -181,16 +181,17 @@ module DiscourseAi
streamed_diff = parse_diff(input, partial_response) if completion_prompt.diff? streamed_diff = parse_diff(input, partial_response) if completion_prompt.diff?
# Throttle the updates and # Throttle updates and check for safe stream points
# checking length prevents partial tags
# that aren't sanitized correctly yet (i.e. '<output')
# from being sent in the stream
if (streamed_result.length > 10 && (Time.now - start > 0.3)) || Rails.env.test? if (streamed_result.length > 10 && (Time.now - start > 0.3)) || Rails.env.test?
payload = { result: sanitize_result(streamed_result), diff: streamed_diff, done: false } sanitized = sanitize_result(streamed_result)
if DiscourseAi::Utils::DiffUtils::SafetyChecker.safe_to_stream?(sanitized)
payload = { result: sanitized, diff: streamed_diff, done: false }
publish_update(channel, payload, user) publish_update(channel, payload, user)
start = Time.now start = Time.now
end end
end end
end
final_diff = parse_diff(input, streamed_result) if completion_prompt.diff? final_diff = parse_diff(input, streamed_result) if completion_prompt.diff?

View File

@ -0,0 +1,91 @@
# frozen_string_literal: true
require "cgi"
module DiscourseAi
module Utils
module DiffUtils
class SafetyChecker
def self.safe_to_stream?(html_text)
new(html_text).safe?
end
def initialize(html_text)
@original_html = html_text
@text = sanitize(html_text)
end
def safe?
return false if unclosed_markdown_links?
return false if unclosed_raw_html_tag?
return false if trailing_incomplete_url?
return false if unclosed_backticks?
return false if unbalanced_bold_or_italic?
return false if incomplete_image_markdown?
return false if unbalanced_quote_blocks?
return false if unclosed_triple_backticks?
return false if partial_emoji?
true
end
private
def sanitize(html)
text = html.gsub(%r{</?[^>]+>}, "") # remove tags like <span>, <del>, etc.
CGI.unescapeHTML(text)
end
def unclosed_markdown_links?
open_brackets = @text.count("[")
close_brackets = @text.count("]")
open_parens = @text.count("(")
close_parens = @text.count(")")
open_brackets != close_brackets || open_parens != close_parens
end
def unclosed_raw_html_tag?
last_lt = @text.rindex("<")
last_gt = @text.rindex(">")
last_lt && (!last_gt || last_gt < last_lt)
end
def trailing_incomplete_url?
last_word = @text.split(/\s/).last
last_word =~ %r{\Ahttps?://[^\s]*\z} && last_word !~ /[)\].,!?:;'"]\z/
end
def unclosed_backticks?
@text.count("`").odd?
end
def unbalanced_bold_or_italic?
@text.scan(/\*\*/).count.odd? || @text.scan(/\*(?!\*)/).count.odd? ||
@text.scan(/_/).count.odd?
end
def incomplete_image_markdown?
last_image = @text[/!\[.*?\]\(.*?$/, 0]
last_image && last_image[-1] != ")"
end
def unbalanced_quote_blocks?
opens = @text.scan(/\[quote(=.*?)?\]/i).count
closes = @text.scan(%r{\[/quote\]}i).count
opens > closes
end
def unclosed_triple_backticks?
@text.scan(/```/).count.odd?
end
def partial_emoji?
text = @text.gsub(/!\[.*?\]\(.*?\)/, "").gsub(%r{https?://[^\s]+}, "")
tokens = text.scan(/:[a-z0-9_+\-\.]+:?/i)
tokens.any? { |token| token.start_with?(":") && !token.end_with?(":") }
end
end
end
end
end

View File

@ -0,0 +1,80 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Utils::DiffUtils::SafetyChecker do
describe "#safe?" do
subject { described_class.new(text).safe? }
context "with safe text" do
let(:text) { "This is a simple safe text without issues." }
it { is_expected.to eq(true) }
context "with normal HTML tags" do
let(:text) { "Here is <strong>bold</strong> and <em>italic</em> text." }
it { is_expected.to eq(true) }
end
context "with balanced markdown and no partial emoji" do
let(:text) { "This is **bold**, *italic*, and a smiley :smile:!" }
it { is_expected.to eq(true) }
end
context "with balanced quote blocks" do
let(:text) { "[quote]Quoted text[/quote]" }
it { is_expected.to eq(true) }
end
context "with complete image markdown" do
let(:text) { "![alt text](https://example.com/image.png)" }
it { is_expected.to eq(true) }
end
end
context "with unsafe text" do
context "with unclosed markdown link" do
let(:text) { "This is a [link(https://example.com)" }
it { is_expected.to eq(false) }
end
context "with unclosed raw HTML tag" do
let(:text) { "Text with <div unclosed tag" }
it { is_expected.to eq(false) }
end
context "with trailing incomplete URL" do
let(:text) { "Check this out https://example.com/something" } # no closing punctuation
it { is_expected.to eq(false) }
end
context "with unclosed backticks" do
let(:text) { "Here is some `inline code without closing" }
it { is_expected.to eq(false) }
end
context "with unbalanced bold or italic markdown" do
let(:text) { "This is *italic without closing" }
it { is_expected.to eq(false) }
end
context "with incomplete image markdown" do
let(:text) { "Image ![alt text](https://example.com/image.png" } # missing closing )
it { is_expected.to eq(false) }
end
context "with unbalanced quote blocks" do
let(:text) { "[quote]Unclosed quote block" }
it { is_expected.to eq(false) }
end
context "with unclosed triple backticks" do
let(:text) { "```code block without closing" }
it { is_expected.to eq(false) }
end
context "with partial emoji" do
let(:text) { "A partial emoji :smile" }
it { is_expected.to eq(false) }
end
end
end
end