FEATURE: smooth streaming of AI responses on the client (#413)

This PR introduces 3 things:

1. Fake bot that can be used on local so you can test LLMs, to enable on dev use:

SiteSetting.ai_bot_enabled_chat_bots = "fake"

2. More elegant smooth streaming of progress on LLM completion

This leans on JavaScript to buffer and trickle llm results through. It also amends it so the progress dot is much 
more consistently rendered

3. It fixes the Claude dialect 

Claude needs newlines **exactly** at the right spot, amended so it is happy 

---------

Co-authored-by: Martin Brennan <martin@discourse.org>
This commit is contained in:
Sam 2024-01-11 15:56:40 +11:00 committed by GitHub
parent 37b957dbbb
commit 8df966e9c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 493 additions and 103 deletions

View File

@ -0,0 +1,149 @@
import { later } from "@ember/runloop";
import loadScript from "discourse/lib/load-script";
import { cook } from "discourse/lib/text";
const PROGRESS_INTERVAL = 40;
const GIVE_UP_INTERVAL = 10000;
const LETTERS_PER_INTERVAL = 6;
let progressTimer = null;
function lastNonEmptyChild(element) {
let lastChild = element.lastChild;
while (
lastChild &&
lastChild.nodeType === Node.TEXT_NODE &&
!/\S/.test(lastChild.textContent)
) {
lastChild = lastChild.previousSibling;
}
return lastChild;
}
export function addProgressDot(element) {
let lastBlock = element;
while (true) {
let lastChild = lastNonEmptyChild(lastBlock);
if (!lastChild) {
break;
}
if (lastChild.nodeType === Node.ELEMENT_NODE) {
lastBlock = lastChild;
} else {
break;
}
}
const dotElement = document.createElement("span");
dotElement.classList.add("progress-dot");
lastBlock.appendChild(dotElement);
}
async function applyProgress(postStatus, postStream) {
postStatus.startTime = postStatus.startTime || Date.now();
let post = postStream.findLoadedPost(postStatus.post_id);
const postElement = document.querySelector(`#post_${postStatus.post_number}`);
if (Date.now() - postStatus.startTime > GIVE_UP_INTERVAL) {
if (postElement) {
postElement.classList.remove("streaming");
}
return true;
}
if (!post) {
// wait till later
return false;
}
const oldRaw = post.get("raw") || "";
if (postStatus.raw === oldRaw && !postStatus.done) {
const hasProgressDot =
postElement && postElement.querySelector(".progress-dot");
if (hasProgressDot) {
return false;
}
}
if (postStatus.raw) {
const newRaw = postStatus.raw.substring(
0,
oldRaw.length + LETTERS_PER_INTERVAL
);
const cooked = await cook(newRaw);
post.set("raw", newRaw);
post.set("cooked", cooked);
// resets animation
postElement.classList.remove("streaming");
void postElement.offsetWidth;
postElement.classList.add("streaming");
const cookedElement = document.createElement("div");
cookedElement.innerHTML = cooked;
addProgressDot(cookedElement);
const element = document.querySelector(
`#post_${postStatus.post_number} .cooked`
);
await loadScript("/javascripts/diffhtml.min.js");
window.diff.innerHTML(element, cookedElement.innerHTML);
}
if (postStatus.done) {
if (postElement) {
postElement.classList.remove("streaming");
}
}
return postStatus.done;
}
async function handleProgress(postStream) {
const status = postStream.aiStreamingStatus;
let keepPolling = false;
const promises = Object.keys(status).map(async (postId) => {
let postStatus = status[postId];
const done = await applyProgress(postStatus, postStream);
if (done) {
delete status[postId];
} else {
keepPolling = true;
}
});
await Promise.all(promises);
return keepPolling;
}
function ensureProgress(postStream) {
if (!progressTimer) {
progressTimer = later(async () => {
const keepPolling = await handleProgress(postStream);
progressTimer = null;
if (keepPolling) {
ensureProgress(postStream);
}
}, PROGRESS_INTERVAL);
}
}
export default function streamText(postStream, data) {
let status = (postStream.aiStreamingStatus =
postStream.aiStreamingStatus || {});
status[data.post_id] = data;
ensureProgress(postStream);
}

View File

@ -1,19 +1,17 @@
import { later } from "@ember/runloop";
import { hbs } from "ember-cli-htmlbars";
import { ajax } from "discourse/lib/ajax";
import { popupAjaxError } from "discourse/lib/ajax-error";
import loadScript from "discourse/lib/load-script";
import { withPluginApi } from "discourse/lib/plugin-api";
import { cook } from "discourse/lib/text";
import { registerWidgetShim } from "discourse/widgets/render-glimmer";
import { composeAiBotMessage } from "discourse/plugins/discourse-ai/discourse/lib/ai-bot-helper";
import ShareModal from "../discourse/components/modal/share-modal";
import streamText from "../discourse/lib/ai-streamer";
import copyConversation from "../discourse/lib/copy-conversation";
const AUTO_COPY_THRESHOLD = 4;
function isGPTBot(user) {
return user && [-110, -111, -112, -113, -114, -115].includes(user.id);
return user && [-110, -111, -112, -113, -114, -115, -116].includes(user.id);
}
function attachHeaderIcon(api) {
@ -93,62 +91,7 @@ function initializeAIBotReplies(api) {
pluginId: "discourse-ai",
onAIBotStreamedReply: function (data) {
const post = this.model.postStream.findLoadedPost(data.post_id);
// it may take us a few seconds to load the post
// we need to requeue the event
if (!post && !data.done) {
const refresh = this.onAIBotStreamedReply.bind(this);
data.retries = data.retries || 5;
data.retries -= 1;
data.skipIfStreaming = true;
if (data.retries > 0) {
later(() => {
refresh(data);
}, 1000);
}
}
if (post) {
if (data.raw) {
const postElement = document.querySelector(
`#post_${data.post_number}`
);
if (
data.skipIfStreaming &&
postElement.classList.contains("streaming")
) {
return;
}
cook(data.raw).then((cooked) => {
post.set("raw", data.raw);
post.set("cooked", cooked);
// resets animation
postElement.classList.remove("streaming");
void postElement.offsetWidth;
postElement.classList.add("streaming");
const cookedElement = document.createElement("div");
cookedElement.innerHTML = cooked;
let element = document.querySelector(
`#post_${data.post_number} .cooked`
);
loadScript("/javascripts/diffhtml.min.js").then(() => {
window.diff.innerHTML(element, cookedElement.innerHTML);
});
});
}
if (data.done) {
document
.querySelector(`#post_${data.post_number}`)
.classList.remove("streaming");
}
}
streamText(this.model.postStream, data);
},
subscribe: function () {
this._super();

View File

@ -57,23 +57,18 @@ article.streaming nav.post-controls .actions button.cancel-streaming {
}
}
article.streaming .cooked > {
:not(ol):not(ul):not(pre):last-child::after,
ol:last-child li:last-child p:last-child::after,
ol:last-child li:last-child:not(:has(p))::after,
ul:last-child li:last-child p:last-child::after,
ul:last-child li:last-child:not(:has(p))::after,
pre:last-child code::after {
article.streaming .cooked .progress-dot::after {
content: "\25CF";
font-family: Söhne Circle, system-ui, -apple-system, Segoe UI, Roboto,
Ubuntu, Cantarell, Noto Sans, sans-serif;
font-family: Söhne Circle, system-ui, -apple-system, Segoe UI, Roboto, Ubuntu,
Cantarell, Noto Sans, sans-serif;
line-height: normal;
margin-left: 0.25rem;
vertical-align: baseline;
animation: flashing 1.5s 3s infinite;
display: inline-block;
}
font-size: 1rem;
color: var(--tertiary-medium);
}
.ai-bot-available-bot-options {

View File

@ -131,6 +131,8 @@ module DiscourseAi
"mistralai/Mixtral-8x7B-Instruct-v0.1"
when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
"gemini-pro"
when DiscourseAi::AiBot::EntryPoint::FAKE_ID
"fake"
else
nil
end

View File

@ -11,6 +11,8 @@ module DiscourseAi
GPT4_TURBO_ID = -113
MIXTRAL_ID = -114
GEMINI_ID = -115
FAKE_ID = -116 # only used for dev and test
BOTS = [
[GPT4_ID, "gpt4_bot", "gpt-4"],
[GPT3_5_TURBO_ID, "gpt3.5_bot", "gpt-3.5-turbo"],
@ -18,6 +20,7 @@ module DiscourseAi
[GPT4_TURBO_ID, "gpt4t_bot", "gpt-4-turbo"],
[MIXTRAL_ID, "mixtral_bot", "mixtral-8x7B-Instruct-V0.1"],
[GEMINI_ID, "gemini_bot", "gemini-pro"],
[FAKE_ID, "fake_bot", "fake"],
]
def self.map_bot_model_to_user_id(model_name)
@ -34,6 +37,8 @@ module DiscourseAi
MIXTRAL_ID
in "gemini-pro"
GEMINI_ID
in "fake"
FAKE_ID
else
nil
end

View File

@ -5,6 +5,10 @@ module DiscourseAi::AiBot::SiteSettingsExtension
enabled_bots = SiteSetting.ai_bot_enabled_chat_bots_map
enabled_bots = [] if !SiteSetting.ai_bot_enabled
DiscourseAi::AiBot::EntryPoint::BOTS.each do |id, bot_name, name|
if id == DiscourseAi::AiBot::EntryPoint::FAKE_ID
next if Rails.env.production?
end
active = enabled_bots.include?(name)
user = User.find_by(id: id)

View File

@ -14,6 +14,14 @@ module DiscourseAi
end
end
def pad_newlines!(prompt)
if prompt[-1..-1] != "\n"
prompt << "\n\n"
elsif prompt[-2..-1] != "\n\n"
prompt << "\n"
end
end
def translate
claude_prompt = uses_system_message? ? +"" : +"Human: "
claude_prompt << prompt[:insts] << "\n"
@ -22,8 +30,12 @@ module DiscourseAi
claude_prompt << build_examples(prompt[:examples]) if prompt[:examples]
pad_newlines!(claude_prompt)
claude_prompt << conversation_context if prompt[:conversation_context]
pad_newlines!(claude_prompt)
if uses_system_message? && (prompt[:input] || prompt[:post_insts])
claude_prompt << "Human: "
end
@ -31,10 +43,11 @@ module DiscourseAi
claude_prompt << "#{prompt[:post_insts]}\n" if prompt[:post_insts]
claude_prompt << "\n\n"
claude_prompt << "Assistant:"
pad_newlines!(claude_prompt)
claude_prompt << "Assistant: "
claude_prompt << " #{prompt[:final_insts]}:" if prompt[:final_insts]
claude_prompt << "\n"
claude_prompt
end
def max_prompt_tokens
@ -50,12 +63,12 @@ module DiscourseAi
trimmed_context
.reverse
.reduce(+"") do |memo, context|
memo << (context[:type] == "user" ? "Human:" : "Assistant:")
.map do |context|
row = context[:type] == "user" ? +"Human:" : +"Assistant:"
if context[:type] == "tool"
memo << <<~TEXT
row << "\n"
row << (<<~TEXT).strip
<function_results>
<result>
<tool_name>#{context[:name]}</tool_name>
@ -66,11 +79,11 @@ module DiscourseAi
</function_results>
TEXT
else
memo << " " << context[:content] << "\n"
row << " "
row << context[:content]
end
memo
end
.join("\n\n")
end
private

View File

@ -19,6 +19,10 @@ module DiscourseAi
DiscourseAi::Completions::Dialects::Mixtral,
]
if Rails.env.test? || Rails.env.development?
dialects << DiscourseAi::Completions::Dialects::Fake
end
dialect = dialects.find { |d| d.can_translate?(model_name) }
raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL if !dialect
dialect

View File

@ -0,0 +1,19 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class Fake < Dialect
class << self
def can_translate?(model_name)
model_name == "fake"
end
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end
end
end
end
end
end

View File

@ -10,14 +10,20 @@ module DiscourseAi
def self.endpoint_for(model_name)
# Order is important.
# Bedrock has priority over Anthropic if creadentials are present.
[
endpoints = [
DiscourseAi::Completions::Endpoints::AwsBedrock,
DiscourseAi::Completions::Endpoints::Anthropic,
DiscourseAi::Completions::Endpoints::OpenAi,
DiscourseAi::Completions::Endpoints::HuggingFace,
DiscourseAi::Completions::Endpoints::Gemini,
DiscourseAi::Completions::Endpoints::Vllm,
].detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
]
if Rails.env.test? || Rails.env.development?
endpoints << DiscourseAi::Completions::Endpoints::Fake
end
endpoints.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
ek.can_contact?(model_name)
end
end

View File

@ -0,0 +1,124 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Endpoints
class Fake < Base
def self.can_contact?(model_name)
model_name == "fake"
end
STOCK_CONTENT = <<~TEXT
# Discourse Markdown Styles Showcase
Welcome to the **Discourse Markdown Styles Showcase**! This _post_ is designed to demonstrate a wide range of Markdown capabilities available in Discourse.
## Lists and Emphasis
- **Bold Text**: To emphasize a point, you can use bold text.
- _Italic Text_: To subtly highlight text, italics are perfect.
- ~~Strikethrough~~: Sometimes, marking text as obsolete requires a strikethrough.
> **Note**: Combining these _styles_ can **_really_** make your text stand out!
1. First item
2. Second item
* Nested bullet
* Another nested bullet
3. Third item
## Links and Images
You can easily add [links](https://meta.discourse.org) to your posts. For adding images, use this syntax:
![Discourse Logo](https://meta.discourse.org/images/discourse-logo.svg)
## Code and Quotes
Inline `code` is used for mentioning small code snippets like `let x = 10;`. For larger blocks of code, fenced code blocks are used:
```javascript
function greet() {
console.log("Hello, Discourse Community!");
}
greet();
```
> Blockquotes can be very effective for highlighting user comments or important sections from cited sources. They stand out visually and offer great readability.
## Tables and Horizontal Rules
Creating tables in Markdown is straightforward:
| Header 1 | Header 2 | Header 3 |
| ---------|:--------:| --------:|
| Row 1, Col 1 | Centered | Right-aligned |
| Row 2, Col 1 | **Bold** | _Italic_ |
| Row 3, Col 1 | `Inline Code` | [Link](https://meta.discourse.org) |
To separate content sections:
---
## Final Thoughts
Congratulations, you've now seen a small sample of what Discourse's Markdown can do! For more intricate formatting, consider exploring the advanced styling options. Remember that the key to great formatting is not just the available tools, but also the **clarity** and **readability** it brings to your readers.
TEXT
def self.fake_content
@fake_content || STOCK_CONTENT
end
def self.delays
@delays ||= Array.new(10) { rand * 6 }
end
def self.delays=(delays)
@delays = delays
end
def self.chunk_count
@chunk_count ||= 10
end
def self.chunk_count=(chunk_count)
@chunk_count = chunk_count
end
def perform_completion!(dialect, user, model_params = {})
content = self.class.fake_content
if block_given?
split_indices = (1...content.length).to_a.sample(self.class.chunk_count - 1).sort
indexes = [0, *split_indices, content.length]
original_content = content
content = +""
cancel = false
cancel_proc = -> { cancel = true }
i = 0
indexes
.each_cons(2)
.map { |start, finish| original_content[start...finish] }
.each do |chunk|
break if cancel
if self.class.delays.present? &&
(delay = self.class.delays[i % self.class.delays.length])
sleep(delay)
i += 1
end
break if cancel
content << chunk
yield(chunk, cancel_proc)
end
end
content
end
end
end
end
end

View File

@ -47,12 +47,11 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
describe "#translate" do
it "translates a prompt written in our generic format to Claude's format" do
anthropic_version = <<~TEXT
anthropic_version = (<<~TEXT).strip + " "
#{prompt[:insts]}
Human: #{prompt[:input]}
#{prompt[:post_insts]}
Assistant:
TEXT
@ -68,16 +67,16 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
"<ai>The solitary horse.,The horse etched in gold.,A horse's infinite journey.,A horse lost in time.,A horse's last ride.</ai>",
],
]
anthropic_version = <<~TEXT
anthropic_version = (<<~TEXT).strip + " "
#{prompt[:insts]}
<example>
H: #{prompt[:examples][0][0]}
A: #{prompt[:examples][0][1]}
</example>
Human: #{prompt[:input]}
#{prompt[:post_insts]}
Assistant:
TEXT
@ -89,15 +88,15 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
it "include tools inside the prompt" do
prompt[:tools] = [tool]
anthropic_version = <<~TEXT
anthropic_version = (<<~TEXT).strip + " "
#{prompt[:insts]}
#{DiscourseAi::Completions::Dialects::Claude.tool_preamble}
<tools>
#{dialect.tools}</tools>
Human: #{prompt[:input]}
#{prompt[:post_insts]}
Assistant:
TEXT
@ -105,6 +104,34 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
expect(translated).to eq(anthropic_version)
end
it "includes all the right newlines" do
prompt.clear
prompt.merge!(
{
insts: "You are an artist",
conversation_context: [
{ content: "draw another funny cat", type: "user", name: "sam" },
{ content: "ok", type: "assistant" },
{ content: "draw a funny cat", type: "user", name: "sam" },
],
},
)
expected = (<<~TEXT).strip + " "
You are an artist
Human: draw a funny cat
Assistant: ok
Human: draw another funny cat
Assistant:
TEXT
expect(dialect.translate).to eq(expected)
end
end
describe "#conversation_context" do
@ -119,7 +146,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
it "adds conversation in reverse order (first == newer)" do
prompt[:conversation_context] = context
expected = <<~TEXT
expected = (<<~TEXT).strip
Assistant:
<function_results>
<result>
@ -129,7 +156,9 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
</json>
</result>
</function_results>
Assistant: #{context.second[:content]}
Human: #{context.first[:content]}
TEXT

View File

@ -21,6 +21,31 @@ RSpec.describe DiscourseAi::Completions::Llm do
end
end
describe "#generate with fake model" do
before do
DiscourseAi::Completions::Endpoints::Fake.delays = []
DiscourseAi::Completions::Endpoints::Fake.chunk_count = 10
end
let(:llm) { described_class.proxy("fake") }
it "can generate a response" do
response = llm.generate({ input: "fake prompt" }, user: user)
expect(response).to be_present
end
it "can generate content via a block" do
partials = []
response =
llm.generate({ input: "fake prompt" }, user: user) { |partial| partials << partial }
expect(partials.length).to eq(10)
expect(response).to eq(DiscourseAi::Completions::Endpoints::Fake.fake_content)
expect(partials.join).to eq(response)
end
end
describe "#generate" do
let(:prompt) do
{

View File

@ -0,0 +1,72 @@
import { module, test } from "qunit";
import { addProgressDot } from "discourse/plugins/discourse-ai/discourse/lib/ai-streamer";
module("Discourse AI | Unit | Lib | ai-streamer", function () {
function confirmPlaceholder(html, expected, assert) {
const element = document.createElement("div");
element.innerHTML = html;
const expectedElement = document.createElement("div");
expectedElement.innerHTML = expected;
addProgressDot(element);
assert.equal(element.innerHTML, expectedElement.innerHTML);
}
test("inserts progress span in correct location for simple div", function (assert) {
const html = "<div>hello world<div>hello 2</div></div>";
const expected =
"<div>hello world<div>hello 2<span class='progress-dot'></span></div></div>";
confirmPlaceholder(html, expected, assert);
});
test("inserts progress span in correct location for lists", function (assert) {
const html = "<p>test</p><ul><li>hello world</li><li>hello world</li></ul>";
const expected =
"<p>test</p><ul><li>hello world</li><li>hello world<span class='progress-dot'></span></li></ul>";
confirmPlaceholder(html, expected, assert);
});
test("inserts correctly if list has blank html nodes", function (assert) {
const html = `<ul>
<li><strong>Bold Text</strong>: To</li>
</ul>`;
const expected = `<ul>
<li><strong>Bold Text</strong>: To<span class="progress-dot"></span></li>
</ul>`;
confirmPlaceholder(html, expected, assert);
});
test("inserts correctly for tables", function (assert) {
const html = `<table>
<tbody>
<tr>
<td>Bananas</td>
<td>20</td>
<td>$0.50</td>
</tr>
</tbody>
</table>
`;
const expected = `<table>
<tbody>
<tr>
<td>Bananas</td>
<td>20</td>
<td>$0.50<span class="progress-dot"></span></td>
</tr>
</tbody>
</table>
`;
confirmPlaceholder(html, expected, assert);
});
});