mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-06-25 17:12:16 +00:00
DEV: Improve diff streaming accuracy with safety checker (#1338)
This update adds a safety checker which scans the streamed updates. It ensures that incomplete segments of text are not sent yet over message bus as this will cause breakage with the diff streamer. It also updates the diff streamer to handle a thinking state for when we are waiting for message bus updates.
This commit is contained in:
parent
ff2e18f9ca
commit
dfea784fc4
@ -22,6 +22,7 @@ export default class ModalDiffModal extends Component {
|
||||
@service messageBus;
|
||||
|
||||
@tracked loading = false;
|
||||
@tracked finalResult = "";
|
||||
@tracked diffStreamer = new DiffStreamer(this.args.model.selectedText);
|
||||
@tracked suggestion = "";
|
||||
@tracked
|
||||
@ -65,6 +66,10 @@ export default class ModalDiffModal extends Component {
|
||||
async updateResult(result) {
|
||||
this.loading = false;
|
||||
|
||||
if (result.done) {
|
||||
this.finalResult = result.result;
|
||||
}
|
||||
|
||||
if (this.args.model.showResultAsDiff) {
|
||||
this.diffStreamer.updateResult(result, "result");
|
||||
} else {
|
||||
@ -105,10 +110,14 @@ export default class ModalDiffModal extends Component {
|
||||
);
|
||||
}
|
||||
|
||||
if (this.args.model.showResultAsDiff && this.diffStreamer.suggestion) {
|
||||
const finalResult =
|
||||
this.finalResult?.length > 0
|
||||
? this.finalResult
|
||||
: this.diffStreamer.suggestion;
|
||||
if (this.args.model.showResultAsDiff && finalResult) {
|
||||
this.args.model.toolbarEvent.replaceText(
|
||||
this.args.model.selectedText,
|
||||
this.diffStreamer.suggestion
|
||||
finalResult
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -131,6 +140,7 @@ export default class ModalDiffModal extends Component {
|
||||
"composer-ai-helper-modal__suggestion"
|
||||
"streamable-content"
|
||||
(if this.isStreaming "streaming")
|
||||
(if this.diffStreamer.isThinking "thinking")
|
||||
(if @model.showResultAsDiff "inline-diff")
|
||||
}}
|
||||
>
|
||||
|
@ -12,6 +12,8 @@ export default class DiffStreamer {
|
||||
@tracked lastResultText = "";
|
||||
@tracked diff = "";
|
||||
@tracked suggestion = "";
|
||||
@tracked isDone = false;
|
||||
@tracked isThinking = false;
|
||||
typingTimer = null;
|
||||
currentWordIndex = 0;
|
||||
|
||||
@ -35,6 +37,7 @@ export default class DiffStreamer {
|
||||
const newText = result[newTextKey];
|
||||
const diffText = newText.slice(this.lastResultText.length).trim();
|
||||
const newWords = diffText.split(/\s+/).filter(Boolean);
|
||||
this.isDone = result?.done;
|
||||
|
||||
if (newWords.length > 0) {
|
||||
this.isStreaming = true;
|
||||
@ -64,7 +67,12 @@ export default class DiffStreamer {
|
||||
* Highlights the current word if streaming is ongoing.
|
||||
*/
|
||||
#streamNextWord() {
|
||||
if (this.currentWordIndex === this.words.length) {
|
||||
if (this.currentWordIndex === this.words.length && !this.isDone) {
|
||||
this.isThinking = true;
|
||||
}
|
||||
|
||||
if (this.currentWordIndex === this.words.length && this.isDone) {
|
||||
this.isThinking = false;
|
||||
this.diff = this.#compareText(this.selectedText, this.suggestion, {
|
||||
markLastWord: false,
|
||||
});
|
||||
@ -72,6 +80,7 @@ export default class DiffStreamer {
|
||||
}
|
||||
|
||||
if (this.currentWordIndex < this.words.length) {
|
||||
this.isThinking = false;
|
||||
this.suggestion += this.words[this.currentWordIndex] + " ";
|
||||
this.diff = this.#compareText(this.selectedText, this.suggestion, {
|
||||
markLastWord: true,
|
||||
@ -99,22 +108,36 @@ export default class DiffStreamer {
|
||||
const oldWords = oldText.trim().split(/\s+/);
|
||||
const newWords = newText.trim().split(/\s+/);
|
||||
|
||||
// Track where the line breaks are in the original oldText
|
||||
const lineBreakMap = (() => {
|
||||
const lines = oldText.trim().split("\n");
|
||||
const map = new Set();
|
||||
let wordIndex = 0;
|
||||
|
||||
for (const line of lines) {
|
||||
const wordsInLine = line.trim().split(/\s+/);
|
||||
wordIndex += wordsInLine.length;
|
||||
map.add(wordIndex - 1); // Mark the last word in each line
|
||||
}
|
||||
|
||||
return map;
|
||||
})();
|
||||
|
||||
const diff = [];
|
||||
let i = 0;
|
||||
|
||||
while (i < oldWords.length) {
|
||||
while (i < oldWords.length || i < newWords.length) {
|
||||
const oldWord = oldWords[i];
|
||||
const newWord = newWords[i];
|
||||
|
||||
let wordHTML = "";
|
||||
let originalWordHTML = `<span class="ghost">${oldWord}</span>`;
|
||||
|
||||
if (newWord === undefined) {
|
||||
wordHTML = originalWordHTML;
|
||||
wordHTML = `<span class="ghost">${oldWord}</span>`;
|
||||
} else if (oldWord === newWord) {
|
||||
wordHTML = `<span class="same-word">${newWord}</span>`;
|
||||
} else if (oldWord !== newWord) {
|
||||
wordHTML = `<del>${oldWord}</del> <ins>${newWord}</ins>`;
|
||||
wordHTML = `<del>${oldWord ?? ""}</del> <ins>${newWord ?? ""}</ins>`;
|
||||
}
|
||||
|
||||
if (i === newWords.length - 1 && opts.markLastWord) {
|
||||
@ -122,6 +145,12 @@ export default class DiffStreamer {
|
||||
}
|
||||
|
||||
diff.push(wordHTML);
|
||||
|
||||
// Add a line break after this word if it ended a line in the original text
|
||||
if (lineBreakMap.has(i)) {
|
||||
diff.push("<br>");
|
||||
}
|
||||
|
||||
i++;
|
||||
}
|
||||
|
||||
|
@ -79,3 +79,18 @@ article.streaming .cooked {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes mark-blink {
|
||||
0%,
|
||||
100% {
|
||||
border-color: var(--highlight-high);
|
||||
}
|
||||
|
||||
50% {
|
||||
border-color: transparent;
|
||||
}
|
||||
}
|
||||
|
||||
.composer-ai-helper-modal__suggestion.thinking mark.highlight {
|
||||
animation: mark-blink 1s step-start 0s infinite;
|
||||
}
|
||||
|
@ -181,14 +181,15 @@ module DiscourseAi
|
||||
|
||||
streamed_diff = parse_diff(input, partial_response) if completion_prompt.diff?
|
||||
|
||||
# Throttle the updates and
|
||||
# checking length prevents partial tags
|
||||
# that aren't sanitized correctly yet (i.e. '<output')
|
||||
# from being sent in the stream
|
||||
# Throttle updates and check for safe stream points
|
||||
if (streamed_result.length > 10 && (Time.now - start > 0.3)) || Rails.env.test?
|
||||
payload = { result: sanitize_result(streamed_result), diff: streamed_diff, done: false }
|
||||
publish_update(channel, payload, user)
|
||||
start = Time.now
|
||||
sanitized = sanitize_result(streamed_result)
|
||||
|
||||
if DiscourseAi::Utils::DiffUtils::SafetyChecker.safe_to_stream?(sanitized)
|
||||
payload = { result: sanitized, diff: streamed_diff, done: false }
|
||||
publish_update(channel, payload, user)
|
||||
start = Time.now
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
91
lib/utils/diff_utils/safety_checker.rb
Normal file
91
lib/utils/diff_utils/safety_checker.rb
Normal file
@ -0,0 +1,91 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
require "cgi"
|
||||
|
||||
module DiscourseAi
|
||||
module Utils
|
||||
module DiffUtils
|
||||
class SafetyChecker
|
||||
def self.safe_to_stream?(html_text)
|
||||
new(html_text).safe?
|
||||
end
|
||||
|
||||
def initialize(html_text)
|
||||
@original_html = html_text
|
||||
@text = sanitize(html_text)
|
||||
end
|
||||
|
||||
def safe?
|
||||
return false if unclosed_markdown_links?
|
||||
return false if unclosed_raw_html_tag?
|
||||
return false if trailing_incomplete_url?
|
||||
return false if unclosed_backticks?
|
||||
return false if unbalanced_bold_or_italic?
|
||||
return false if incomplete_image_markdown?
|
||||
return false if unbalanced_quote_blocks?
|
||||
return false if unclosed_triple_backticks?
|
||||
return false if partial_emoji?
|
||||
|
||||
true
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def sanitize(html)
|
||||
text = html.gsub(%r{</?[^>]+>}, "") # remove tags like <span>, <del>, etc.
|
||||
CGI.unescapeHTML(text)
|
||||
end
|
||||
|
||||
def unclosed_markdown_links?
|
||||
open_brackets = @text.count("[")
|
||||
close_brackets = @text.count("]")
|
||||
open_parens = @text.count("(")
|
||||
close_parens = @text.count(")")
|
||||
|
||||
open_brackets != close_brackets || open_parens != close_parens
|
||||
end
|
||||
|
||||
def unclosed_raw_html_tag?
|
||||
last_lt = @text.rindex("<")
|
||||
last_gt = @text.rindex(">")
|
||||
last_lt && (!last_gt || last_gt < last_lt)
|
||||
end
|
||||
|
||||
def trailing_incomplete_url?
|
||||
last_word = @text.split(/\s/).last
|
||||
last_word =~ %r{\Ahttps?://[^\s]*\z} && last_word !~ /[)\].,!?:;'"]\z/
|
||||
end
|
||||
|
||||
def unclosed_backticks?
|
||||
@text.count("`").odd?
|
||||
end
|
||||
|
||||
def unbalanced_bold_or_italic?
|
||||
@text.scan(/\*\*/).count.odd? || @text.scan(/\*(?!\*)/).count.odd? ||
|
||||
@text.scan(/_/).count.odd?
|
||||
end
|
||||
|
||||
def incomplete_image_markdown?
|
||||
last_image = @text[/!\[.*?\]\(.*?$/, 0]
|
||||
last_image && last_image[-1] != ")"
|
||||
end
|
||||
|
||||
def unbalanced_quote_blocks?
|
||||
opens = @text.scan(/\[quote(=.*?)?\]/i).count
|
||||
closes = @text.scan(%r{\[/quote\]}i).count
|
||||
opens > closes
|
||||
end
|
||||
|
||||
def unclosed_triple_backticks?
|
||||
@text.scan(/```/).count.odd?
|
||||
end
|
||||
|
||||
def partial_emoji?
|
||||
text = @text.gsub(/!\[.*?\]\(.*?\)/, "").gsub(%r{https?://[^\s]+}, "")
|
||||
tokens = text.scan(/:[a-z0-9_+\-\.]+:?/i)
|
||||
tokens.any? { |token| token.start_with?(":") && !token.end_with?(":") }
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
80
spec/lib/utils/diff_utils/safety_checker_spec.rb
Normal file
80
spec/lib/utils/diff_utils/safety_checker_spec.rb
Normal file
@ -0,0 +1,80 @@
|
||||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Utils::DiffUtils::SafetyChecker do
|
||||
describe "#safe?" do
|
||||
subject { described_class.new(text).safe? }
|
||||
|
||||
context "with safe text" do
|
||||
let(:text) { "This is a simple safe text without issues." }
|
||||
|
||||
it { is_expected.to eq(true) }
|
||||
|
||||
context "with normal HTML tags" do
|
||||
let(:text) { "Here is <strong>bold</strong> and <em>italic</em> text." }
|
||||
it { is_expected.to eq(true) }
|
||||
end
|
||||
|
||||
context "with balanced markdown and no partial emoji" do
|
||||
let(:text) { "This is **bold**, *italic*, and a smiley :smile:!" }
|
||||
it { is_expected.to eq(true) }
|
||||
end
|
||||
|
||||
context "with balanced quote blocks" do
|
||||
let(:text) { "[quote]Quoted text[/quote]" }
|
||||
it { is_expected.to eq(true) }
|
||||
end
|
||||
|
||||
context "with complete image markdown" do
|
||||
let(:text) { "" }
|
||||
it { is_expected.to eq(true) }
|
||||
end
|
||||
end
|
||||
|
||||
context "with unsafe text" do
|
||||
context "with unclosed markdown link" do
|
||||
let(:text) { "This is a [link(https://example.com)" }
|
||||
it { is_expected.to eq(false) }
|
||||
end
|
||||
|
||||
context "with unclosed raw HTML tag" do
|
||||
let(:text) { "Text with <div unclosed tag" }
|
||||
it { is_expected.to eq(false) }
|
||||
end
|
||||
|
||||
context "with trailing incomplete URL" do
|
||||
let(:text) { "Check this out https://example.com/something" } # no closing punctuation
|
||||
it { is_expected.to eq(false) }
|
||||
end
|
||||
|
||||
context "with unclosed backticks" do
|
||||
let(:text) { "Here is some `inline code without closing" }
|
||||
it { is_expected.to eq(false) }
|
||||
end
|
||||
|
||||
context "with unbalanced bold or italic markdown" do
|
||||
let(:text) { "This is *italic without closing" }
|
||||
it { is_expected.to eq(false) }
|
||||
end
|
||||
|
||||
context "with incomplete image markdown" do
|
||||
let(:text) { "Image 
|
||||
it { is_expected.to eq(false) }
|
||||
end
|
||||
|
||||
context "with unbalanced quote blocks" do
|
||||
let(:text) { "[quote]Unclosed quote block" }
|
||||
it { is_expected.to eq(false) }
|
||||
end
|
||||
|
||||
context "with unclosed triple backticks" do
|
||||
let(:text) { "```code block without closing" }
|
||||
it { is_expected.to eq(false) }
|
||||
end
|
||||
|
||||
context "with partial emoji" do
|
||||
let(:text) { "A partial emoji :smile" }
|
||||
it { is_expected.to eq(false) }
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
Loading…
x
Reference in New Issue
Block a user