diff --git a/assets/javascripts/discourse/lib/ai-streamer.js b/assets/javascripts/discourse/lib/ai-streamer.js
index 04ba9f63..df34d1a9 100644
--- a/assets/javascripts/discourse/lib/ai-streamer.js
+++ b/assets/javascripts/discourse/lib/ai-streamer.js
@@ -3,8 +3,9 @@ import loadScript from "discourse/lib/load-script";
 import { cook } from "discourse/lib/text";
 
 const PROGRESS_INTERVAL = 40;
-const GIVE_UP_INTERVAL = 10000;
-const LETTERS_PER_INTERVAL = 6;
+const GIVE_UP_INTERVAL = 60000;
+export const MIN_LETTERS_PER_INTERVAL = 6;
+const MAX_FLUSH_TIME = 800;
 
 let progressTimer = null;
 
@@ -41,69 +42,146 @@ export function addProgressDot(element) {
   lastBlock.appendChild(dotElement);
 }
 
-async function applyProgress(postStatus, postStream) {
-  postStatus.startTime = postStatus.startTime || Date.now();
-  let post = postStream.findLoadedPost(postStatus.post_id);
+// this is the interface we need to implement
+// for a streaming updater
+class StreamUpdater {
+  set streaming(value) {
+    throw "not implemented";
+  }
 
-  const postElement = document.querySelector(`#post_${postStatus.post_number}`);
+  async setCooked() {
+    throw "not implemented";
+  }
 
-  if (Date.now() - postStatus.startTime > GIVE_UP_INTERVAL) {
-    if (postElement) {
-      postElement.classList.remove("streaming");
+  async setRaw() {
+    throw "not implemented";
+  }
+
+  get element() {
+    throw "not implemented";
+  }
+
+  get raw() {
+    throw "not implemented";
+  }
+}
+
+class PostUpdater extends StreamUpdater {
+  constructor(postStream, postId) {
+    super();
+    this.postStream = postStream;
+    this.postId = postId;
+    this.post = postStream.findLoadedPost(postId);
+
+    if (this.post) {
+      this.postElement = document.querySelector(
+        `#post_${this.post.post_number}`
+      );
     }
+  }
+
+  get element() {
+    return this.postElement;
+  }
+
+  set streaming(value) {
+    if (this.postElement) {
+      if (value) {
+        this.postElement.classList.add("streaming");
+      } else {
+        this.postElement.classList.remove("streaming");
+      }
+    }
+  }
+
+  async setRaw(value, done) {
+    this.post.set("raw", value);
+    const cooked = await cook(value);
+
+    // resets animation
+    this.element.classList.remove("streaming");
+    void this.element.offsetWidth;
+    this.element.classList.add("streaming");
+
+    const cookedElement = document.createElement("div");
+    cookedElement.innerHTML = cooked;
+
+    if (!done) {
+      addProgressDot(cookedElement);
+    }
+
+    await this.setCooked(cookedElement.innerHTML);
+  }
+
+  async setCooked(value) {
+    this.post.set("cooked", value);
+
+    const oldElement = this.postElement.querySelector(".cooked");
+
+    await loadScript("/javascripts/diffhtml.min.js");
+    window.diff.innerHTML(oldElement, value);
+  }
+
+  get raw() {
+    return this.post.get("raw") || "";
+  }
+}
+
+export async function applyProgress(status, updater) {
+  status.startTime = status.startTime || Date.now();
+
+  if (Date.now() - status.startTime > GIVE_UP_INTERVAL) {
+    updater.streaming = false;
     return true;
   }
 
-  if (!post) {
+  if (!updater.element) {
     // wait till later
     return false;
   }
 
-  const oldRaw = post.get("raw") || "";
+  const oldRaw = updater.raw;
 
-  if (postStatus.raw === oldRaw && !postStatus.done) {
-    const hasProgressDot =
-      postElement && postElement.querySelector(".progress-dot");
+  if (status.raw === oldRaw && !status.done) {
+    const hasProgressDot = updater.element.querySelector(".progress-dot");
     if (hasProgressDot) {
       return false;
     }
   }
 
-  if (postStatus.raw) {
-    const newRaw = postStatus.raw.substring(
-      0,
-      oldRaw.length + LETTERS_PER_INTERVAL
-    );
-    const cooked = await cook(newRaw);
+  if (status.raw !== undefined) {
+    let newRaw = status.raw;
 
-    post.set("raw", newRaw);
-    post.set("cooked", cooked);
+    if (!status.done) {
+      // rush update if we have a </details> tag (function call)
+      if (oldRaw.length === 0 && newRaw.indexOf("</details>") !== -1) {
+        newRaw = status.raw;
+      } else {
+        const diff = newRaw.length - oldRaw.length;
 
-    // resets animation
-    postElement.classList.remove("streaming");
-    void postElement.offsetWidth;
-    postElement.classList.add("streaming");
+        // progress interval is 40ms
+        // by default we add 6 letters per interval
+        // but ... we want to be done in MAX_FLUSH_TIME
+        let letters = Math.floor(diff / (MAX_FLUSH_TIME / PROGRESS_INTERVAL));
+        if (letters < MIN_LETTERS_PER_INTERVAL) {
+          letters = MIN_LETTERS_PER_INTERVAL;
+        }
 
-    const cookedElement = document.createElement("div");
-    cookedElement.innerHTML = cooked;
-
-    addProgressDot(cookedElement);
-
-    const element = document.querySelector(
-      `#post_${postStatus.post_number} .cooked`
-    );
-
-    await loadScript("/javascripts/diffhtml.min.js");
-    window.diff.innerHTML(element, cookedElement.innerHTML);
-  }
-
-  if (postStatus.done) {
-    if (postElement) {
-      postElement.classList.remove("streaming");
+        newRaw = status.raw.substring(0, oldRaw.length + letters);
+      }
     }
+
+    await updater.setRaw(newRaw, status.done);
   }
 
-  return postStatus.done;
+  if (status.done) {
+    if (status.cooked) {
+      await updater.setCooked(status.cooked);
+    }
+    updater.streaming = false;
+  }
+
+  return status.done;
 }
 
 async function handleProgress(postStream) {
@@ -114,7 +192,8 @@ async function handleProgress(postStream) {
   const promises = Object.keys(status).map(async (postId) => {
     let postStatus = status[postId];
 
-    const done = await applyProgress(postStatus, postStream);
+    const postUpdater = new PostUpdater(postStream, postStatus.post_id);
+    const done = await applyProgress(postStatus, postUpdater);
 
     if (done) {
       delete status[postId];
@@ -142,6 +221,10 @@ function ensureProgress(postStream) {
 }
 
 export default function streamText(postStream, data) {
+  if (data.noop) {
+    return;
+  }
+
   let status = (postStream.aiStreamingStatus =
     postStream.aiStreamingStatus || {});
   status[data.post_id] = data;
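The interesting piece above is the adaptive flush rate: each 40ms tick normally advances `MIN_LETTERS_PER_INTERVAL` characters, but when a large raw backlog arrives the step size scales up so the whole backlog lands within `MAX_FLUSH_TIME`. A standalone sketch of that rule (the function name `lettersPerTick` is illustrative, not part of the patch):

```js
// Sketch of the pacing rule used in applyProgress above.
const PROGRESS_INTERVAL = 40; // ms between ticks
const MAX_FLUSH_TIME = 800; // drain any backlog within ~800ms
const MIN_LETTERS_PER_INTERVAL = 6;

function lettersPerTick(pendingChars) {
  // number of ticks available to flush the backlog
  const ticks = MAX_FLUSH_TIME / PROGRESS_INTERVAL; // 20 ticks
  return Math.max(Math.floor(pendingChars / ticks), MIN_LETTERS_PER_INTERVAL);
}

// A 1000-char backlog flushes at 50 chars per tick (~800ms total),
// while a 40-char backlog falls back to the 6-char minimum.
console.log(lettersPerTick(1000)); // 50
console.log(lettersPerTick(40)); // 6
```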
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index 073902b4..8582fc7e 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -59,7 +59,7 @@ en:
           description: "Enable debug mode to see the raw input and output of the LLM"
         priority_group:
           label: "Priority Group"
-          description: "Priotize content from this group in the report"
+          description: "Prioritize content from this group in the report"
 
   llm_triage:
     fields:
diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb
index d29cbe30..a4116439 100644
--- a/lib/ai_bot/playground.rb
+++ b/lib/ai_bot/playground.rb
@@ -112,9 +112,10 @@ module DiscourseAi
           topic_id: post.topic_id,
           raw: "",
           skip_validations: true,
+          skip_jobs: true,
         )
 
-      publish_update(reply_post, raw: "<p></p>")
+      publish_update(reply_post, { raw: reply_post.cooked })
 
       redis_stream_key = "gpt_cancel:#{reply_post.id}"
       Discourse.redis.setex(redis_stream_key, 60, 1)
@@ -139,12 +140,14 @@ module DiscourseAi
 
           Discourse.redis.expire(redis_stream_key, 60)
 
-          publish_update(reply_post, raw: raw)
+          publish_update(reply_post, { raw: raw })
         end
 
       return if reply.blank?
 
-      publish_update(reply_post, done: true)
+      # land the final message prior to saving so we don't clash
+      reply_post.cooked = PrettyText.cook(reply)
+      publish_final_update(reply_post)
 
       reply_post.revise(bot.bot_user, { raw: reply }, skip_validations: true, skip_revision: true)
 
@@ -157,10 +160,25 @@ module DiscourseAi
       end
 
       reply_post
+    ensure
+      publish_final_update(reply_post)
     end
 
     private
 
+    def publish_final_update(reply_post)
+      return if @published_final_update
+      if reply_post
+        publish_update(reply_post, { cooked: reply_post.cooked, done: true })
+        # we subscribe at position -2 so we will always get this message;
+        # replaying the full cooked content on every page load is wasteful,
+        # so we publish a benign noop message at the end - the backlog size
+        # of 2 ensures the last real message is still delivered
+        publish_update(reply_post, { noop: true })
+        @published_final_update = true
+      end
+    end
+
     attr_reader :bot
 
     def can_attach?(post)
@@ -201,10 +219,15 @@ module DiscourseAi
     end
 
     def publish_update(bot_reply_post, payload)
+      payload = { post_id: bot_reply_post.id, post_number: bot_reply_post.post_number }.merge(
+        payload,
+      )
       MessageBus.publish(
         "discourse-ai/ai-bot/topic/#{bot_reply_post.topic_id}",
         payload,
         user_ids: bot_reply_post.topic.allowed_user_ids,
+        max_backlog_size: 2,
+        max_backlog_age: 60,
      )
    end
For example: [1,"two",3.0] - + Here are the tools available: TEXT end @@ -115,36 +115,49 @@ module DiscourseAi current_token_count = 0 message_step_size = (max_prompt_tokens / 25).to_i * -1 - reversed_trimmed_msgs = - messages - .reverse - .reduce([]) do |acc, msg| - message_tokens = calculate_message_token(msg) + trimmed_messages = [] - dupped_msg = msg.dup + range = (0..-1) + if messages.dig(0, :type) == :system + system_message = messages[0] + trimmed_messages << system_message + current_token_count += calculate_message_token(system_message) + range = (1..-1) + end - # Don't trim tool call metadata. - if msg[:type] == :tool_call - current_token_count += message_tokens + per_message_overhead - acc << dupped_msg - next(acc) - end + reversed_trimmed_msgs = [] - # Trimming content to make sure we respect token limit. - while dupped_msg[:content].present? && - message_tokens + current_token_count + per_message_overhead > prompt_limit - dupped_msg[:content] = dupped_msg[:content][0..message_step_size] || "" - message_tokens = calculate_message_token(dupped_msg) - end + messages[range].reverse.each do |msg| + break if current_token_count >= prompt_limit - next(acc) if dupped_msg[:content].blank? + message_tokens = calculate_message_token(msg) - current_token_count += message_tokens + per_message_overhead + dupped_msg = msg.dup - acc << dupped_msg - end + # Don't trim tool call metadata. + if msg[:type] == :tool_call + break if current_token_count + message_tokens + per_message_overhead > prompt_limit - reversed_trimmed_msgs.reverse + current_token_count += message_tokens + per_message_overhead + reversed_trimmed_msgs << dupped_msg + next + end + + # Trimming content to make sure we respect token limit. + while dupped_msg[:content].present? && + message_tokens + current_token_count + per_message_overhead > prompt_limit + dupped_msg[:content] = dupped_msg[:content][0..message_step_size] || "" + message_tokens = calculate_message_token(dupped_msg) + end + + next if dupped_msg[:content].blank? + + current_token_count += message_tokens + per_message_overhead + + reversed_trimmed_msgs << dupped_msg + end + + trimmed_messages.concat(reversed_trimmed_msgs.reverse) end def per_message_overhead diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb index a9b17f31..481275b8 100644 --- a/lib/completions/llm.rb +++ b/lib/completions/llm.rb @@ -83,6 +83,16 @@ module DiscourseAi stop_sequences: stop_sequences, } + if prompt.is_a?(String) + prompt = + DiscourseAi::Completions::Prompt.new( + "You are a helpful bot", + messages: [{ type: :user, content: prompt }], + ) + elsif prompt.is_a?(Array) + prompt = DiscourseAi::Completions::Prompt.new(messages: prompt) + end + model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? 
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index a9b17f31..481275b8 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -83,6 +83,16 @@ module DiscourseAi
           stop_sequences: stop_sequences,
         }
 
+        if prompt.is_a?(String)
+          prompt =
+            DiscourseAi::Completions::Prompt.new(
+              "You are a helpful bot",
+              messages: [{ type: :user, content: prompt }],
+            )
+        elsif prompt.is_a?(Array)
+          prompt = DiscourseAi::Completions::Prompt.new(messages: prompt)
+        end
+
         model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }
 
         dialect = dialect_klass.new(prompt, model_name, opts: model_params)
diff --git a/lib/completions/prompt.rb b/lib/completions/prompt.rb
index 7e852400..95ef3fe8 100644
--- a/lib/completions/prompt.rb
+++ b/lib/completions/prompt.rb
@@ -5,16 +5,22 @@ module DiscourseAi
     class Prompt
       INVALID_TURN = Class.new(StandardError)
 
-      attr_reader :system_message, :messages
+      attr_reader :messages
       attr_accessor :tools
 
-      def initialize(system_msg, messages: [], tools: [])
+      def initialize(system_message_text = nil, messages: [], tools: [])
         raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
         raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
 
-        system_message = { type: :system, content: system_msg }
+        @messages = []
+
+        if system_message_text
+          system_message = { type: :system, content: system_message_text }
+          @messages << system_message
+        end
+
+        @messages.concat(messages)
 
-        @messages = [system_message].concat(messages)
         @messages.each { |message| validate_message(message) }
         @messages.each_cons(2) { |last_turn, new_turn| validate_turn(last_turn, new_turn) }
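With these two changes, `generate` now normalizes three prompt shapes: a bare string, an array of message hashes, or a full `Prompt` object. The same coercion, sketched in JavaScript purely for illustration (the plugin's actual implementation is the Ruby branch above):

```js
// Illustrative mirror of the Ruby prompt coercion in llm.rb.
function normalizePrompt(prompt) {
  if (typeof prompt === "string") {
    // bare string: wrap it as a user turn under a default system message
    return {
      messages: [
        { type: "system", content: "You are a helpful bot" },
        { type: "user", content: prompt },
      ],
    };
  }
  if (Array.isArray(prompt)) {
    // array of turns: use as-is, no implicit system message
    return { messages: prompt };
  }
  return prompt; // already a structured prompt object
}
```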
diff --git a/spec/lib/completions/dialects/chat_gpt_spec.rb b/spec/lib/completions/dialects/chat_gpt_spec.rb
index 07acc23a..d26c6fe1 100644
--- a/spec/lib/completions/dialects/chat_gpt_spec.rb
+++ b/spec/lib/completions/dialects/chat_gpt_spec.rb
@@ -50,6 +50,21 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
       expect(translated.last[:role]).to eq("user")
       expect(translated.last[:content].length).to be < context.long_message_text.length
     end
+
+    it "always preserves system message when trimming" do
+      # gpt-4 is 8k tokens, so the last message blows way past the limit
+      prompt = DiscourseAi::Completions::Prompt.new("You are a bot")
+      prompt.push(type: :user, content: "a " * 100)
+      prompt.push(type: :model, content: "b " * 100)
+      prompt.push(type: :user, content: "zjk " * 10_000)
+
+      translated = context.dialect(prompt).translate
+
+      expect(translated.length).to eq(2)
+      expect(translated.first).to eq(content: "You are a bot", role: "system")
+      expect(translated.last[:role]).to eq("user")
+      expect(translated.last[:content].length).to be < (8000 * 4)
+    end
   end
 
   describe "#tools" do
diff --git a/spec/lib/completions/llm_spec.rb b/spec/lib/completions/llm_spec.rb
index b76c6e6f..f141bfd6 100644
--- a/spec/lib/completions/llm_spec.rb
+++ b/spec/lib/completions/llm_spec.rb
@@ -52,6 +52,26 @@ RSpec.describe DiscourseAi::Completions::Llm do
     end
   end
 
+  describe "#generate with various style prompts" do
+    let :canned_response do
+      DiscourseAi::Completions::Endpoints::CannedResponse.new(["world"])
+    end
+
+    it "can generate a response to a simple string" do
+      response = llm.generate("hello", user: user)
+      expect(response).to eq("world")
+    end
+
+    it "can generate a response from an array" do
+      response =
+        llm.generate(
+          [{ type: :system, content: "you are a bot" }, { type: :user, content: "hello" }],
+          user: user,
+        )
+      expect(response).to eq("world")
+    end
+  end
+
   describe "#generate" do
     let(:prompt) do
       system_insts = (<<~TEXT).strip
diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb
index 27468801..63a3c16e 100644
--- a/spec/lib/modules/ai_bot/playground_spec.rb
+++ b/spec/lib/modules/ai_bot/playground_spec.rb
@@ -64,16 +64,21 @@ RSpec.describe DiscourseAi::AiBot::Playground do
         playground.reply_to(third_post)
       end
 
+      reply = pm.reload.posts.last
+
+      noop_signal = messages.pop
+      expect(noop_signal.data[:noop]).to eq(true)
+
       done_signal = messages.pop
       expect(done_signal.data[:done]).to eq(true)
+      expect(done_signal.data[:cooked]).to eq(reply.cooked)
 
-      # we need this for styling
-      expect(messages.first.data[:raw]).to eq("<p></p>")
+      expect(messages.first.data[:raw]).to eq("")
 
       messages[1..-1].each_with_index do |m, idx|
         expect(m.data[:raw]).to eq(expected_bot_response[0..idx])
       end
 
-      expect(pm.reload.posts.last.cooked).to eq(PrettyText.cook(expected_bot_response))
+      expect(reply.cooked).to eq(PrettyText.cook(expected_bot_response))
     end
   end
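Putting the spec's assertions together, the publish sequence for one reply looks roughly like this, oldest first (a schematic with illustrative values; `post_id`/`post_number` come from the merge in `publish_update`):

```js
// Schematic of the per-reply publish sequence the spec above asserts.
const published = [
  { post_id: 1, post_number: 2, raw: "" }, // initial placeholder
  { post_id: 1, post_number: 2, raw: "Hello" }, // incremental raw updates...
  { post_id: 1, post_number: 2, raw: "Hello world" },
  { post_id: 1, post_number: 2, cooked: "<p>Hello world</p>", done: true }, // final cooked
  { post_id: 1, post_number: 2, noop: true }, // keeps backlog replay cheap
];
```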
diff --git a/test/javascripts/unit/lib/ai-streamer-test.js b/test/javascripts/unit/lib/ai-streamer-test.js
index 3870868e..eb851772 100644
--- a/test/javascripts/unit/lib/ai-streamer-test.js
+++ b/test/javascripts/unit/lib/ai-streamer-test.js
@@ -1,5 +1,49 @@
 import { module, test } from "qunit";
-import { addProgressDot } from "discourse/plugins/discourse-ai/discourse/lib/ai-streamer";
+import {
+  addProgressDot,
+  applyProgress,
+  MIN_LETTERS_PER_INTERVAL,
+} from "discourse/plugins/discourse-ai/discourse/lib/ai-streamer";
+
+class FakeStreamUpdater {
+  constructor() {
+    this._streaming = true;
+    this._raw = "";
+    this._cooked = "";
+    this._element = document.createElement("div");
+  }
+
+  get streaming() {
+    return this._streaming;
+  }
+  set streaming(value) {
+    this._streaming = value;
+  }
+
+  get cooked() {
+    return this._cooked;
+  }
+
+  get raw() {
+    return this._raw;
+  }
+
+  async setRaw(value) {
+    this._raw = value;
+    // just fake it, calling cook is tricky
+    const cooked = `<p>${value}</p>`;
+    await this.setCooked(cooked);
+  }
+
+  async setCooked(value) {
+    this._cooked = value;
+    this._element.innerHTML = value;
+  }
+
+  get element() {
+    return this._element;
+  }
+}
 
 module("Discourse AI | Unit | Lib | ai-streamer", function () {
   function confirmPlaceholder(html, expected, assert) {
@@ -69,4 +113,55 @@ module("Discourse AI | Unit | Lib | ai-streamer", function () {
 
     confirmPlaceholder(html, expected, assert);
   });
+
+  test("can perform delta updates", async function (assert) {
+    const status = {
+      startTime: Date.now(),
+      raw: "some raw content",
+      done: false,
+    };
+
+    const streamUpdater = new FakeStreamUpdater();
+
+    let done = await applyProgress(status, streamUpdater);
+
+    assert.notOk(done, "The update should not be done.");
+
+    assert.equal(
+      streamUpdater.raw,
+      status.raw.substring(0, MIN_LETTERS_PER_INTERVAL),
+      "The raw content should delta update."
+    );
+
+    done = await applyProgress(status, streamUpdater);
+
+    assert.notOk(done, "The update should not be done.");
+
+    assert.equal(
+      streamUpdater.raw,
+      status.raw.substring(0, MIN_LETTERS_PER_INTERVAL * 2),
+      "The raw content should delta update."
+    );
+
+    // last chunk
+    await applyProgress(status, streamUpdater);
+
+    const innerHtml = streamUpdater.element.innerHTML;
+    assert.equal(
+      innerHtml,
+      "<p>some raw content</p>",
+      "The cooked content should be updated."
+    );
+
+    status.done = true;
+    status.cooked = "<p>updated cooked</p>";
+
+    await applyProgress(status, streamUpdater);
+
+    assert.equal(
+      streamUpdater.element.innerHTML,
+      "<p>updated cooked</p>",
+      "The cooked content should be updated."
+    );
+  });
 });