FEATURE: even smoother streaming (#420)
- Account properly for function calls, don't stream through <details> blocks
- Rush cooked content back to client
- Wait longer (up to 60 seconds) before giving up on streaming
- Clean up message bus channels so we don't have leftover data
- Make ai streamer much more reusable and much easier to read
- If buffer grows quickly, rush update so you are not artificially waiting
- Refine prompt interface
- Fix lost system message when prompt gets long
parent 6b8a57d957
commit 825f01cfb2
@@ -3,8 +3,9 @@ import loadScript from "discourse/lib/load-script";
 import { cook } from "discourse/lib/text";
 
 const PROGRESS_INTERVAL = 40;
-const GIVE_UP_INTERVAL = 10000;
-const LETTERS_PER_INTERVAL = 6;
+const GIVE_UP_INTERVAL = 60000;
+export const MIN_LETTERS_PER_INTERVAL = 6;
+const MAX_FLUSH_TIME = 800;
 
 let progressTimer = null;
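The new MAX_FLUSH_TIME constant works together with PROGRESS_INTERVAL and MIN_LETTERS_PER_INTERVAL to implement the "rush update" item from the commit message: when the buffered text outgrows what fixed-rate streaming would drain, the per-tick chunk widens. A minimal sketch of that arithmetic (the `lettersPerTick` helper is illustrative and not part of the commit; the same computation appears inline in `applyProgress` below):

```javascript
// Sketch (not part of the diff): how the new constants interact.
// If the server buffer runs ahead of the client by `diff` characters,
// the streamer widens the per-tick chunk so the backlog drains within
// MAX_FLUSH_TIME instead of crawling at a fixed 6 chars per 40ms tick.
const PROGRESS_INTERVAL = 40; // ms per animation tick
const MAX_FLUSH_TIME = 800; // target: drain the backlog within ~800ms
const MIN_LETTERS_PER_INTERVAL = 6;

function lettersPerTick(diff) {
  const ticksAvailable = MAX_FLUSH_TIME / PROGRESS_INTERVAL; // 20 ticks
  return Math.max(Math.floor(diff / ticksAvailable), MIN_LETTERS_PER_INTERVAL);
}

// e.g. a 400-char backlog streams at 20 chars/tick and finishes in ~800ms,
// while a small backlog keeps the smooth 6 chars/tick minimum.
```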
@@ -41,69 +42,146 @@ export function addProgressDot(element) {
   lastBlock.appendChild(dotElement);
 }
 
-async function applyProgress(postStatus, postStream) {
-  postStatus.startTime = postStatus.startTime || Date.now();
-  let post = postStream.findLoadedPost(postStatus.post_id);
-
-  const postElement = document.querySelector(`#post_${postStatus.post_number}`);
-
-  if (Date.now() - postStatus.startTime > GIVE_UP_INTERVAL) {
-    if (postElement) {
-      postElement.classList.remove("streaming");
-    }
-    return true;
-  }
-
-  if (!post) {
-    // wait till later
-    return false;
-  }
-
-  const oldRaw = post.get("raw") || "";
-
-  if (postStatus.raw === oldRaw && !postStatus.done) {
-    const hasProgressDot =
-      postElement && postElement.querySelector(".progress-dot");
-    if (hasProgressDot) {
-      return false;
-    }
-  }
-
-  if (postStatus.raw) {
-    const newRaw = postStatus.raw.substring(
-      0,
-      oldRaw.length + LETTERS_PER_INTERVAL
-    );
-    const cooked = await cook(newRaw);
-
-    post.set("raw", newRaw);
-    post.set("cooked", cooked);
-
-    // resets animation
-    postElement.classList.remove("streaming");
-    void postElement.offsetWidth;
-    postElement.classList.add("streaming");
-
-    const cookedElement = document.createElement("div");
-    cookedElement.innerHTML = cooked;
-
-    addProgressDot(cookedElement);
-
-    const element = document.querySelector(
-      `#post_${postStatus.post_number} .cooked`
-    );
-
-    await loadScript("/javascripts/diffhtml.min.js");
-    window.diff.innerHTML(element, cookedElement.innerHTML);
-  }
-
-  if (postStatus.done) {
-    if (postElement) {
-      postElement.classList.remove("streaming");
-    }
-  }
-
-  return postStatus.done;
-}
+// this is the interface we need to implement
+// for a streaming updater
+class StreamUpdater {
+  set streaming(value) {
+    throw "not implemented";
+  }
+
+  async setCooked() {
+    throw "not implemented";
+  }
+
+  async setRaw() {
+    throw "not implemented";
+  }
+
+  get element() {
+    throw "not implemented";
+  }
+
+  get raw() {
+    throw "not implemented";
+  }
+}
+
+class PostUpdater extends StreamUpdater {
+  constructor(postStream, postId) {
+    super();
+    this.postStream = postStream;
+    this.postId = postId;
+    this.post = postStream.findLoadedPost(postId);
+
+    if (this.post) {
+      this.postElement = document.querySelector(
+        `#post_${this.post.post_number}`
+      );
+    }
+  }
+
+  get element() {
+    return this.postElement;
+  }
+
+  set streaming(value) {
+    if (this.postElement) {
+      if (value) {
+        this.postElement.classList.add("streaming");
+      } else {
+        this.postElement.classList.remove("streaming");
+      }
+    }
+  }
+
+  async setRaw(value, done) {
+    this.post.set("raw", value);
+    const cooked = await cook(value);
+
+    // resets animation
+    this.element.classList.remove("streaming");
+    void this.element.offsetWidth;
+    this.element.classList.add("streaming");
+
+    const cookedElement = document.createElement("div");
+    cookedElement.innerHTML = cooked;
+
+    if (!done) {
+      addProgressDot(cookedElement);
+    }
+
+    await this.setCooked(cookedElement.innerHTML);
+  }
+
+  async setCooked(value) {
+    this.post.set("cooked", value);
+
+    const oldElement = this.postElement.querySelector(".cooked");
+
+    await loadScript("/javascripts/diffhtml.min.js");
+    window.diff.innerHTML(oldElement, value);
+  }
+
+  get raw() {
+    return this.post.get("raw") || "";
+  }
+}
+
+export async function applyProgress(status, updater) {
+  status.startTime = status.startTime || Date.now();
+
+  if (Date.now() - status.startTime > GIVE_UP_INTERVAL) {
+    updater.streaming = false;
+    return true;
+  }
+
+  if (!updater.element) {
+    // wait till later
+    return false;
+  }
+
+  const oldRaw = updater.raw;
+
+  if (status.raw === oldRaw && !status.done) {
+    const hasProgressDot = updater.element.querySelector(".progress-dot");
+    if (hasProgressDot) {
+      return false;
+    }
+  }
+
+  if (status.raw !== undefined) {
+    let newRaw = status.raw;
+
+    if (!status.done) {
+      // rush update if we have a </details> tag (function call)
+      if (oldRaw.length === 0 && newRaw.indexOf("</details>") !== -1) {
+        newRaw = status.raw;
+      } else {
+        const diff = newRaw.length - oldRaw.length;
+
+        // progress interval is 40ms
+        // by default we add 6 letters per interval
+        // but ... we want to be done in MAX_FLUSH_TIME
+        let letters = Math.floor(diff / (MAX_FLUSH_TIME / PROGRESS_INTERVAL));
+        if (letters < MIN_LETTERS_PER_INTERVAL) {
+          letters = MIN_LETTERS_PER_INTERVAL;
+        }
+
+        newRaw = status.raw.substring(0, oldRaw.length + letters);
+      }
+    }
+
+    await updater.setRaw(newRaw, status.done);
+  }
+
+  if (status.done) {
+    if (status.cooked) {
+      await updater.setCooked(status.cooked);
+    }
+    updater.streaming = false;
+  }
+
+  return status.done;
+}
 
 async function handleProgress(postStream) {
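A detail worth noting in `setRaw` above: removing and then re-adding the `streaming` class only restarts the CSS animation because the `offsetWidth` read in between forces a synchronous reflow. A standalone sketch of that pattern (the `restartAnimation` helper name is illustrative, not from the commit):

```javascript
// Sketch: restart a CSS animation on `element` by forcing a reflow.
// Without the offsetWidth read, the browser may coalesce the class
// removal and re-addition into one style update and never restart
// the animation.
function restartAnimation(element, className = "streaming") {
  element.classList.remove(className);
  void element.offsetWidth; // force reflow; `void` discards the value
  element.classList.add(className);
}
```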
@@ -114,7 +192,8 @@ async function handleProgress(postStream) {
   const promises = Object.keys(status).map(async (postId) => {
     let postStatus = status[postId];
 
-    const done = await applyProgress(postStatus, postStream);
+    const postUpdater = new PostUpdater(postStream, postStatus.post_id);
+    const done = await applyProgress(postStatus, postUpdater);
 
     if (done) {
       delete status[postId];
@@ -142,6 +221,10 @@ function ensureProgress(postStream) {
 }
 
 export default function streamText(postStream, data) {
+  if (data.noop) {
+    return;
+  }
+
   let status = (postStream.aiStreamingStatus =
     postStream.aiStreamingStatus || {});
   status[data.post_id] = data;
@@ -59,7 +59,7 @@ en:
         description: "Enable debug mode to see the raw input and output of the LLM"
       priority_group:
         label: "Priority Group"
-        description: "Priotize content from this group in the report"
+        description: "Prioritize content from this group in the report"
 
   llm_triage:
     fields:
@@ -112,9 +112,10 @@ module DiscourseAi
             topic_id: post.topic_id,
             raw: "",
             skip_validations: true,
+            skip_jobs: true,
           )
 
-        publish_update(reply_post, raw: "<p></p>")
+        publish_update(reply_post, { raw: reply_post.cooked })
 
         redis_stream_key = "gpt_cancel:#{reply_post.id}"
         Discourse.redis.setex(redis_stream_key, 60, 1)
@@ -139,12 +140,14 @@ module DiscourseAi
 
           Discourse.redis.expire(redis_stream_key, 60)
 
-          publish_update(reply_post, raw: raw)
+          publish_update(reply_post, { raw: raw })
         end
 
         return if reply.blank?
 
-        publish_update(reply_post, done: true)
+        # land the final message prior to saving so we don't clash
+        reply_post.cooked = PrettyText.cook(reply)
+        publish_final_update(reply_post)
 
         reply_post.revise(bot.bot_user, { raw: reply }, skip_validations: true, skip_revision: true)
@@ -157,10 +160,25 @@ module DiscourseAi
         end
 
         reply_post
+      ensure
+        publish_final_update(reply_post)
       end
 
       private
 
+      def publish_final_update(reply_post)
+        return if @published_final_update
+        if reply_post
+          publish_update(reply_post, { cooked: reply_post.cooked, done: true })
+          # we subscribe at position -2 so we will always get this message
+          # moving all cooked on every page load is wasteful ... this means
+          # we have a benign message at the end, 2 is set to ensure last message
+          # is delivered
+          publish_update(reply_post, { noop: true })
+          @published_final_update = true
+        end
+      end
+
       attr_reader :bot
 
       def can_attach?(post)
@@ -201,10 +219,15 @@ module DiscourseAi
       end
 
       def publish_update(bot_reply_post, payload)
+        payload = { post_id: bot_reply_post.id, post_number: bot_reply_post.post_number }.merge(
+          payload,
+        )
         MessageBus.publish(
           "discourse-ai/ai-bot/topic/#{bot_reply_post.topic_id}",
-          payload.merge(post_id: bot_reply_post.id, post_number: bot_reply_post.post_number),
+          payload,
           user_ids: bot_reply_post.topic.allowed_user_ids,
+          max_backlog_size: 2,
+          max_backlog_age: 60,
         )
       end
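With max_backlog_size: 2 and max_backlog_age: 60 the channel now retains only the last couple of messages, which is why the final cooked/done message plus the benign noop are published above: a client that subscribes after streaming has ended can still replay the tail of the channel. A hedged sketch of the consuming side, assuming the standard message-bus JavaScript client signature `subscribe(channel, callback, lastId)` (the actual subscription code is not part of this diff):

```javascript
// Sketch (assumption: standard message-bus JS client API).
// Subscribing at position -2 asks the server to replay the most recent
// message on the channel, so a page that loads after streaming finished
// still receives the final payload. The benign noop keeps that slot
// occupied without re-sending the full cooked post on every page load.
messageBus.subscribe(
  `discourse-ai/ai-bot/topic/${topicId}`,
  (data) => streamText(postStream, data), // streamText ignores noops
  -2
);
```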
@@ -24,11 +24,11 @@ module DiscourseAi
       end
 
       def query
-        parameters[:query].to_s
+        parameters[:query].to_s.strip
       end
 
       def invoke(bot_user, llm)
-        yield("") # Triggers placeholder update
+        yield(query)
 
         api_key = SiteSetting.ai_google_custom_search_api_key
         cx = SiteSetting.ai_google_custom_search_cx
@@ -25,7 +25,10 @@ module DiscourseAi
         messages = prompt.messages
 
         # ChatGPT doesn't use an assistant msg to improve long-context responses.
-        messages.pop if messages.last[:type] == :model
+        if messages.last[:type] == :model
+          messages = messages.dup
+          messages.pop
+        end
 
         trimmed_messages = trim_messages(messages)
@@ -115,36 +115,49 @@ module DiscourseAi
         current_token_count = 0
         message_step_size = (max_prompt_tokens / 25).to_i * -1
 
-        reversed_trimmed_msgs =
-          messages
-            .reverse
-            .reduce([]) do |acc, msg|
-              message_tokens = calculate_message_token(msg)
-
-              dupped_msg = msg.dup
-
-              # Don't trim tool call metadata.
-              if msg[:type] == :tool_call
-                current_token_count += message_tokens + per_message_overhead
-                acc << dupped_msg
-                next(acc)
-              end
-
-              # Trimming content to make sure we respect token limit.
-              while dupped_msg[:content].present? &&
-                      message_tokens + current_token_count + per_message_overhead > prompt_limit
-                dupped_msg[:content] = dupped_msg[:content][0..message_step_size] || ""
-                message_tokens = calculate_message_token(dupped_msg)
-              end
-
-              next(acc) if dupped_msg[:content].blank?
-
-              current_token_count += message_tokens + per_message_overhead
-
-              acc << dupped_msg
-            end
-
-        reversed_trimmed_msgs.reverse
+        trimmed_messages = []
+
+        range = (0..-1)
+        if messages.dig(0, :type) == :system
+          system_message = messages[0]
+          trimmed_messages << system_message
+          current_token_count += calculate_message_token(system_message)
+          range = (1..-1)
+        end
+
+        reversed_trimmed_msgs = []
+
+        messages[range].reverse.each do |msg|
+          break if current_token_count >= prompt_limit
+
+          message_tokens = calculate_message_token(msg)
+
+          dupped_msg = msg.dup
+
+          # Don't trim tool call metadata.
+          if msg[:type] == :tool_call
+            break if current_token_count + message_tokens + per_message_overhead > prompt_limit
+
+            current_token_count += message_tokens + per_message_overhead
+            reversed_trimmed_msgs << dupped_msg
+            next
+          end
+
+          # Trimming content to make sure we respect token limit.
+          while dupped_msg[:content].present? &&
+                  message_tokens + current_token_count + per_message_overhead > prompt_limit
+            dupped_msg[:content] = dupped_msg[:content][0..message_step_size] || ""
+            message_tokens = calculate_message_token(dupped_msg)
+          end
+
+          next if dupped_msg[:content].blank?
+
+          current_token_count += message_tokens + per_message_overhead
+
+          reversed_trimmed_msgs << dupped_msg
+        end
+
+        trimmed_messages.concat(reversed_trimmed_msgs.reverse)
       end
 
       def per_message_overhead
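This rewrite of trim_messages is what fixes the "lost system message" item from the commit message: the system message is now pinned to the prompt before the remaining messages are filled in newest-first, instead of competing with them for the token budget. A simplified JavaScript sketch of the strategy (names are illustrative; the Ruby version above additionally truncates oversized message content rather than only skipping it):

```javascript
// Sketch of the trimming strategy: reserve the system message first,
// then walk the remaining messages newest-to-oldest until the token
// budget is exhausted, and restore chronological order at the end.
function trimMessages(messages, promptLimit, countTokens) {
  const kept = [];
  let used = 0;

  let rest = messages;
  if (messages[0]?.type === "system") {
    kept.push(messages[0]); // system message is always preserved
    used += countTokens(messages[0]);
    rest = messages.slice(1);
  }

  const reversed = [];
  for (const msg of [...rest].reverse()) {
    const cost = countTokens(msg);
    if (used + cost > promptLimit) {
      break; // (the Ruby version also truncates oversized content)
    }
    used += cost;
    reversed.push(msg);
  }

  return kept.concat(reversed.reverse());
}
```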
@@ -83,6 +83,16 @@ module DiscourseAi
           stop_sequences: stop_sequences,
         }
 
+      if prompt.is_a?(String)
+        prompt =
+          DiscourseAi::Completions::Prompt.new(
+            "You are a helpful bot",
+            messages: [{ type: :user, content: prompt }],
+          )
+      elsif prompt.is_a?(Array)
+        prompt = DiscourseAi::Completions::Prompt.new(messages: prompt)
+      end
+
       model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }
 
       dialect = dialect_klass.new(prompt, model_name, opts: model_params)
@@ -5,16 +5,22 @@ module DiscourseAi
     class Prompt
       INVALID_TURN = Class.new(StandardError)
 
-      attr_reader :system_message, :messages
+      attr_reader :messages
       attr_accessor :tools
 
-      def initialize(system_msg, messages: [], tools: [])
+      def initialize(system_message_text = nil, messages: [], tools: [])
         raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
         raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
 
-        system_message = { type: :system, content: system_msg }
-
-        @messages = [system_message].concat(messages)
+        @messages = []
+
+        if system_message_text
+          system_message = { type: :system, content: system_message_text }
+          @messages << system_message
+        end
+
+        @messages.concat(messages)
+
         @messages.each { |message| validate_message(message) }
         @messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) }
@@ -50,6 +50,21 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
       expect(translated.last[:role]).to eq("user")
       expect(translated.last[:content].length).to be < context.long_message_text.length
     end
 
+    it "always preserves system message when trimming" do
+      # gpt-4 is 8k tokens so last message totally blows everything
+      prompt = DiscourseAi::Completions::Prompt.new("You are a bot")
+      prompt.push(type: :user, content: "a " * 100)
+      prompt.push(type: :model, content: "b " * 100)
+      prompt.push(type: :user, content: "zjk " * 10_000)
+
+      translated = context.dialect(prompt).translate
+
+      expect(translated.length).to eq(2)
+      expect(translated.first).to eq(content: "You are a bot", role: "system")
+      expect(translated.last[:role]).to eq("user")
+      expect(translated.last[:content].length).to be < (8000 * 4)
+    end
   end
 
   describe "#tools" do
@@ -52,6 +52,26 @@ RSpec.describe DiscourseAi::Completions::Llm do
     end
   end
 
+  describe "#generate with various style prompts" do
+    let :canned_response do
+      DiscourseAi::Completions::Endpoints::CannedResponse.new(["world"])
+    end
+
+    it "can generate a response to a simple string" do
+      response = llm.generate("hello", user: user)
+      expect(response).to eq("world")
+    end
+
+    it "can generate a response from an array" do
+      response =
+        llm.generate(
+          [{ type: :system, content: "you are a bot" }, { type: :user, content: "hello" }],
+          user: user,
+        )
+      expect(response).to eq("world")
+    end
+  end
+
   describe "#generate" do
     let(:prompt) do
       system_insts = (<<~TEXT).strip
@@ -64,16 +64,21 @@ RSpec.describe DiscourseAi::AiBot::Playground do
         playground.reply_to(third_post)
       end
 
+      reply = pm.reload.posts.last
+
+      noop_signal = messages.pop
+      expect(noop_signal.data[:noop]).to eq(true)
+
       done_signal = messages.pop
       expect(done_signal.data[:done]).to eq(true)
+      expect(done_signal.data[:cooked]).to eq(reply.cooked)
 
       # we need this for styling
-      expect(messages.first.data[:raw]).to eq("<p></p>")
+      expect(messages.first.data[:raw]).to eq("")
       messages[1..-1].each_with_index do |m, idx|
         expect(m.data[:raw]).to eq(expected_bot_response[0..idx])
       end
 
-      expect(pm.reload.posts.last.cooked).to eq(PrettyText.cook(expected_bot_response))
+      expect(reply.cooked).to eq(PrettyText.cook(expected_bot_response))
     end
   end
@@ -1,5 +1,49 @@
 import { module, test } from "qunit";
-import { addProgressDot } from "discourse/plugins/discourse-ai/discourse/lib/ai-streamer";
+import {
+  addProgressDot,
+  applyProgress,
+  MIN_LETTERS_PER_INTERVAL,
+} from "discourse/plugins/discourse-ai/discourse/lib/ai-streamer";
+
+class FakeStreamUpdater {
+  constructor() {
+    this._streaming = true;
+    this._raw = "";
+    this._cooked = "";
+    this._element = document.createElement("div");
+  }
+
+  get streaming() {
+    return this._streaming;
+  }
+
+  set streaming(value) {
+    this._streaming = value;
+  }
+
+  get cooked() {
+    return this._cooked;
+  }
+
+  get raw() {
+    return this._raw;
+  }
+
+  async setRaw(value) {
+    this._raw = value;
+    // just fake it, calling cook is tricky
+    const cooked = `<p>${value}</p>`;
+    await this.setCooked(cooked);
+  }
+
+  async setCooked(value) {
+    this._cooked = value;
+    this._element.innerHTML = value;
+  }
+
+  get element() {
+    return this._element;
+  }
+}
 
 module("Discourse AI | Unit | Lib | ai-streamer", function () {
   function confirmPlaceholder(html, expected, assert) {
@@ -69,4 +113,55 @@ module("Discourse AI | Unit | Lib | ai-streamer", function () {
 
     confirmPlaceholder(html, expected, assert);
   });
+
+  test("can perform delta updates", async function (assert) {
+    const status = {
+      startTime: Date.now(),
+      raw: "some raw content",
+      done: false,
+    };
+
+    const streamUpdater = new FakeStreamUpdater();
+
+    let done = await applyProgress(status, streamUpdater);
+
+    assert.notOk(done, "The update should not be done.");
+
+    assert.equal(
+      streamUpdater.raw,
+      status.raw.substring(0, MIN_LETTERS_PER_INTERVAL),
+      "The raw content should delta update."
+    );
+
+    done = await applyProgress(status, streamUpdater);
+
+    assert.notOk(done, "The update should not be done.");
+
+    assert.equal(
+      streamUpdater.raw,
+      status.raw.substring(0, MIN_LETTERS_PER_INTERVAL * 2),
+      "The raw content should delta update."
+    );
+
+    // last chunk
+    await applyProgress(status, streamUpdater);
+
+    const innerHtml = streamUpdater.element.innerHTML;
+    assert.equal(
+      innerHtml,
+      "<p>some raw content</p>",
+      "The cooked content should be updated."
+    );
+
+    status.done = true;
+    status.cooked = "<p>updated cooked</p>";
+
+    await applyProgress(status, streamUpdater);
+
+    assert.equal(
+      streamUpdater.element.innerHTML,
+      "<p>updated cooked</p>",
+      "The cooked content should be updated."
+    );
+  });
 });