diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index 71702469..40595b67 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -178,9 +178,16 @@ module DiscourseAi
         when DiscourseAi::AiBot::EntryPoint::FAKE_ID
           "fake:fake"
         when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID
+          # no Bedrock support yet (as of 2024-03-18)
          "anthropic:claude-3-opus"
        when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_SONNET_ID
-          "anthropic:claude-3-sonnet"
+          if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(
+               "claude-3-sonnet",
+             )
+            "aws_bedrock:claude-3-sonnet"
+          else
+            "anthropic:claude-3-sonnet"
+          end
        else
          nil
        end
diff --git a/lib/automation/report_runner.rb b/lib/automation/report_runner.rb
index 029b22f2..1ab2eb9f 100644
--- a/lib/automation/report_runner.rb
+++ b/lib/automation/report_runner.rb
@@ -216,13 +216,16 @@ Follow the provided writing composition instructions carefully and precisely ste
     def translate_model(model)
       return "google:gemini-pro" if model == "gemini-pro"
       return "open_ai:#{model}" if model.start_with? "gpt"
-      return "anthropic:#{model}" if model.start_with? "claude-3"
 
-      if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
-        "aws_bedrock:#{model}"
-      else
-        "anthropic:#{model}"
+      if model.start_with? "claude"
+        if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(model)
+          return "aws_bedrock:#{model}"
+        else
+          return "anthropic:#{model}"
+        end
       end
+
+      raise "Unknown model #{model}"
     end
 
     private
diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb
index 6043a042..616fd168 100644
--- a/lib/completions/dialects/claude.rb
+++ b/lib/completions/dialects/claude.rb
@@ -6,7 +6,9 @@ module DiscourseAi
       class Claude < Dialect
         class << self
           def can_translate?(model_name)
-            %w[claude-instant-1 claude-2].include?(model_name)
+            %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus].include?(
+              model_name,
+            )
           end
 
           def tokenizer
@@ -14,53 +16,69 @@ module DiscourseAi
           end
         end
 
+        class ClaudePrompt
+          attr_reader :system_prompt
+          attr_reader :messages
+
+          def initialize(system_prompt, messages)
+            @system_prompt = system_prompt
+            @messages = messages
+          end
+        end
+
         def translate
           messages = prompt.messages
+          system_prompt = +""
 
-          trimmed_messages = trim_messages(messages)
+          messages =
+            trim_messages(messages)
+              .map do |msg|
+                case msg[:type]
+                when :system
+                  system_prompt << msg[:content]
+                  nil
+                when :tool_call
+                  { role: "assistant", content: tool_call_to_xml(msg) }
+                when :tool
+                  { role: "user", content: tool_result_to_xml(msg) }
+                when :model
+                  { role: "assistant", content: msg[:content] }
+                when :user
+                  content = +""
+                  content << "#{msg[:id]}: " if msg[:id]
+                  content << msg[:content]
 
-          # Need to include this differently
-          last_message = trimmed_messages.last[:type] == :assistant ? trimmed_messages.pop : nil
-
-          claude_prompt =
-            trimmed_messages.reduce(+"") do |memo, msg|
-              if msg[:type] == :tool_call
-                memo << "\n\nAssistant: #{tool_call_to_xml(msg)}"
-              elsif msg[:type] == :system
-                memo << "Human: " unless uses_system_message?
-                memo << msg[:content]
-                if prompt.tools.present?
-                  memo << "\n"
-                  memo << build_tools_prompt
+                  { role: "user", content: content }
                 end
-              elsif msg[:type] == :model
-                memo << "\n\nAssistant: #{msg[:content]}"
-              elsif msg[:type] == :tool
-                memo << "\n\nHuman:\n"
-                memo << tool_result_to_xml(msg)
-              else
-                memo << "\n\nHuman: "
-                memo << "#{msg[:id]}: " if msg[:id]
-                memo << msg[:content]
              end
+              .compact
 
-              memo
-            end
 
+          if prompt.tools.present?
+            system_prompt << "\n\n"
+            system_prompt << build_tools_prompt
+          end
+
+          interleaving_messages = []
+
+          previous_message = nil
+          messages.each do |message|
+            if previous_message
+              if previous_message[:role] == "user" && message[:role] == "user"
+                interleaving_messages << { role: "assistant", content: "OK" }
+              elsif previous_message[:role] == "assistant" && message[:role] == "assistant"
+                interleaving_messages << { role: "user", content: "OK" }
+              end
+            end
+            interleaving_messages << message
+            previous_message = message
+          end
 
-          claude_prompt << "\n\nAssistant:"
-          claude_prompt << " #{last_message[:content]}:" if last_message
-
-          claude_prompt
+          ClaudePrompt.new(system_prompt.presence, interleaving_messages)
         end
 
         def max_prompt_tokens
-          100_000 # Claude-2.1 has a 200k context window.
-        end
-
-        private
-
-        def uses_system_message?
-          model_name == "claude-2"
+          # Longer term Claude will support over 1 million tokens of context
+          200_000 # Claude-3 has a 200k context window for now
         end
       end
     end
   end
 end
diff --git a/lib/completions/dialects/claude_messages.rb b/lib/completions/dialects/claude_messages.rb
deleted file mode 100644
index c0e9feb0..00000000
--- a/lib/completions/dialects/claude_messages.rb
+++ /dev/null
@@ -1,85 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Completions
-    module Dialects
-      class ClaudeMessages < Dialect
-        class << self
-          def can_translate?(model_name)
-            # TODO: add haiku not released yet as of 2024-03-05
-            %w[claude-3-sonnet claude-3-opus].include?(model_name)
-          end
-
-          def tokenizer
-            DiscourseAi::Tokenizer::AnthropicTokenizer
-          end
-        end
-
-        class ClaudePrompt
-          attr_reader :system_prompt
-          attr_reader :messages
-
-          def initialize(system_prompt, messages)
-            @system_prompt = system_prompt
-            @messages = messages
-          end
-        end
-
-        def translate
-          messages = prompt.messages
-          system_prompt = +""
-
-          messages =
-            trim_messages(messages)
-              .map do |msg|
-                case msg[:type]
-                when :system
-                  system_prompt << msg[:content]
-                  nil
-                when :tool_call
-                  { role: "assistant", content: tool_call_to_xml(msg) }
-                when :tool
-                  { role: "user", content: tool_result_to_xml(msg) }
-                when :model
-                  { role: "assistant", content: msg[:content] }
-                when :user
-                  content = +""
-                  content << "#{msg[:id]}: " if msg[:id]
-                  content << msg[:content]
-
-                  { role: "user", content: content }
-                end
-              end
-              .compact
-
-          if prompt.tools.present?
-            system_prompt << "\n\n"
-            system_prompt << build_tools_prompt
-          end
-
-          interleving_messages = []
-
-          previous_message = nil
-          messages.each do |message|
-            if previous_message
-              if previous_message[:role] == "user" && message[:role] == "user"
-                interleving_messages << { role: "assistant", content: "OK" }
-              elsif previous_message[:role] == "assistant" && message[:role] == "assistant"
-                interleving_messages << { role: "user", content: "OK" }
-              end
-            end
-            interleving_messages << message
-            previous_message = message
-          end
-
-          ClaudePrompt.new(system_prompt.presence, interleving_messages)
-        end
-
-        def max_prompt_tokens
-          # Longer term it will have over 1 million
-          200_000 # Claude-3 has a 200k context window for now
-        end
-      end
-    end
-  end
-end
diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb
index fb7b48da..3baaa68d 100644
--- a/lib/completions/dialects/dialect.rb
+++ b/lib/completions/dialects/dialect.rb
@@ -11,13 +11,12 @@ module DiscourseAi
         def dialect_for(model_name)
           dialects = [
-            DiscourseAi::Completions::Dialects::Claude,
             DiscourseAi::Completions::Dialects::Llama2Classic,
             DiscourseAi::Completions::Dialects::ChatGpt,
             DiscourseAi::Completions::Dialects::OrcaStyle,
             DiscourseAi::Completions::Dialects::Gemini,
             DiscourseAi::Completions::Dialects::Mixtral,
-            DiscourseAi::Completions::Dialects::ClaudeMessages,
+            DiscourseAi::Completions::Dialects::Claude,
           ]
 
           if Rails.env.test? || Rails.env.development?
diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb
index 0f766c4d..8c27a269 100644
--- a/lib/completions/endpoints/anthropic.rb
+++ b/lib/completions/endpoints/anthropic.rb
@@ -6,7 +6,10 @@ module DiscourseAi
       class Anthropic < Base
         class << self
           def can_contact?(endpoint_name, model_name)
-            endpoint_name == "anthropic" && %w[claude-instant-1 claude-2].include?(model_name)
+            endpoint_name == "anthropic" &&
+              %w[claude-instant-1 claude-2 claude-3-haiku claude-3-opus claude-3-sonnet].include?(
+                model_name,
+              )
           end
 
           def dependant_setting_names
@@ -23,23 +26,32 @@ module DiscourseAi
         end
 
         def normalize_model_params(model_params)
-          model_params = model_params.dup
-
-          # temperature, stop_sequences are already supported
-          #
-          if model_params[:max_tokens]
-            model_params[:max_tokens_to_sample] = model_params.delete(:max_tokens)
-          end
-
+          # max_tokens, temperature, stop_sequences are already supported
           model_params
         end
 
-        def default_options
-          {
-            model: model == "claude-2" ? "claude-2.1" : model,
-            max_tokens_to_sample: 3_000,
-            stop_sequences: ["\n\nHuman:", "</function_calls>"],
-          }
+        def default_options(dialect)
+          # skipping 2.0 support for now, since other models are better
+          mapped_model =
+            case model
+            when "claude-2"
+              "claude-2.1"
+            when "claude-instant-1"
+              "claude-instant-1.2"
+            when "claude-3-haiku"
+              "claude-3-haiku-20240307"
+            when "claude-3-sonnet"
+              "claude-3-sonnet-20240229"
+            when "claude-3-opus"
+              "claude-3-opus-20240229"
+            else
+              raise "Unsupported model: #{model}"
+            end
+
+          options = { model: mapped_model, max_tokens: 3_000 }
+
+          options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
+          options
         end
 
         def provider_id
@@ -48,15 +60,22 @@ module DiscourseAi
 
         private
 
