FEATURE: add Claude 3 sonnet/haiku support for Amazon Bedrock (#534)

This PR consolidates the  implements new Anthropic Messages interface for Bedrock Claude endpoints and adds support for the new Claude 3 models (haiku, opus, sonnet).

Key changes:
- Renamed `AnthropicMessages` and `Anthropic` endpoint classes into a single `Anthropic` class (ditto for ClaudeMessages -> Claude)
- Updated `AwsBedrock` endpoints to use the new `/messages` API format for all Claude models
- Added `claude-3-haiku`, `claude-3-opus` and `claude-3-sonnet` model support in both Anthropic and AWS Bedrock endpoints
- Updated specs for the new consolidated endpoints and Claude 3 model support

This refactor removes support for old non messages API which has been deprecated by anthropic
This commit is contained in:
Sam 2024-03-19 06:48:46 +11:00 committed by GitHub
parent d7ed8180af
commit f62703760f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 721 additions and 1001 deletions

View File

@ -178,9 +178,16 @@ module DiscourseAi
when DiscourseAi::AiBot::EntryPoint::FAKE_ID when DiscourseAi::AiBot::EntryPoint::FAKE_ID
"fake:fake" "fake:fake"
when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID
# no bedrock support yet 18-03
"anthropic:claude-3-opus" "anthropic:claude-3-opus"
when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_SONNET_ID when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_SONNET_ID
"anthropic:claude-3-sonnet" if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(
"claude-3-sonnet",
)
"aws_bedrock:claude-3-sonnet"
else
"anthropic:claude-3-sonnet"
end
else else
nil nil
end end

View File

@ -216,13 +216,16 @@ Follow the provided writing composition instructions carefully and precisely ste
def translate_model(model) def translate_model(model)
return "google:gemini-pro" if model == "gemini-pro" return "google:gemini-pro" if model == "gemini-pro"
return "open_ai:#{model}" if model.start_with? "gpt" return "open_ai:#{model}" if model.start_with? "gpt"
return "anthropic:#{model}" if model.start_with? "claude-3"
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2") if model.start_with? "claude"
"aws_bedrock:#{model}" if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(model)
else return "aws_bedrock:#{model}"
"anthropic:#{model}" else
return "anthropic:#{model}"
end
end end
raise "Unknown model #{model}"
end end
private private

View File

@ -6,7 +6,9 @@ module DiscourseAi
class Claude < Dialect class Claude < Dialect
class << self class << self
def can_translate?(model_name) def can_translate?(model_name)
%w[claude-instant-1 claude-2].include?(model_name) %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus].include?(
model_name,
)
end end
def tokenizer def tokenizer
@ -14,53 +16,69 @@ module DiscourseAi
end end
end end
class ClaudePrompt
attr_reader :system_prompt
attr_reader :messages
def initialize(system_prompt, messages)
@system_prompt = system_prompt
@messages = messages
end
end
def translate def translate
messages = prompt.messages messages = prompt.messages
system_prompt = +""
trimmed_messages = trim_messages(messages) messages =
trim_messages(messages)
.map do |msg|
case msg[:type]
when :system
system_prompt << msg[:content]
nil
when :tool_call
{ role: "assistant", content: tool_call_to_xml(msg) }
when :tool
{ role: "user", content: tool_result_to_xml(msg) }
when :model
{ role: "assistant", content: msg[:content] }
when :user
content = +""
content << "#{msg[:id]}: " if msg[:id]
content << msg[:content]
# Need to include this differently { role: "user", content: content }
last_message = trimmed_messages.last[:type] == :assistant ? trimmed_messages.pop : nil
claude_prompt =
trimmed_messages.reduce(+"") do |memo, msg|
if msg[:type] == :tool_call
memo << "\n\nAssistant: #{tool_call_to_xml(msg)}"
elsif msg[:type] == :system
memo << "Human: " unless uses_system_message?
memo << msg[:content]
if prompt.tools.present?
memo << "\n"
memo << build_tools_prompt
end end
elsif msg[:type] == :model
memo << "\n\nAssistant: #{msg[:content]}"
elsif msg[:type] == :tool
memo << "\n\nHuman:\n"
memo << tool_result_to_xml(msg)
else
memo << "\n\nHuman: "
memo << "#{msg[:id]}: " if msg[:id]
memo << msg[:content]
end end
.compact
memo if prompt.tools.present?
system_prompt << "\n\n"
system_prompt << build_tools_prompt
end
interleving_messages = []
previous_message = nil
messages.each do |message|
if previous_message
if previous_message[:role] == "user" && message[:role] == "user"
interleving_messages << { role: "assistant", content: "OK" }
elsif previous_message[:role] == "assistant" && message[:role] == "assistant"
interleving_messages << { role: "user", content: "OK" }
end
end end
interleving_messages << message
previous_message = message
end
claude_prompt << "\n\nAssistant:" ClaudePrompt.new(system_prompt.presence, interleving_messages)
claude_prompt << " #{last_message[:content]}:" if last_message
claude_prompt
end end
def max_prompt_tokens def max_prompt_tokens
100_000 # Claude-2.1 has a 200k context window. # Longer term it will have over 1 million
end 200_000 # Claude-3 has a 200k context window for now
private
def uses_system_message?
model_name == "claude-2"
end end
end end
end end

View File

