diff --git a/assets/javascripts/discourse/components/modal/diff-modal.gjs b/assets/javascripts/discourse/components/modal/diff-modal.gjs index 280bb1d6..68d16f5a 100644 --- a/assets/javascripts/discourse/components/modal/diff-modal.gjs +++ b/assets/javascripts/discourse/components/modal/diff-modal.gjs @@ -22,6 +22,7 @@ export default class ModalDiffModal extends Component { @service messageBus; @tracked loading = false; + @tracked finalResult = ""; @tracked diffStreamer = new DiffStreamer(this.args.model.selectedText); @tracked suggestion = ""; @tracked @@ -65,6 +66,10 @@ export default class ModalDiffModal extends Component { async updateResult(result) { this.loading = false; + if (result.done) { + this.finalResult = result.result; + } + if (this.args.model.showResultAsDiff) { this.diffStreamer.updateResult(result, "result"); } else { @@ -105,10 +110,14 @@ export default class ModalDiffModal extends Component { ); } - if (this.args.model.showResultAsDiff && this.diffStreamer.suggestion) { + const finalResult = + this.finalResult?.length > 0 + ? this.finalResult + : this.diffStreamer.suggestion; + if (this.args.model.showResultAsDiff && finalResult) { this.args.model.toolbarEvent.replaceText( this.args.model.selectedText, - this.diffStreamer.suggestion + finalResult ); } } @@ -131,6 +140,7 @@ export default class ModalDiffModal extends Component { "composer-ai-helper-modal__suggestion" "streamable-content" (if this.isStreaming "streaming") + (if this.diffStreamer.isThinking "thinking") (if @model.showResultAsDiff "inline-diff") }} > diff --git a/assets/javascripts/discourse/lib/diff-streamer.gjs b/assets/javascripts/discourse/lib/diff-streamer.gjs index 82f76856..e915ec7d 100644 --- a/assets/javascripts/discourse/lib/diff-streamer.gjs +++ b/assets/javascripts/discourse/lib/diff-streamer.gjs @@ -12,6 +12,8 @@ export default class DiffStreamer { @tracked lastResultText = ""; @tracked diff = ""; @tracked suggestion = ""; + @tracked isDone = false; + @tracked isThinking = false; typingTimer = null; currentWordIndex = 0; @@ -35,6 +37,7 @@ export default class DiffStreamer { const newText = result[newTextKey]; const diffText = newText.slice(this.lastResultText.length).trim(); const newWords = diffText.split(/\s+/).filter(Boolean); + this.isDone = result?.done; if (newWords.length > 0) { this.isStreaming = true; @@ -64,7 +67,12 @@ export default class DiffStreamer { * Highlights the current word if streaming is ongoing. */ #streamNextWord() { - if (this.currentWordIndex === this.words.length) { + if (this.currentWordIndex === this.words.length && !this.isDone) { + this.isThinking = true; + } + + if (this.currentWordIndex === this.words.length && this.isDone) { + this.isThinking = false; this.diff = this.#compareText(this.selectedText, this.suggestion, { markLastWord: false, }); @@ -72,6 +80,7 @@ export default class DiffStreamer { } if (this.currentWordIndex < this.words.length) { + this.isThinking = false; this.suggestion += this.words[this.currentWordIndex] + " "; this.diff = this.#compareText(this.selectedText, this.suggestion, { markLastWord: true, @@ -99,22 +108,36 @@ export default class DiffStreamer { const oldWords = oldText.trim().split(/\s+/); const newWords = newText.trim().split(/\s+/); + // Track where the line breaks are in the original oldText + const lineBreakMap = (() => { + const lines = oldText.trim().split("\n"); + const map = new Set(); + let wordIndex = 0; + + for (const line of lines) { + const wordsInLine = line.trim().split(/\s+/); + wordIndex += wordsInLine.length; + map.add(wordIndex - 1); // Mark the last word in each line + } + + return map; + })(); + const diff = []; let i = 0; - while (i < oldWords.length) { + while (i < oldWords.length || i < newWords.length) { const oldWord = oldWords[i]; const newWord = newWords[i]; let wordHTML = ""; - let originalWordHTML = `${oldWord}`; if (newWord === undefined) { - wordHTML = originalWordHTML; + wordHTML = `${oldWord}`; } else if (oldWord === newWord) { wordHTML = `${newWord}`; } else if (oldWord !== newWord) { - wordHTML = `${oldWord} ${newWord}`; + wordHTML = `${oldWord ?? ""} ${newWord ?? ""}`; } if (i === newWords.length - 1 && opts.markLastWord) { @@ -122,6 +145,12 @@ export default class DiffStreamer { } diff.push(wordHTML); + + // Add a line break after this word if it ended a line in the original text + if (lineBreakMap.has(i)) { + diff.push("
"); + } + i++; } diff --git a/assets/stylesheets/common/streaming.scss b/assets/stylesheets/common/streaming.scss index e6561e32..b26c914e 100644 --- a/assets/stylesheets/common/streaming.scss +++ b/assets/stylesheets/common/streaming.scss @@ -79,3 +79,18 @@ article.streaming .cooked { } } } + +@keyframes mark-blink { + 0%, + 100% { + border-color: var(--highlight-high); + } + + 50% { + border-color: transparent; + } +} + +.composer-ai-helper-modal__suggestion.thinking mark.highlight { + animation: mark-blink 1s step-start 0s infinite; +} diff --git a/lib/ai_helper/assistant.rb b/lib/ai_helper/assistant.rb index c15bd358..95fd5ab7 100644 --- a/lib/ai_helper/assistant.rb +++ b/lib/ai_helper/assistant.rb @@ -181,14 +181,15 @@ module DiscourseAi streamed_diff = parse_diff(input, partial_response) if completion_prompt.diff? - # Throttle the updates and - # checking length prevents partial tags - # that aren't sanitized correctly yet (i.e. ' 10 && (Time.now - start > 0.3)) || Rails.env.test? - payload = { result: sanitize_result(streamed_result), diff: streamed_diff, done: false } - publish_update(channel, payload, user) - start = Time.now + sanitized = sanitize_result(streamed_result) + + if DiscourseAi::Utils::DiffUtils::SafetyChecker.safe_to_stream?(sanitized) + payload = { result: sanitized, diff: streamed_diff, done: false } + publish_update(channel, payload, user) + start = Time.now + end end end diff --git a/lib/utils/diff_utils/safety_checker.rb b/lib/utils/diff_utils/safety_checker.rb new file mode 100644 index 00000000..16adb3ab --- /dev/null +++ b/lib/utils/diff_utils/safety_checker.rb @@ -0,0 +1,91 @@ +# frozen_string_literal: true + +require "cgi" + +module DiscourseAi + module Utils + module DiffUtils + class SafetyChecker + def self.safe_to_stream?(html_text) + new(html_text).safe? + end + + def initialize(html_text) + @original_html = html_text + @text = sanitize(html_text) + end + + def safe? + return false if unclosed_markdown_links? + return false if unclosed_raw_html_tag? + return false if trailing_incomplete_url? + return false if unclosed_backticks? + return false if unbalanced_bold_or_italic? + return false if incomplete_image_markdown? + return false if unbalanced_quote_blocks? + return false if unclosed_triple_backticks? + return false if partial_emoji? + + true + end + + private + + def sanitize(html) + text = html.gsub(%r{]+>}, "") # remove tags like , , etc. + CGI.unescapeHTML(text) + end + + def unclosed_markdown_links? + open_brackets = @text.count("[") + close_brackets = @text.count("]") + open_parens = @text.count("(") + close_parens = @text.count(")") + + open_brackets != close_brackets || open_parens != close_parens + end + + def unclosed_raw_html_tag? + last_lt = @text.rindex("<") + last_gt = @text.rindex(">") + last_lt && (!last_gt || last_gt < last_lt) + end + + def trailing_incomplete_url? + last_word = @text.split(/\s/).last + last_word =~ %r{\Ahttps?://[^\s]*\z} && last_word !~ /[)\].,!?:;'"]\z/ + end + + def unclosed_backticks? + @text.count("`").odd? + end + + def unbalanced_bold_or_italic? + @text.scan(/\*\*/).count.odd? || @text.scan(/\*(?!\*)/).count.odd? || + @text.scan(/_/).count.odd? + end + + def incomplete_image_markdown? + last_image = @text[/!\[.*?\]\(.*?$/, 0] + last_image && last_image[-1] != ")" + end + + def unbalanced_quote_blocks? + opens = @text.scan(/\[quote(=.*?)?\]/i).count + closes = @text.scan(%r{\[/quote\]}i).count + opens > closes + end + + def unclosed_triple_backticks? + @text.scan(/```/).count.odd? + end + + def partial_emoji? + text = @text.gsub(/!\[.*?\]\(.*?\)/, "").gsub(%r{https?://[^\s]+}, "") + tokens = text.scan(/:[a-z0-9_+\-\.]+:?/i) + tokens.any? { |token| token.start_with?(":") && !token.end_with?(":") } + end + end + end + end +end diff --git a/spec/lib/utils/diff_utils/safety_checker_spec.rb b/spec/lib/utils/diff_utils/safety_checker_spec.rb new file mode 100644 index 00000000..19a8a861 --- /dev/null +++ b/spec/lib/utils/diff_utils/safety_checker_spec.rb @@ -0,0 +1,80 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::Utils::DiffUtils::SafetyChecker do + describe "#safe?" do + subject { described_class.new(text).safe? } + + context "with safe text" do + let(:text) { "This is a simple safe text without issues." } + + it { is_expected.to eq(true) } + + context "with normal HTML tags" do + let(:text) { "Here is bold and italic text." } + it { is_expected.to eq(true) } + end + + context "with balanced markdown and no partial emoji" do + let(:text) { "This is **bold**, *italic*, and a smiley :smile:!" } + it { is_expected.to eq(true) } + end + + context "with balanced quote blocks" do + let(:text) { "[quote]Quoted text[/quote]" } + it { is_expected.to eq(true) } + end + + context "with complete image markdown" do + let(:text) { "![alt text](https://example.com/image.png)" } + it { is_expected.to eq(true) } + end + end + + context "with unsafe text" do + context "with unclosed markdown link" do + let(:text) { "This is a [link(https://example.com)" } + it { is_expected.to eq(false) } + end + + context "with unclosed raw HTML tag" do + let(:text) { "Text with