diff --git a/app/models/ai_api_audit_log.rb b/app/models/ai_api_audit_log.rb index c3718574..8a023259 100644 --- a/app/models/ai_api_audit_log.rb +++ b/app/models/ai_api_audit_log.rb @@ -3,5 +3,21 @@ class AiApiAuditLog < ActiveRecord::Base module Provider OpenAI = 1 + Anthropic = 2 end end + +# == Schema Information +# +# Table name: ai_api_audit_logs +# +# id :bigint not null, primary key +# provider_id :integer not null +# user_id :integer +# request_tokens :integer +# response_tokens :integer +# raw_request_payload :string +# raw_response_payload :string +# created_at :datetime not null +# updated_at :datetime not null +# diff --git a/app/models/completion_prompt.rb b/app/models/completion_prompt.rb index 135408d7..167f9cbc 100644 --- a/app/models/completion_prompt.rb +++ b/app/models/completion_prompt.rb @@ -4,52 +4,11 @@ class CompletionPrompt < ActiveRecord::Base # TODO(roman): Remove sept 2023. self.ignored_columns = ["value"] - # GPT 3.5 allows 4000 tokens - MAX_PROMPT_TOKENS = 3500 - enum :prompt_type, { text: 0, list: 1, diff: 2 } validates :messages, length: { maximum: 20 } validate :each_message_length - def self.bot_prompt_with_topic_context(post) - messages = [] - conversation = - post - .topic - .posts - .includes(:user) - .where("post_number <= ?", post.post_number) - .order("post_number desc") - .pluck(:raw, :username) - - total_prompt_tokens = 0 - messages = - conversation.reduce([]) do |memo, (raw, username)| - break(memo) if total_prompt_tokens >= MAX_PROMPT_TOKENS - - tokens = DiscourseAi::Tokenizer.tokenize(raw) - - if tokens.length + total_prompt_tokens > MAX_PROMPT_TOKENS - tokens = tokens[0...(MAX_PROMPT_TOKENS - total_prompt_tokens)] - raw = tokens.join(" ") - end - - total_prompt_tokens += tokens.length - role = username == Discourse.gpt_bot.username ? "system" : "user" - - memo.unshift({ role: role, content: raw }) - end - - messages.unshift({ role: "system", content: <<~TEXT }) - You are gpt-bot. You answer questions and generate text. - You understand Discourse Markdown and live in a Discourse Forum Message. - You are provided you with context of previous discussions. - TEXT - - messages - end - def messages_with_user_input(user_input) if ::DiscourseAi::AiHelper::LlmPrompt.new.enabled_provider == "openai" self.messages << { role: "user", content: user_input } diff --git a/assets/javascripts/initializers/ai-bot-replies.js b/assets/javascripts/initializers/ai-bot-replies.js index fdd4e9b1..a056f014 100644 --- a/assets/javascripts/initializers/ai-bot-replies.js +++ b/assets/javascripts/initializers/ai-bot-replies.js @@ -5,7 +5,7 @@ import { popupAjaxError } from "discourse/lib/ajax-error"; import loadScript from "discourse/lib/load-script"; function isGPTBot(user) { - return user && user.id === -110; + return user && [-110, -111, -112].includes(user.id); } function initializeAIBotReplies(api) { @@ -74,7 +74,7 @@ function initializeAIBotReplies(api) { if ( this.model.isPrivateMessage && this.model.details.allowed_users && - this.model.details.allowed_users.filter(isGPTBot).length === 1 + this.model.details.allowed_users.filter(isGPTBot).length >= 1 ) { this.messageBus.subscribe( `discourse-ai/ai-bot/topic/${this.model.id}`, @@ -83,7 +83,7 @@ function initializeAIBotReplies(api) { } }, unsubscribe: function () { - this.messageBus.unsubscribe("discourse-ai/ai-bot/topic/"); + this.messageBus.unsubscribe("discourse-ai/ai-bot/topic/*"); this._super(); }, }); diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index c877400a..4bed6676 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -32,6 +32,7 @@ en: ai_nsfw_models: "Models to use for NSFW inference." ai_openai_api_key: "API key for OpenAI API" + ai_anthropic_api_key: "API key for Anthropic API" composer_ai_helper_enabled: "Enable the Composer's AI helper." ai_helper_allowed_groups: "Users on these groups will see the AI helper button in the composer." @@ -58,6 +59,7 @@ en: ai_bot_enabled: "Enable the AI Bot module." ai_bot_allowed_groups: "When the GPT Bot has access to the PM, it will reply to members of these groups." + ai_bot_enabled_chat_bots: "Available models to act as an AI Bot" reviewables: diff --git a/config/settings.yml b/config/settings.yml index 0d3ed860..a2251581 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -179,4 +179,13 @@ plugins: client: true type: group_list list_type: compact - default: "3|14" # 3: @staff, 14: @trust_level_4 \ No newline at end of file + default: "3|14" # 3: @staff, 14: @trust_level_4 + # Adding a new bot? Make sure to create a user for it on the seed file. + ai_bot_enabled_chat_bots: + type: list + default: "gpt-3.5-turbo" + client: true + choices: + - gpt-3.5-turbo + - gpt-4 + - claude-v1 \ No newline at end of file diff --git a/db/fixtures/ai_bot/602_bot_users.rb b/db/fixtures/ai_bot/602_bot_users.rb index 7beb41b2..a7ec1f12 100644 --- a/db/fixtures/ai_bot/602_bot_users.rb +++ b/db/fixtures/ai_bot/602_bot_users.rb @@ -1,20 +1,22 @@ # frozen_string_literal: true -UserEmail.seed do |ue| - ue.id = -110 - ue.email = "no_email_gpt_bot" - ue.primary = true - ue.user_id = -110 -end +DiscourseAi::AiBot::EntryPoint::BOTS.each do |id, bot_username| + UserEmail.seed do |ue| + ue.id = id + ue.email = "no_email_#{bot_username}" + ue.primary = true + ue.user_id = id + end -User.seed do |u| - u.id = -110 - u.name = "GPT Bot" - u.username = UserNameSuggester.suggest("gpt_bot") - u.password = SecureRandom.hex - u.active = true - u.admin = true - u.moderator = true - u.approved = true - u.trust_level = TrustLevel[4] + User.seed do |u| + u.id = id + u.name = bot_username.titleize + u.username = UserNameSuggester.suggest(bot_username) + u.password = SecureRandom.hex + u.active = true + u.admin = true + u.moderator = true + u.approved = true + u.trust_level = TrustLevel[4] + end end diff --git a/lib/modules/ai_bot/anthropic_bot.rb b/lib/modules/ai_bot/anthropic_bot.rb new file mode 100644 index 00000000..954963dd --- /dev/null +++ b/lib/modules/ai_bot/anthropic_bot.rb @@ -0,0 +1,45 @@ +# frozen_string_literal: true + +module DiscourseAi + module AiBot + class AnthropicBot < Bot + def self.can_reply_as?(bot_user) + bot_user.id == DiscourseAi::AiBot::EntryPoint::CLAUDE_V1_ID + end + + def bot_prompt_with_topic_context(post) + super(post).join("\n\n") + end + + def prompt_limit + 7500 # https://console.anthropic.com/docs/prompt-design#what-is-a-prompt + end + + private + + def build_message(poster_username, content) + role = poster_username == bot_user.username ? "Assistant" : "Human" + + "#{role}: #{content}" + end + + def model_for + "claude-v1" + end + + def update_with_delta(_, partial) + partial[:completion] + end + + def submit_prompt_and_stream_reply(prompt, &blk) + DiscourseAi::Inference::AnthropicCompletions.perform!( + prompt, + model_for, + temperature: 0.4, + max_tokens: 3000, + &blk + ) + end + end + end +end diff --git a/lib/modules/ai_bot/bot.rb b/lib/modules/ai_bot/bot.rb new file mode 100644 index 00000000..e9af6a80 --- /dev/null +++ b/lib/modules/ai_bot/bot.rb @@ -0,0 +1,144 @@ +# frozen_string_literal: true + +module DiscourseAi + module AiBot + class Bot + BOT_NOT_FOUND = Class.new(StandardError) + + def self.as(bot_user) + available_bots = [DiscourseAi::AiBot::OpenAiBot, DiscourseAi::AiBot::AnthropicBot] + + bot = + available_bots.detect(-> { raise BOT_NOT_FOUND }) do |bot_klass| + bot_klass.can_reply_as?(bot_user) + end + + bot.new(bot_user) + end + + def initialize(bot_user) + @bot_user = bot_user + end + + def reply_to(post) + prompt = bot_prompt_with_topic_context(post) + + redis_stream_key = nil + reply = +"" + bot_reply_post = nil + start = Time.now + + submit_prompt_and_stream_reply(prompt) do |partial, cancel| + reply = update_with_delta(reply, partial) + + if redis_stream_key && !Discourse.redis.get(redis_stream_key) + cancel&.call + + bot_reply_post.update!(raw: reply, cooked: PrettyText.cook(reply)) if bot_reply_post + end + + next if reply.length < SiteSetting.min_personal_message_post_length + # Minor hack to skip the delay during tests. + next if (Time.now - start < 0.5) && !Rails.env.test? + + if bot_reply_post + Discourse.redis.expire(redis_stream_key, 60) + start = Time.now + + publish_update(bot_reply_post, raw: reply.dup) + else + bot_reply_post = + PostCreator.create!( + bot_user, + topic_id: post.topic_id, + raw: reply, + skip_validations: false, + ) + redis_stream_key = "gpt_cancel:#{bot_reply_post.id}" + Discourse.redis.setex(redis_stream_key, 60, 1) + end + end + + if bot_reply_post + publish_update(bot_reply_post, done: true) + bot_reply_post.revise( + bot_user, + { raw: reply }, + skip_validations: true, + skip_revision: true, + ) + end + rescue => e + Discourse.warn_exception(e, message: "ai-bot: Reply failed") + end + + def bot_prompt_with_topic_context(post) + messages = [] + conversation = conversation_context(post) + + total_prompt_tokens = 0 + messages = + conversation.reduce([]) do |memo, (raw, username)| + break(memo) if total_prompt_tokens >= prompt_limit + + tokens = DiscourseAi::Tokenizer.tokenize(raw) + + if tokens.length + total_prompt_tokens > prompt_limit + tokens = tokens[0...(prompt_limit - total_prompt_tokens)] + raw = tokens.join(" ") + end + + total_prompt_tokens += tokens.length + + memo.unshift(build_message(username, raw)) + end + + messages.unshift(build_message(bot_user.username, <<~TEXT)) + You are gpt-bot. You answer questions and generate text. + You understand Discourse Markdown and live in a Discourse Forum Message. + You are provided you with context of previous discussions. + TEXT + + messages + end + + def prompt_limit + raise NotImplemented + end + + protected + + attr_reader :bot_user + + def model_for(bot) + raise NotImplemented + end + + def get_delta_from(partial) + raise NotImplemented + end + + def submit_prompt_and_stream_reply(prompt, &blk) + raise NotImplemented + end + + def conversation_context(post) + post + .topic + .posts + .includes(:user) + .where("post_number <= ?", post.post_number) + .order("post_number desc") + .pluck(:raw, :username) + end + + def publish_update(bot_reply_post, payload) + MessageBus.publish( + "discourse-ai/ai-bot/topic/#{bot_reply_post.topic_id}", + payload.merge(post_id: bot_reply_post.id, post_number: bot_reply_post.post_number), + user_ids: bot_reply_post.topic.allowed_user_ids, + ) + end + end + end +end diff --git a/lib/modules/ai_bot/entry_point.rb b/lib/modules/ai_bot/entry_point.rb index af934f65..e61bc022 100644 --- a/lib/modules/ai_bot/entry_point.rb +++ b/lib/modules/ai_bot/entry_point.rb @@ -3,10 +3,20 @@ module DiscourseAi module AiBot class EntryPoint - AI_BOT_ID = -110 + GPT4_ID = -110 + GPT3_5_TURBO_ID = -111 + CLAUDE_V1_ID = -112 + BOTS = [ + [GPT4_ID, "gpt4_bot"], + [GPT3_5_TURBO_ID, "gpt3.5_bot"], + [CLAUDE_V1_ID, "claude_v1_bot"], + ] def load_files require_relative "jobs/regular/create_ai_reply" + require_relative "bot" + require_relative "anthropic_bot" + require_relative "open_ai_bot" end def inject_into(plugin) @@ -14,21 +24,15 @@ module DiscourseAi Rails.root.join("plugins", "discourse-ai", "db", "fixtures", "ai_bot"), ) - plugin.add_class_method(Discourse, :gpt_bot) do - @ai_bots ||= {} - current_db = RailsMultisite::ConnectionManagement.current_db - @ai_bots[current_db] ||= User.find(AI_BOT_ID) - end - plugin.on(:post_created) do |post| - if post.topic.private_message? && post.user_id != AI_BOT_ID && - post.topic.topic_allowed_users.exists?(user_id: Discourse.gpt_bot.id) - in_allowed_group = - SiteSetting.ai_bot_allowed_groups_map.any? do |group_id| - post.user.group_ids.include?(group_id) - end + bot_ids = BOTS.map(&:first) - Jobs.enqueue(:create_ai_reply, post_id: post.id) if in_allowed_group + if post.topic.private_message? && !bot_ids.include?(post.user_id) + if (SiteSetting.ai_bot_allowed_groups_map & post.user.group_ids).present? + bot_id = post.topic.topic_allowed_users.where(user_id: bot_ids).first&.user_id + + Jobs.enqueue(:create_ai_reply, post_id: post.id, bot_user_id: bot_id) if bot_id + end end end end diff --git a/lib/modules/ai_bot/jobs/regular/create_ai_reply.rb b/lib/modules/ai_bot/jobs/regular/create_ai_reply.rb index e2f1a6f6..d32879dc 100644 --- a/lib/modules/ai_bot/jobs/regular/create_ai_reply.rb +++ b/lib/modules/ai_bot/jobs/regular/create_ai_reply.rb @@ -5,72 +5,11 @@ module ::Jobs sidekiq_options retry: false def execute(args) + return unless bot_user = User.find_by(id: args[:bot_user_id]) + return unless bot = DiscourseAi::AiBot::Bot.as(bot_user) return unless post = Post.includes(:topic).find_by(id: args[:post_id]) - prompt = CompletionPrompt.bot_prompt_with_topic_context(post) - - redis_stream_key = nil - reply = +"" - bot_reply_post = nil - start = Time.now - - DiscourseAi::Inference::OpenAiCompletions.perform!( - prompt, - temperature: 0.4, - top_p: 0.9, - max_tokens: 3000, - ) do |partial, cancel| - content_delta = partial.dig(:choices, 0, :delta, :content) - reply << content_delta if content_delta - - if redis_stream_key && !Discourse.redis.get(redis_stream_key) - cancel&.call - - bot_reply_post.update!(raw: reply, cooked: PrettyText.cook(reply)) if bot_reply_post - end - - next if reply.length < SiteSetting.min_personal_message_post_length - # Minor hack to skip the delay during tests. - next if (Time.now - start < 0.5) && !Rails.env.test? - - if bot_reply_post - Discourse.redis.expire(redis_stream_key, 60) - start = Time.now - - MessageBus.publish( - "discourse-ai/ai-bot/topic/#{post.topic_id}", - { raw: reply.dup, post_id: bot_reply_post.id, post_number: bot_reply_post.post_number }, - user_ids: post.topic.allowed_user_ids, - ) - else - bot_reply_post = - PostCreator.create!( - Discourse.gpt_bot, - topic_id: post.topic_id, - raw: reply, - skip_validations: false, - ) - redis_stream_key = "gpt_cancel:#{bot_reply_post.id}" - Discourse.redis.setex(redis_stream_key, 60, 1) - end - end - - MessageBus.publish( - "discourse-ai/ai-bot/topic/#{post.topic_id}", - { done: true, post_id: bot_reply_post.id, post_number: bot_reply_post.post_number }, - user_ids: post.topic.allowed_user_ids, - ) - - if bot_reply_post - bot_reply_post.revise( - Discourse.gpt_bot, - { raw: reply }, - skip_validations: true, - skip_revision: true, - ) - end - rescue => e - Discourse.warn_exception(e, message: "ai-bot: Reply failed") + bot.reply_to(post) end end end diff --git a/lib/modules/ai_bot/open_ai_bot.rb b/lib/modules/ai_bot/open_ai_bot.rb new file mode 100644 index 00000000..855e17b2 --- /dev/null +++ b/lib/modules/ai_bot/open_ai_bot.rb @@ -0,0 +1,48 @@ +# frozen_string_literal: true + +module DiscourseAi + module AiBot + class OpenAiBot < Bot + def self.can_reply_as?(bot_user) + open_ai_bot_ids = [ + DiscourseAi::AiBot::EntryPoint::GPT4_ID, + DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID, + ] + + open_ai_bot_ids.include?(bot_user.id) + end + + def prompt_limit + 3500 + end + + private + + def build_message(poster_username, content) + role = poster_username == bot_user.username ? "system" : "user" + + { role: role, content: content } + end + + def model_for + return "gpt-4" if bot_user.id == DiscourseAi::AiBot::EntryPoint::GPT4_ID + "gpt-3.5-turbo" + end + + def update_with_delta(current_delta, partial) + current_delta + partial.dig(:choices, 0, :delta, :content).to_s + end + + def submit_prompt_and_stream_reply(prompt, &blk) + DiscourseAi::Inference::OpenAiCompletions.perform!( + prompt, + model_for, + temperature: 0.4, + top_p: 0.9, + max_tokens: 3000, + &blk + ) + end + end + end +end diff --git a/lib/modules/ai_helper/llm_prompt.rb b/lib/modules/ai_helper/llm_prompt.rb index 76ccbdab..1f09e459 100644 --- a/lib/modules/ai_helper/llm_prompt.rb +++ b/lib/modules/ai_helper/llm_prompt.rb @@ -64,7 +64,7 @@ module DiscourseAi messages = prompt.messages_with_user_input(text) result[:suggestions] = DiscourseAi::Inference::OpenAiCompletions - .perform!(messages) + .perform!(messages, SiteSetting.ai_helper_model) .dig(:choices) .to_a .flat_map { |choice| parse_content(prompt, choice.dig(:message, :content).to_s) } diff --git a/lib/shared/inference/anthropic_completions.rb b/lib/shared/inference/anthropic_completions.rb index 6346452f..c009214f 100644 --- a/lib/shared/inference/anthropic_completions.rb +++ b/lib/shared/inference/anthropic_completions.rb @@ -4,32 +4,106 @@ module ::DiscourseAi module Inference class AnthropicCompletions CompletionFailed = Class.new(StandardError) + TIMEOUT = 60 - def self.perform!(prompt) + def self.perform!( + prompt, + model = "claude-v1", + temperature: nil, + top_p: nil, + max_tokens: nil, + user_id: nil + ) + url = URI("https://api.anthropic.com/v1/complete") headers = { "x-api-key" => SiteSetting.ai_anthropic_api_key, "Content-Type" => "application/json", } - model = "claude-v1" + payload = { model: model, prompt: prompt } - connection_opts = { request: { write_timeout: 60, read_timeout: 60, open_timeout: 60 } } + payload[:temperature] = temperature if temperature + payload[:top_p] = top_p if top_p + payload[:max_tokens_to_sample] = max_tokens || 300 + payload[:stream] = true if block_given? - response = - Faraday.new(nil, connection_opts).post( - "https://api.anthropic.com/v1/complete", - { model: model, prompt: prompt, max_tokens_to_sample: 300 }.to_json, - headers, - ) + Net::HTTP.start( + url.host, + url.port, + use_ssl: true, + read_timeout: TIMEOUT, + open_timeout: TIMEOUT, + write_timeout: TIMEOUT, + ) do |http| + request = Net::HTTP::Post.new(url, headers) + request_body = payload.to_json + request.body = request_body - if response.status != 200 - Rails.logger.error( - "AnthropicCompletions: status: #{response.status} - body: #{response.body}", - ) - raise CompletionFailed + http.request(request) do |response| + if response.code.to_i != 200 + Rails.logger.error( + "AnthropicCompletions: status: #{response.code.to_i} - body: #{response.body}", + ) + raise CompletionFailed + end + + log = + AiApiAuditLog.create!( + provider_id: AiApiAuditLog::Provider::Anthropic, + raw_request_payload: request_body, + user_id: user_id, + ) + + if !block_given? + response_body = response.read_body + parsed_response = JSON.parse(response_body, symbolize_names: true) + + log.update!( + raw_response_payload: response_body, + request_tokens: DiscourseAi::Tokenizer.size(prompt), + response_tokens: DiscourseAi::Tokenizer.size(parsed_response[:completion]), + ) + return parsed_response + end + + begin + cancelled = false + cancel = lambda { cancelled = true } + response_data = +"" + response_raw = +"" + + response.read_body do |chunk| + if cancelled + http.finish + return + end + + response_raw << chunk + + chunk + .split("\n") + .each do |line| + data = line.split("data: ", 2)[1] + next if !data || data.squish == "[DONE]" + + if !cancelled && partial = JSON.parse(data, symbolize_names: true) + response_data << partial[:completion].to_s + + yield partial, cancel + end + end + rescue IOError + raise if !cancelled + ensure + log.update!( + raw_response_payload: response_raw, + request_tokens: DiscourseAi::Tokenizer.size(prompt), + response_tokens: DiscourseAi::Tokenizer.size(response_data), + ) + end + end + end end - - JSON.parse(response.body, symbolize_names: true) end end end diff --git a/lib/shared/inference/openai_completions.rb b/lib/shared/inference/openai_completions.rb index 7ef49e03..04f560c9 100644 --- a/lib/shared/inference/openai_completions.rb +++ b/lib/shared/inference/openai_completions.rb @@ -9,7 +9,7 @@ module ::DiscourseAi def self.perform!( messages, - model = SiteSetting.ai_helper_model, + model, temperature: nil, top_p: nil, max_tokens: nil, diff --git a/spec/lib/modules/ai_bot/entry_point_spec.rb b/spec/lib/modules/ai_bot/entry_point_spec.rb index f6446232..2347059a 100644 --- a/spec/lib/modules/ai_bot/entry_point_spec.rb +++ b/spec/lib/modules/ai_bot/entry_point_spec.rb @@ -4,7 +4,7 @@ RSpec.describe DiscourseAi::AiBot::EntryPoint do describe "#inject_into" do describe "subscribes to the post_created event" do fab!(:admin) { Fabricate(:admin) } - let(:gpt_bot) { Discourse.gpt_bot } + let(:gpt_bot) { User.find(described_class::GPT4_ID) } fab!(:bot_allowed_group) { Fabricate(:group) } let(:post_args) do @@ -13,7 +13,6 @@ RSpec.describe DiscourseAi::AiBot::EntryPoint do raw: "Hello, Can you please tell me a story?", archetype: Archetype.private_message, target_usernames: [gpt_bot.username].join(","), - category: 1, } end @@ -29,6 +28,19 @@ RSpec.describe DiscourseAi::AiBot::EntryPoint do ).by(1) end + it "includes the bot's user_id" do + claude_bot = User.find(described_class::CLAUDE_V1_ID) + claude_post_attrs = post_args.merge(target_usernames: [claude_bot.username].join(",")) + + expect { PostCreator.create!(admin, claude_post_attrs) }.to change( + Jobs::CreateAiReply.jobs, + :size, + ).by(1) + + job_args = Jobs::CreateAiReply.jobs.last["args"].first + expect(job_args["bot_user_id"]).to eq(claude_bot.id) + end + context "when the post is not from a PM" do it "does nothing" do expect { diff --git a/spec/lib/modules/ai_bot/jobs/regular/create_ai_reply_spec.rb b/spec/lib/modules/ai_bot/jobs/regular/create_ai_reply_spec.rb index bf305e02..50e7cbd0 100644 --- a/spec/lib/modules/ai_bot/jobs/regular/create_ai_reply_spec.rb +++ b/spec/lib/modules/ai_bot/jobs/regular/create_ai_reply_spec.rb @@ -1,6 +1,7 @@ # frozen_string_literal: true require_relative "../../../../../support/openai_completions_inference_stubs" +require_relative "../../../../../support/anthropic_completion_stubs" RSpec.describe Jobs::CreateAiReply do describe "#execute" do @@ -10,44 +11,82 @@ RSpec.describe Jobs::CreateAiReply do let(:expected_response) do "Hello this is a bot and what you just said is an interesting question" end - let(:deltas) { expected_response.split(" ").map { |w| { content: "#{w} " } } } - before do - SiteSetting.min_personal_message_post_length = 5 + before { SiteSetting.min_personal_message_post_length = 5 } - OpenAiCompletionsInferenceStubs.stub_streamed_response( - CompletionPrompt.bot_prompt_with_topic_context(post), - deltas, - req_opts: { - temperature: 0.4, - top_p: 0.9, - max_tokens: 3000, - stream: true, - }, - ) - end + context "when chatting with the Open AI bot" do + let(:deltas) { expected_response.split(" ").map { |w| { content: "#{w} " } } } - it "adds a reply from the GPT bot" do - subject.execute(post_id: topic.first_post.id) + before do + bot_user = User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) - expect(topic.posts.last.raw).to eq(expected_response) - end - - it "streams the reply on the fly to the client through MB" do - messages = - MessageBus.track_publish("discourse-ai/ai-bot/topic/#{topic.id}") do - subject.execute(post_id: topic.first_post.id) - end - - done_signal = messages.pop - - expect(messages.length).to eq(deltas.length) - - messages.each_with_index do |m, idx| - expect(m.data[:raw]).to eq(deltas[0..(idx + 1)].map { |d| d[:content] }.join) + OpenAiCompletionsInferenceStubs.stub_streamed_response( + DiscourseAi::AiBot::OpenAiBot.new(bot_user).bot_prompt_with_topic_context(post), + deltas, + req_opts: { + temperature: 0.4, + top_p: 0.9, + max_tokens: 3000, + stream: true, + }, + ) end - expect(done_signal.data[:done]).to eq(true) + it "adds a reply from the GPT bot" do + subject.execute( + post_id: topic.first_post.id, + bot_user_id: DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID, + ) + + expect(topic.posts.last.raw).to eq(expected_response) + end + + it "streams the reply on the fly to the client through MB" do + messages = + MessageBus.track_publish("discourse-ai/ai-bot/topic/#{topic.id}") do + subject.execute( + post_id: topic.first_post.id, + bot_user_id: DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID, + ) + end + + done_signal = messages.pop + + expect(messages.length).to eq(deltas.length) + + messages.each_with_index do |m, idx| + expect(m.data[:raw]).to eq(deltas[0..(idx + 1)].map { |d| d[:content] }.join) + end + + expect(done_signal.data[:done]).to eq(true) + end + end + + context "when chatting with Claude from Anthropic" do + let(:deltas) { expected_response.split(" ").map { |w| "#{w} " } } + + before do + bot_user = User.find(DiscourseAi::AiBot::EntryPoint::CLAUDE_V1_ID) + + AnthropicCompletionStubs.stub_streamed_response( + DiscourseAi::AiBot::AnthropicBot.new(bot_user).bot_prompt_with_topic_context(post), + deltas, + req_opts: { + temperature: 0.4, + max_tokens_to_sample: 3000, + stream: true, + }, + ) + end + + it "adds a reply from the Claude bot" do + subject.execute( + post_id: topic.first_post.id, + bot_user_id: DiscourseAi::AiBot::EntryPoint::CLAUDE_V1_ID, + ) + + expect(topic.posts.last.raw).to eq(expected_response) + end end end end diff --git a/spec/lib/modules/ai_bot/open_ai_bot_spec.rb b/spec/lib/modules/ai_bot/open_ai_bot_spec.rb new file mode 100644 index 00000000..f2731a74 --- /dev/null +++ b/spec/lib/modules/ai_bot/open_ai_bot_spec.rb @@ -0,0 +1,64 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::AiBot::OpenAiBot do + describe "#bot_prompt_with_topic_context" do + fab!(:topic) { Fabricate(:topic) } + + def post_body(post_number) + "This is post #{post_number}" + end + + def bot_user + User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID) + end + + subject { described_class.new(bot_user) } + + context "when the topic has one post" do + fab!(:post_1) { Fabricate(:post, topic: topic, raw: post_body(1), post_number: 1) } + + it "includes it in the prompt" do + prompt_messages = subject.bot_prompt_with_topic_context(post_1) + + post_1_message = prompt_messages[1] + + expect(post_1_message[:role]).to eq("user") + expect(post_1_message[:content]).to eq(post_body(1)) + end + end + + context "when prompt gets very long" do + fab!(:post_1) { Fabricate(:post, topic: topic, raw: "test " * 6000, post_number: 1) } + + it "trims the prompt" do + prompt_messages = subject.bot_prompt_with_topic_context(post_1) + + expect(prompt_messages[0][:role]).to eq("system") + expect(prompt_messages[1][:role]).to eq("user") + expected_length = ("test " * (subject.prompt_limit)).length + expect(prompt_messages[1][:content].length).to eq(expected_length) + end + end + + context "when the topic has multiple posts" do + fab!(:post_1) { Fabricate(:post, topic: topic, raw: post_body(1), post_number: 1) } + fab!(:post_2) do + Fabricate(:post, topic: topic, user: bot_user, raw: post_body(2), post_number: 2) + end + fab!(:post_3) { Fabricate(:post, topic: topic, raw: post_body(3), post_number: 3) } + + it "includes them in the prompt respecting the post number order" do + prompt_messages = subject.bot_prompt_with_topic_context(post_3) + + expect(prompt_messages[1][:role]).to eq("user") + expect(prompt_messages[1][:content]).to eq(post_body(1)) + + expect(prompt_messages[2][:role]).to eq("system") + expect(prompt_messages[2][:content]).to eq(post_body(2)) + + expect(prompt_messages[3][:role]).to eq("user") + expect(prompt_messages[3][:content]).to eq(post_body(3)) + end + end + end +end diff --git a/spec/models/completion_prompt_spec.rb b/spec/models/completion_prompt_spec.rb index 9fe29457..1eb356df 100644 --- a/spec/models/completion_prompt_spec.rb +++ b/spec/models/completion_prompt_spec.rb @@ -18,59 +18,4 @@ RSpec.describe CompletionPrompt do end end end - - describe ".bot_prompt_with_topic_context" do - fab!(:topic) { Fabricate(:topic) } - - def post_body(post_number) - "This is post #{post_number}" - end - - context "when the topic has one post" do - fab!(:post_1) { Fabricate(:post, topic: topic, raw: post_body(1), post_number: 1) } - - it "includes it in the prompt" do - prompt_messages = described_class.bot_prompt_with_topic_context(post_1) - - post_1_message = prompt_messages[1] - - expect(post_1_message[:role]).to eq("user") - expect(post_1_message[:content]).to eq(post_body(1)) - end - end - - context "when prompt gets very long" do - fab!(:post_1) { Fabricate(:post, topic: topic, raw: "test " * 6000, post_number: 1) } - - it "trims the prompt" do - prompt_messages = described_class.bot_prompt_with_topic_context(post_1) - - expect(prompt_messages[0][:role]).to eq("system") - expect(prompt_messages[1][:role]).to eq("user") - expected_length = ("test " * (CompletionPrompt::MAX_PROMPT_TOKENS)).length - expect(prompt_messages[1][:content].length).to eq(expected_length) - end - end - - context "when the topic has multiple posts" do - fab!(:post_1) { Fabricate(:post, topic: topic, raw: post_body(1), post_number: 1) } - fab!(:post_2) do - Fabricate(:post, topic: topic, user: Discourse.gpt_bot, raw: post_body(2), post_number: 2) - end - fab!(:post_3) { Fabricate(:post, topic: topic, raw: post_body(3), post_number: 3) } - - it "includes them in the prompt respecting the post number order" do - prompt_messages = described_class.bot_prompt_with_topic_context(post_3) - - expect(prompt_messages[1][:role]).to eq("user") - expect(prompt_messages[1][:content]).to eq(post_body(1)) - - expect(prompt_messages[2][:role]).to eq("system") - expect(prompt_messages[2][:content]).to eq(post_body(2)) - - expect(prompt_messages[3][:role]).to eq("user") - expect(prompt_messages[3][:content]).to eq(post_body(3)) - end - end - end end diff --git a/spec/shared/inference/anthropic_completions_spec.rb b/spec/shared/inference/anthropic_completions_spec.rb new file mode 100644 index 00000000..2da9f066 --- /dev/null +++ b/spec/shared/inference/anthropic_completions_spec.rb @@ -0,0 +1,67 @@ +# frozen_string_literal: true + +require_relative "../../support/anthropic_completion_stubs" + +RSpec.describe DiscourseAi::Inference::AnthropicCompletions do + before { SiteSetting.ai_anthropic_api_key = "abc-123" } + + it "can complete a trivial prompt" do + response_text = "1. Serenity\\n2. Laughter\\n3. Adventure" + prompt = "Human: write 3 words\n\n" + user_id = 183 + req_opts = { temperature: 0.5, max_tokens_to_sample: 700 } + + AnthropicCompletionStubs.stub_response(prompt, response_text, req_opts: req_opts) + + completions = + DiscourseAi::Inference::AnthropicCompletions.perform!( + prompt, + "claude-v1", + temperature: req_opts[:temperature], + max_tokens: req_opts[:max_tokens_to_sample], + user_id: user_id, + ) + + expect(completions[:completion]).to eq(response_text) + + expect(AiApiAuditLog.count).to eq(1) + log = AiApiAuditLog.first + + request_body = { model: "claude-v1", prompt: prompt }.merge(req_opts).to_json + response_body = AnthropicCompletionStubs.response(response_text).to_json + + expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic) + expect(log.request_tokens).to eq(7) + expect(log.response_tokens).to eq(16) + expect(log.raw_request_payload).to eq(request_body) + expect(log.raw_response_payload).to eq(response_body) + end + + it "supports streaming mode" do + deltas = ["Mount", "ain", " ", "Tree ", "Frog"] + prompt = "Human: write 3 words\n\n" + req_opts = { max_tokens_to_sample: 300, stream: true } + content = +"" + + AnthropicCompletionStubs.stub_streamed_response(prompt, deltas, req_opts: req_opts) + + DiscourseAi::Inference::AnthropicCompletions.perform!(prompt, "claude-v1") do |partial, cancel| + data = partial[:completion] + content = data if data + cancel.call if content.split(" ").length == 2 + end + + expect(content).to eq("Mountain Tree ") + + expect(AiApiAuditLog.count).to eq(1) + log = AiApiAuditLog.first + + request_body = { model: "claude-v1", prompt: prompt }.merge(req_opts).to_json + + expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic) + expect(log.request_tokens).to eq(7) + expect(log.response_tokens).to eq(9) + expect(log.raw_request_payload).to eq(request_body) + expect(log.raw_response_payload).to be_present + end +end diff --git a/spec/support/anthropic_completion_stubs.rb b/spec/support/anthropic_completion_stubs.rb new file mode 100644 index 00000000..01891573 --- /dev/null +++ b/spec/support/anthropic_completion_stubs.rb @@ -0,0 +1,55 @@ +# frozen_string_literal: true + +class AnthropicCompletionStubs + class << self + def response(content) + { + completion: content, + stop: "\n\nHuman:", + stop_reason: "stop_sequence", + truncated: false, + log_id: "12dcc7feafbee4a394e0de9dffde3ac5", + model: "claude-v1", + exception: nil, + } + end + + def stub_response(prompt, response_text, req_opts: {}) + WebMock + .stub_request(:post, "https://api.anthropic.com/v1/complete") + .with(body: { model: "claude-v1", prompt: prompt }.merge(req_opts).to_json) + .to_return(status: 200, body: JSON.dump(response(response_text))) + end + + def stream_line(delta, finish_reason: nil) + +"data: " << { + completion: delta, + stop: finish_reason ? "\n\nHuman:" : nil, + stop_reason: finish_reason, + truncated: false, + log_id: "12b029451c6d18094d868bc04ce83f63", + model: "claude-v1", + exception: nil, + }.to_json + end + + def stub_streamed_response(prompt, deltas, req_opts: {}) + chunks = + deltas.each_with_index.map do |_, index| + if index == (deltas.length - 1) + stream_line(deltas.join(""), finish_reason: "stop_sequence") + else + stream_line(deltas[0..index].join("")) + end + end + + chunks << "[DONE]" + chunks = chunks.join("\n\n") + + WebMock + .stub_request(:post, "https://api.anthropic.com/v1/complete") + .with(body: { model: "claude-v1", prompt: prompt }.merge(req_opts).to_json) + .to_return(status: 200, body: chunks) + end + end +end