@ -1,85 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class ClaudeMessages < Dialect
class << self
def can_translate?(model_name)
# TODO: add haiku not released yet as of 2024-03-05
%w[claude-3-sonnet claude-3-opus].include?(model_name)
end
def tokenizer
DiscourseAi::Tokenizer::AnthropicTokenizer
end
end
class ClaudePrompt
attr_reader :system_prompt
attr_reader :messages
def initialize(system_prompt, messages)
@system_prompt = system_prompt
@messages = messages
end
end
def translate
messages = prompt.messages
system_prompt = +""
messages =
trim_messages(messages)
.map do |msg|
case msg[:type]
when :system
system_prompt << msg[:content]
nil
when :tool_call
{ role: "assistant", content: tool_call_to_xml(msg) }
when :tool
{ role: "user", content: tool_result_to_xml(msg) }
when :model
{ role: "assistant", content: msg[:content] }
when :user
content = +""
content << "#{msg[:id]}: " if msg[:id]
content << msg[:content]
{ role: "user", content: content }
end
end
.compact
if prompt.tools.present?
system_prompt << "\n\n"
system_prompt << build_tools_prompt
end
interleving_messages = []
previous_message = nil
messages.each do |message|
if previous_message
if previous_message[:role] == "user" && message[:role] == "user"
interleving_messages << { role: "assistant", content: "OK" }
elsif previous_message[:role] == "assistant" && message[:role] == "assistant"
interleving_messages << { role: "user", content: "OK" }
end
end
interleving_messages << message
previous_message = message
end
ClaudePrompt.new(system_prompt.presence, interleving_messages)
end
def max_prompt_tokens
# Longer term it will have over 1 million
200_000 # Claude-3 has a 200k context window for now
end
end
end
end
end

View File

@ -11,13 +11,12 @@ module DiscourseAi
def dialect_for(model_name) def dialect_for(model_name)
dialects = [ dialects = [
DiscourseAi::Completions::Dialects::Claude,
DiscourseAi::Completions::Dialects::Llama2Classic, DiscourseAi::Completions::Dialects::Llama2Classic,
DiscourseAi::Completions::Dialects::ChatGpt, DiscourseAi::Completions::Dialects::ChatGpt,
DiscourseAi::Completions::Dialects::OrcaStyle, DiscourseAi::Completions::Dialects::OrcaStyle,
DiscourseAi::Completions::Dialects::Gemini, DiscourseAi::Completions::Dialects::Gemini,
DiscourseAi::Completions::Dialects::Mixtral, DiscourseAi::Completions::Dialects::Mixtral,
DiscourseAi::Completions::Dialects::ClaudeMessages, DiscourseAi::Completions::Dialects::Claude,
] ]
if Rails.env.test? || Rails.env.development? if Rails.env.test? || Rails.env.development?

View File