-        def model_uri
-          @uri ||= URI("https://api.anthropic.com/v1/complete")
+        # this is an approximation; we update it with the real usage numbers if the
+        # request goes through
+        def prompt_size(prompt)
+          super(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
         end
 
-        def prepare_payload(prompt, model_params, _dialect)
-          default_options
-            .merge(model_params)
-            .merge(prompt: prompt)
-            .tap { |payload| payload[:stream] = true if @streaming_mode }
+        def model_uri
+          @uri ||= URI("https://api.anthropic.com/v1/messages")
+        end
+
+        def prepare_payload(prompt, model_params, dialect)
+          payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
+
+          payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
+          payload[:stream] = true if @streaming_mode
+
+          payload
         end
 
         def prepare_request(payload)
@@ -69,8 +88,30 @@ module DiscourseAi
           Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
         end
 
+        def final_log_update(log)
+          log.request_tokens = @input_tokens if @input_tokens
+          log.response_tokens = @output_tokens if @output_tokens
+        end
+
         def extract_completion_from(response_raw)
-          JSON.parse(response_raw, symbolize_names: true)[:completion].to_s
+          result = ""
+          parsed = JSON.parse(response_raw, symbolize_names: true)
+
+          if @streaming_mode
+            if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
+              result = parsed.dig(:delta, :text).to_s
+            elsif parsed[:type] == "message_start"
+              @input_tokens = parsed.dig(:message, :usage, :input_tokens)
+            elsif parsed[:type] == "message_delta"
+              @output_tokens = parsed.dig(:delta, :usage, :output_tokens)
+            end
+          else
+            result = parsed.dig(:content, 0, :text).to_s
+            @input_tokens = parsed.dig(:usage, :input_tokens)
+            @output_tokens = parsed.dig(:usage, :output_tokens)
+          end
+
+          result
         end
 
         def partials_from(decoded_chunk)
diff --git a/lib/completions/endpoints/anthropic_messages.rb b/lib/completions/endpoints/anthropic_messages.rb
deleted file mode 100644
index 7485e0e6..00000000
--- a/lib/completions/endpoints/anthropic_messages.rb
+++ /dev/null
@@ -1,103 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Completions
-    module Endpoints
-      class AnthropicMessages < Base
-        class << self
-          def can_contact?(endpoint_name, model_name)
-            endpoint_name == "anthropic" && %w[claude-3-opus claude-3-sonnet].include?(model_name)
-          end
-
-          def dependant_setting_names
-            %w[ai_anthropic_api_key]
-          end
-
-          def correctly_configured?(_model_name)
-            SiteSetting.ai_anthropic_api_key.present?
-          end
-
-          def endpoint_name(model_name)
-            "Anthropic - #{model_name}"
-          end
-        end
-
-        def normalize_model_params(model_params)
-          # max_tokens, temperature, stop_sequences are already supported
-          model_params
-        end
-
-        def default_options(dialect)
-          options = { model: model + "-20240229", max_tokens: 3_000 }
-
-          options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
-          options
-        end
-
-        def provider_id
-          AiApiAuditLog::Provider::Anthropic
-        end
-
-        private
-
-        # this is an approximation, we will update it later if request goes through
-        def prompt_size(prompt)
-          super(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
-        end
-
-        def model_uri
-          @uri ||= URI("https://api.anthropic.com/v1/messages")
-        end
-
-        def prepare_payload(prompt, model_params, dialect)
-          payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
-
-          payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
-          payload[:stream] = true if @streaming_mode
-
-          payload
-        end
-
-        def prepare_request(payload)
-          headers = {
-            "anthropic-version" => "2023-06-01",
-            "x-api-key" => SiteSetting.ai_anthropic_api_key,
-            "content-type" => "application/json",
-          }
-
-          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
-        end
-
-        def final_log_update(log)
-          log.request_tokens = @input_tokens if @input_tokens
-          log.response_tokens = @output_tokens if @output_tokens
-        end
-
-        def extract_completion_from(response_raw)
-          result = ""
-          parsed = JSON.parse(response_raw, symbolize_names: true)
-
-          if @streaming_mode
-            if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
-              result = parsed.dig(:delta, :text).to_s
-            elsif parsed[:type] == "message_start"
-              @input_tokens = parsed.dig(:message, :usage, :input_tokens)
-            elsif parsed[:type] == "message_delta"
-              @output_tokens = parsed.dig(:delta, :usage, :output_tokens)
-            end
-          else
-            result = parsed.dig(:content, 0, :text).to_s
-            @input_tokens = parsed.dig(:usage, :input_tokens)
-            @output_tokens = parsed.dig(:usage, :output_tokens)
-          end
-
-          result
-        end
-
-        def partials_from(decoded_chunk)
-          decoded_chunk.split("\n").map { |line| line.split("data: ", 2)[1] }.compact
-        end
-      end
-    end
-  end
-end
diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb
index 92c5e16f..d62cfe41 100644
--- a/lib/completions/endpoints/aws_bedrock.rb
+++ b/lib/completions/endpoints/aws_bedrock.rb
@@ -8,17 +8,18 @@ module DiscourseAi
       class AwsBedrock < Base
         class << self
           def can_contact?(endpoint_name, model_name)
-            endpoint_name == "aws_bedrock" && %w[claude-instant-1 claude-2].include?(model_name)
+            endpoint_name == "aws_bedrock" &&
+              %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet].include?(model_name)
           end
 
           def dependant_setting_names
             %w[ai_bedrock_access_key_id ai_bedrock_secret_access_key ai_bedrock_region]
           end
 
-          def correctly_configured?(_model_name)
+          def correctly_configured?(model)
             SiteSetting.ai_bedrock_access_key_id.present? &&
               SiteSetting.ai_bedrock_secret_access_key.present? &&
-              SiteSetting.ai_bedrock_region.present?
+              SiteSetting.ai_bedrock_region.present? && can_contact?("aws_bedrock", model)
           end
 
           def endpoint_name(model_name)
@@ -29,17 +30,15 @@ module DiscourseAi
         def normalize_model_params(model_params)
           model_params = model_params.dup
 
-          # temperature, stop_sequences are already supported
-          #
-          if model_params[:max_tokens]
-            model_params[:max_tokens_to_sample] = model_params.delete(:max_tokens)
-          end
+          # max_tokens, temperature, stop_sequences, top_p are already supported
 
           model_params
         end
 
-        def default_options
-          { max_tokens_to_sample: 3_000, stop_sequences: ["\n\nHuman:", "</function_calls>"] }
+        def default_options(dialect)
+          options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" }
+          options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
+          options
         end
 
         def provider_id
@@ -48,25 +47,40 @@ module DiscourseAi
 
         private
 
-        def model_uri
-          # Bedrock uses slightly different names
-          # See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
-          bedrock_model_id = model.split("-")
-          bedrock_model_id[-1] = "v#{bedrock_model_id.last}"
-          bedrock_model_id = +(bedrock_model_id.join("-"))
+        def prompt_size(prompt)
+          # approximation
+          super(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
+        end
 
-          bedrock_model_id << ":1" if model == "claude-2" # For claude-2.1
+        def model_uri
+          # See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
+          #
+          # FYI there is also a 2.0 version of Claude, but there is very little need
+          # to support it given haiku/sonnet are better fits anyway, so we map
+          # claude-2 to claude-2.1
+          bedrock_model_id =
+            case model
+            when "claude-2"
+              "anthropic.claude-v2:1"
+            when "claude-3-haiku"
+              "anthropic.claude-3-haiku-20240307-v1:0"
+            when "claude-3-sonnet"
+              "anthropic.claude-3-sonnet-20240229-v1:0"
+            when "claude-instant-1"
+              "anthropic.claude-instant-v1"
+            end
 
           api_url =
-            "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.#{bedrock_model_id}/invoke"
+            "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{bedrock_model_id}/invoke"
 
           api_url = @streaming_mode ? (api_url + "-with-response-stream") : api_url
 
           URI(api_url)
         end
 
-        def prepare_payload(prompt, model_params, _dialect)
-          default_options.merge(prompt: prompt).merge(model_params)
+        def prepare_payload(prompt, model_params, dialect)
+          payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
+
+          payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
+          payload
         end
 
         def prepare_request(payload)
@@ -117,8 +131,30 @@ module DiscourseAi
           nil
         end
 
+        def final_log_update(log)
+          log.request_tokens = @input_tokens if @input_tokens
+          log.response_tokens = @output_tokens if @output_tokens
+        end
+
         def extract_completion_from(response_raw)
-          JSON.parse(response_raw, symbolize_names: true)[:completion].to_s
+          result = ""
+          parsed = JSON.parse(response_raw, symbolize_names: true)
+
+          if @streaming_mode
+            if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
+              result = parsed.dig(:delta, :text).to_s
+            elsif parsed[:type] == "message_start"
+              @input_tokens = parsed.dig(:message, :usage, :input_tokens)
+            elsif parsed[:type] == "message_delta"
+              @output_tokens = parsed.dig(:delta, :usage, :output_tokens)
+            end
+          else
+            result = parsed.dig(:content, 0, :text).to_s
+            @input_tokens = parsed.dig(:usage, :input_tokens)
+            @output_tokens = parsed.dig(:usage, :output_tokens)
+          end
+
+          result
         end
 
         def partials_from(decoded_chunk)
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index ec813cab..51c8c012 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -11,12 +11,11 @@ module DiscourseAi
         def endpoint_for(provider_name, model_name)
           endpoints = [
             DiscourseAi::Completions::Endpoints::AwsBedrock,
-            DiscourseAi::Completions::Endpoints::Anthropic,
             DiscourseAi::Completions::Endpoints::OpenAi,
             DiscourseAi::Completions::Endpoints::HuggingFace,
             DiscourseAi::Completions::Endpoints::Gemini,
             DiscourseAi::Completions::Endpoints::Vllm,
-            DiscourseAi::Completions::Endpoints::AnthropicMessages,
+            DiscourseAi::Completions::Endpoints::Anthropic,
           ]
 
           if Rails.env.test? || Rails.env.development?
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index ba4bdf34..ed1b216b 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -23,8 +23,8 @@ module DiscourseAi
       # However, since they use the same URL/key settings, there's no reason to duplicate them.
       @models_by_provider ||=
         {
-          aws_bedrock: %w[claude-instant-1 claude-2],
-          anthropic: %w[claude-instant-1 claude-2 claude-3-sonnet claude-3-opus],
+          aws_bedrock: %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet],
+          anthropic: %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus],
           vllm: %w[
             mistralai/Mixtral-8x7B-Instruct-v0.1
             mistralai/Mistral-7B-Instruct-v0.2
diff --git a/spec/lib/completions/dialects/claude_messages_spec.rb b/spec/lib/completions/dialects/claude_messages_spec.rb
deleted file mode 100644
index 7e1f1edc..00000000
--- a/spec/lib/completions/dialects/claude_messages_spec.rb
+++ /dev/null
@@ -1,87 +0,0 @@
-# frozen_string_literal: true
-
-RSpec.describe DiscourseAi::Completions::Dialects::ClaudeMessages do
-  describe "#translate" do
-    it "can insert OKs to make stuff interleve properly" do
-      messages = [
-        { type: :user, id: "user1", content: "1" },
-        { type: :model, content: "2" },
-        { type: :user, id: "user1", content: "4" },
-        { type: :user, id: "user1", content: "5" },
-        { type: :model, content: "6" },
-      ]
-
-      prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
-
-      dialectKlass = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
-      dialect = dialectKlass.new(prompt, "claude-3-opus")
-      translated = dialect.translate
-
-      expected_messages = [
-        { role: "user", content: "user1: 1" },
-        { role: "assistant", content: "2" },
-        { role: "user", content: "user1: 4" },
-        { role: "assistant", content: "OK" },
-        { role: "user", content: "user1: 5" },
-        { role: "assistant", content: "6" },
-      ]
-
-      expect(translated.messages).to eq(expected_messages)
-    end
-
-    it "can properly translate a prompt" do
-      dialect = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
-
-      tools = [
-        {
-          name: "echo",
-          description: "echo a string",
-          parameters: [
-            { name: "text", type: "string", description: "string to echo", required: true },
-          ],
-        },
-      ]
-
-      tool_call_prompt = { name: "echo", arguments: { text: "something" } }
-
-      messages = [
-        { type: :user, id: "user1", content: "echo something" },
-        { type: :tool_call, content: tool_call_prompt.to_json },
-        { type: :tool, id: "tool_id", content: "something".to_json },
-        { type: :model, content: "I did it" },
-        { type: :user, id: "user1", content: "echo something else" },
-      ]
-
-      prompt =
-        DiscourseAi::Completions::Prompt.new(
-          "You are a helpful bot",
-          messages: messages,
-          tools: tools,
-        )
-
-      dialect = dialect.new(prompt, "claude-3-opus")
-      translated = dialect.translate
-
-      expect(translated.system_prompt).to start_with("You are a helpful bot")
-      expect(translated.system_prompt).to include("echo a string")
-
-      expected = [
-        { role: "user", content: "user1: echo something" },
-        {
-          role: "assistant",
-          content:
-            "<function_calls>\n<invoke>\n<tool_name>echo</tool_name>\n<parameters>\n<text>something</text>\n</parameters>\n</invoke>\n</function_calls>",
-        },
-        {
-          role: "user",
-          content:
-            "<function_results>\n<result>\n<tool_name>tool_id</tool_name>\n<json>\n\"something\"\n</json>\n</result>\n</function_results>",
-        },
-        { role: "assistant", content: "I did it" },
-        { role: "user", content: "user1: echo something else" },
-      ]
-
-      expect(translated.messages).to eq(expected)
-    end
-  end
-end
diff --git a/spec/lib/completions/dialects/claude_spec.rb b/spec/lib/completions/dialects/claude_spec.rb
index efb62bb3..af082c22 100644
--- a/spec/lib/completions/dialects/claude_spec.rb
+++ b/spec/lib/completions/dialects/claude_spec.rb
@@ -1,100 +1,87 @@
 # frozen_string_literal: true
 
-require_relative "dialect_context"
-
 RSpec.describe DiscourseAi::Completions::Dialects::Claude do
-  let(:model_name) { "claude-2" }
-  let(:context) { DialectContext.new(described_class, model_name) }
-
   describe "#translate" do
-    it "translates a prompt written in our generic format to Claude's format" do
-      anthropic_version = (<<~TEXT).strip
-      #{context.system_insts}
-      #{described_class.tool_preamble}
-      <tools>
-      #{context.dialect_tools}</tools>
+    it "can insert OKs to make stuff interleave properly" do
+      messages = [
+        { type: :user, id: "user1", content: "1" },
+        { type: :model, content: "2" },
+        { type: :user, id: "user1", content: "4" },
+        { type: :user, id: "user1", content: "5" },
+        { type: :model, content: "6" },
+      ]
 
-      Human: #{context.simple_user_input}
+      prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
 
-      Assistant:
-      TEXT
+      dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
+      dialect = dialect_klass.new(prompt, "claude-3-opus")
+      translated = dialect.translate
 
-      translated = context.system_user_scenario
+      expected_messages = [
+        { role: "user", content: "user1: 1" },
+        { role: "assistant", content: "2" },
+        { role: "user", content: "user1: 4" },
+        { role: "assistant", content: "OK" },
+        { role: "user", content: "user1: 5" },
+        { role: "assistant", content: "6" },
+      ]
 
-      expect(translated).to eq(anthropic_version)
+      expect(translated.messages).to eq(expected_messages)
     end
 
-    it "translates tool messages" do
-      expected = +(<<~TEXT).strip
-      #{context.system_insts}
-      #{described_class.tool_preamble}
-      <tools>
-      #{context.dialect_tools}</tools>
+    it "can properly translate a prompt" do
+      dialect = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
 
-      Human: user1: This is a message by a user
+      tools = [
+        {
+          name: "echo",
+          description: "echo a string",
+          parameters: [
+            { name: "text", type: "string", description: "string to echo", required: true },
+          ],
+        },
+      ]
 
-      Assistant: I'm a previous bot reply, that's why there's no user
+      tool_call_prompt = { name: "echo", arguments: { text: "something" } }
 
-      Human: user1: This is a new message by a user
+      messages = [
+        { type: :user, id: "user1", content: "echo something" },
+        { type: :tool_call, content: tool_call_prompt.to_json },
+        { type: :tool, id: "tool_id", content: "something".to_json },
+        { type: :model, content: "I did it" },
+        { type: :user, id: "user1", content: "echo something else" },
+      ]
 
-      Assistant: <function_calls>
-      <invoke>
-      <tool_name>get_weather</tool_name>
-      <parameters>
-      <location>Sydney</location>
-      <unit>c</unit>
-      </parameters>
-      </invoke>
-      </function_calls>
-
-      Human:
-      <function_results>
-      <result>
-      <tool_name>get_weather</tool_name>
-      <json>
-      "I'm a tool result"
-      </json>
-      </result>
-      </function_results>
-
-      Assistant:
-      TEXT
-
-      expect(context.multi_turn_scenario).to eq(expected)
-    end
-
-    it "trims content if it's getting too long" do
-      length = 19_000
-
-      translated = context.long_user_input_scenario(length: length)
-
-      expect(translated.length).to be < context.long_message_text(length: length).length
-    end
-
-    it "retains usernames in generated prompt" do
       prompt =
        DiscourseAi::Completions::Prompt.new(
-          "You are a bot",
-          messages: [
-            { id: "👻", type: :user, content: "Message1" },
-            { type: :model, content: "Ok" },
-            { id: "joe", type: :user, content: "Message2" },
-          ],
+          "You are a helpful bot",
+          messages: messages,
+          tools: tools,
        )
 
-      translated = context.dialect(prompt).translate
+      dialect = dialect.new(prompt, "claude-3-opus")
+      translated = dialect.translate
 
-      expect(translated).to eq(<<~TEXT.strip)
-        You are a bot
+      expect(translated.system_prompt).to start_with("You are a helpful bot")
+      expect(translated.system_prompt).to include("echo a string")
 
-        Human: 👻: Message1
+      expected = [
+        { role: "user", content: "user1: echo something" },
+        {
+          role: "assistant",
+          content:
+            "<function_calls>\n<invoke>\n<tool_name>echo</tool_name>\n<parameters>\n<text>something</text>\n</parameters>\n</invoke>\n</function_calls>",
+        },
+        {
+          role: "user",
+          content:
+            "<function_results>\n<result>\n<tool_name>tool_id</tool_name>\n<json>\n\"something\"\n</json>\n</result>\n</function_results>",
+        },
+        { role: "assistant", content: "I did it" },
+        { role: "user", content: "user1: echo something else" },
+      ]
 
-        Assistant: Ok
-
-        Human: joe: Message2
-
-        Assistant:
-      TEXT
+      expect(translated.messages).to eq(expected)
     end
   end
 end
diff --git a/spec/lib/completions/endpoints/anthropic_messages_spec.rb b/spec/lib/completions/endpoints/anthropic_messages_spec.rb
deleted file mode 100644
index 1d52e46c..00000000
--- a/spec/lib/completions/endpoints/anthropic_messages_spec.rb
+++ /dev/null
@@ -1,395 +0,0 @@
-# frozen_string_literal: true
-
-RSpec.describe DiscourseAi::Completions::Endpoints::AnthropicMessages do
-  let(:llm) { DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-opus") }
-
-  let(:prompt) do
-    DiscourseAi::Completions::Prompt.new(
-      "You are hello bot",
-      messages: [type: :user, id: "user1", content: "hello"],
-    )
-  end
-
-  let(:echo_tool) do
-    {
-      name: "echo",
-      description: "echo something",
-      parameters: [{ name: "text", type: "string", description: "text to echo", required: true }],
-    }
-  end
-
-  let(:google_tool) do
-    {
-      name: "google",
-      description: "google something",
-      parameters: [
-        { name: "query", type: "string", description: "text to google", required: true },
-      ],
-    }
-  end
-
-  let(:prompt_with_echo_tool) do
-    prompt_with_tools = prompt
-    prompt.tools = [echo_tool]
-    prompt_with_tools
-  end
-
-  let(:prompt_with_google_tool) do
-    prompt_with_tools = prompt
-    prompt.tools = [echo_tool]
-    prompt_with_tools
-  end
-
-  before { SiteSetting.ai_anthropic_api_key = "123" }
-
-  it "does not eat spaces with tool calls" do
-    body = <<~STRING
-      event: message_start
-      data: {"type":"message_start","message":{"id":"msg_019kmW9Q3GqfWmuFJbePJTBR","type":"message","role":"assistant","content":[],"model":"claude-3-opus-20240229","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":347,"output_tokens":1}}}
-
-      event: content_block_start
-      data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
-
-      event: ping
-      data: {"type": "ping"}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<function_calls>"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<invoke>"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<tool_name>"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"google"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</tool_name>"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<parameters>"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<query>"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"top"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"10"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"things"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" to"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" do"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" in"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" japan"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" for"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" tourists"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</query>"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</parameters>"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</invoke>"}}
-
-      event: content_block_delta
-      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
-
-      event: content_block_stop
-      data: {"type":"content_block_stop","index":0}
-
-      event: message_delta
-      data: {"type":"message_delta","delta":{"stop_reason":"stop_sequence","stop_sequence":"</function_calls>"},"usage":{"output_tokens":57}}
-
-      event: message_stop
-      data: {"type":"message_stop"}
-    STRING
-
-    stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
-
-    result = +""
-    llm.generate(prompt_with_google_tool, user: Discourse.system_user) do |partial|
-      result << partial
-    end
-
-    expected = (<<~TEXT).strip
-      <function_calls>
-      <invoke>
-      <tool_name>google</tool_name>
-      <parameters>
-      <query>top 10 things to do in japan for tourists</query>
-      </parameters>
-      <tool_id>tool_0</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
-
-    expect(result.strip).to eq(expected)
-  end
-
-  it "can stream a response" do
-    body = (<<~STRING).strip
-      event: message_start
-      data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}}
-
-      event: content_block_start
-      data: {"type": "content_block_start", "index":0, "content_block": {"type": "text", "text": ""}}
-
-      event: ping
-      data: {"type": "ping"}
-
-      event: content_block_delta
-      data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}
-
-      event: content_block_delta
-      data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "!"}}
-
-      event: content_block_stop
-      data: {"type": "content_block_stop", "index": 0}
-
-      event: message_delta
-      data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null, "usage":{"output_tokens": 15}}}
-
-      event: message_stop
-      data: {"type": "message_stop"}
-    STRING
-
-    parsed_body = nil
-
-    stub_request(:post, "https://api.anthropic.com/v1/messages").with(
-      body:
-        proc do |req_body|
-          parsed_body = JSON.parse(req_body, symbolize_names: true)
-          true
-        end,
-      headers: {
-        "Content-Type" => "application/json",
-        "X-Api-Key" => "123",
-        "Anthropic-Version" => "2023-06-01",
-      },
-    ).to_return(status: 200, body: body)
-
-    result = +""
-    llm.generate(prompt, user: Discourse.system_user) { |partial, cancel| result << partial }
-
-    expect(result).to eq("Hello!")
-
-    expected_body = {
-      model: "claude-3-opus-20240229",
-      max_tokens: 3000,
-      messages: [{ role: "user", content: "user1: hello" }],
-      system: "You are hello bot",
-      stream: true,
-    }
-    expect(parsed_body).to eq(expected_body)
-
-    log = AiApiAuditLog.order(:id).last
-    expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
-    expect(log.request_tokens).to eq(25)
-    expect(log.response_tokens).to eq(15)
-  end
-
-  it "can return multiple function calls" do
-    functions = <<~FUNCTIONS
-      <function_calls>
-      <invoke>
-      <tool_name>echo</tool_name>
-      <parameters>
-      <text>something</text>
-      </parameters>
-      </invoke>
-      <invoke>
-      <tool_name>echo</tool_name>
-      <parameters>
-      <text>something else</text>
-      </parameters>
-      </invoke>
-      </function_calls>
-    FUNCTIONS
-
-    body = <<~STRING
-      {
-        "content": [
-          {
-            "text": "Hello!\n\n#{functions}\njunk",
-            "type": "text"
-          }
-        ],
-        "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
-        "model": "claude-3-opus-20240229",
-        "role": "assistant",
-        "stop_reason": "end_turn",
-        "stop_sequence": null,
-        "type": "message",
-        "usage": {
-          "input_tokens": 10,
-          "output_tokens": 25
-        }
-      }
-    STRING
-
-    stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
-
-    result = llm.generate(prompt_with_echo_tool, user: Discourse.system_user)
-
-    expected = (<<~EXPECTED).strip
-      <function_calls>
-      <invoke>
-      <tool_name>echo</tool_name>
-      <parameters>
-      <text>something</text>
-      </parameters>
-      <tool_id>tool_0</tool_id>
-      </invoke>
-      <invoke>
-      <tool_name>echo</tool_name>
-      <parameters>
-      <text>something else</text>
-      </parameters>
-      <tool_id>tool_1</tool_id>
-      </invoke>
-      </function_calls>
-    EXPECTED
-
-    expect(result.strip).to eq(expected)
-  end
-
-  it "can operate in regular mode" do
-    body = <<~STRING
-      {
-        "content": [
-          {
-            "text": "Hello!",
-            "type": "text"
-          }
-        ],
-        "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
-        "model": "claude-3-opus-20240229",
-        "role": "assistant",
-        "stop_reason": "end_turn",
-        "stop_sequence": null,
-        "type": "message",
-        "usage": {
-          "input_tokens": 10,
-          "output_tokens": 25
-        }
-      }
-    STRING
-
-    parsed_body = nil
-    stub_request(:post, "https://api.anthropic.com/v1/messages").with(
-      body:
-        proc do |req_body|
-          parsed_body = JSON.parse(req_body, symbolize_names: true)
-          true
-        end,
-      headers: {
-        "Content-Type" => "application/json",
-        "X-Api-Key" => "123",
-        "Anthropic-Version" => "2023-06-01",
-      },
-    ).to_return(status: 200, body: body)
-
-    result = llm.generate(prompt, user: Discourse.system_user)
-    expect(result).to eq("Hello!")
-
-    expected_body = {
-      model: "claude-3-opus-20240229",
-      max_tokens: 3000,
-      messages: [{ role: "user", content: "user1: hello" }],
-      system: "You are hello bot",
-    }
-    expect(parsed_body).to eq(expected_body)
-
-    log = AiApiAuditLog.order(:id).last
-    expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
-    expect(log.request_tokens).to eq(10)
-    expect(log.response_tokens).to eq(25)
-  end
-end
diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb
index b05b79b0..d34acbf5 100644
--- a/spec/lib/completions/endpoints/anthropic_spec.rb
+++ b/spec/lib/completions/endpoints/anthropic_spec.rb
@@ -1,96 +1,395 @@
-# frozen_String_literal: true
+# frozen_string_literal: true
 
