FEATURE: even smoother streaming (#420)

Account properly for function calls, don't stream through <details> blocks
- Rush cooked content back to client
- Wait longer (up to 60 seconds) before giving up on streaming
- Clean up message bus channels so we don't have leftover data
- Make ai streamer much more reusable and much easier to read
- If buffer grows quickly, rush update so you are not artificially waiting
- Refine prompt interface
- Fix lost system message when prompt gets long
This commit is contained in:
Sam 2024-01-15 18:51:14 +11:00 committed by GitHub
parent 6b8a57d957
commit 825f01cfb2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 358 additions and 85 deletions

View File

@ -3,8 +3,9 @@ import loadScript from "discourse/lib/load-script";
import { cook } from "discourse/lib/text";
const PROGRESS_INTERVAL = 40;
const GIVE_UP_INTERVAL = 10000;
const LETTERS_PER_INTERVAL = 6;
const GIVE_UP_INTERVAL = 60000;
export const MIN_LETTERS_PER_INTERVAL = 6;
const MAX_FLUSH_TIME = 800;
let progressTimer = null;
@ -41,69 +42,146 @@ export function addProgressDot(element) {
lastBlock.appendChild(dotElement);
}
async function applyProgress(postStatus, postStream) {
postStatus.startTime = postStatus.startTime || Date.now();
let post = postStream.findLoadedPost(postStatus.post_id);
// this is the interface we need to implement
// for a streaming updater
class StreamUpdater {
set streaming(value) {
throw "not implemented";
}
const postElement = document.querySelector(`#post_${postStatus.post_number}`);
async setCooked() {
throw "not implemented";
}
if (Date.now() - postStatus.startTime > GIVE_UP_INTERVAL) {
if (postElement) {
postElement.classList.remove("streaming");
async setRaw() {
throw "not implemented";
}
get element() {
throw "not implemented";
}
get raw() {
throw "not implemented";
}
}
class PostUpdater extends StreamUpdater {
constructor(postStream, postId) {
super();
this.postStream = postStream;
this.postId = postId;
this.post = postStream.findLoadedPost(postId);
if (this.post) {
this.postElement = document.querySelector(
`#post_${this.post.post_number}`
);
}
}
get element() {
return this.postElement;
}
set streaming(value) {
if (this.postElement) {
if (value) {
this.postElement.classList.add("streaming");
} else {
this.postElement.classList.remove("streaming");
}
}
}
async setRaw(value, done) {
this.post.set("raw", value);
const cooked = await cook(value);
// resets animation
this.element.classList.remove("streaming");
void this.element.offsetWidth;
this.element.classList.add("streaming");
const cookedElement = document.createElement("div");
cookedElement.innerHTML = cooked;
if (!done) {
addProgressDot(cookedElement);
}
await this.setCooked(cookedElement.innerHTML);
}
async setCooked(value) {
this.post.set("cooked", value);
const oldElement = this.postElement.querySelector(".cooked");
await loadScript("/javascripts/diffhtml.min.js");
window.diff.innerHTML(oldElement, value);
}
get raw() {
return this.post.get("raw") || "";
}
}
export async function applyProgress(status, updater) {
status.startTime = status.startTime || Date.now();
if (Date.now() - status.startTime > GIVE_UP_INTERVAL) {
updater.streaming = false;
return true;
}
if (!post) {
if (!updater.element) {
// wait till later
return false;
}
const oldRaw = post.get("raw") || "";
const oldRaw = updater.raw;
if (postStatus.raw === oldRaw && !postStatus.done) {
const hasProgressDot =
postElement && postElement.querySelector(".progress-dot");
if (status.raw === oldRaw && !status.done) {
const hasProgressDot = updater.element.querySelector(".progress-dot");
if (hasProgressDot) {
return false;
}
}
if (postStatus.raw) {
const newRaw = postStatus.raw.substring(
0,
oldRaw.length + LETTERS_PER_INTERVAL
);
const cooked = await cook(newRaw);
if (status.raw !== undefined) {
let newRaw = status.raw;
post.set("raw", newRaw);
post.set("cooked", cooked);
if (!status.done) {
// rush update if we have a </details> tag (function call)
if (oldRaw.length === 0 && newRaw.indexOf("</details>") !== -1) {
newRaw = status.raw;
} else {
const diff = newRaw.length - oldRaw.length;
// resets animation
postElement.classList.remove("streaming");
void postElement.offsetWidth;
postElement.classList.add("streaming");
// progress interval is 40ms
// by default we add 6 letters per interval
// but ... we want to be done in MAX_FLUSH_TIME
let letters = Math.floor(diff / (MAX_FLUSH_TIME / PROGRESS_INTERVAL));
if (letters < MIN_LETTERS_PER_INTERVAL) {
letters = MIN_LETTERS_PER_INTERVAL;
}
const cookedElement = document.createElement("div");
cookedElement.innerHTML = cooked;
addProgressDot(cookedElement);
const element = document.querySelector(
`#post_${postStatus.post_number} .cooked`
);
await loadScript("/javascripts/diffhtml.min.js");
window.diff.innerHTML(element, cookedElement.innerHTML);
}
if (postStatus.done) {
if (postElement) {
postElement.classList.remove("streaming");
newRaw = status.raw.substring(0, oldRaw.length + letters);
}
}
await updater.setRaw(newRaw, status.done);
}
return postStatus.done;
if (status.done) {
if (status.cooked) {
await updater.setCooked(status.cooked);
}
updater.streaming = false;
}
return status.done;
}
async function handleProgress(postStream) {
@ -114,7 +192,8 @@ async function handleProgress(postStream) {
const promises = Object.keys(status).map(async (postId) => {
let postStatus = status[postId];
const done = await applyProgress(postStatus, postStream);
const postUpdater = new PostUpdater(postStream, postStatus.post_id);
const done = await applyProgress(postStatus, postUpdater);
if (done) {
delete status[postId];
@ -142,6 +221,10 @@ function ensureProgress(postStream) {
}
export default function streamText(postStream, data) {
if (data.noop) {
return;
}
let status = (postStream.aiStreamingStatus =
postStream.aiStreamingStatus || {});
status[data.post_id] = data;

View File

@ -59,7 +59,7 @@ en:
description: "Enable debug mode to see the raw input and output of the LLM"
priority_group:
label: "Priority Group"
description: "Priotize content from this group in the report"
description: "Prioritize content from this group in the report"
llm_triage:
fields:

View File

@ -112,9 +112,10 @@ module DiscourseAi
topic_id: post.topic_id,
raw: "",
skip_validations: true,
skip_jobs: true,
)
publish_update(reply_post, raw: "<p></p>")
publish_update(reply_post, { raw: reply_post.cooked })
redis_stream_key = "gpt_cancel:#{reply_post.id}"
Discourse.redis.setex(redis_stream_key, 60, 1)
@ -139,12 +140,14 @@ module DiscourseAi
Discourse.redis.expire(redis_stream_key, 60)
publish_update(reply_post, raw: raw)
publish_update(reply_post, { raw: raw })
end
return if reply.blank?
publish_update(reply_post, done: true)
# land the final message prior to saving so we don't clash
reply_post.cooked = PrettyText.cook(reply)
publish_final_update(reply_post)
reply_post.revise(bot.bot_user, { raw: reply }, skip_validations: true, skip_revision: true)
@ -157,10 +160,25 @@ module DiscourseAi
end
reply_post
ensure
publish_final_update(reply_post)
end
private
def publish_final_update(reply_post)
return if @published_final_update
if reply_post
publish_update(reply_post, { cooked: reply_post.cooked, done: true })
# we subscribe at position -2 so we will always get this message
# moving all cooked on every page load is wasteful ... this means
# we have a benign message at the end, 2 is set to ensure last message
# is delivered
publish_update(reply_post, { noop: true })
@published_final_update = true
end
end
attr_reader :bot
def can_attach?(post)
@ -201,10 +219,15 @@ module DiscourseAi
end
def publish_update(bot_reply_post, payload)
payload = { post_id: bot_reply_post.id, post_number: bot_reply_post.post_number }.merge(
payload,
)
MessageBus.publish(
"discourse-ai/ai-bot/topic/#{bot_reply_post.topic_id}",
payload.merge(post_id: bot_reply_post.id, post_number: bot_reply_post.post_number),
payload,
user_ids: bot_reply_post.topic.allowed_user_ids,
max_backlog_size: 2,
max_backlog_age: 60,
)
end

View File

@ -24,11 +24,11 @@ module DiscourseAi
end
def query
parameters[:query].to_s
parameters[:query].to_s.strip
end
def invoke(bot_user, llm)
yield("") # Triggers placeholder update
yield(query)
api_key = SiteSetting.ai_google_custom_search_api_key
cx = SiteSetting.ai_google_custom_search_cx

View File

@ -25,7 +25,10 @@ module DiscourseAi
messages = prompt.messages
# ChatGPT doesn't use an assistant msg to improve long-context responses.
messages.pop if messages.last[:type] == :model
if messages.last[:type] == :model
messages = messages.dup
messages.pop
end
trimmed_messages = trim_messages(messages)

View File

@ -115,36 +115,49 @@ module DiscourseAi
current_token_count = 0
message_step_size = (max_prompt_tokens / 25).to_i * -1
reversed_trimmed_msgs =
messages
.reverse
.reduce([]) do |acc, msg|
message_tokens = calculate_message_token(msg)
trimmed_messages = []
dupped_msg = msg.dup
range = (0..-1)
if messages.dig(0, :type) == :system
system_message = messages[0]
trimmed_messages << system_message
current_token_count += calculate_message_token(system_message)
range = (1..-1)
end
# Don't trim tool call metadata.
if msg[:type] == :tool_call
current_token_count += message_tokens + per_message_overhead
acc << dupped_msg
next(acc)
end
reversed_trimmed_msgs = []
# Trimming content to make sure we respect token limit.
while dupped_msg[:content].present? &&
message_tokens + current_token_count + per_message_overhead > prompt_limit
dupped_msg[:content] = dupped_msg[:content][0..message_step_size] || ""
message_tokens = calculate_message_token(dupped_msg)
end
messages[range].reverse.each do |msg|
break if current_token_count >= prompt_limit
next(acc) if dupped_msg[:content].blank?
message_tokens = calculate_message_token(msg)
current_token_count += message_tokens + per_message_overhead
dupped_msg = msg.dup
acc << dupped_msg
end
# Don't trim tool call metadata.
if msg[:type] == :tool_call
break if current_token_count + message_tokens + per_message_overhead > prompt_limit
reversed_trimmed_msgs.reverse
current_token_count += message_tokens + per_message_overhead
reversed_trimmed_msgs << dupped_msg
next
end
# Trimming content to make sure we respect token limit.
while dupped_msg[:content].present? &&
message_tokens + current_token_count + per_message_overhead > prompt_limit
dupped_msg[:content] = dupped_msg[:content][0..message_step_size] || ""
message_tokens = calculate_message_token(dupped_msg)
end
next if dupped_msg[:content].blank?
current_token_count += message_tokens + per_message_overhead
reversed_trimmed_msgs << dupped_msg
end
trimmed_messages.concat(reversed_trimmed_msgs.reverse)
end
def per_message_overhead

View File

@ -83,6 +83,16 @@ module DiscourseAi
stop_sequences: stop_sequences,
}
if prompt.is_a?(String)
prompt =
DiscourseAi::Completions::Prompt.new(
"You are a helpful bot",
messages: [{ type: :user, content: prompt }],
)
elsif prompt.is_a?(Array)
prompt = DiscourseAi::Completions::Prompt.new(messages: prompt)
end
model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }
dialect = dialect_klass.new(prompt, model_name, opts: model_params)

View File

@ -5,16 +5,22 @@ module DiscourseAi
class Prompt
INVALID_TURN = Class.new(StandardError)
attr_reader :system_message, :messages
attr_reader :messages
attr_accessor :tools
def initialize(system_msg, messages: [], tools: [])
def initialize(system_message_text = nil, messages: [], tools: [])
raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
system_message = { type: :system, content: system_msg }
@messages = []
if system_message_text
system_message = { type: :system, content: system_message_text }
@messages << system_message
end
@messages.concat(messages)
@messages = [system_message].concat(messages)
@messages.each { |message| validate_message(message) }
@messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) }

View File

@ -50,6 +50,21 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
expect(translated.last[:role]).to eq("user")
expect(translated.last[:content].length).to be < context.long_message_text.length
end
it "always preserves system message when trimming" do
# gpt-4 is 8k tokens so last message totally blows everything
prompt = DiscourseAi::Completions::Prompt.new("You are a bot")
prompt.push(type: :user, content: "a " * 100)
prompt.push(type: :model, content: "b " * 100)
prompt.push(type: :user, content: "zjk " * 10_000)
translated = context.dialect(prompt).translate
expect(translated.length).to eq(2)
expect(translated.first).to eq(content: "You are a bot", role: "system")
expect(translated.last[:role]).to eq("user")
expect(translated.last[:content].length).to be < (8000 * 4)
end
end
describe "#tools" do

View File

@ -52,6 +52,26 @@ RSpec.describe DiscourseAi::Completions::Llm do
end
end
describe "#generate with various style prompts" do
let :canned_response do
DiscourseAi::Completions::Endpoints::CannedResponse.new(["world"])
end
it "can generate a response to a simple string" do
response = llm.generate("hello", user: user)
expect(response).to eq("world")
end
it "can generate a response from an array" do
response =
llm.generate(
[{ type: :system, content: "you are a bot" }, { type: :user, content: "hello" }],
user: user,
)
expect(response).to eq("world")
end
end
describe "#generate" do
let(:prompt) do
system_insts = (<<~TEXT).strip

View File

@ -64,16 +64,21 @@ RSpec.describe DiscourseAi::AiBot::Playground do
playground.reply_to(third_post)
end
reply = pm.reload.posts.last
noop_signal = messages.pop
expect(noop_signal.data[:noop]).to eq(true)
done_signal = messages.pop
expect(done_signal.data[:done]).to eq(true)
expect(done_signal.data[:cooked]).to eq(reply.cooked)
# we need this for styling
expect(messages.first.data[:raw]).to eq("<p></p>")
expect(messages.first.data[:raw]).to eq("")
messages[1..-1].each_with_index do |m, idx|
expect(m.data[:raw]).to eq(expected_bot_response[0..idx])
end
expect(pm.reload.posts.last.cooked).to eq(PrettyText.cook(expected_bot_response))
expect(reply.cooked).to eq(PrettyText.cook(expected_bot_response))
end
end

View File

@ -1,5 +1,49 @@
import { module, test } from "qunit";
import { addProgressDot } from "discourse/plugins/discourse-ai/discourse/lib/ai-streamer";
import {
addProgressDot,
applyProgress,
MIN_LETTERS_PER_INTERVAL,
} from "discourse/plugins/discourse-ai/discourse/lib/ai-streamer";
class FakeStreamUpdater {
constructor() {
this._streaming = true;
this._raw = "";
this._cooked = "";
this._element = document.createElement("div");
}
get streaming() {
return this._streaming;
}
set streaming(value) {
this._streaming = value;
}
get cooked() {
return this._cooked;
}
get raw() {
return this._raw;
}
async setRaw(value) {
this._raw = value;
// just fake it, calling cook is tricky
const cooked = `<p>${value}</p>`;
await this.setCooked(cooked);
}
async setCooked(value) {
this._cooked = value;
this._element.innerHTML = value;
}
get element() {
return this._element;
}
}
module("Discourse AI | Unit | Lib | ai-streamer", function () {
function confirmPlaceholder(html, expected, assert) {
@ -69,4 +113,55 @@ module("Discourse AI | Unit | Lib | ai-streamer", function () {
confirmPlaceholder(html, expected, assert);
});
test("can perform delta updates", async function (assert) {
const status = {
startTime: Date.now(),
raw: "some raw content",
done: false,
};
const streamUpdater = new FakeStreamUpdater();
let done = await applyProgress(status, streamUpdater);
assert.notOk(done, "The update should not be done.");
assert.equal(
streamUpdater.raw,
status.raw.substring(0, MIN_LETTERS_PER_INTERVAL),
"The raw content should delta update."
);
done = await applyProgress(status, streamUpdater);
assert.notOk(done, "The update should not be done.");
assert.equal(
streamUpdater.raw,
status.raw.substring(0, MIN_LETTERS_PER_INTERVAL * 2),
"The raw content should delta update."
);
// last chunk
await applyProgress(status, streamUpdater);
const innerHtml = streamUpdater.element.innerHTML;
assert.equal(
innerHtml,
"<p>some raw content</p>",
"The cooked content should be updated."
);
status.done = true;
status.cooked = "<p>updated cooked</p>";
await applyProgress(status, streamUpdater);
assert.equal(
streamUpdater.element.innerHTML,
"<p>updated cooked</p>",
"The cooked content should be updated."
);
});
});