@ -6,7 +6,10 @@ module DiscourseAi
class Anthropic < Base class Anthropic < Base
class << self class << self
def can_contact?(endpoint_name, model_name) def can_contact?(endpoint_name, model_name)
endpoint_name == "anthropic" && %w[claude-instant-1 claude-2].include?(model_name) endpoint_name == "anthropic" &&
%w[claude-instant-1 claude-2 claude-3-haiku claude-3-opus claude-3-sonnet].include?(
model_name,
)
end end
def dependant_setting_names def dependant_setting_names
@ -23,23 +26,32 @@ module DiscourseAi
end end
def normalize_model_params(model_params) def normalize_model_params(model_params)
model_params = model_params.dup # max_tokens, temperature, stop_sequences are already supported
# temperature, stop_sequences are already supported
#
if model_params[:max_tokens]
model_params[:max_tokens_to_sample] = model_params.delete(:max_tokens)
end
model_params model_params
end end
def default_options def default_options(dialect)
{ # skipping 2.0 support for now, since other models are better
model: model == "claude-2" ? "claude-2.1" : model, mapped_model =
max_tokens_to_sample: 3_000, case model
stop_sequences: ["\n\nHuman:", "</function_calls>"], when "claude-2"
} "claude-2.1"
when "claude-instant-1"
"claude-instant-1.2"
when "claude-3-haiku"
"claude-3-haiku-20240307"
when "claude-3-sonnet"
"claude-3-sonnet-20240229"
when "claude-3-opus"
"claude-3-opus-20240229"
else
raise "Unsupported model: #{model}"
end
options = { model: mapped_model, max_tokens: 3_000 }
options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
options
end end
def provider_id def provider_id
@ -48,15 +60,22 @@ module DiscourseAi
private private
def model_uri # this is an approximation, we will update it later if request goes through
@uri ||= URI("https://api.anthropic.com/v1/complete") def prompt_size(prompt)
super(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
end end
def prepare_payload(prompt, model_params, _dialect) def model_uri
default_options @uri ||= URI("https://api.anthropic.com/v1/messages")
.merge(model_params) end
.merge(prompt: prompt)
.tap { |payload| payload[:stream] = true if @streaming_mode } def prepare_payload(prompt, model_params, dialect)
payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
payload[:stream] = true if @streaming_mode
payload
end end
def prepare_request(payload) def prepare_request(payload)
@ -69,8 +88,30 @@ module DiscourseAi
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end end
def final_log_update(log)
log.request_tokens = @input_tokens if @input_tokens
log.response_tokens = @output_tokens if @output_tokens
end
def extract_completion_from(response_raw) def extract_completion_from(response_raw)
JSON.parse(response_raw, symbolize_names: true)[:completion].to_s result = ""
parsed = JSON.parse(response_raw, symbolize_names: true)
if @streaming_mode
if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
result = parsed.dig(:delta, :text).to_s
elsif parsed[:type] == "message_start"
@input_tokens = parsed.dig(:message, :usage, :input_tokens)
elsif parsed[:type] == "message_delta"
@output_tokens = parsed.dig(:delta, :usage, :output_tokens)
end
else
result = parsed.dig(:content, 0, :text).to_s
@input_tokens = parsed.dig(:usage, :input_tokens)
@output_tokens = parsed.dig(:usage, :output_tokens)
end
result
end end
def partials_from(decoded_chunk) def partials_from(decoded_chunk)

View File

@ -1,103 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Endpoints
class AnthropicMessages < Base
class << self
def can_contact?(endpoint_name, model_name)
endpoint_name == "anthropic" && %w[claude-3-opus claude-3-sonnet].include?(model_name)
end
def dependant_setting_names
%w[ai_anthropic_api_key]
end
def correctly_configured?(_model_name)
SiteSetting.ai_anthropic_api_key.present?
end
def endpoint_name(model_name)
"Anthropic - #{model_name}"
end
end
def normalize_model_params(model_params)
# max_tokens, temperature, stop_sequences are already supported
model_params
end
def default_options(dialect)
options = { model: model + "-20240229", max_tokens: 3_000 }
options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
options
end
def provider_id
AiApiAuditLog::Provider::Anthropic
end
private
# this is an approximation, we will update it later if request goes through
def prompt_size(prompt)
super(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
end
def model_uri
@uri ||= URI("https://api.anthropic.com/v1/messages")
end
def prepare_payload(prompt, model_params, dialect)
payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
payload[:stream] = true if @streaming_mode
payload
end
def prepare_request(payload)
headers = {
"anthropic-version" => "2023-06-01",
"x-api-key" => SiteSetting.ai_anthropic_api_key,
"content-type" => "application/json",
}
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
def final_log_update(log)
log.request_tokens = @input_tokens if @input_tokens
log.response_tokens = @output_tokens if @output_tokens
end
def extract_completion_from(response_raw)
result = ""
parsed = JSON.parse(response_raw, symbolize_names: true)
if @streaming_mode
if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
result = parsed.dig(:delta, :text).to_s
elsif parsed[:type] == "message_start"
@input_tokens = parsed.dig(:message, :usage, :input_tokens)
elsif parsed[:type] == "message_delta"
@output_tokens = parsed.dig(:delta, :usage, :output_tokens)
end
else
result = parsed.dig(:content, 0, :text).to_s
@input_tokens = parsed.dig(:usage, :input_tokens)
@output_tokens = parsed.dig(:usage, :output_tokens)
end
result
end
def partials_from(decoded_chunk)
decoded_chunk.split("\n").map { |line| line.split("data: ", 2)[1] }.compact
end
end
end
end
end

View File

@ -8,17 +8,18 @@ module DiscourseAi
class AwsBedrock < Base class AwsBedrock < Base
class << self class << self
def can_contact?(endpoint_name, model_name) def can_contact?(endpoint_name, model_name)
endpoint_name == "aws_bedrock" && %w[claude-instant-1 claude-2].include?(model_name) endpoint_name == "aws_bedrock" &&
%w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet].include?(model_name)
end end
def dependant_setting_names def dependant_setting_names
%w[ai_bedrock_access_key_id ai_bedrock_secret_access_key ai_bedrock_region] %w[ai_bedrock_access_key_id ai_bedrock_secret_access_key ai_bedrock_region]
end end
def correctly_configured?(_model_name) def correctly_configured?(model)
SiteSetting.ai_bedrock_access_key_id.present? && SiteSetting.ai_bedrock_access_key_id.present? &&
SiteSetting.ai_bedrock_secret_access_key.present? && SiteSetting.ai_bedrock_secret_access_key.present? &&
SiteSetting.ai_bedrock_region.present? SiteSetting.ai_bedrock_region.present? && can_contact?("aws_bedrock", model)
end end
def endpoint_name(model_name) def endpoint_name(model_name)
@ -29,17 +30,15 @@ module DiscourseAi
def normalize_model_params(model_params) def normalize_model_params(model_params)
model_params = model_params.dup model_params = model_params.dup
# temperature, stop_sequences are already supported # max_tokens, temperature, stop_sequences, top_p are already supported
#
if model_params[:max_tokens]
model_params[:max_tokens_to_sample] = model_params.delete(:max_tokens)
end
model_params model_params
end end
def default_options def default_options(dialect)
{ max_tokens_to_sample: 3_000, stop_sequences: ["\n\nHuman:", "</function_calls>"] } options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" }
options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
options
end end
def provider_id def provider_id
@ -48,25 +47,40 @@ module DiscourseAi
private private
def model_uri def prompt_size(prompt)
# Bedrock uses slightly different names # approximation
# See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html super(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
bedrock_model_id = model.split("-") end
bedrock_model_id[-1] = "v#{bedrock_model_id.last}"
bedrock_model_id = +(bedrock_model_id.join("-"))
bedrock_model_id << ":1" if model == "claude-2" # For claude-2.1 def model_uri
# See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
#
# FYI there is a 2.0 version of Claude, very little need to support it given
# haiku/sonnet are better fits anyway, we map to claude-2.1
bedrock_model_id =
case model
when "claude-2"
"anthropic.claude-v2:1"
when "claude-3-haiku"
"anthropic.claude-3-haiku-20240307-v1:0"
when "claude-3-sonnet"
"anthropic.claude-3-sonnet-20240229-v1:0"
when "claude-instant-1"
"anthropic.claude-instant-v1"
end
api_url = api_url =
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.#{bedrock_model_id}/invoke" "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{bedrock_model_id}/invoke"
api_url = @streaming_mode ? (api_url + "-with-response-stream") : api_url api_url = @streaming_mode ? (api_url + "-with-response-stream") : api_url
URI(api_url) URI(api_url)
end end
def prepare_payload(prompt, model_params, _dialect) def prepare_payload(prompt, model_params, dialect)
default_options.merge(prompt: prompt).merge(model_params) payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
payload
end end
def prepare_request(payload) def prepare_request(payload)
@ -117,8 +131,30 @@ module DiscourseAi
nil nil
end end
def final_log_update(log)
log.request_tokens = @input_tokens if @input_tokens
log.response_tokens = @output_tokens if @output_tokens
end
def extract_completion_from(response_raw) def extract_completion_from(response_raw)
JSON.parse(response_raw, symbolize_names: true)[:completion].to_s result = ""
parsed = JSON.parse(response_raw, symbolize_names: true)
if @streaming_mode
if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
result = parsed.dig(:delta, :text).to_s
elsif parsed[:type] == "message_start"
@input_tokens = parsed.dig(:message, :usage, :input_tokens)
elsif parsed[:type] == "message_delta"
@output_tokens = parsed.dig(:delta, :usage, :output_tokens)
end
else
result = parsed.dig(:content, 0, :text).to_s
@input_tokens = parsed.dig(:usage, :input_tokens)
@output_tokens = parsed.dig(:usage, :output_tokens)
end
result
end end
def partials_from(decoded_chunk) def partials_from(decoded_chunk)

View File

@ -11,12 +11,11 @@ module DiscourseAi
def endpoint_for(provider_name, model_name) def endpoint_for(provider_name, model_name)
endpoints = [ endpoints = [
DiscourseAi::Completions::Endpoints::AwsBedrock, DiscourseAi::Completions::Endpoints::AwsBedrock,
DiscourseAi::Completions::Endpoints::Anthropic,
DiscourseAi::Completions::Endpoints::OpenAi, DiscourseAi::Completions::Endpoints::OpenAi,
DiscourseAi::Completions::Endpoints::HuggingFace, DiscourseAi::Completions::Endpoints::HuggingFace,
DiscourseAi::Completions::Endpoints::Gemini, DiscourseAi::Completions::Endpoints::Gemini,
DiscourseAi::Completions::Endpoints::Vllm, DiscourseAi::Completions::Endpoints::Vllm,
DiscourseAi::Completions::Endpoints::AnthropicMessages, DiscourseAi::Completions::Endpoints::Anthropic,
] ]
if Rails.env.test? || Rails.env.development? if Rails.env.test? || Rails.env.development?

View File

@ -23,8 +23,8 @@ module DiscourseAi
# However, since they use the same URL/key settings, there's no reason to duplicate them. # However, since they use the same URL/key settings, there's no reason to duplicate them.
@models_by_provider ||= @models_by_provider ||=
{ {
aws_bedrock: %w[claude-instant-1 claude-2], aws_bedrock: %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet],
anthropic: %w[claude-instant-1 claude-2 claude-3-sonnet claude-3-opus], anthropic: %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus],
vllm: %w[ vllm: %w[
mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2 mistralai/Mistral-7B-Instruct-v0.2

View File

@ -1,87 +0,0 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::Dialects::ClaudeMessages do
describe "#translate" do
it "can insert OKs to make stuff interleve properly" do
messages = [
{ type: :user, id: "user1", content: "1" },
{ type: :model, content: "2" },
{ type: :user, id: "user1", content: "4" },
{ type: :user, id: "user1", content: "5" },
{ type: :model, content: "6" },
]
prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
dialectKlass = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
dialect = dialectKlass.new(prompt, "claude-3-opus")
translated = dialect.translate
expected_messages = [
{ role: "user", content: "user1: 1" },
{ role: "assistant", content: "2" },
{ role: "user", content: "user1: 4" },
{ role: "assistant", content: "OK" },
{ role: "user", content: "user1: 5" },
{ role: "assistant", content: "6" },
]
expect(translated.messages).to eq(expected_messages)
end
it "can properly translate a prompt" do
dialect = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
tools = [
{
name: "echo",
description: "echo a string",
parameters: [
{ name: "text", type: "string", description: "string to echo", required: true },
],
},
]
tool_call_prompt = { name: "echo", arguments: { text: "something" } }
messages = [
{ type: :user, id: "user1", content: "echo something" },
{ type: :tool_call, content: tool_call_prompt.to_json },
{ type: :tool, id: "tool_id", content: "something".to_json },
{ type: :model, content: "I did it" },
{ type: :user, id: "user1", content: "echo something else" },
]
prompt =
DiscourseAi::Completions::Prompt.new(
"You are a helpful bot",
messages: messages,
tools: tools,
)
dialect = dialect.new(prompt, "claude-3-opus")
translated = dialect.translate
expect(translated.system_prompt).to start_with("You are a helpful bot")
expect(translated.system_prompt).to include("echo a string")
expected = [
{ role: "user", content: "user1: echo something" },
{
role: "assistant",
content:
"<function_calls>\n<invoke>\n<tool_name>echo</tool_name>\n<parameters>\n<text>something</text>\n</parameters>\n</invoke>\n</function_calls>",
},
{
role: "user",
content:
"<function_results>\n<result>\n<tool_name>tool_id</tool_name>\n<json>\n\"something\"\n</json>\n</result>\n</function_results>",
},
{ role: "assistant", content: "I did it" },
{ role: "user", content: "user1: echo something else" },
]
expect(translated.messages).to eq(expected)
end
end
end

View File

@ -1,100 +1,87 @@
# frozen_string_literal: true # frozen_string_literal: true
require_relative "dialect_context"
RSpec.describe DiscourseAi::Completions::Dialects::Claude do RSpec.describe DiscourseAi::Completions::Dialects::Claude do
let(:model_name) { "claude-2" }
let(:context) { DialectContext.new(described_class, model_name) }
describe "#translate" do describe "#translate" do
it "translates a prompt written in our generic format to Claude's format" do it "can insert OKs to make stuff interleve properly" do
anthropic_version = (<<~TEXT).strip messages = [
#{context.system_insts} { type: :user, id: "user1", content: "1" },
#{described_class.tool_preamble} { type: :model, content: "2" },
<tools> { type: :user, id: "user1", content: "4" },
#{context.dialect_tools}</tools> { type: :user, id: "user1", content: "5" },
{ type: :model, content: "6" },
]
Human: #{context.simple_user_input} prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
Assistant: dialectKlass = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
TEXT dialect = dialectKlass.new(prompt, "claude-3-opus")
translated = dialect.translate
translated = context.system_user_scenario expected_messages = [
{ role: "user", content: "user1: 1" },
{ role: "assistant", content: "2" },
{ role: "user", content: "user1: 4" },
{ role: "assistant", content: "OK" },
{ role: "user", content: "user1: 5" },
{ role: "assistant", content: "6" },
]
expect(translated).to eq(anthropic_version) expect(translated.messages).to eq(expected_messages)
end end
it "translates tool messages" do it "can properly translate a prompt" do
expected = +(<<~TEXT).strip dialect = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
#{context.system_insts}
#{described_class.tool_preamble}
<tools>
#{context.dialect_tools}</tools>
Human: user1: This is a message by a user tools = [
{
name: "echo",
description: "echo a string",
parameters: [
{ name: "text", type: "string", description: "string to echo", required: true },
],
},
]
Assistant: I'm a previous bot reply, that's why there's no user tool_call_prompt = { name: "echo", arguments: { text: "something" } }
Human: user1: This is a new message by a user messages = [
{ type: :user, id: "user1", content: "echo something" },
{ type: :tool_call, content: tool_call_prompt.to_json },
{ type: :tool, id: "tool_id", content: "something".to_json },
{ type: :model, content: "I did it" },
{ type: :user, id: "user1", content: "echo something else" },
]
Assistant: <function_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
Human:
<function_results>
<result>
<tool_name>get_weather</tool_name>
<json>
"I'm a tool result"
</json>
</result>
</function_results>
Assistant:
TEXT
expect(context.multi_turn_scenario).to eq(expected)
end
it "trims content if it's getting too long" do
length = 19_000
translated = context.long_user_input_scenario(length: length)
expect(translated.length).to be < context.long_message_text(length: length).length
end
it "retains usernames in generated prompt" do
prompt = prompt =
DiscourseAi::Completions::Prompt.new( DiscourseAi::Completions::Prompt.new(
"You are a bot", "You are a helpful bot",
messages: [ messages: messages,
{ id: "👻", type: :user, content: "Message1" }, tools: tools,
{ type: :model, content: "Ok" },
{ id: "joe", type: :user, content: "Message2" },
],
) )
translated = context.dialect(prompt).translate dialect = dialect.new(prompt, "claude-3-opus")
translated = dialect.translate
expect(translated).to eq(<<~TEXT.strip) expect(translated.system_prompt).to start_with("You are a helpful bot")
You are a bot expect(translated.system_prompt).to include("echo a string")
Human: 👻: Message1 expected = [
{ role: "user", content: "user1: echo something" },
{
role: "assistant",
content:
"<function_calls>\n<invoke>\n<tool_name>echo</tool_name>\n<parameters>\n<text>something</text>\n</parameters>\n</invoke>\n</function_calls>",
},
{
role: "user",
content:
"<function_results>\n<result>\n<tool_name>tool_id</tool_name>\n<json>\n\"something\"\n</json>\n</result>\n</function_results>",
},
{ role: "assistant", content: "I did it" },
{ role: "user", content: "user1: echo something else" },
]
Assistant: Ok expect(translated.messages).to eq(expected)
Human: joe: Message2
Assistant:
TEXT
end end
end end
end end

View File

@ -1,395 +0,0 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::Endpoints::AnthropicMessages do
let(:llm) { DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-opus") }
let(:prompt) do
DiscourseAi::Completions::Prompt.new(
"You are hello bot",
messages: [type: :user, id: "user1", content: "hello"],
)
end
let(:echo_tool) do
{
name: "echo",
description: "echo something",
parameters: [{ name: "text", type: "string", description: "text to echo", required: true }],
}
end
let(:google_tool) do
{
name: "google",
description: "google something",
parameters: [
{ name: "query", type: "string", description: "text to google", required: true },
],
}
end
let(:prompt_with_echo_tool) do
prompt_with_tools = prompt
prompt.tools = [echo_tool]
prompt_with_tools
end
let(:prompt_with_google_tool) do
prompt_with_tools = prompt
prompt.tools = [echo_tool]
prompt_with_tools
end
before { SiteSetting.ai_anthropic_api_key = "123" }
it "does not eat spaces with tool calls" do
body = <<~STRING
event: message_start
data: {"type":"message_start","message":{"id":"msg_019kmW9Q3GqfWmuFJbePJTBR","type":"message","role":"assistant","content":[],"model":"claude-3-opus-20240229","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":347,"output_tokens":1}}}
event: content_block_start
data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
event: ping
data: {"type": "ping"}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<function"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"calls"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<invoke"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<tool"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"name"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"google"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</tool"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"name"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<parameters"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<query"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"top"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"10"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"things"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" to"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" do"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" in"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" japan"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" for"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" tourists"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</query"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</parameters"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</invoke"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_stop
data: {"type":"content_block_stop","index":0}
event: message_delta
data: {"type":"message_delta","delta":{"stop_reason":"stop_sequence","stop_sequence":"</function_calls>"},"usage":{"output_tokens":57}}
event: message_stop
data: {"type":"message_stop"}
STRING
stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
result = +""
llm.generate(prompt_with_google_tool, user: Discourse.system_user) do |partial|
result << partial
end
expected = (<<~TEXT).strip
<function_calls>
<invoke>
<tool_name>google</tool_name>
<parameters>
<query>top 10 things to do in japan for tourists</query>
</parameters>
<tool_id>tool_0</tool_id>
</invoke>
</function_calls>
TEXT
expect(result.strip).to eq(expected)
end
it "can stream a response" do
body = (<<~STRING).strip
event: message_start
data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}}
event: content_block_start
data: {"type": "content_block_start", "index":0, "content_block": {"type": "text", "text": ""}}
event: ping
data: {"type": "ping"}
event: content_block_delta
data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}
event: content_block_delta
data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "!"}}
event: content_block_stop
data: {"type": "content_block_stop", "index": 0}
event: message_delta
data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null, "usage":{"output_tokens": 15}}}
event: message_stop
data: {"type": "message_stop"}
STRING
parsed_body = nil
stub_request(:post, "https://api.anthropic.com/v1/messages").with(
body:
proc do |req_body|
parsed_body = JSON.parse(req_body, symbolize_names: true)
true
end,
headers: {
"Content-Type" => "application/json",
"X-Api-Key" => "123",
"Anthropic-Version" => "2023-06-01",
},
).to_return(status: 200, body: body)
result = +""
llm.generate(prompt, user: Discourse.system_user) { |partial, cancel| result << partial }
expect(result).to eq("Hello!")
expected_body = {
model: "claude-3-opus-20240229",
max_tokens: 3000,
messages: [{ role: "user", content: "user1: hello" }],
system: "You are hello bot",
stream: true,
}
expect(parsed_body).to eq(expected_body)
log = AiApiAuditLog.order(:id).last
expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
expect(log.request_tokens).to eq(25)
expect(log.response_tokens).to eq(15)
end
it "can return multiple function calls" do
functions = <<~FUNCTIONS
<function_calls>
<invoke>
<tool_name>echo</tool_name>
<parameters>
<text>something</text>
</parameters>
</invoke>
<invoke>
<tool_name>echo</tool_name>
<parameters>
<text>something else</text>
</parameters>
</invoke>
FUNCTIONS
body = <<~STRING
{
"content": [
{
"text": "Hello!\n\n#{functions}\njunk",
"type": "text"
}
],
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
"model": "claude-3-opus-20240229",
"role": "assistant",
"stop_reason": "end_turn",
"stop_sequence": null,
"type": "message",
"usage": {
"input_tokens": 10,
"output_tokens": 25
}
}
STRING
stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
result = llm.generate(prompt_with_echo_tool, user: Discourse.system_user)
expected = (<<~EXPECTED).strip
<function_calls>
<invoke>
<tool_name>echo</tool_name>
<parameters>
<text>something</text>
</parameters>
<tool_id>tool_0</tool_id>
</invoke>
<invoke>
<tool_name>echo</tool_name>
<parameters>
<text>something else</text>
</parameters>
<tool_id>tool_1</tool_id>
</invoke>
</function_calls>
EXPECTED
expect(result.strip).to eq(expected)
end
it "can operate in regular mode" do
body = <<~STRING
{
"content": [
{
"text": "Hello!",
"type": "text"
}
],
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
"model": "claude-3-opus-20240229",
"role": "assistant",
"stop_reason": "end_turn",
"stop_sequence": null,
"type": "message",
"usage": {
"input_tokens": 10,
"output_tokens": 25
}
}
STRING
parsed_body = nil
stub_request(:post, "https://api.anthropic.com/v1/messages").with(
body:
proc do |req_body|
parsed_body = JSON.parse(req_body, symbolize_names: true)
true
end,
headers: {
"Content-Type" => "application/json",
"X-Api-Key" => "123",
"Anthropic-Version" => "2023-06-01",
},
).to_return(status: 200, body: body)
result = llm.generate(prompt, user: Discourse.system_user)
expect(result).to eq("Hello!")
expected_body = {
model: "claude-3-opus-20240229",
max_tokens: 3000,
messages: [{ role: "user", content: "user1: hello" }],
system: "You are hello bot",
}
expect(parsed_body).to eq(expected_body)
log = AiApiAuditLog.order(:id).last
expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
expect(log.request_tokens).to eq(10)
expect(log.response_tokens).to eq(25)
end
end