-require_relative "endpoint_compliance"
+RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-opus") }
 
-class AnthropicMock < EndpointMock
-  def response(content)
+  let(:prompt) do
+    DiscourseAi::Completions::Prompt.new(
+      "You are hello bot",
+      messages: [type: :user, id: "user1", content: "hello"],
+    )
+  end
+
+  let(:echo_tool) do
     {
-      completion: content,
-      stop: "\n\nHuman:",
-      stop_reason: "stop_sequence",
-      truncated: false,
-      log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
-      model: "claude-2",
-      exception: nil,
+      name: "echo",
+      description: "echo something",
+      parameters: [{ name: "text", type: "string", description: "text to echo", required: true }],
     }
   end
 
-  def stub_response(prompt, response_text, tool_call: false)
-    WebMock
-      .stub_request(:post, "https://api.anthropic.com/v1/complete")
-      .with(body: model.default_options.merge(prompt: prompt).to_json)
-      .to_return(status: 200, body: JSON.dump(response(response_text)))
+  let(:google_tool) do
+    {
+      name: "google",
+      description: "google something",
+      parameters: [
+        { name: "query", type: "string", description: "text to google", required: true },
+      ],
+    }
   end
 
-  def stream_line(delta, finish_reason: nil)
-    +"data: " << {
-      completion: delta,
-      stop: finish_reason ? "\n\nHuman:" : nil,
-      stop_reason: finish_reason,
-      truncated: false,
-      log_id: "12b029451c6d18094d868bc04ce83f63",
-      model: "claude-2",
-      exception: nil,
-    }.to_json
+  let(:prompt_with_echo_tool) do
+    prompt_with_tools = prompt
+    prompt.tools = [echo_tool]
+    prompt_with_tools
   end
 
