From 71b105a1bb58c48a45a34ab64d5284ab7a498b29 Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Fri, 5 May 2023 15:28:31 -0300 Subject: [PATCH] FEATURE: Introduce the ai-bot module (#52) This module lets you chat with our GPT bot inside a PM. The bot only replies to members of the groups listed on the ai_bot_allowed_groups setting and only if you invite it to participate in the PM. --- .../discourse_ai/ai_bot/bot_controller.rb | 19 +++ app/models/completion_prompt.rb | 32 +++++ .../initializers/ai-bot-replies.js | 104 ++++++++++++++ .../modules/ai-bot/common/bot-replies.scss | 7 + config/locales/client.en.yml | 5 + config/locales/server.en.yml | 4 + config/routes.rb | 4 + config/settings.yml | 9 ++ db/fixtures/ai_bot/602_bot_users.rb | 21 +++ .../600_openai_completion_prompts.rb | 0 .../601_anthropic_completion_prompts.rb | 0 lib/modules/ai_bot/entry_point.rb | 37 +++++ .../ai_bot/jobs/regular/create_ai_reply.rb | 76 ++++++++++ lib/modules/ai_helper/entry_point.rb | 2 +- lib/shared/inference/openai_completions.rb | 136 ++++++++---------- plugin.rb | 3 + spec/lib/modules/ai_bot/entry_point_spec.rb | 73 ++++++++++ .../jobs/regular/create_ai_reply_spec.rb | 53 +++++++ spec/models/completion_prompt_spec.rb | 42 ++++++ spec/requests/ai_bot/bot_controller_spec.rb | 29 ++++ .../inference/openai_completions_spec.rb | 109 +++++--------- .../openai_completions_inference_stubs.rb | 32 ++++- 22 files changed, 644 insertions(+), 153 deletions(-) create mode 100644 app/controllers/discourse_ai/ai_bot/bot_controller.rb create mode 100644 assets/javascripts/initializers/ai-bot-replies.js create mode 100644 assets/stylesheets/modules/ai-bot/common/bot-replies.scss create mode 100644 db/fixtures/ai_bot/602_bot_users.rb rename db/fixtures/{ai-helper => ai_helper}/600_openai_completion_prompts.rb (100%) rename db/fixtures/{ai-helper => ai_helper}/601_anthropic_completion_prompts.rb (100%) create mode 100644 lib/modules/ai_bot/entry_point.rb create mode 100644 lib/modules/ai_bot/jobs/regular/create_ai_reply.rb create mode 100644 spec/lib/modules/ai_bot/entry_point_spec.rb create mode 100644 spec/lib/modules/ai_bot/jobs/regular/create_ai_reply_spec.rb create mode 100644 spec/requests/ai_bot/bot_controller_spec.rb diff --git a/app/controllers/discourse_ai/ai_bot/bot_controller.rb b/app/controllers/discourse_ai/ai_bot/bot_controller.rb new file mode 100644 index 00000000..d80df452 --- /dev/null +++ b/app/controllers/discourse_ai/ai_bot/bot_controller.rb @@ -0,0 +1,19 @@ +# frozen_string_literal: true + +module DiscourseAi + module AiBot + class BotController < ::ApplicationController + requires_plugin ::DiscourseAi::PLUGIN_NAME + requires_login + + def stop_streaming_response + post = Post.find(params[:post_id]) + guardian.ensure_can_see!(post) + + Discourse.redis.del("gpt_cancel:#{post.id}") + + render json: {}, status: 200 + end + end + end +end diff --git a/app/models/completion_prompt.rb b/app/models/completion_prompt.rb index 167f9cbc..c8227359 100644 --- a/app/models/completion_prompt.rb +++ b/app/models/completion_prompt.rb @@ -4,11 +4,43 @@ class CompletionPrompt < ActiveRecord::Base # TODO(roman): Remove sept 2023. 
self.ignored_columns = ["value"]
 
+  MAX_PROMPT_LENGTH = 3000
+
   enum :prompt_type, { text: 0, list: 1, diff: 2 }
 
   validates :messages, length: { maximum: 20 }
   validate :each_message_length
 
+  def self.bot_prompt_with_topic_context(post)
+    messages = []
+    conversation =
+      post
+        .topic
+        .posts
+        .includes(:user)
+        .where("post_number <= ?", post.post_number)
+        .order("post_number desc")
+        .pluck(:raw, :username)
+
+    total_prompt_length = 0
+    messages =
+      conversation.reduce([]) do |memo, (raw, username)|
+        total_prompt_length += raw.length
+        break(memo) if total_prompt_length > MAX_PROMPT_LENGTH
+        role = username == Discourse.gpt_bot.username ? "system" : "user"
+
+        memo.unshift({ role: role, content: raw })
+      end
+
+    messages.unshift({ role: "system", content: <<~TEXT })
+      You are gpt-bot. You answer questions and generate text.
+      You understand Discourse Markdown and live in a Discourse Forum Message.
+      You are provided with the context of previous discussions.
+    TEXT
+
+    messages
+  end
+
   def messages_with_user_input(user_input)
     if ::DiscourseAi::AiHelper::LlmPrompt.new.enabled_provider == "openai"
       self.messages << { role: "user", content: user_input }
diff --git a/assets/javascripts/initializers/ai-bot-replies.js b/assets/javascripts/initializers/ai-bot-replies.js
new file mode 100644
index 00000000..fdd4e9b1
--- /dev/null
+++ b/assets/javascripts/initializers/ai-bot-replies.js
@@ -0,0 +1,104 @@
+import { withPluginApi } from "discourse/lib/plugin-api";
+import { cookAsync } from "discourse/lib/text";
+import { ajax } from "discourse/lib/ajax";
+import { popupAjaxError } from "discourse/lib/ajax-error";
+import loadScript from "discourse/lib/load-script";
+
+function isGPTBot(user) {
+  return user && user.id === -110;
+}
+
+function initializeAIBotReplies(api) {
+  api.addPostMenuButton("cancel-gpt", (post) => {
+    if (isGPTBot(post.user)) {
+      return {
+        icon: "pause",
+        action: "cancelStreaming",
+        title: "discourse_ai.ai_bot.cancel_streaming",
+        className: "btn btn-default cancel-streaming",
+        position: "first",
+      };
+    }
+  });
+
+  api.attachWidgetAction("post", "cancelStreaming", function () {
+    ajax(`/discourse-ai/ai-bot/post/${this.model.id}/stop-streaming`, {
+      type: "POST",
+    })
+      .then(() => {
+        document
+          .querySelector(`#post_${this.model.post_number}`)
+          .classList.remove("streaming");
+      })
+      .catch(popupAjaxError);
+  });
+
+  api.modifyClass("controller:topic", {
+    pluginId: "discourse-ai",
+
+    onAIBotStreamedReply: function (data) {
+      const post = this.model.postStream.findLoadedPost(data.post_id);
+
+      if (post) {
+        if (data.raw) {
+          cookAsync(data.raw).then((cooked) => {
+            post.set("raw", data.raw);
+            post.set("cooked", cooked);
+
+            document
+              .querySelector(`#post_${data.post_number}`)
+              .classList.add("streaming");
+
+            const cookedElement = document.createElement("div");
+            cookedElement.innerHTML = cooked;
+
+            let element = document.querySelector(
+              `#post_${data.post_number} .cooked`
+            );
+
+            loadScript("/javascripts/diffhtml.min.js").then(() => {
+              window.diff.innerHTML(element, cookedElement.innerHTML);
+            });
+          });
+        }
+        if (data.done) {
+          document
+            .querySelector(`#post_${data.post_number}`)
+            .classList.remove("streaming");
+        }
+      }
+    },
+    subscribe: function () {
+      this._super();
+
+      if (
+        this.model.isPrivateMessage &&
+        this.model.details.allowed_users &&
+        this.model.details.allowed_users.filter(isGPTBot).length === 1
+      ) {
+        this.messageBus.subscribe(
+          `discourse-ai/ai-bot/topic/${this.model.id}`,
+          this.onAIBotStreamedReply.bind(this)
+        );
+      }
+    },
+    unsubscribe:
function () {
+      this.messageBus.unsubscribe("discourse-ai/ai-bot/topic/");
+      this._super();
+    },
+  });
+}
+
+export default {
+  name: "discourse-ai-bot-replies",
+
+  initialize(container) {
+    const settings = container.lookup("service:site-settings");
+    const aiBotEnabled =
+      settings.discourse_ai_enabled && settings.ai_bot_enabled;
+
+    if (aiBotEnabled) {
+      withPluginApi("1.6.0", initializeAIBotReplies);
+    }
+  },
+};
diff --git a/assets/stylesheets/modules/ai-bot/common/bot-replies.scss b/assets/stylesheets/modules/ai-bot/common/bot-replies.scss
new file mode 100644
index 00000000..de71dcde
--- /dev/null
+++ b/assets/stylesheets/modules/ai-bot/common/bot-replies.scss
@@ -0,0 +1,7 @@
+nav.post-controls .actions button.cancel-streaming {
+  display: none;
+}
+
+article.streaming nav.post-controls .actions button.cancel-streaming {
+  display: inline-block;
+}
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index 42f0e946..f2250b48 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -24,6 +24,11 @@ en:
         since:
           one: "Last hour"
           other: "Last %{count} hours"
+
+      ai_bot:
+        cancel_streaming: Stop reply
+
+
   review:
     types:
       reviewable_ai_post:
diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml
index 01138091..c877400a 100644
--- a/config/locales/server.en.yml
+++ b/config/locales/server.en.yml
@@ -56,6 +56,10 @@ en:
     ai_summarization_model: "Model to use for summarization."
     ai_summarization_rate_limit_minutes: "Minutes to elapse after the summarization limit is reached (6 requests)."
+
+    ai_bot_enabled: "Enable the AI Bot module."
+    ai_bot_allowed_groups: "When the GPT Bot has access to the PM, it will reply to members of these groups."
+
   reviewables:
     reasons:
       flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic.
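Note: the ai-bot-replies.js initializer above and the Jobs::CreateAiReply job further down share a small streaming contract over MessageBus. A minimal sketch of the two payload shapes — `topic` and `bot_post` are illustrative stand-ins, not names from this patch:

    channel = "discourse-ai/ai-bot/topic/#{topic.id}"

    # While tokens arrive, the job republishes the full raw reply so far; the
    # initializer re-cooks it and patches the post body in place via diffhtml.
    MessageBus.publish(
      channel,
      { raw: "the reply so far", post_id: bot_post.id, post_number: bot_post.post_number },
      user_ids: topic.allowed_user_ids,
    )

    # A terminal signal drops the "streaming" post class, which also hides the
    # cancel button styled in bot-replies.scss above.
    MessageBus.publish(
      channel,
      { done: true, post_id: bot_post.id, post_number: bot_post.post_number },
      user_ids: topic.allowed_user_ids,
    )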
diff --git a/config/routes.rb b/config/routes.rb index 097e659a..e3b24702 100644 --- a/config/routes.rb +++ b/config/routes.rb @@ -13,6 +13,10 @@ DiscourseAi::Engine.routes.draw do scope module: :summarization, path: "/summarization", defaults: { format: :json } do post "summary" => "summary#show" end + + scope module: :ai_bot, path: "/ai-bot", defaults: { format: :json } do + post "post/:post_id/stop-streaming" => "bot#stop_streaming_response" + end end Discourse::Application.routes.append { mount ::DiscourseAi::Engine, at: "discourse-ai" } diff --git a/config/settings.yml b/config/settings.yml index c92cda0b..d2e16797 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -170,3 +170,12 @@ plugins: - gpt-4 - claude-v1 ai_summarization_rate_limit_minutes: 10 + + ai_bot_enabled: + default: false + client: true + ai_bot_allowed_groups: + client: true + type: group_list + list_type: compact + default: "3|14" # 3: @staff, 14: @trust_level_4 \ No newline at end of file diff --git a/db/fixtures/ai_bot/602_bot_users.rb b/db/fixtures/ai_bot/602_bot_users.rb new file mode 100644 index 00000000..8c0ff990 --- /dev/null +++ b/db/fixtures/ai_bot/602_bot_users.rb @@ -0,0 +1,21 @@ +# frozen_string_literal: true + +UserEmail.seed do |ue| + ue.id = -110 + ue.email = "no_email_gpt_bot" + ue.primary = true + ue.user_id = -110 +end + +User.seed do |u| + u.id = -110 + u.name = "GPT Bot" + u.username = "gpt_bot" + u.username_lower = "gpt_bot" + u.password = SecureRandom.hex + u.active = true + u.admin = true + u.moderator = true + u.approved = true + u.trust_level = TrustLevel[4] +end diff --git a/db/fixtures/ai-helper/600_openai_completion_prompts.rb b/db/fixtures/ai_helper/600_openai_completion_prompts.rb similarity index 100% rename from db/fixtures/ai-helper/600_openai_completion_prompts.rb rename to db/fixtures/ai_helper/600_openai_completion_prompts.rb diff --git a/db/fixtures/ai-helper/601_anthropic_completion_prompts.rb b/db/fixtures/ai_helper/601_anthropic_completion_prompts.rb similarity index 100% rename from db/fixtures/ai-helper/601_anthropic_completion_prompts.rb rename to db/fixtures/ai_helper/601_anthropic_completion_prompts.rb diff --git a/lib/modules/ai_bot/entry_point.rb b/lib/modules/ai_bot/entry_point.rb new file mode 100644 index 00000000..af934f65 --- /dev/null +++ b/lib/modules/ai_bot/entry_point.rb @@ -0,0 +1,37 @@ +# frozen_string_literal: true + +module DiscourseAi + module AiBot + class EntryPoint + AI_BOT_ID = -110 + + def load_files + require_relative "jobs/regular/create_ai_reply" + end + + def inject_into(plugin) + plugin.register_seedfu_fixtures( + Rails.root.join("plugins", "discourse-ai", "db", "fixtures", "ai_bot"), + ) + + plugin.add_class_method(Discourse, :gpt_bot) do + @ai_bots ||= {} + current_db = RailsMultisite::ConnectionManagement.current_db + @ai_bots[current_db] ||= User.find(AI_BOT_ID) + end + + plugin.on(:post_created) do |post| + if post.topic.private_message? && post.user_id != AI_BOT_ID && + post.topic.topic_allowed_users.exists?(user_id: Discourse.gpt_bot.id) + in_allowed_group = + SiteSetting.ai_bot_allowed_groups_map.any? 
do |group_id| + post.user.group_ids.include?(group_id) + end + + Jobs.enqueue(:create_ai_reply, post_id: post.id) if in_allowed_group + end + end + end + end + end +end diff --git a/lib/modules/ai_bot/jobs/regular/create_ai_reply.rb b/lib/modules/ai_bot/jobs/regular/create_ai_reply.rb new file mode 100644 index 00000000..e2f1a6f6 --- /dev/null +++ b/lib/modules/ai_bot/jobs/regular/create_ai_reply.rb @@ -0,0 +1,76 @@ +# frozen_string_literal: true + +module ::Jobs + class CreateAiReply < ::Jobs::Base + sidekiq_options retry: false + + def execute(args) + return unless post = Post.includes(:topic).find_by(id: args[:post_id]) + + prompt = CompletionPrompt.bot_prompt_with_topic_context(post) + + redis_stream_key = nil + reply = +"" + bot_reply_post = nil + start = Time.now + + DiscourseAi::Inference::OpenAiCompletions.perform!( + prompt, + temperature: 0.4, + top_p: 0.9, + max_tokens: 3000, + ) do |partial, cancel| + content_delta = partial.dig(:choices, 0, :delta, :content) + reply << content_delta if content_delta + + if redis_stream_key && !Discourse.redis.get(redis_stream_key) + cancel&.call + + bot_reply_post.update!(raw: reply, cooked: PrettyText.cook(reply)) if bot_reply_post + end + + next if reply.length < SiteSetting.min_personal_message_post_length + # Minor hack to skip the delay during tests. + next if (Time.now - start < 0.5) && !Rails.env.test? + + if bot_reply_post + Discourse.redis.expire(redis_stream_key, 60) + start = Time.now + + MessageBus.publish( + "discourse-ai/ai-bot/topic/#{post.topic_id}", + { raw: reply.dup, post_id: bot_reply_post.id, post_number: bot_reply_post.post_number }, + user_ids: post.topic.allowed_user_ids, + ) + else + bot_reply_post = + PostCreator.create!( + Discourse.gpt_bot, + topic_id: post.topic_id, + raw: reply, + skip_validations: false, + ) + redis_stream_key = "gpt_cancel:#{bot_reply_post.id}" + Discourse.redis.setex(redis_stream_key, 60, 1) + end + end + + MessageBus.publish( + "discourse-ai/ai-bot/topic/#{post.topic_id}", + { done: true, post_id: bot_reply_post.id, post_number: bot_reply_post.post_number }, + user_ids: post.topic.allowed_user_ids, + ) + + if bot_reply_post + bot_reply_post.revise( + Discourse.gpt_bot, + { raw: reply }, + skip_validations: true, + skip_revision: true, + ) + end + rescue => e + Discourse.warn_exception(e, message: "ai-bot: Reply failed") + end + end +end diff --git a/lib/modules/ai_helper/entry_point.rb b/lib/modules/ai_helper/entry_point.rb index 0c92f22e..cdafab20 100644 --- a/lib/modules/ai_helper/entry_point.rb +++ b/lib/modules/ai_helper/entry_point.rb @@ -8,7 +8,7 @@ module DiscourseAi def inject_into(plugin) plugin.register_seedfu_fixtures( - Rails.root.join("plugins", "discourse-ai", "db", "fixtures", "ai-helper"), + Rails.root.join("plugins", "discourse-ai", "db", "fixtures", "ai_helper"), ) plugin.register_svg_icon("discourse-sparkles") end diff --git a/lib/shared/inference/openai_completions.rb b/lib/shared/inference/openai_completions.rb index 63643e16..7ef49e03 100644 --- a/lib/shared/inference/openai_completions.rb +++ b/lib/shared/inference/openai_completions.rb @@ -13,23 +13,20 @@ module ::DiscourseAi temperature: nil, top_p: nil, max_tokens: nil, - stream: false, - user_id: nil, - &blk + user_id: nil ) - raise ArgumentError, "block must be supplied in streaming mode" if stream && !blk - url = URI("https://api.openai.com/v1/chat/completions") headers = { - "Content-Type": "application/json", - Authorization: "Bearer #{SiteSetting.ai_openai_api_key}", + "Content-Type" => "application/json", + 
"Authorization" => "Bearer #{SiteSetting.ai_openai_api_key}", } + payload = { model: model, messages: messages } payload[:temperature] = temperature if temperature payload[:top_p] = top_p if top_p payload[:max_tokens] = max_tokens if max_tokens - payload[:stream] = true if stream + payload[:stream] = true if block_given? Net::HTTP.start( url.host, @@ -43,84 +40,71 @@ module ::DiscourseAi request_body = payload.to_json request.body = request_body - response = http.request(request) + http.request(request) do |response| + if response.code.to_i != 200 + Rails.logger.error( + "OpenAiCompletions: status: #{response.code.to_i} - body: #{response.body}", + ) + raise CompletionFailed + end - if response.code.to_i != 200 - Rails.logger.error( - "OpenAiCompletions: status: #{response.code.to_i} - body: #{response.body}", - ) - raise CompletionFailed - end + log = + AiApiAuditLog.create!( + provider_id: AiApiAuditLog::Provider::OpenAI, + raw_request_payload: request_body, + user_id: user_id, + ) - log = - AiApiAuditLog.create!( - provider_id: AiApiAuditLog::Provider::OpenAI, - raw_request_payload: request_body, - user_id: user_id, - ) + if !block_given? + response_body = response.read_body + parsed_response = JSON.parse(response_body, symbolize_names: true) - if stream - stream(http, response, messages, log, &blk) - else - response_body = response.body - parsed = JSON.parse(response_body, symbolize_names: true) + log.update!( + raw_response_payload: response_body, + request_tokens: parsed_response.dig(:usage, :prompt_tokens), + response_tokens: parsed_response.dig(:usage, :completion_tokens), + ) + return parsed_response + end - log.update!( - raw_response_payload: response_body, - request_tokens: parsed.dig(:usage, :prompt_tokens), - response_tokens: parsed.dig(:usage, :completion_tokens), - ) - parsed - end - end - end + begin + cancelled = false + cancel = lambda { cancelled = true } + response_data = +"" + response_raw = +"" - def self.stream(http, response, messages, log) - cancelled = false - cancel = lambda { cancelled = true } - - response_data = +"" - response_raw = +"" - - response.read_body do |chunk| - if cancelled - http.finish - break - end - - response_raw << chunk - - chunk - .split("\n") - .each do |line| - data = line.split("data: ", 2)[1] - - next if data == "[DONE]" - - if data - json = JSON.parse(data, symbolize_names: true) - choices = json[:choices] - if choices && choices[0] - delta = choices[0].dig(:delta, :content) - response_data << delta if delta + response.read_body do |chunk| + if cancelled + http.finish + return end - yield json, cancel - end - if cancelled - http.finish - break + response_raw << chunk + + chunk + .split("\n") + .each do |line| + data = line.split("data: ", 2)[1] + next if !data || data == "[DONE]" + + if !cancelled && partial = JSON.parse(data, symbolize_names: true) + response_data << partial.dig(:choices, 0, :delta, :content).to_s + + yield partial, cancel + end + end + rescue IOError + raise if !cancelled + ensure + log.update!( + raw_response_payload: response_raw, + request_tokens: DiscourseAi::Tokenizer.size(extract_prompt(messages)), + response_tokens: DiscourseAi::Tokenizer.size(response_data), + ) end end + end end - rescue IOError - raise if !cancelled - ensure - log.update!( - raw_response_payload: response_raw, - request_tokens: DiscourseAi::Tokenizer.size(extract_prompt(messages)), - response_tokens: DiscourseAi::Tokenizer.size(response_data), - ) end def self.extract_prompt(messages) diff --git a/plugin.rb b/plugin.rb index 
0d357a73..5aeffaee 100644 --- a/plugin.rb +++ b/plugin.rb @@ -13,6 +13,7 @@ enabled_site_setting :discourse_ai_enabled register_asset "stylesheets/modules/ai-helper/common/ai-helper.scss" register_asset "stylesheets/modules/summarization/common/summarization.scss" +register_asset "stylesheets/modules/ai-bot/common/bot-replies.scss" module ::DiscourseAi PLUGIN_NAME = "discourse-ai" @@ -41,6 +42,7 @@ after_initialize do require_relative "lib/modules/ai_helper/entry_point" require_relative "lib/modules/embeddings/entry_point" require_relative "lib/modules/summarization/entry_point" + require_relative "lib/modules/ai_bot/entry_point" [ DiscourseAi::Embeddings::EntryPoint.new, @@ -49,6 +51,7 @@ after_initialize do DiscourseAi::Sentiment::EntryPoint.new, DiscourseAi::AiHelper::EntryPoint.new, DiscourseAi::Summarization::EntryPoint.new, + DiscourseAi::AiBot::EntryPoint.new, ].each do |a_module| a_module.load_files a_module.inject_into(self) diff --git a/spec/lib/modules/ai_bot/entry_point_spec.rb b/spec/lib/modules/ai_bot/entry_point_spec.rb new file mode 100644 index 00000000..f6446232 --- /dev/null +++ b/spec/lib/modules/ai_bot/entry_point_spec.rb @@ -0,0 +1,73 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::AiBot::EntryPoint do + describe "#inject_into" do + describe "subscribes to the post_created event" do + fab!(:admin) { Fabricate(:admin) } + let(:gpt_bot) { Discourse.gpt_bot } + fab!(:bot_allowed_group) { Fabricate(:group) } + + let(:post_args) do + { + title: "Dear AI, I want to ask a question", + raw: "Hello, Can you please tell me a story?", + archetype: Archetype.private_message, + target_usernames: [gpt_bot.username].join(","), + category: 1, + } + end + + before do + SiteSetting.ai_bot_allowed_groups = bot_allowed_group.id + bot_allowed_group.add(admin) + end + + it "queues a job to generate a reply by the AI" do + expect { PostCreator.create!(admin, post_args) }.to change( + Jobs::CreateAiReply.jobs, + :size, + ).by(1) + end + + context "when the post is not from a PM" do + it "does nothing" do + expect { + PostCreator.create!(admin, post_args.merge(archetype: Archetype.default)) + }.not_to change(Jobs::CreateAiReply.jobs, :size) + end + end + + context "when the bot doesn't have access to the PM" do + it "does nothing" do + user_2 = Fabricate(:user) + expect { + PostCreator.create!(admin, post_args.merge(target_usernames: [user_2.username])) + }.not_to change(Jobs::CreateAiReply.jobs, :size) + end + end + + context "when the user is not allowed to interact with the bot" do + it "does nothing" do + bot_allowed_group.remove(admin) + expect { PostCreator.create!(admin, post_args) }.not_to change( + Jobs::CreateAiReply.jobs, + :size, + ) + end + end + + context "when the post was created by the bot" do + it "does nothing" do + gpt_topic_id = PostCreator.create!(admin, post_args).topic_id + reply_args = + post_args.except(:archetype, :target_usernames, :title).merge(topic_id: gpt_topic_id) + + expect { PostCreator.create!(gpt_bot, reply_args) }.not_to change( + Jobs::CreateAiReply.jobs, + :size, + ) + end + end + end + end +end diff --git a/spec/lib/modules/ai_bot/jobs/regular/create_ai_reply_spec.rb b/spec/lib/modules/ai_bot/jobs/regular/create_ai_reply_spec.rb new file mode 100644 index 00000000..bf305e02 --- /dev/null +++ b/spec/lib/modules/ai_bot/jobs/regular/create_ai_reply_spec.rb @@ -0,0 +1,53 @@ +# frozen_string_literal: true + +require_relative "../../../../../support/openai_completions_inference_stubs" + +RSpec.describe Jobs::CreateAiReply do + describe 
"#execute" do + fab!(:topic) { Fabricate(:topic) } + fab!(:post) { Fabricate(:post, topic: topic) } + + let(:expected_response) do + "Hello this is a bot and what you just said is an interesting question" + end + let(:deltas) { expected_response.split(" ").map { |w| { content: "#{w} " } } } + + before do + SiteSetting.min_personal_message_post_length = 5 + + OpenAiCompletionsInferenceStubs.stub_streamed_response( + CompletionPrompt.bot_prompt_with_topic_context(post), + deltas, + req_opts: { + temperature: 0.4, + top_p: 0.9, + max_tokens: 3000, + stream: true, + }, + ) + end + + it "adds a reply from the GPT bot" do + subject.execute(post_id: topic.first_post.id) + + expect(topic.posts.last.raw).to eq(expected_response) + end + + it "streams the reply on the fly to the client through MB" do + messages = + MessageBus.track_publish("discourse-ai/ai-bot/topic/#{topic.id}") do + subject.execute(post_id: topic.first_post.id) + end + + done_signal = messages.pop + + expect(messages.length).to eq(deltas.length) + + messages.each_with_index do |m, idx| + expect(m.data[:raw]).to eq(deltas[0..(idx + 1)].map { |d| d[:content] }.join) + end + + expect(done_signal.data[:done]).to eq(true) + end + end +end diff --git a/spec/models/completion_prompt_spec.rb b/spec/models/completion_prompt_spec.rb index 1eb356df..c4f6cb57 100644 --- a/spec/models/completion_prompt_spec.rb +++ b/spec/models/completion_prompt_spec.rb @@ -18,4 +18,46 @@ RSpec.describe CompletionPrompt do end end end + + describe ".bot_prompt_with_topic_context" do + fab!(:topic) { Fabricate(:topic) } + + def post_body(post_number) + "This is post #{post_number}" + end + + context "when the topic has one post" do + fab!(:post_1) { Fabricate(:post, topic: topic, raw: post_body(1), post_number: 1) } + + it "includes it in the prompt" do + prompt_messages = described_class.bot_prompt_with_topic_context(post_1) + + post_1_message = prompt_messages[1] + + expect(post_1_message[:role]).to eq("user") + expect(post_1_message[:content]).to eq(post_body(1)) + end + end + + context "when the topic has multiple posts" do + fab!(:post_1) { Fabricate(:post, topic: topic, raw: post_body(1), post_number: 1) } + fab!(:post_2) do + Fabricate(:post, topic: topic, user: Discourse.gpt_bot, raw: post_body(2), post_number: 2) + end + fab!(:post_3) { Fabricate(:post, topic: topic, raw: post_body(3), post_number: 3) } + + it "includes them in the prompt respecting the post number order" do + prompt_messages = described_class.bot_prompt_with_topic_context(post_3) + + expect(prompt_messages[1][:role]).to eq("user") + expect(prompt_messages[1][:content]).to eq(post_body(1)) + + expect(prompt_messages[2][:role]).to eq("system") + expect(prompt_messages[2][:content]).to eq(post_body(2)) + + expect(prompt_messages[3][:role]).to eq("user") + expect(prompt_messages[3][:content]).to eq(post_body(3)) + end + end + end end diff --git a/spec/requests/ai_bot/bot_controller_spec.rb b/spec/requests/ai_bot/bot_controller_spec.rb new file mode 100644 index 00000000..6efa385f --- /dev/null +++ b/spec/requests/ai_bot/bot_controller_spec.rb @@ -0,0 +1,29 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::AiBot::BotController do + describe "#stop_streaming_response" do + fab!(:pm_topic) { Fabricate(:private_message_topic) } + fab!(:pm_post) { Fabricate(:post, topic: pm_topic) } + + let(:redis_stream_key) { "gpt_cancel:#{pm_post.id}" } + + before { Discourse.redis.setex(redis_stream_key, 60, 1) } + + it "returns a 403 when the user cannot see the PM" do + 
sign_in(Fabricate(:user))
+
+      post "/discourse-ai/ai-bot/post/#{pm_post.id}/stop-streaming"
+
+      expect(response.status).to eq(403)
+    end
+
+    it "deletes the key used to track the streaming" do
+      sign_in(pm_topic.topic_allowed_users.first.user)
+
+      post "/discourse-ai/ai-bot/post/#{pm_post.id}/stop-streaming"
+
+      expect(response.status).to eq(200)
+      expect(Discourse.redis.get(redis_stream_key)).to be_nil
+    end
+  end
+end
diff --git a/spec/shared/inference/openai_completions_spec.rb b/spec/shared/inference/openai_completions_spec.rb
index cac34c22..c6c1e6cc 100644
--- a/spec/shared/inference/openai_completions_spec.rb
+++ b/spec/shared/inference/openai_completions_spec.rb
@@ -1,26 +1,19 @@
 # frozen_string_literal: true
 
 require "rails_helper"
+require_relative "../../support/openai_completions_inference_stubs"
+
 describe DiscourseAi::Inference::OpenAiCompletions do
   before { SiteSetting.ai_openai_api_key = "abc-123" }
 
   it "can complete a trivial prompt" do
-    body = <<~JSON
-      {"id":"chatcmpl-74OT0yKnvbmTkqyBINbHgAW0fpbxc","object":"chat.completion","created":1681281718,"model":"gpt-3.5-turbo-0301","usage":{"prompt_tokens":12,"completion_tokens":13,"total_tokens":25},"choices":[{"message":{"role":"assistant","content":"1. Serenity\\n2. Laughter\\n3. Adventure"},"finish_reason":"stop","index":0}]}
-    JSON
-
-    stub_request(:post, "https://api.openai.com/v1/chat/completions").with(
-      body:
-        "{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"write 3 words\"}],\"temperature\":0.5,\"top_p\":0.8,\"max_tokens\":700}",
-      headers: {
-        "Authorization" => "Bearer #{SiteSetting.ai_openai_api_key}",
-        "Content-Type" => "application/json",
-      },
-    ).to_return(status: 200, body: body, headers: {})
-
-    user_id = 183
-
+    response_text = "1. Serenity\\n2. Laughter\\n3. Adventure"
     prompt = [role: "user", content: "write 3 words"]
+    user_id = 183
+    req_opts = { temperature: 0.5, top_p: 0.8, max_tokens: 700 }
+
+    OpenAiCompletionsInferenceStubs.stub_response(prompt, response_text, req_opts: req_opts)
+
     completions =
       DiscourseAi::Inference::OpenAiCompletions.perform!(
         prompt,
@@ -30,75 +23,45 @@ describe DiscourseAi::Inference::OpenAiCompletions do
         max_tokens: 700,
         user_id: user_id,
       )
-    expect(completions[:choices][0][:message][:content]).to eq(
-      "1. Serenity\n2. Laughter\n3.
Adventure", - ) + + expect(completions.dig(:choices, 0, :message, :content)).to eq(response_text) expect(AiApiAuditLog.count).to eq(1) log = AiApiAuditLog.first - request_body = (<<~JSON).strip - {"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"write 3 words"}],"temperature":0.5,"top_p":0.8,"max_tokens":700} - JSON + body = { model: "gpt-3.5-turbo", messages: prompt }.merge(req_opts).to_json + request_body = OpenAiCompletionsInferenceStubs.response(response_text).to_json expect(log.provider_id).to eq(AiApiAuditLog::Provider::OpenAI) - expect(log.request_tokens).to eq(12) - expect(log.response_tokens).to eq(13) - expect(log.raw_request_payload).to eq(request_body) - expect(log.raw_response_payload).to eq(body) - end - - it "raises an error if attempting to stream without a block" do - expect do - DiscourseAi::Inference::OpenAiCompletions.perform!([], "gpt-3.5-turbo", stream: true) - end.to raise_error(ArgumentError) - end - - def stream_line(finish_reason: nil, delta: {}) - +"data: " << { - id: "chatcmpl-#{SecureRandom.hex}", - object: "chat.completion.chunk", - created: 1_681_283_881, - model: "gpt-3.5-turbo-0301", - choices: [{ delta: delta }], - finish_reason: finish_reason, - index: 0, - }.to_json + expect(log.request_tokens).to eq(337) + expect(log.response_tokens).to eq(162) + expect(log.raw_request_payload).to eq(body) + expect(log.raw_response_payload).to eq(request_body) end it "can operate in streaming mode" do - payload = [ - stream_line(delta: { role: "assistant" }), - stream_line(delta: { content: "Mount" }), - stream_line(delta: { content: "ain" }), - stream_line(delta: { content: " " }), - stream_line(delta: { content: "Tree " }), - stream_line(delta: { content: "Frog" }), - stream_line(finish_reason: "stop"), - "[DONE]", - ].join("\n\n") - - stub_request(:post, "https://api.openai.com/v1/chat/completions").with( - body: - "{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"write 3 words\"}],\"stream\":true}", - headers: { - "Accept" => "*/*", - "Authorization" => "Bearer abc-123", - "Content-Type" => "application/json", - "Host" => "api.openai.com", - }, - ).to_return(status: 200, body: payload, headers: {}) + deltas = [ + { role: "assistant" }, + { content: "Mount" }, + { content: "ain" }, + { content: " " }, + { content: "Tree " }, + { content: "Frog" }, + ] prompt = [role: "user", content: "write 3 words"] - content = +"" - DiscourseAi::Inference::OpenAiCompletions.perform!( + OpenAiCompletionsInferenceStubs.stub_streamed_response( prompt, - "gpt-3.5-turbo", - stream: true, - ) do |partial, cancel| - data = partial[:choices][0].dig(:delta, :content) + deltas, + req_opts: { + stream: true, + }, + ) + + DiscourseAi::Inference::OpenAiCompletions.perform!(prompt, "gpt-3.5-turbo") do |partial, cancel| + data = partial.dig(:choices, 0, :delta, :content) content << data if data cancel.call if content.split(" ").length == 2 end @@ -108,14 +71,12 @@ describe DiscourseAi::Inference::OpenAiCompletions do expect(AiApiAuditLog.count).to eq(1) log = AiApiAuditLog.first - request_body = (<<~JSON).strip - {"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"write 3 words"}],"stream":true} - JSON + request_body = { model: "gpt-3.5-turbo", messages: prompt, stream: true }.to_json expect(log.provider_id).to eq(AiApiAuditLog::Provider::OpenAI) expect(log.request_tokens).to eq(5) expect(log.response_tokens).to eq(4) expect(log.raw_request_payload).to eq(request_body) - expect(log.raw_response_payload).to eq(payload) + 
expect(log.raw_response_payload).to be_present end end diff --git a/spec/support/openai_completions_inference_stubs.rb b/spec/support/openai_completions_inference_stubs.rb index 369983fa..59288151 100644 --- a/spec/support/openai_completions_inference_stubs.rb +++ b/spec/support/openai_completions_inference_stubs.rb @@ -98,10 +98,38 @@ class OpenAiCompletionsInferenceStubs prompt_messages = CompletionPrompt.find_by(name: type).messages_with_user_input(text) + stub_response(prompt_messages, response_text_for(type)) + end + + def stub_response(messages, response_text, req_opts: {}) WebMock .stub_request(:post, "https://api.openai.com/v1/chat/completions") - .with(body: { model: "gpt-3.5-turbo", messages: prompt_messages }.to_json) - .to_return(status: 200, body: JSON.dump(response(response_text_for(type)))) + .with(body: { model: "gpt-3.5-turbo", messages: messages }.merge(req_opts).to_json) + .to_return(status: 200, body: JSON.dump(response(response_text))) + end + + def stream_line(finish_reason: nil, delta: {}) + +"data: " << { + id: "chatcmpl-#{SecureRandom.hex}", + object: "chat.completion.chunk", + created: 1_681_283_881, + model: "gpt-3.5-turbo-0301", + choices: [{ delta: delta }], + finish_reason: finish_reason, + index: 0, + }.to_json + end + + def stub_streamed_response(messages, deltas, req_opts: {}) + chunks = deltas.map { |d| stream_line(delta: d) } + chunks << stream_line(finish_reason: "stop") + chunks << "[DONE]" + chunks = chunks.join("\n\n") + + WebMock + .stub_request(:post, "https://api.openai.com/v1/chat/completions") + .with(body: { model: "gpt-3.5-turbo", messages: messages }.merge(req_opts).to_json) + .to_return(status: 200, body: chunks) end end end
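A closing note on the reworked inference API: DiscourseAi::Inference::OpenAiCompletions.perform! now derives its mode from block_given? — without a block it returns the parsed response; with a block it sets stream: true and yields each parsed chunk together with a cancel lambda. A minimal usage sketch (prompt, model, and the early-stop condition are illustrative values, not code from this patch):

    prompt = [{ role: "user", content: "write 3 words" }]

    # Blocking mode: returns the parsed (symbolized) response hash.
    result = DiscourseAi::Inference::OpenAiCompletions.perform!(prompt, "gpt-3.5-turbo")
    result.dig(:choices, 0, :message, :content)

    # Streaming mode: each SSE chunk is parsed and yielded; calling cancel
    # finishes the HTTP connection when the next chunk arrives.
    reply = +""
    DiscourseAi::Inference::OpenAiCompletions.perform!(prompt, "gpt-3.5-turbo") do |partial, cancel|
      delta = partial.dig(:choices, 0, :delta, :content)
      reply << delta if delta
      cancel.call if reply.length > 2_000 # illustrative early stop
    end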