View File

@ -1,96 +1,395 @@
# frozen_String_literal: true # frozen_string_literal: true
require_relative "endpoint_compliance" RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
let(:llm) { DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-opus") }
class AnthropicMock < EndpointMock let(:prompt) do
def response(content) DiscourseAi::Completions::Prompt.new(
"You are hello bot",
messages: [type: :user, id: "user1", content: "hello"],
)
end
let(:echo_tool) do
{ {
completion: content, name: "echo",
stop: "\n\nHuman:", description: "echo something",
stop_reason: "stop_sequence", parameters: [{ name: "text", type: "string", description: "text to echo", required: true }],
truncated: false,
log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
model: "claude-2",
exception: nil,
} }
end end
def stub_response(prompt, response_text, tool_call: false) let(:google_tool) do
WebMock {
.stub_request(:post, "https://api.anthropic.com/v1/complete") name: "google",
.with(body: model.default_options.merge(prompt: prompt).to_json) description: "google something",
.to_return(status: 200, body: JSON.dump(response(response_text))) parameters: [
{ name: "query", type: "string", description: "text to google", required: true },
],
}
end end
def stream_line(delta, finish_reason: nil) let(:prompt_with_echo_tool) do
+"data: " << { prompt_with_tools = prompt
completion: delta, prompt.tools = [echo_tool]
stop: finish_reason ? "\n\nHuman:" : nil, prompt_with_tools
stop_reason: finish_reason,
truncated: false,
log_id: "12b029451c6d18094d868bc04ce83f63",
model: "claude-2",
exception: nil,
}.to_json
end end
def stub_streamed_response(prompt, deltas, tool_call: false) let(:prompt_with_google_tool) do
chunks = prompt_with_tools = prompt
deltas.each_with_index.map do |_, index| prompt.tools = [echo_tool]
if index == (deltas.length - 1) prompt_with_tools
stream_line(deltas[index], finish_reason: "stop_sequence")
else
stream_line(deltas[index])
end
end
chunks = chunks.join("\n\n").split("")
WebMock
.stub_request(:post, "https://api.anthropic.com/v1/complete")
.with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
.to_return(status: 200, body: chunks)
end
end
RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
subject(:endpoint) { described_class.new("claude-2", DiscourseAi::Tokenizer::AnthropicTokenizer) }
fab!(:user)
let(:anthropic_mock) { AnthropicMock.new(endpoint) }
let(:compliance) do
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Claude, user)
end end
describe "#perform_completion!" do before { SiteSetting.ai_anthropic_api_key = "123" }
context "when using regular mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.regular_mode_simple_prompt(anthropic_mock)
end
end
context "with tools" do it "does not eat spaces with tool calls" do
it "returns a function invocation" do body = <<~STRING
compliance.regular_mode_tools(anthropic_mock) event: message_start
end data: {"type":"message_start","message":{"id":"msg_019kmW9Q3GqfWmuFJbePJTBR","type":"message","role":"assistant","content":[],"model":"claude-3-opus-20240229","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":347,"output_tokens":1}}}
end
event: content_block_start
data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
event: ping
data: {"type": "ping"}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<function"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"calls"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<invoke"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<tool"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"name"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"google"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</tool"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"name"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<parameters"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<query"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"top"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"10"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"things"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" to"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" do"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" in"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" japan"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" for"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" tourists"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</query"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</parameters"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</invoke"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_stop
data: {"type":"content_block_stop","index":0}
event: message_delta
data: {"type":"message_delta","delta":{"stop_reason":"stop_sequence","stop_sequence":"</function_calls>"},"usage":{"output_tokens":57}}
event: message_stop
data: {"type":"message_stop"}
STRING
stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
result = +""
llm.generate(prompt_with_google_tool, user: Discourse.system_user) do |partial|
result << partial
end end
describe "when using streaming mode" do expected = (<<~TEXT).strip
context "with simple prompts" do <function_calls>
it "completes a trivial prompt and logs the response" do <invoke>
compliance.streaming_mode_simple_prompt(anthropic_mock) <tool_name>google</tool_name>
end <parameters>
end <query>top 10 things to do in japan for tourists</query>
</parameters>
<tool_id>tool_0</tool_id>
</invoke>
</function_calls>
TEXT
context "with tools" do expect(result.strip).to eq(expected)
it "returns a function invocation" do end
compliance.streaming_mode_tools(anthropic_mock)
end it "can stream a response" do
end body = (<<~STRING).strip
end event: message_start
data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}}
event: content_block_start
data: {"type": "content_block_start", "index":0, "content_block": {"type": "text", "text": ""}}
event: ping
data: {"type": "ping"}
event: content_block_delta
data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}
event: content_block_delta
data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "!"}}
event: content_block_stop
data: {"type": "content_block_stop", "index": 0}
event: message_delta
data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null, "usage":{"output_tokens": 15}}}
event: message_stop
data: {"type": "message_stop"}
STRING
parsed_body = nil
stub_request(:post, "https://api.anthropic.com/v1/messages").with(
body:
proc do |req_body|
parsed_body = JSON.parse(req_body, symbolize_names: true)
true
end,
headers: {
"Content-Type" => "application/json",
"X-Api-Key" => "123",
"Anthropic-Version" => "2023-06-01",
},
).to_return(status: 200, body: body)
result = +""
llm.generate(prompt, user: Discourse.system_user) { |partial, cancel| result << partial }
expect(result).to eq("Hello!")
expected_body = {
model: "claude-3-opus-20240229",
max_tokens: 3000,
messages: [{ role: "user", content: "user1: hello" }],
system: "You are hello bot",
stream: true,
}
expect(parsed_body).to eq(expected_body)
log = AiApiAuditLog.order(:id).last
expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
expect(log.request_tokens).to eq(25)
expect(log.response_tokens).to eq(15)
end
it "can return multiple function calls" do
functions = <<~FUNCTIONS
<function_calls>
<invoke>
<tool_name>echo</tool_name>
<parameters>
<text>something</text>
</parameters>
</invoke>
<invoke>
<tool_name>echo</tool_name>
<parameters>
<text>something else</text>
</parameters>
</invoke>
FUNCTIONS
body = <<~STRING
{
"content": [
{
"text": "Hello!\n\n#{functions}\njunk",
"type": "text"
}
],
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
"model": "claude-3-opus-20240229",
"role": "assistant",
"stop_reason": "end_turn",
"stop_sequence": null,
"type": "message",
"usage": {
"input_tokens": 10,
"output_tokens": 25
}
}
STRING
stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
result = llm.generate(prompt_with_echo_tool, user: Discourse.system_user)
expected = (<<~EXPECTED).strip
<function_calls>
<invoke>
<tool_name>echo</tool_name>
<parameters>
<text>something</text>
</parameters>
<tool_id>tool_0</tool_id>
</invoke>
<invoke>
<tool_name>echo</tool_name>
<parameters>
<text>something else</text>
</parameters>
<tool_id>tool_1</tool_id>
</invoke>
</function_calls>
EXPECTED
expect(result.strip).to eq(expected)
end
it "can operate in regular mode" do
body = <<~STRING
{
"content": [
{
"text": "Hello!",
"type": "text"
}
],
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
"model": "claude-3-opus-20240229",
"role": "assistant",
"stop_reason": "end_turn",
"stop_sequence": null,
"type": "message",
"usage": {
"input_tokens": 10,
"output_tokens": 25
}
}
STRING
parsed_body = nil
stub_request(:post, "https://api.anthropic.com/v1/messages").with(
body:
proc do |req_body|
parsed_body = JSON.parse(req_body, symbolize_names: true)
true
end,
headers: {
"Content-Type" => "application/json",
"X-Api-Key" => "123",
"Anthropic-Version" => "2023-06-01",
},
).to_return(status: 200, body: body)
result = llm.generate(prompt, user: Discourse.system_user)
expect(result).to eq("Hello!")
expected_body = {
model: "claude-3-opus-20240229",
max_tokens: 3000,
messages: [{ role: "user", content: "user1: hello" }],
system: "You are hello bot",
}
expect(parsed_body).to eq(expected_body)
log = AiApiAuditLog.order(:id).last
expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
expect(log.request_tokens).to eq(10)
expect(log.response_tokens).to eq(25)
end end
end end