-  def stub_streamed_response(prompt, deltas, tool_call: false)
-    chunks =
-      deltas.each_with_index.map do |_, index|
-        if index == (deltas.length - 1)
-          stream_line(deltas[index], finish_reason: "stop_sequence")
-        else
-          stream_line(deltas[index])
-        end
-      end
-
-    chunks = chunks.join("\n\n").split("")
-
-    WebMock
-      .stub_request(:post, "https://api.anthropic.com/v1/complete")
-      .with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
-      .to_return(status: 200, body: chunks)
-  end
-end
-
-RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
-  subject(:endpoint) { described_class.new("claude-2", DiscourseAi::Tokenizer::AnthropicTokenizer) }
-
-  fab!(:user)
-
-  let(:anthropic_mock) { AnthropicMock.new(endpoint) }
-
-  let(:compliance) do
-    EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Claude, user)
+  let(:prompt_with_google_tool) do
+    prompt_with_tools = prompt
+    prompt.tools = [google_tool]
+    prompt_with_tools
   end
 
-  describe "#perform_completion!" do
-    context "when using regular mode" do
-      context "with simple prompts" do
-        it "completes a trivial prompt and logs the response" do
-          compliance.regular_mode_simple_prompt(anthropic_mock)
-        end
-      end
+  before { SiteSetting.ai_anthropic_api_key = "123" }
 
