diff --git a/assets/javascripts/discourse/lib/ai-streamer.js b/assets/javascripts/discourse/lib/ai-streamer.js
new file mode 100644
index 00000000..04ba9f63
--- /dev/null
+++ b/assets/javascripts/discourse/lib/ai-streamer.js
@@ -0,0 +1,149 @@
+import { later } from "@ember/runloop";
+import loadScript from "discourse/lib/load-script";
+import { cook } from "discourse/lib/text";
+
+const PROGRESS_INTERVAL = 40;
+const GIVE_UP_INTERVAL = 10000;
+const LETTERS_PER_INTERVAL = 6;
+
+let progressTimer = null;
+
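+// Returns the last child of the element, skipping trailing whitespace-only text nodes.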
+function lastNonEmptyChild(element) {
+  let lastChild = element.lastChild;
+  while (
+    lastChild &&
+    lastChild.nodeType === Node.TEXT_NODE &&
+    !/\S/.test(lastChild.textContent)
+  ) {
+    lastChild = lastChild.previousSibling;
+  }
+  return lastChild;
+}
+
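+// Walks down to the deepest trailing element and appends the animated
+// "progress dot" span there, so the typing indicator follows the most
+// recently streamed text.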
+export function addProgressDot(element) {
+  let lastBlock = element;
+
+  while (true) {
+    let lastChild = lastNonEmptyChild(lastBlock);
+    if (!lastChild) {
+      break;
+    }
+
+    if (lastChild.nodeType === Node.ELEMENT_NODE) {
+      lastBlock = lastChild;
+    } else {
+      break;
+    }
+  }
+
+  const dotElement = document.createElement("span");
+  dotElement.classList.add("progress-dot");
+  lastBlock.appendChild(dotElement);
+}
+
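+// Advances the streamed post by LETTERS_PER_INTERVAL characters: re-cooks the
+// truncated raw, patches the rendered HTML via diffHTML and re-attaches the
+// progress dot. Returns true when streaming is finished (or GIVE_UP_INTERVAL
+// has elapsed) so the caller can stop polling this post.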
+async function applyProgress(postStatus, postStream) {
+  postStatus.startTime = postStatus.startTime || Date.now();
+  let post = postStream.findLoadedPost(postStatus.post_id);
+
+  const postElement = document.querySelector(`#post_${postStatus.post_number}`);
+
+  if (Date.now() - postStatus.startTime > GIVE_UP_INTERVAL) {
+    if (postElement) {
+      postElement.classList.remove("streaming");
+    }
+    return true;
+  }
+
+  if (!post) {
+    // the post is not in the stream yet, try again on the next poll
+    return false;
+  }
+
+  const oldRaw = post.get("raw") || "";
+
+  if (postStatus.raw === oldRaw && !postStatus.done) {
+    const hasProgressDot =
+      postElement && postElement.querySelector(".progress-dot");
+    if (hasProgressDot) {
+      return false;
+    }
+  }
+
+  if (postStatus.raw) {
+    const newRaw = postStatus.raw.substring(
+      0,
+      oldRaw.length + LETTERS_PER_INTERVAL
+    );
+    const cooked = await cook(newRaw);
+
+    post.set("raw", newRaw);
+    post.set("cooked", cooked);
+
+    // toggling the class resets the CSS streaming animation
+    postElement.classList.remove("streaming");
+    void postElement.offsetWidth;
+    postElement.classList.add("streaming");
+
+    const cookedElement = document.createElement("div");
+    cookedElement.innerHTML = cooked;
+
+    addProgressDot(cookedElement);
+
+    const element = document.querySelector(
+      `#post_${postStatus.post_number} .cooked`
+    );
+
+    await loadScript("/javascripts/diffhtml.min.js");
+    window.diff.innerHTML(element, cookedElement.innerHTML);
+  }
+
+  if (postStatus.done) {
+    if (postElement) {
+      postElement.classList.remove("streaming");
+    }
+  }
+
+  return postStatus.done;
+}
+
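+// Applies one progress step to every post currently streaming and reports
+// whether any of them still needs polling.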
+async function handleProgress(postStream) {
+  const status = postStream.aiStreamingStatus;
+
+  let keepPolling = false;
+
+  const promises = Object.keys(status).map(async (postId) => {
+    let postStatus = status[postId];
+
+    const done = await applyProgress(postStatus, postStream);
+
+    if (done) {
+      delete status[postId];
+    } else {
+      keepPolling = true;
+    }
+  });
+
+  await Promise.all(promises);
+  return keepPolling;
+}
+
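+// Arms a single timer that re-runs every PROGRESS_INTERVAL ms for as long as
+// at least one post is still streaming.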
+function ensureProgress(postStream) {
+  if (!progressTimer) {
+    progressTimer = later(async () => {
+      const keepPolling = await handleProgress(postStream);
+
+      progressTimer = null;
+
+      if (keepPolling) {
+        ensureProgress(postStream);
+      }
+    }, PROGRESS_INTERVAL);
+  }
+}
+
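+// Entry point for the message-bus handler: records the latest payload for the
+// post and makes sure the polling loop is running.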
+export default function streamText(postStream, data) {
+  let status = (postStream.aiStreamingStatus =
+    postStream.aiStreamingStatus || {});
+  status[data.post_id] = data;
+  ensureProgress(postStream);
+}
diff --git a/assets/javascripts/initializers/ai-bot-replies.js b/assets/javascripts/initializers/ai-bot-replies.js
index 774fdf37..8e1f4312 100644
--- a/assets/javascripts/initializers/ai-bot-replies.js
+++ b/assets/javascripts/initializers/ai-bot-replies.js
@@ -1,19 +1,17 @@
-import { later } from "@ember/runloop";
import { hbs } from "ember-cli-htmlbars";
import { ajax } from "discourse/lib/ajax";
import { popupAjaxError } from "discourse/lib/ajax-error";
-import loadScript from "discourse/lib/load-script";
import { withPluginApi } from "discourse/lib/plugin-api";
-import { cook } from "discourse/lib/text";
import { registerWidgetShim } from "discourse/widgets/render-glimmer";
import { composeAiBotMessage } from "discourse/plugins/discourse-ai/discourse/lib/ai-bot-helper";
import ShareModal from "../discourse/components/modal/share-modal";
+import streamText from "../discourse/lib/ai-streamer";
import copyConversation from "../discourse/lib/copy-conversation";
 
 const AUTO_COPY_THRESHOLD = 4;
 
 function isGPTBot(user) {
-  return user && [-110, -111, -112, -113, -114, -115].includes(user.id);
+  return user && [-110, -111, -112, -113, -114, -115, -116].includes(user.id);
 }
 
function attachHeaderIcon(api) {
@@ -93,62 +91,7 @@ function initializeAIBotReplies(api) {
     pluginId: "discourse-ai",
 
     onAIBotStreamedReply: function (data) {
-      const post = this.model.postStream.findLoadedPost(data.post_id);
-
-      // it may take us a few seconds to load the post
-      // we need to requeue the event
-      if (!post && !data.done) {
-        const refresh = this.onAIBotStreamedReply.bind(this);
-        data.retries = data.retries || 5;
-        data.retries -= 1;
-        data.skipIfStreaming = true;
-        if (data.retries > 0) {
-          later(() => {
-            refresh(data);
-          }, 1000);
-        }
-      }
-
-      if (post) {
-        if (data.raw) {
-          const postElement = document.querySelector(
-            `#post_${data.post_number}`
-          );
-
-          if (
-            data.skipIfStreaming &&
-            postElement.classList.contains("streaming")
-          ) {
-            return;
-          }
-
-          cook(data.raw).then((cooked) => {
-            post.set("raw", data.raw);
-            post.set("cooked", cooked);
-
-            // resets animation
-            postElement.classList.remove("streaming");
-            void postElement.offsetWidth;
-            postElement.classList.add("streaming");
-
-            const cookedElement = document.createElement("div");
-            cookedElement.innerHTML = cooked;
-
-            let element = document.querySelector(
-              `#post_${data.post_number} .cooked`
-            );
-
-            loadScript("/javascripts/diffhtml.min.js").then(() => {
-              window.diff.innerHTML(element, cookedElement.innerHTML);
-            });
-          });
-        }
-        if (data.done) {
-          document
-            .querySelector(`#post_${data.post_number}`)
-            .classList.remove("streaming");
-        }
-      }
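+      // batching, cooking and the diff animation now live in discourse/lib/ai-streamer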
+      streamText(this.model.postStream, data);
     },
     subscribe: function () {
       this._super();
diff --git a/assets/stylesheets/modules/ai-bot/common/bot-replies.scss b/assets/stylesheets/modules/ai-bot/common/bot-replies.scss
index eeefb0ea..c5be4ebf 100644
--- a/assets/stylesheets/modules/ai-bot/common/bot-replies.scss
+++ b/assets/stylesheets/modules/ai-bot/common/bot-replies.scss
@@ -57,23 +57,18 @@ article.streaming nav.post-controls .actions button.cancel-streaming {
}
}
 
-article.streaming .cooked > {
-  :not(ol):not(ul):not(pre):last-child::after,
-  ol:last-child li:last-child p:last-child::after,
-  ol:last-child li:last-child:not(:has(p))::after,
-  ul:last-child li:last-child p:last-child::after,
-  ul:last-child li:last-child:not(:has(p))::after,
-  pre:last-child code::after {
-    content: "\25CF";
-    font-family: Söhne Circle, system-ui, -apple-system, Segoe UI, Roboto,
-      Ubuntu, Cantarell, Noto Sans, sans-serif;
-    line-height: normal;
-    margin-left: 0.25rem;
-    vertical-align: baseline;
+article.streaming .cooked .progress-dot::after {
+  content: "\25CF";
+  font-family: Söhne Circle, system-ui, -apple-system, Segoe UI, Roboto, Ubuntu,
+    Cantarell, Noto Sans, sans-serif;
+  line-height: normal;
+  margin-left: 0.25rem;
+  vertical-align: baseline;
-    animation: flashing 1.5s 3s infinite;
-    display: inline-block;
-  }
+  animation: flashing 1.5s 3s infinite;
+  display: inline-block;
+  font-size: 1rem;
+  color: var(--tertiary-medium);
 }
 
.ai-bot-available-bot-options {
diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index 661ee8d4..30cdbf7a 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -131,6 +131,8 @@ module DiscourseAi
           "mistralai/Mixtral-8x7B-Instruct-v0.1"
         when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
           "gemini-pro"
+        when DiscourseAi::AiBot::EntryPoint::FAKE_ID
+          "fake"
         else
           nil
         end
diff --git a/lib/ai_bot/entry_point.rb b/lib/ai_bot/entry_point.rb
index 50637240..3c9d9a75 100644
--- a/lib/ai_bot/entry_point.rb
+++ b/lib/ai_bot/entry_point.rb
@@ -11,6 +11,8 @@ module DiscourseAi
       GPT4_TURBO_ID = -113
       MIXTRAL_ID = -114
       GEMINI_ID = -115
+      FAKE_ID = -116 # only used for dev and test
+
       BOTS = [
         [GPT4_ID, "gpt4_bot", "gpt-4"],
         [GPT3_5_TURBO_ID, "gpt3.5_bot", "gpt-3.5-turbo"],
@@ -18,6 +20,7 @@ module DiscourseAi
         [GPT4_TURBO_ID, "gpt4t_bot", "gpt-4-turbo"],
         [MIXTRAL_ID, "mixtral_bot", "mixtral-8x7B-Instruct-V0.1"],
         [GEMINI_ID, "gemini_bot", "gemini-pro"],
+        [FAKE_ID, "fake_bot", "fake"],
       ]
 
       def self.map_bot_model_to_user_id(model_name)
@@ -34,6 +37,8 @@ module DiscourseAi
           MIXTRAL_ID
         in "gemini-pro"
           GEMINI_ID
+        in "fake"
+          FAKE_ID
         else
           nil
         end
diff --git a/lib/ai_bot/site_settings_extension.rb b/lib/ai_bot/site_settings_extension.rb
index d98873d4..4e8d11a2 100644
--- a/lib/ai_bot/site_settings_extension.rb
+++ b/lib/ai_bot/site_settings_extension.rb
@@ -5,6 +5,10 @@ module DiscourseAi::AiBot::SiteSettingsExtension
     enabled_bots = SiteSetting.ai_bot_enabled_chat_bots_map
     enabled_bots = [] if !SiteSetting.ai_bot_enabled
     DiscourseAi::AiBot::EntryPoint::BOTS.each do |id, bot_name, name|
+      if id == DiscourseAi::AiBot::EntryPoint::FAKE_ID
+        next if Rails.env.production?
+      end
+
       active = enabled_bots.include?(name)
       user = User.find_by(id: id)
diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb
index 71997056..09b8ba85 100644
--- a/lib/completions/dialects/claude.rb
+++ b/lib/completions/dialects/claude.rb
@@ -14,6 +14,14 @@ module DiscourseAi
           end
         end
 
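+        # Ensure the prompt ends with a blank line ("\n\n") before the next section is appended.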
+        def pad_newlines!(prompt)
+          if prompt[-1..-1] != "\n"
+            prompt << "\n\n"
+          elsif prompt[-2..-1] != "\n\n"
+            prompt << "\n"
+          end
+        end
+
         def translate
           claude_prompt = uses_system_message? ? +"" : +"Human: "
           claude_prompt << prompt[:insts] << "\n"
@@ -22,8 +30,12 @@ module DiscourseAi
           claude_prompt << build_examples(prompt[:examples]) if prompt[:examples]
+          pad_newlines!(claude_prompt)
+
           claude_prompt << conversation_context if prompt[:conversation_context]
+          pad_newlines!(claude_prompt)
+
           if uses_system_message? && (prompt[:input] || prompt[:post_insts])
             claude_prompt << "Human: "
           end
@@ -31,10 +43,11 @@ module DiscourseAi
           claude_prompt << "#{prompt[:post_insts]}\n" if prompt[:post_insts]
-          claude_prompt << "\n\n"
-          claude_prompt << "Assistant:"
+          pad_newlines!(claude_prompt)
+
+          claude_prompt << "Assistant: "
           claude_prompt << " #{prompt[:final_insts]}:" if prompt[:final_insts]
-          claude_prompt << "\n"
+          claude_prompt
         end
 
         def max_prompt_tokens
@@ -50,27 +63,27 @@ module DiscourseAi
           trimmed_context
             .reverse
-            .reduce(+"") do |memo, context|
-              memo << (context[:type] == "user" ? "Human:" : "Assistant:")
+            .map do |context|
+              row = context[:type] == "user" ? +"Human:" : +"Assistant:"
               if context[:type] == "tool"
-                memo << <<~TEXT
-
-