View File

@ -5,71 +5,6 @@ require "aws-eventstream"
require "aws-sigv4" require "aws-sigv4"
class BedrockMock < EndpointMock class BedrockMock < EndpointMock
def response(content)
{
completion: content,
stop: "\n\nHuman:",
stop_reason: "stop_sequence",
truncated: false,
log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
model: "claude",
exception: nil,
}
end
def stub_response(prompt, response_content, tool_call: false)
WebMock
.stub_request(:post, "#{base_url}/invoke")
.with(body: model.default_options.merge(prompt: prompt).to_json)
.to_return(status: 200, body: JSON.dump(response(response_content)))
end
def stream_line(delta, finish_reason: nil)
encoder = Aws::EventStream::Encoder.new
message =
Aws::EventStream::Message.new(
payload:
StringIO.new(
{
bytes:
Base64.encode64(
{
completion: delta,
stop: finish_reason ? "\n\nHuman:" : nil,
stop_reason: finish_reason,
truncated: false,
log_id: "12b029451c6d18094d868bc04ce83f63",
model: "claude-2.1",
exception: nil,
}.to_json,
),
}.to_json,
),
)
encoder.encode(message)
end
def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)
stream_line(deltas[index], finish_reason: "stop_sequence")
else
stream_line(deltas[index])
end
end
WebMock
.stub_request(:post, "#{base_url}/invoke-with-response-stream")
.with(body: model.default_options.merge(prompt: prompt).to_json)
.to_return(status: 200, body: chunks)
end
def base_url
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.claude-v2:1"
end
end end
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
@ -89,32 +24,98 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
SiteSetting.ai_bedrock_region = "us-east-1" SiteSetting.ai_bedrock_region = "us-east-1"
end end
describe "#perform_completion!" do describe "Claude 3 Sonnet support" do
context "when using regular mode" do it "supports the sonnet model" do
context "with simple prompts" do proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
it "completes a trivial prompt and logs the response" do
compliance.regular_mode_simple_prompt(bedrock_mock)
end
end
context "with tools" do request = nil
it "returns a function invocation" do
compliance.regular_mode_tools(bedrock_mock) content = {
content: [text: "hello sam"],
usage: {
input_tokens: 10,
output_tokens: 20,
},
}.to_json
stub_request(
:post,
"https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke",
)
.with do |inner_request|
request = inner_request
true
end end
end .to_return(status: 200, body: content)
response = proxy.generate("hello world", user: user)
expect(request.headers["Authorization"]).to be_present
expect(request.headers["X-Amz-Content-Sha256"]).to be_present
expected = {
"max_tokens" => 3000,
"anthropic_version" => "bedrock-2023-05-31",
"messages" => [{ "role" => "user", "content" => "hello world" }],
"system" => "You are a helpful bot",
}
expect(JSON.parse(request.body)).to eq(expected)
expect(response).to eq("hello sam")
log = AiApiAuditLog.order(:id).last
expect(log.request_tokens).to eq(10)
expect(log.response_tokens).to eq(20)
end end
describe "when using streaming mode" do it "supports claude 3 sonnet streaming" do
context "with simple prompts" do proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
it "completes a trivial prompt and logs the response" do
compliance.streaming_mode_simple_prompt(bedrock_mock)
end
end
context "with tools" do request = nil
it "returns a function invocation" do
compliance.streaming_mode_tools(bedrock_mock) messages =
[
{ type: "message_start", message: { usage: { input_tokens: 9 } } },
{ type: "content_block_delta", delta: { text: "hello " } },
{ type: "content_block_delta", delta: { text: "sam" } },
{ type: "message_delta", delta: { usage: { output_tokens: 25 } } },
].map do |message|
wrapped = { bytes: Base64.encode64(message.to_json) }.to_json
io = StringIO.new(wrapped)
aws_message = Aws::EventStream::Message.new(payload: io)
Aws::EventStream::Encoder.new.encode(aws_message)
end end
bedrock_mock.with_chunk_array_support do
stub_request(
:post,
"https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
)
.with do |inner_request|
request = inner_request
true
end
.to_return(status: 200, body: messages)
response = +""
proxy.generate("hello world", user: user) { |partial| response << partial }
expect(request.headers["Authorization"]).to be_present
expect(request.headers["X-Amz-Content-Sha256"]).to be_present
expected = {
"max_tokens" => 3000,
"anthropic_version" => "bedrock-2023-05-31",
"messages" => [{ "role" => "user", "content" => "hello world" }],
"system" => "You are a helpful bot",
}
expect(JSON.parse(request.body)).to eq(expected)
expect(response).to eq("hello sam")
log = AiApiAuditLog.order(:id).last
expect(log.request_tokens).to eq(9)
expect(log.response_tokens).to eq(25)
end end
end end
end end