-      context "with tools" do
-        it "returns a function invocation" do
-          compliance.regular_mode_tools(anthropic_mock)
-        end
-      end
+  it "does not eat spaces with tool calls" do
+    body = <<~STRING
+      event: message_start
+      data: {"type":"message_start","message":{"id":"msg_019kmW9Q3GqfWmuFJbePJTBR","type":"message","role":"assistant","content":[],"model":"claude-3-opus-20240229","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":347,"output_tokens":1}}}
+
+      event: content_block_start
+      data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
+
+      event: ping
+      data: {"type": "ping"}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<function_calls>"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<invoke>"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<tool_name>"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"google"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</tool_name>"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<parameters>"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<query>"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"top"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"10"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"things"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" to"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" do"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" in"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" japan"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" for"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" tourists"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</query>"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</parameters>"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</invoke>"}}
+
+      event: content_block_delta
+      data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
+
+      event: content_block_stop
+      data: {"type":"content_block_stop","index":0}
+
+      event: message_delta
+      data: {"type":"message_delta","delta":{"stop_reason":"stop_sequence","stop_sequence":"</function_calls>"},"usage":{"output_tokens":57}}
+
+      event: message_stop
+      data: {"type":"message_stop"}
+    STRING
+
+    stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
+
+    result = +""
+    llm.generate(prompt_with_google_tool, user: Discourse.system_user) do |partial|
+      result << partial
     end
 
