FEATURE: even smoother streaming (#420)
- Account properly for function calls; don't stream through <details> blocks
- Rush cooked content back to the client
- Wait longer (up to 60 seconds) before giving up on streaming
- Clean up message bus channels so we don't have leftover data
- Make the ai-streamer much more reusable and much easier to read
- If the buffer grows quickly, rush the update so we are not artificially waiting
- Refine the prompt interface
- Fix the lost system message when the prompt gets long
commit 825f01cfb2 (parent 6b8a57d957)
@@ -3,8 +3,9 @@ import loadScript from "discourse/lib/load-script";
 import { cook } from "discourse/lib/text";
 
 const PROGRESS_INTERVAL = 40;
-const GIVE_UP_INTERVAL = 10000;
-const LETTERS_PER_INTERVAL = 6;
+const GIVE_UP_INTERVAL = 60000;
+export const MIN_LETTERS_PER_INTERVAL = 6;
+const MAX_FLUSH_TIME = 800;
 
 let progressTimer = null;

@@ -41,69 +42,146 @@ export function addProgressDot(element) {
   lastBlock.appendChild(dotElement);
 }
 
-async function applyProgress(postStatus, postStream) {
-  postStatus.startTime = postStatus.startTime || Date.now();
-  let post = postStream.findLoadedPost(postStatus.post_id);
+// this is the interface we need to implement
+// for a streaming updater
+class StreamUpdater {
+  set streaming(value) {
+    throw "not implemented";
+  }
 
-  const postElement = document.querySelector(`#post_${postStatus.post_number}`);
+  async setCooked() {
+    throw "not implemented";
+  }
+
+  async setRaw() {
+    throw "not implemented";
+  }
+
+  get element() {
+    throw "not implemented";
+  }
+
+  get raw() {
+    throw "not implemented";
+  }
+}
+
+class PostUpdater extends StreamUpdater {
+  constructor(postStream, postId) {
+    super();
+    this.postStream = postStream;
+    this.postId = postId;
+    this.post = postStream.findLoadedPost(postId);
+
+    if (this.post) {
+      this.postElement = document.querySelector(
+        `#post_${this.post.post_number}`
+      );
+    }
+  }
+
+  get element() {
+    return this.postElement;
+  }
+
+  set streaming(value) {
+    if (this.postElement) {
+      if (value) {
+        this.postElement.classList.add("streaming");
+      } else {
+        this.postElement.classList.remove("streaming");
+      }
+    }
+  }
+
+  async setRaw(value, done) {
+    this.post.set("raw", value);
+    const cooked = await cook(value);
+
+    // resets animation
+    this.element.classList.remove("streaming");
+    void this.element.offsetWidth;
+    this.element.classList.add("streaming");
+
+    const cookedElement = document.createElement("div");
+    cookedElement.innerHTML = cooked;
+
+    if (!done) {
+      addProgressDot(cookedElement);
+    }
+
+    await this.setCooked(cookedElement.innerHTML);
+  }
+
+  async setCooked(value) {
+    this.post.set("cooked", value);
+
+    const oldElement = this.postElement.querySelector(".cooked");
+
+    await loadScript("/javascripts/diffhtml.min.js");
+    window.diff.innerHTML(oldElement, value);
+  }
+
+  get raw() {
+    return this.post.get("raw") || "";
+  }
+}
+
+export async function applyProgress(status, updater) {
+  status.startTime = status.startTime || Date.now();
 
-  if (Date.now() - postStatus.startTime > GIVE_UP_INTERVAL) {
-    if (postElement) {
-      postElement.classList.remove("streaming");
-    }
+  if (Date.now() - status.startTime > GIVE_UP_INTERVAL) {
+    updater.streaming = false;
     return true;
   }
 
-  if (!post) {
+  if (!updater.element) {
     // wait till later
     return false;
   }
 
-  const oldRaw = post.get("raw") || "";
+  const oldRaw = updater.raw;
 
-  if (postStatus.raw === oldRaw && !postStatus.done) {
-    const hasProgressDot =
-      postElement && postElement.querySelector(".progress-dot");
+  if (status.raw === oldRaw && !status.done) {
+    const hasProgressDot = updater.element.querySelector(".progress-dot");
     if (hasProgressDot) {
       return false;
    }
   }
 
-  if (postStatus.raw) {
-    const newRaw = postStatus.raw.substring(
-      0,
-      oldRaw.length + LETTERS_PER_INTERVAL
-    );
-    const cooked = await cook(newRaw);
+  if (status.raw !== undefined) {
+    let newRaw = status.raw;
 
-    post.set("raw", newRaw);
-    post.set("cooked", cooked);
+    if (!status.done) {
+      // rush update if we have a </details> tag (function call)
+      if (oldRaw.length === 0 && newRaw.indexOf("</details>") !== -1) {
+        newRaw = status.raw;
+      } else {
+        const diff = newRaw.length - oldRaw.length;
 
-    // resets animation
-    postElement.classList.remove("streaming");
-    void postElement.offsetWidth;
-    postElement.classList.add("streaming");
+        // progress interval is 40ms
+        // by default we add 6 letters per interval
+        // but ... we want to be done in MAX_FLUSH_TIME
+        let letters = Math.floor(diff / (MAX_FLUSH_TIME / PROGRESS_INTERVAL));
+        if (letters < MIN_LETTERS_PER_INTERVAL) {
+          letters = MIN_LETTERS_PER_INTERVAL;
+        }
 
-    const cookedElement = document.createElement("div");
-    cookedElement.innerHTML = cooked;
+        newRaw = status.raw.substring(0, oldRaw.length + letters);
+      }
+    }
 
-    addProgressDot(cookedElement);
-
-    const element = document.querySelector(
-      `#post_${postStatus.post_number} .cooked`
-    );
-
-    await loadScript("/javascripts/diffhtml.min.js");
-    window.diff.innerHTML(element, cookedElement.innerHTML);
-  }
-
-  if (postStatus.done) {
-    if (postElement) {
-      postElement.classList.remove("streaming");
-    }
+    await updater.setRaw(newRaw, status.done);
  }
 
-  return postStatus.done;
+  if (status.done) {
+    if (status.cooked) {
+      await updater.setCooked(status.cooked);
+    }
+    updater.streaming = false;
+  }
+
+  return status.done;
 }
 
 async function handleProgress(postStream) {

@@ -114,7 +192,8 @@ async function handleProgress(postStream) {
   const promises = Object.keys(status).map(async (postId) => {
     let postStatus = status[postId];
 
-    const done = await applyProgress(postStatus, postStream);
+    const postUpdater = new PostUpdater(postStream, postStatus.post_id);
+    const done = await applyProgress(postStatus, postUpdater);
 
     if (done) {
       delete status[postId];

@@ -142,6 +221,10 @@ function ensureProgress(postStream) {
 }
 
 export default function streamText(postStream, data) {
+  if (data.noop) {
+    return;
+  }
+
   let status = (postStream.aiStreamingStatus =
     postStream.aiStreamingStatus || {});
   status[data.post_id] = data;
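The rush logic in the streamer above is plain arithmetic over its three constants: with PROGRESS_INTERVAL = 40 ms and MAX_FLUSH_TIME = 800 ms there are 20 ticks in a flush window, so a backlog of 300 unstreamed characters flows at floor(300 / 20) = 15 letters per tick, while a small backlog falls back to the MIN_LETTERS_PER_INTERVAL floor of 6. The 60-second GIVE_UP_INTERVAL now marks the stream as stopped through the updater interface instead of poking the DOM directly.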
@@ -59,7 +59,7 @@ en:
       description: "Enable debug mode to see the raw input and output of the LLM"
     priority_group:
       label: "Priority Group"
-      description: "Priotize content from this group in the report"
+      description: "Prioritize content from this group in the report"
 
   llm_triage:
     fields:
@@ -112,9 +112,10 @@ module DiscourseAi
            topic_id: post.topic_id,
            raw: "",
            skip_validations: true,
+           skip_jobs: true,
          )
 
       publish_update(reply_post, raw: "<p></p>")
-      publish_update(reply_post, raw: "<p></p>")
+      publish_update(reply_post, { raw: reply_post.cooked })
 
       redis_stream_key = "gpt_cancel:#{reply_post.id}"
       Discourse.redis.setex(redis_stream_key, 60, 1)

@@ -139,12 +140,14 @@ module DiscourseAi
 
         Discourse.redis.expire(redis_stream_key, 60)
 
-        publish_update(reply_post, raw: raw)
+        publish_update(reply_post, { raw: raw })
       end
 
       return if reply.blank?
 
-      publish_update(reply_post, done: true)
+      # land the final message prior to saving so we don't clash
+      reply_post.cooked = PrettyText.cook(reply)
+      publish_final_update(reply_post)
 
       reply_post.revise(bot.bot_user, { raw: reply }, skip_validations: true, skip_revision: true)

@@ -157,10 +160,25 @@ module DiscourseAi
       end
 
       reply_post
+    ensure
+      publish_final_update(reply_post)
     end
 
     private
 
+    def publish_final_update(reply_post)
+      return if @published_final_update
+      if reply_post
+        publish_update(reply_post, { cooked: reply_post.cooked, done: true })
+        # we subscribe at position -2 so we will always get this message
+        # moving all cooked on every page load is wasteful ... this means
+        # we have a benign message at the end, 2 is set to ensure last message
+        # is delivered
+        publish_update(reply_post, { noop: true })
+        @published_final_update = true
+      end
+    end
+
     attr_reader :bot
 
     def can_attach?(post)

@@ -201,10 +219,15 @@ module DiscourseAi
     end
 
     def publish_update(bot_reply_post, payload)
+      payload = { post_id: bot_reply_post.id, post_number: bot_reply_post.post_number }.merge(
+        payload,
+      )
       MessageBus.publish(
         "discourse-ai/ai-bot/topic/#{bot_reply_post.topic_id}",
-        payload.merge(post_id: bot_reply_post.id, post_number: bot_reply_post.post_number),
+        payload,
         user_ids: bot_reply_post.topic.allowed_user_ids,
+        max_backlog_size: 2,
+        max_backlog_age: 60,
       )
     end
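Publishing the final cooked payload followed by a benign noop, with max_backlog_size: 2 on the channel, means a client that subscribes at message-bus position -2 always replays exactly the done message plus the noop rather than the whole stream; streamText returns early on data.noop, so the extra message is harmless. The ensure block, together with the @published_final_update guard, makes the final update fire exactly once even if reply generation raises.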
@@ -24,11 +24,11 @@ module DiscourseAi
         end
 
         def query
-          parameters[:query].to_s
+          parameters[:query].to_s.strip
         end
 
         def invoke(bot_user, llm)
-          yield("") # Triggers placeholder update
+          yield(query)
 
           api_key = SiteSetting.ai_google_custom_search_api_key
           cx = SiteSetting.ai_google_custom_search_cx
@@ -25,7 +25,10 @@ module DiscourseAi
         messages = prompt.messages
 
         # ChatGPT doesn't use an assistant msg to improve long-context responses.
-        messages.pop if messages.last[:type] == :model
+        if messages.last[:type] == :model
+          messages = messages.dup
+          messages.pop
+        end
 
         trimmed_messages = trim_messages(messages)
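Note the dup: prompt.messages exposes the Prompt's own array (an attr_reader), so popping it in place would have permanently dropped the trailing model message from the shared Prompt object; duplicating first keeps the trim local to this dialect.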
@@ -45,10 +45,10 @@ module DiscourseAi
           </parameters>
           </invoke>
           </function_calls>
 
           if a parameter type is an array, return a JSON array of values. For example:
           [1,"two",3.0]
 
           Here are the tools available:
         TEXT
       end

@@ -115,36 +115,49 @@ module DiscourseAi
         current_token_count = 0
         message_step_size = (max_prompt_tokens / 25).to_i * -1
 
-        reversed_trimmed_msgs =
-          messages
-            .reverse
-            .reduce([]) do |acc, msg|
-              message_tokens = calculate_message_token(msg)
+        trimmed_messages = []
 
-              dupped_msg = msg.dup
+        range = (0..-1)
+        if messages.dig(0, :type) == :system
+          system_message = messages[0]
+          trimmed_messages << system_message
+          current_token_count += calculate_message_token(system_message)
+          range = (1..-1)
+        end
 
-              # Don't trim tool call metadata.
-              if msg[:type] == :tool_call
-                current_token_count += message_tokens + per_message_overhead
-                acc << dupped_msg
-                next(acc)
-              end
+        reversed_trimmed_msgs = []
 
-              # Trimming content to make sure we respect token limit.
-              while dupped_msg[:content].present? &&
-                      message_tokens + current_token_count + per_message_overhead > prompt_limit
-                dupped_msg[:content] = dupped_msg[:content][0..message_step_size] || ""
-                message_tokens = calculate_message_token(dupped_msg)
-              end
+        messages[range].reverse.each do |msg|
+          break if current_token_count >= prompt_limit
 
-              next(acc) if dupped_msg[:content].blank?
+          message_tokens = calculate_message_token(msg)
 
-              current_token_count += message_tokens + per_message_overhead
+          dupped_msg = msg.dup
 
-              acc << dupped_msg
-            end
+          # Don't trim tool call metadata.
+          if msg[:type] == :tool_call
+            break if current_token_count + message_tokens + per_message_overhead > prompt_limit
 
-        reversed_trimmed_msgs.reverse
+            current_token_count += message_tokens + per_message_overhead
+            reversed_trimmed_msgs << dupped_msg
+            next
+          end
+
+          # Trimming content to make sure we respect token limit.
+          while dupped_msg[:content].present? &&
+                  message_tokens + current_token_count + per_message_overhead > prompt_limit
+            dupped_msg[:content] = dupped_msg[:content][0..message_step_size] || ""
+            message_tokens = calculate_message_token(dupped_msg)
+          end
+
+          next if dupped_msg[:content].blank?
+
+          current_token_count += message_tokens + per_message_overhead
+
+          reversed_trimmed_msgs << dupped_msg
+        end
+
+        trimmed_messages.concat(reversed_trimmed_msgs.reverse)
       end
 
       def per_message_overhead
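The net effect of this rewrite: any leading system message is seated first and charged against the token budget, and only then are the remaining messages consumed newest-first until prompt_limit is exhausted. A very long final user message can still be truncated, but it can no longer evict the system instructions, which is the lost-system-message bug called out in the commit message.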
@@ -83,6 +83,16 @@ module DiscourseAi
           stop_sequences: stop_sequences,
         }
 
+        if prompt.is_a?(String)
+          prompt =
+            DiscourseAi::Completions::Prompt.new(
+              "You are a helpful bot",
+              messages: [{ type: :user, content: prompt }],
+            )
+        elsif prompt.is_a?(Array)
+          prompt = DiscourseAi::Completions::Prompt.new(messages: prompt)
+        end
+
         model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }
 
         dialect = dialect_klass.new(prompt, model_name, opts: model_params)
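With this coercion in place, generate accepts a plain string, an array of message hashes, or a fully built Prompt. A usage sketch mirroring the specs added later in this commit (llm is assumed to be an already-constructed DiscourseAi::Completions::Llm and user a Discourse user):

    # a plain string is wrapped in a default "You are a helpful bot" system prompt
    llm.generate("hello", user: user)

    # an array of message hashes is wrapped in a Prompt as-is
    llm.generate(
      [{ type: :system, content: "you are a bot" }, { type: :user, content: "hello" }],
      user: user,
    )

    # or build the Prompt yourself
    prompt =
      DiscourseAi::Completions::Prompt.new(
        "You are a helpful bot",
        messages: [{ type: :user, content: "hello" }],
      )
    llm.generate(prompt, user: user)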
@@ -5,16 +5,22 @@ module DiscourseAi
     class Prompt
       INVALID_TURN = Class.new(StandardError)
 
-      attr_reader :system_message, :messages
+      attr_reader :messages
       attr_accessor :tools
 
-      def initialize(system_msg, messages: [], tools: [])
+      def initialize(system_message_text = nil, messages: [], tools: [])
         raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
         raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
 
-        system_message = { type: :system, content: system_msg }
-
-        @messages = [system_message].concat(messages)
+        @messages = []
+
+        if system_message_text
+          system_message = { type: :system, content: system_message_text }
+          @messages << system_message
+        end
+
+        @messages.concat(messages)
 
         @messages.each { |message| validate_message(message) }
         @messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) }
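The system message is now optional; when present it is simply seeded as the first :system entry of @messages, and callers can instead carry it inside the message array. A construction sketch (the push calls mirror the dialect spec below):

    # with system instructions, appending turns afterwards
    prompt = DiscourseAi::Completions::Prompt.new("You are a bot")
    prompt.push(type: :user, content: "hello")

    # without: the message array carries its own system entry
    DiscourseAi::Completions::Prompt.new(
      messages: [{ type: :system, content: "you are a bot" }, { type: :user, content: "hello" }],
    )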
@@ -50,6 +50,21 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
       expect(translated.last[:role]).to eq("user")
       expect(translated.last[:content].length).to be < context.long_message_text.length
     end
+
+    it "always preserves system message when trimming" do
+      # gpt-4 is 8k tokens so last message totally blows everything
+      prompt = DiscourseAi::Completions::Prompt.new("You are a bot")
+      prompt.push(type: :user, content: "a " * 100)
+      prompt.push(type: :model, content: "b " * 100)
+      prompt.push(type: :user, content: "zjk " * 10_000)
+
+      translated = context.dialect(prompt).translate
+
+      expect(translated.length).to eq(2)
+      expect(translated.first).to eq(content: "You are a bot", role: "system")
+      expect(translated.last[:role]).to eq("user")
+      expect(translated.last[:content].length).to be < (8000 * 4)
+    end
   end
 
   describe "#tools" do
@@ -52,6 +52,26 @@ RSpec.describe DiscourseAi::Completions::Llm do
     end
   end
 
+  describe "#generate with various style prompts" do
+    let :canned_response do
+      DiscourseAi::Completions::Endpoints::CannedResponse.new(["world"])
+    end
+
+    it "can generate a response to a simple string" do
+      response = llm.generate("hello", user: user)
+      expect(response).to eq("world")
+    end
+
+    it "can generate a response from an array" do
+      response =
+        llm.generate(
+          [{ type: :system, content: "you are a bot" }, { type: :user, content: "hello" }],
+          user: user,
+        )
+      expect(response).to eq("world")
+    end
+  end
+
   describe "#generate" do
     let(:prompt) do
       system_insts = (<<~TEXT).strip
@@ -64,16 +64,21 @@ RSpec.describe DiscourseAi::AiBot::Playground do
         playground.reply_to(third_post)
       end
 
+      reply = pm.reload.posts.last
+
+      noop_signal = messages.pop
+      expect(noop_signal.data[:noop]).to eq(true)
+
       done_signal = messages.pop
       expect(done_signal.data[:done]).to eq(true)
+      expect(done_signal.data[:cooked]).to eq(reply.cooked)
 
-      # we need this for styling
-      expect(messages.first.data[:raw]).to eq("<p></p>")
+      expect(messages.first.data[:raw]).to eq("")
 
       messages[1..-1].each_with_index do |m, idx|
         expect(m.data[:raw]).to eq(expected_bot_response[0..idx])
       end
 
-      expect(pm.reload.posts.last.cooked).to eq(PrettyText.cook(expected_bot_response))
+      expect(reply.cooked).to eq(PrettyText.cook(expected_bot_response))
     end
   end
@@ -1,5 +1,49 @@
 import { module, test } from "qunit";
-import { addProgressDot } from "discourse/plugins/discourse-ai/discourse/lib/ai-streamer";
+import {
+  addProgressDot,
+  applyProgress,
+  MIN_LETTERS_PER_INTERVAL,
+} from "discourse/plugins/discourse-ai/discourse/lib/ai-streamer";
+
+class FakeStreamUpdater {
+  constructor() {
+    this._streaming = true;
+    this._raw = "";
+    this._cooked = "";
+    this._element = document.createElement("div");
+  }
+
+  get streaming() {
+    return this._streaming;
+  }
+
+  set streaming(value) {
+    this._streaming = value;
+  }
+
+  get cooked() {
+    return this._cooked;
+  }
+
+  get raw() {
+    return this._raw;
+  }
+
+  async setRaw(value) {
+    this._raw = value;
+    // just fake it, calling cook is tricky
+    const cooked = `<p>${value}</p>`;
+    await this.setCooked(cooked);
+  }
+
+  async setCooked(value) {
+    this._cooked = value;
+    this._element.innerHTML = value;
+  }
+
+  get element() {
+    return this._element;
+  }
+}
+
 module("Discourse AI | Unit | Lib | ai-streamer", function () {
   function confirmPlaceholder(html, expected, assert) {

@@ -69,4 +113,55 @@ module("Discourse AI | Unit | Lib | ai-streamer", function () {
 
     confirmPlaceholder(html, expected, assert);
   });
+
+  test("can perform delta updates", async function (assert) {
+    const status = {
+      startTime: Date.now(),
+      raw: "some raw content",
+      done: false,
+    };
+
+    const streamUpdater = new FakeStreamUpdater();
+
+    let done = await applyProgress(status, streamUpdater);
+
+    assert.notOk(done, "The update should not be done.");
+
+    assert.equal(
+      streamUpdater.raw,
+      status.raw.substring(0, MIN_LETTERS_PER_INTERVAL),
+      "The raw content should delta update."
+    );
+
+    done = await applyProgress(status, streamUpdater);
+
+    assert.notOk(done, "The update should not be done.");
+
+    assert.equal(
+      streamUpdater.raw,
+      status.raw.substring(0, MIN_LETTERS_PER_INTERVAL * 2),
+      "The raw content should delta update."
+    );
+
+    // last chunk
+    await applyProgress(status, streamUpdater);
+
+    const innerHtml = streamUpdater.element.innerHTML;
+    assert.equal(
+      innerHtml,
+      "<p>some raw content</p>",
+      "The cooked content should be updated."
+    );
+
+    status.done = true;
+    status.cooked = "<p>updated cooked</p>";
+
+    await applyProgress(status, streamUpdater);
+
+    assert.equal(
+      streamUpdater.element.innerHTML,
+      "<p>updated cooked</p>",
+      "The cooked content should be updated."
+    );
+  });
 });