From dfea784fc4b5dd88fc89ce23a34b7ee10f24de97 Mon Sep 17 00:00:00 2001 From: Keegan George Date: Thu, 15 May 2025 11:38:46 -0700 Subject: [PATCH] DEV: Improve diff streaming accuracy with safety checker (#1338) This update adds a safety checker which scans the streamed updates. It ensures that incomplete segments of text are not sent yet over message bus as this will cause breakage with the diff streamer. It also updates the diff streamer to handle a thinking state for when we are waiting for message bus updates. --- .../discourse/components/modal/diff-modal.gjs | 14 ++- .../discourse/lib/diff-streamer.gjs | 39 +++++++- assets/stylesheets/common/streaming.scss | 15 +++ lib/ai_helper/assistant.rb | 15 +-- lib/utils/diff_utils/safety_checker.rb | 91 +++++++++++++++++++ .../utils/diff_utils/safety_checker_spec.rb | 80 ++++++++++++++++ 6 files changed, 240 insertions(+), 14 deletions(-) create mode 100644 lib/utils/diff_utils/safety_checker.rb create mode 100644 spec/lib/utils/diff_utils/safety_checker_spec.rb diff --git a/assets/javascripts/discourse/components/modal/diff-modal.gjs b/assets/javascripts/discourse/components/modal/diff-modal.gjs index 280bb1d6..68d16f5a 100644 --- a/assets/javascripts/discourse/components/modal/diff-modal.gjs +++ b/assets/javascripts/discourse/components/modal/diff-modal.gjs @@ -22,6 +22,7 @@ export default class ModalDiffModal extends Component { @service messageBus; @tracked loading = false; + @tracked finalResult = ""; @tracked diffStreamer = new DiffStreamer(this.args.model.selectedText); @tracked suggestion = ""; @tracked @@ -65,6 +66,10 @@ export default class ModalDiffModal extends Component { async updateResult(result) { this.loading = false; + if (result.done) { + this.finalResult = result.result; + } + if (this.args.model.showResultAsDiff) { this.diffStreamer.updateResult(result, "result"); } else { @@ -105,10 +110,14 @@ export default class ModalDiffModal extends Component { ); } - if (this.args.model.showResultAsDiff && this.diffStreamer.suggestion) { + const finalResult = + this.finalResult?.length > 0 + ? this.finalResult + : this.diffStreamer.suggestion; + if (this.args.model.showResultAsDiff && finalResult) { this.args.model.toolbarEvent.replaceText( this.args.model.selectedText, - this.diffStreamer.suggestion + finalResult ); } } @@ -131,6 +140,7 @@ export default class ModalDiffModal extends Component { "composer-ai-helper-modal__suggestion" "streamable-content" (if this.isStreaming "streaming") + (if this.diffStreamer.isThinking "thinking") (if @model.showResultAsDiff "inline-diff") }} > diff --git a/assets/javascripts/discourse/lib/diff-streamer.gjs b/assets/javascripts/discourse/lib/diff-streamer.gjs index 82f76856..e915ec7d 100644 --- a/assets/javascripts/discourse/lib/diff-streamer.gjs +++ b/assets/javascripts/discourse/lib/diff-streamer.gjs @@ -12,6 +12,8 @@ export default class DiffStreamer { @tracked lastResultText = ""; @tracked diff = ""; @tracked suggestion = ""; + @tracked isDone = false; + @tracked isThinking = false; typingTimer = null; currentWordIndex = 0; @@ -35,6 +37,7 @@ export default class DiffStreamer { const newText = result[newTextKey]; const diffText = newText.slice(this.lastResultText.length).trim(); const newWords = diffText.split(/\s+/).filter(Boolean); + this.isDone = result?.done; if (newWords.length > 0) { this.isStreaming = true; @@ -64,7 +67,12 @@ export default class DiffStreamer { * Highlights the current word if streaming is ongoing. */ #streamNextWord() { - if (this.currentWordIndex === this.words.length) { + if (this.currentWordIndex === this.words.length && !this.isDone) { + this.isThinking = true; + } + + if (this.currentWordIndex === this.words.length && this.isDone) { + this.isThinking = false; this.diff = this.#compareText(this.selectedText, this.suggestion, { markLastWord: false, }); @@ -72,6 +80,7 @@ export default class DiffStreamer { } if (this.currentWordIndex < this.words.length) { + this.isThinking = false; this.suggestion += this.words[this.currentWordIndex] + " "; this.diff = this.#compareText(this.selectedText, this.suggestion, { markLastWord: true, @@ -99,22 +108,36 @@ export default class DiffStreamer { const oldWords = oldText.trim().split(/\s+/); const newWords = newText.trim().split(/\s+/); + // Track where the line breaks are in the original oldText + const lineBreakMap = (() => { + const lines = oldText.trim().split("\n"); + const map = new Set(); + let wordIndex = 0; + + for (const line of lines) { + const wordsInLine = line.trim().split(/\s+/); + wordIndex += wordsInLine.length; + map.add(wordIndex - 1); // Mark the last word in each line + } + + return map; + })(); + const diff = []; let i = 0; - while (i < oldWords.length) { + while (i < oldWords.length || i < newWords.length) { const oldWord = oldWords[i]; const newWord = newWords[i]; let wordHTML = ""; - let originalWordHTML = `${oldWord}`; if (newWord === undefined) { - wordHTML = originalWordHTML; + wordHTML = `${oldWord}`; } else if (oldWord === newWord) { wordHTML = `${newWord}`; } else if (oldWord !== newWord) { - wordHTML = `${oldWord} ${newWord}`; + wordHTML = `${oldWord ?? ""} ${newWord ?? ""}`; } if (i === newWords.length - 1 && opts.markLastWord) { @@ -122,6 +145,12 @@ export default class DiffStreamer { } diff.push(wordHTML); + + // Add a line break after this word if it ended a line in the original text + if (lineBreakMap.has(i)) { + diff.push("
"); + } + i++; } diff --git a/assets/stylesheets/common/streaming.scss b/assets/stylesheets/common/streaming.scss index e6561e32..b26c914e 100644 --- a/assets/stylesheets/common/streaming.scss +++ b/assets/stylesheets/common/streaming.scss @@ -79,3 +79,18 @@ article.streaming .cooked { } } } + +@keyframes mark-blink { + 0%, + 100% { + border-color: var(--highlight-high); + } + + 50% { + border-color: transparent; + } +} + +.composer-ai-helper-modal__suggestion.thinking mark.highlight { + animation: mark-blink 1s step-start 0s infinite; +} diff --git a/lib/ai_helper/assistant.rb b/lib/ai_helper/assistant.rb index c15bd358..95fd5ab7 100644 --- a/lib/ai_helper/assistant.rb +++ b/lib/ai_helper/assistant.rb @@ -181,14 +181,15 @@ module DiscourseAi streamed_diff = parse_diff(input, partial_response) if completion_prompt.diff? - # Throttle the updates and - # checking length prevents partial tags - # that aren't sanitized correctly yet (i.e. ' 10 && (Time.now - start > 0.3)) || Rails.env.test? - payload = { result: sanitize_result(streamed_result), diff: streamed_diff, done: false } - publish_update(channel, payload, user) - start = Time.now + sanitized = sanitize_result(streamed_result) + + if DiscourseAi::Utils::DiffUtils::SafetyChecker.safe_to_stream?(sanitized) + payload = { result: sanitized, diff: streamed_diff, done: false } + publish_update(channel, payload, user) + start = Time.now + end end end diff --git a/lib/utils/diff_utils/safety_checker.rb b/lib/utils/diff_utils/safety_checker.rb new file mode 100644 index 00000000..16adb3ab --- /dev/null +++ b/lib/utils/diff_utils/safety_checker.rb @@ -0,0 +1,91 @@ +# frozen_string_literal: true + +require "cgi" + +module DiscourseAi + module Utils + module DiffUtils + class SafetyChecker + def self.safe_to_stream?(html_text) + new(html_text).safe? + end + + def initialize(html_text) + @original_html = html_text + @text = sanitize(html_text) + end + + def safe? + return false if unclosed_markdown_links? + return false if unclosed_raw_html_tag? + return false if trailing_incomplete_url? + return false if unclosed_backticks? + return false if unbalanced_bold_or_italic? + return false if incomplete_image_markdown? + return false if unbalanced_quote_blocks? + return false if unclosed_triple_backticks? + return false if partial_emoji? + + true + end + + private + + def sanitize(html) + text = html.gsub(%r{]+>}, "") # remove tags like , , etc. + CGI.unescapeHTML(text) + end + + def unclosed_markdown_links? + open_brackets = @text.count("[") + close_brackets = @text.count("]") + open_parens = @text.count("(") + close_parens = @text.count(")") + + open_brackets != close_brackets || open_parens != close_parens + end + + def unclosed_raw_html_tag? + last_lt = @text.rindex("<") + last_gt = @text.rindex(">") + last_lt && (!last_gt || last_gt < last_lt) + end + + def trailing_incomplete_url? + last_word = @text.split(/\s/).last + last_word =~ %r{\Ahttps?://[^\s]*\z} && last_word !~ /[)\].,!?:;'"]\z/ + end + + def unclosed_backticks? + @text.count("`").odd? + end + + def unbalanced_bold_or_italic? + @text.scan(/\*\*/).count.odd? || @text.scan(/\*(?!\*)/).count.odd? || + @text.scan(/_/).count.odd? + end + + def incomplete_image_markdown? + last_image = @text[/!\[.*?\]\(.*?$/, 0] + last_image && last_image[-1] != ")" + end + + def unbalanced_quote_blocks? + opens = @text.scan(/\[quote(=.*?)?\]/i).count + closes = @text.scan(%r{\[/quote\]}i).count + opens > closes + end + + def unclosed_triple_backticks? + @text.scan(/```/).count.odd? + end + + def partial_emoji? + text = @text.gsub(/!\[.*?\]\(.*?\)/, "").gsub(%r{https?://[^\s]+}, "") + tokens = text.scan(/:[a-z0-9_+\-\.]+:?/i) + tokens.any? { |token| token.start_with?(":") && !token.end_with?(":") } + end + end + end + end +end diff --git a/spec/lib/utils/diff_utils/safety_checker_spec.rb b/spec/lib/utils/diff_utils/safety_checker_spec.rb new file mode 100644 index 00000000..19a8a861 --- /dev/null +++ b/spec/lib/utils/diff_utils/safety_checker_spec.rb @@ -0,0 +1,80 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::Utils::DiffUtils::SafetyChecker do + describe "#safe?" do + subject { described_class.new(text).safe? } + + context "with safe text" do + let(:text) { "This is a simple safe text without issues." } + + it { is_expected.to eq(true) } + + context "with normal HTML tags" do + let(:text) { "Here is bold and italic text." } + it { is_expected.to eq(true) } + end + + context "with balanced markdown and no partial emoji" do + let(:text) { "This is **bold**, *italic*, and a smiley :smile:!" } + it { is_expected.to eq(true) } + end + + context "with balanced quote blocks" do + let(:text) { "[quote]Quoted text[/quote]" } + it { is_expected.to eq(true) } + end + + context "with complete image markdown" do + let(:text) { "![alt text](https://example.com/image.png)" } + it { is_expected.to eq(true) } + end + end + + context "with unsafe text" do + context "with unclosed markdown link" do + let(:text) { "This is a [link(https://example.com)" } + it { is_expected.to eq(false) } + end + + context "with unclosed raw HTML tag" do + let(:text) { "Text with