-    describe "when using streaming mode" do
-      context "with simple prompts" do
-        it "completes a trivial prompt and logs the response" do
-          compliance.streaming_mode_simple_prompt(anthropic_mock)
-        end
-      end
+    expected = (<<~TEXT).strip
+      <function_calls>
+      <invoke>
+      <tool_name>google</tool_name>
+      <parameters>
+      <query>top 10 things to do in japan for tourists</query>
+      </parameters>
+      <tool_id>tool_0</tool_id>
+      </invoke>
+      </function_calls>
+    TEXT
 
-      context "with tools" do
-        it "returns a function invocation" do
-          compliance.streaming_mode_tools(anthropic_mock)
-        end
-      end
-    end
-  end
+    expect(result.strip).to eq(expected)
+  end
+
+  it "can stream a response" do
+    body = (<<~STRING).strip
+      event: message_start
+      data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}}
+
+      event: content_block_start
+      data: {"type": "content_block_start", "index":0, "content_block": {"type": "text", "text": ""}}
+
+      event: ping
+      data: {"type": "ping"}
+
+      event: content_block_delta
+      data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}
+
+      event: content_block_delta
+      data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "!"}}
+
+      event: content_block_stop
+      data: {"type": "content_block_stop", "index": 0}
+
+      event: message_delta
+      data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null, "usage":{"output_tokens": 15}}}
+
+      event: message_stop
+      data: {"type": "message_stop"}
+    STRING
+
+    parsed_body = nil
+
+    stub_request(:post, "https://api.anthropic.com/v1/messages").with(
+      body:
+        proc do |req_body|
+          parsed_body = JSON.parse(req_body, symbolize_names: true)
+          true
+        end,
+      headers: {
+        "Content-Type" => "application/json",
+        "X-Api-Key" => "123",
+        "Anthropic-Version" => "2023-06-01",
+      },
+    ).to_return(status: 200, body: body)
+
+    result = +""
+    llm.generate(prompt, user: Discourse.system_user) { |partial, cancel| result << partial }
+
+    expect(result).to eq("Hello!")
+
+    expected_body = {
+      model: "claude-3-opus-20240229",
+      max_tokens: 3000,
+      messages: [{ role: "user", content: "user1: hello" }],
+      system: "You are hello bot",
+      stream: true,
+    }
+    expect(parsed_body).to eq(expected_body)
+
+    log = AiApiAuditLog.order(:id).last
+    expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
+    expect(log.request_tokens).to eq(25)
+    expect(log.response_tokens).to eq(15)
+  end
+
+  it "can return multiple function calls" do
+    functions = <<~FUNCTIONS
+      <function_calls>
+      <invoke>
+      <tool_name>echo</tool_name>
+      <parameters>
+      <text>something</text>
+      </parameters>
+      </invoke>
+      <invoke>
+      <tool_name>echo</tool_name>
+      <parameters>
+      <text>something else</text>
+      </parameters>
+      </invoke>
+      </function_calls>
+    FUNCTIONS
+
+    body = <<~STRING
+      {
+        "content": [
+          {
+            "text": "Hello!\n\n#{functions}\njunk",
+            "type": "text"
+          }
+        ],
+        "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
+        "model": "claude-3-opus-20240229",
+        "role": "assistant",
+        "stop_reason": "end_turn",
+        "stop_sequence": null,
+        "type": "message",
+        "usage": {
+          "input_tokens": 10,
+          "output_tokens": 25
+        }
+      }
+    STRING
+
+    stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
+
+    result = llm.generate(prompt_with_echo_tool, user: Discourse.system_user)
+
+    expected = (<<~EXPECTED).strip
+      <function_calls>
+      <invoke>
+      <tool_name>echo</tool_name>
+      <parameters>
+      <text>something</text>
+      </parameters>
+      <tool_id>tool_0</tool_id>
+      </invoke>
+      <invoke>
+      <tool_name>echo</tool_name>
+      <parameters>
+      <text>something else</text>
+      </parameters>
+      <tool_id>tool_1</tool_id>
+      </invoke>
+      </function_calls>
+    EXPECTED
+
+    expect(result.strip).to eq(expected)
+  end
+
+  it "can operate in regular mode" do
+    body = <<~STRING
+      {
+        "content": [
+          {
+            "text": "Hello!",
+            "type": "text"
+          }
+        ],
+        "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
+        "model": "claude-3-opus-20240229",
+        "role": "assistant",
+        "stop_reason": "end_turn",
+        "stop_sequence": null,
+        "type": "message",
+        "usage": {
+          "input_tokens": 10,
+          "output_tokens": 25
+        }
+      }
+    STRING
+
+    parsed_body = nil
+    stub_request(:post, "https://api.anthropic.com/v1/messages").with(
+      body:
+        proc do |req_body|
+          parsed_body = JSON.parse(req_body, symbolize_names: true)
+          true
+        end,
+      headers: {
+        "Content-Type" => "application/json",
+        "X-Api-Key" => "123",
+        "Anthropic-Version" => "2023-06-01",
+      },
+    ).to_return(status: 200, body: body)
+
+    result = llm.generate(prompt, user: Discourse.system_user)
+    expect(result).to eq("Hello!")
+
+    expected_body = {
+      model: "claude-3-opus-20240229",
+      max_tokens: 3000,
+      messages: [{ role: "user", content: "user1: hello" }],
+      system: "You are hello bot",
+    }
+    expect(parsed_body).to eq(expected_body)
+
+    log = AiApiAuditLog.order(:id).last
+    expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
+    expect(log.request_tokens).to eq(10)
+    expect(log.response_tokens).to eq(25)
  end
 end
diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
index 2461941b..7d3554cd 100644
--- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb
+++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
@@ -5,71 +5,6 @@ require "aws-eventstream"
 require "aws-sigv4"
 
 class BedrockMock < EndpointMock
-  def response(content)
-    {
-      completion: content,
-      stop: "\n\nHuman:",
-      stop_reason: "stop_sequence",
-      truncated: false,
-      log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
-      model: "claude",
-      exception: nil,
-    }
-  end
-
-  def stub_response(prompt, response_content, tool_call: false)
-    WebMock
-      .stub_request(:post, "#{base_url}/invoke")
-      .with(body: model.default_options.merge(prompt: prompt).to_json)
-      .to_return(status: 200, body: JSON.dump(response(response_content)))
-  end
-
-  def stream_line(delta, finish_reason: nil)
-    encoder = Aws::EventStream::Encoder.new
-
-    message =
-      Aws::EventStream::Message.new(
-        payload:
-          StringIO.new(
-            {
-              bytes:
-                Base64.encode64(
-                  {
-                    completion: delta,
-                    stop: finish_reason ? "\n\nHuman:" : nil,
-                    stop_reason: finish_reason,
-                    truncated: false,
-                    log_id: "12b029451c6d18094d868bc04ce83f63",
-                    model: "claude-2.1",
-                    exception: nil,
-                  }.to_json,
-                ),
-            }.to_json,
-          ),
-      )
-
-    encoder.encode(message)
-  end
-
-  def stub_streamed_response(prompt, deltas, tool_call: false)
-    chunks =
-      deltas.each_with_index.map do |_, index|
-        if index == (deltas.length - 1)
-          stream_line(deltas[index], finish_reason: "stop_sequence")
-        else
-          stream_line(deltas[index])
-        end
-      end
-
-    WebMock
-      .stub_request(:post, "#{base_url}/invoke-with-response-stream")
-      .with(body: model.default_options.merge(prompt: prompt).to_json)
-      .to_return(status: 200, body: chunks)
-  end
-
-  def base_url
-    "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.claude-v2:1"
-  end
 end
 
 RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
@@ -89,32 +24,98 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
     SiteSetting.ai_bedrock_region = "us-east-1"
   end
 
-  describe "#perform_completion!" do
-    context "when using regular mode" do
-      context "with simple prompts" do
-        it "completes a trivial prompt and logs the response" do
-          compliance.regular_mode_simple_prompt(bedrock_mock)
-        end
-      end
+  describe "Claude 3 Sonnet support" do
+    it "supports the sonnet model" do
+      proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
 
-      context "with tools" do
-        it "returns a function invocation" do
-          compliance.regular_mode_tools(bedrock_mock)
+      request = nil
+
+      content = {
+        content: [text: "hello sam"],
+        usage: {
+          input_tokens: 10,
+          output_tokens: 20,
+        },
+      }.to_json
+
+      stub_request(
+        :post,
+        "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke",
+      )
+        .with do |inner_request|
+          request = inner_request
+          true
         end
-      end
+        .to_return(status: 200, body: content)
+
+      response = proxy.generate("hello world", user: user)
+
+      expect(request.headers["Authorization"]).to be_present
+      expect(request.headers["X-Amz-Content-Sha256"]).to be_present
+
+      expected = {
+        "max_tokens" => 3000,
+        "anthropic_version" => "bedrock-2023-05-31",
+        "messages" => [{ "role" => "user", "content" => "hello world" }],
+        "system" => "You are a helpful bot",
+      }
+      expect(JSON.parse(request.body)).to eq(expected)
+
+      expect(response).to eq("hello sam")
+
+      log = AiApiAuditLog.order(:id).last
+      expect(log.request_tokens).to eq(10)
+      expect(log.response_tokens).to eq(20)
     end
 
-    describe "when using streaming mode" do
-      context "with simple prompts" do
-        it "completes a trivial prompt and logs the response" do
-          compliance.streaming_mode_simple_prompt(bedrock_mock)
-        end
-      end
+    it "supports claude 3 sonnet streaming" do
+      proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
 
-      context "with tools" do
-        it "returns a function invocation" do
-          compliance.streaming_mode_tools(bedrock_mock)
+      request = nil
+
+      messages =
+        [
+          { type: "message_start", message: { usage: { input_tokens: 9 } } },
+          { type: "content_block_delta", delta: { text: "hello " } },
+          { type: "content_block_delta", delta: { text: "sam" } },
+          { type: "message_delta", delta: { usage: { output_tokens: 25 } } },
+        ].map do |message|
+          wrapped = { bytes: Base64.encode64(message.to_json) }.to_json
+          io = StringIO.new(wrapped)
+          aws_message = Aws::EventStream::Message.new(payload: io)
+          Aws::EventStream::Encoder.new.encode(aws_message)
         end
+
+      bedrock_mock.with_chunk_array_support do
+        stub_request(
+          :post,
+          "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
+        )
+          .with do |inner_request|
+            request = inner_request
+            true
+          end
+          .to_return(status: 200, body: messages)
+
+        response = +""
+        proxy.generate("hello world", user: user) { |partial| response << partial }
+
+        expect(request.headers["Authorization"]).to be_present
+        expect(request.headers["X-Amz-Content-Sha256"]).to be_present
+
+        expected = {
+          "max_tokens" => 3000,
+          "anthropic_version" => "bedrock-2023-05-31",
+          "messages" => [{ "role" => "user", "content" => "hello world" }],
+          "system" => "You are a helpful bot",
+        }
+        expect(JSON.parse(request.body)).to eq(expected)
+
+        expect(response).to eq("hello sam")
+
+        log = AiApiAuditLog.order(:id).last
+        expect(log.request_tokens).to eq(9)
+        expect(log.response_tokens).to eq(25)
+      end
+    end
+  end
 end