# frozen_string_literal: true
require_relative "endpoint_compliance"
require "aws-eventstream"
require "aws-sigv4"
# Bedrock-specific endpoint mock used by this spec. It adds no behaviour of its
# own; everything comes from EndpointMock (not defined in this file —
# presumably loaded via endpoint_compliance).
class BedrockMock < EndpointMock; end
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
# Endpoint under test, built from the fabricated Bedrock model below.
subject(:endpoint) do
  described_class.new(model)
end

# Records shared by the examples.
fab!(:user)
fab!(:model) { Fabricate(:bedrock_model) }

# Mock wrapper around the endpoint; the streaming examples call
# `with_chunk_array_support` on it while stubbing array-bodied responses.
let(:bedrock_mock) do
  BedrockMock.new(endpoint)
end

# Compliance helper bound to the Claude dialect.
# NOTE(review): no example in this file references `compliance` — confirm
# whether it is still needed.
let(:compliance) do
  claude_dialect = DiscourseAi::Completions::Dialects::Claude
  EndpointsCompliance.new(self, endpoint, claude_dialect, user)
end
# Builds one encoded AWS event-stream frame carrying +message+.
#
# The message is serialised to JSON, Base64-encoded, wrapped in a
# `{ "bytes": ... }` JSON envelope, and that envelope becomes the payload of an
# Aws::EventStream::Message which is then run through the event-stream encoder.
#
# @param message [#to_json] the event to embed in the frame
# @return the value returned by Aws::EventStream::Encoder#encode for the frame
def encode_message(message)
  { bytes: Base64.encode64(message.to_json) }
    .to_json
    .then { |envelope| Aws::EventStream::Message.new(payload: StringIO.new(envelope)) }
    .then { |event| Aws::EventStream::Encoder.new.encode(event) }
end
# Checks the max_tokens value in the endpoint's default options for several
# model names, mutating the fabricated model's name between assertions.
it "should provide accurate max token count" do
  dialect =
    DiscourseAi::Completions::Dialects::Claude.new(
      DiscourseAi::Completions::Prompt.new("hello"),
      model,
    )
  bedrock = DiscourseAi::Completions::Endpoints::AwsBedrock.new(model)

  # model name => max_tokens expected from default_options
  expected_limits = {
    "claude-2" => 4096,
    "claude-3-5-sonnet" => 8192,
    "claude-3-5-haiku" => 8192,
  }

  expected_limits.each do |model_name, max_tokens|
    model.name = model_name
    expect(bedrock.default_options(dialect)[:max_tokens]).to eq(max_tokens)
  end
end
describe "function calling" do
# Exercises the non-native tool path: with `disable_native_tools` set on the
# model, the request must carry no "tools" key, and a tool call that arrives as
# XML inside streamed text must surface to the caller as a ToolCall.
#
# NOTE(review): the XML markup in this example had been stripped down to bare
# text / empty strings (heredoc tags, the system-prompt check and the stop
# sequence), which made two assertions vacuous and the expected ToolCall
# unreachable. The markup below is reconstructed from the legacy XML tool-call
# format — confirm the exact tag names against the XML tool processor and tag
# stripper used by the endpoint.
it "supports old school xml function calls" do
  model.provider_params["disable_native_tools"] = true
  model.save!

  proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")

  # Leading tags are expected to be dropped from the output; only the
  # <function_calls> section should turn into a ToolCall.
  incomplete_tool_call = <<~XML.strip
    <thinking>I should be ignored</thinking>
    <search_quality_reflection>also ignored</search_quality_reflection>
    <search_quality_score>0</search_quality_score>
    <function_calls>
    <invoke>
    <tool_name>google</tool_name>
    <parameters><query>sydney weather today</query></parameters>
    </invoke>
    </function_calls>
  XML

  messages =
    [
      { type: "message_start", message: { usage: { input_tokens: 9 } } },
      { type: "content_block_delta", delta: { text: "hello\n" } },
      { type: "content_block_delta", delta: { text: incomplete_tool_call } },
      { type: "message_delta", delta: { usage: { output_tokens: 25 } } },
    ].map { |message| encode_message(message) }

  request = nil

  bedrock_mock.with_chunk_array_support do
    stub_request(
      :post,
      "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
    )
      .with do |inner_request|
        # Capture the outgoing request so its headers and body can be inspected.
        request = inner_request
        true
      end
      .to_return(status: 200, body: messages)

    prompt =
      DiscourseAi::Completions::Prompt.new(
        messages: [{ type: :user, content: "what is the weather in sydney" }],
      )

    tool = {
      name: "google",
      description: "Will search using Google",
      parameters: [
        { name: "query", description: "The search query", type: "string", required: true },
      ],
    }
    prompt.tools = [tool]

    response = []
    proxy.generate(prompt, user: user) { |partial| response << partial }

    # The request must be signed.
    expect(request.headers["Authorization"]).to be_present
    expect(request.headers["X-Amz-Content-Sha256"]).to be_present

    # Tools are described in the system prompt rather than sent natively, and
    # generation is stopped at the closing tag of the XML tool block.
    parsed_body = JSON.parse(request.body)
    expect(parsed_body["system"]).to include("<function_calls>")
    expect(parsed_body["tools"]).to be_nil
    expect(parsed_body["stop_sequences"]).to eq(["</function_calls>"])

    expected = [
      "hello\n",
      DiscourseAi::Completions::ToolCall.new(
        id: "tool_0",
        name: "google",
        parameters: {
          query: "sydney weather today",
        },
      ),
    ]

    expect(response).to eq(expected)
  end
end
# Streams a native tool_use response (tool input delivered as partial JSON
# deltas) and checks that the caller receives a single assembled ToolCall, that
# the request body carries the tool schema, and that the audit log records the
# token counts from the "amazon-bedrock-invocationMetrics" event.
it "supports streaming function calls" do
  proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")

  request = nil

  messages =
    [
      {
        type: "message_start",
        message: {
          id: "msg_bdrk_01WYxeNMk6EKn9s98r6XXrAB",
          type: "message",
          role: "assistant",
          model: "claude-3-sonnet-20240307",
          stop_sequence: nil,
          usage: {
            input_tokens: 840,
            output_tokens: 1,
          },
          content: [],
          stop_reason: nil,
        },
      },
      # Text carried on a content_block_start event; the expected response
      # below contains only the tool call, so this must not be emitted.
      {
        type: "content_block_start",
        index: 0,
        delta: {
          text: "I should be ignored",
        },
      },
      {
        type: "content_block_start",
        index: 0,
        content_block: {
          type: "tool_use",
          id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
          name: "google",
          input: {
          },
        },
      },
      # The tool input arrives as JSON fragments that must be concatenated.
      {
        type: "content_block_delta",
        index: 0,
        delta: {
          type: "input_json_delta",
          partial_json: "",
        },
      },
      {
        type: "content_block_delta",
        index: 0,
        delta: {
          type: "input_json_delta",
          partial_json: "{\"query\": \"s",
        },
      },
      {
        type: "content_block_delta",
        index: 0,
        delta: {
          type: "input_json_delta",
          partial_json: "ydney weat",
        },
      },
      {
        type: "content_block_delta",
        index: 0,
        delta: {
          type: "input_json_delta",
          partial_json: "her today\"}",
        },
      },
      { type: "content_block_stop", index: 0 },
      {
        type: "message_delta",
        delta: {
          stop_reason: "tool_use",
          stop_sequence: nil,
        },
        usage: {
          output_tokens: 53,
        },
      },
      {
        type: "message_stop",
        "amazon-bedrock-invocationMetrics": {
          inputTokenCount: 846,
          outputTokenCount: 39,
          invocationLatency: 880,
          firstByteLatency: 402,
        },
      },
    ].map { |message| encode_message(message) }

  # Deliver the encoded stream one byte per chunk so frames are always split
  # across reads. A bare String#split would instead split on ASCII whitespace
  # and discard it, dropping any frame byte that happens to equal a whitespace
  # character and corrupting the binary stream.
  messages = messages.join.b.chars

  bedrock_mock.with_chunk_array_support do
    stub_request(
      :post,
      "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
    )
      .with do |inner_request|
        # Capture the outgoing request so its headers and body can be inspected.
        request = inner_request
        true
      end
      .to_return(status: 200, body: messages)

    prompt =
      DiscourseAi::Completions::Prompt.new(
        messages: [{ type: :user, content: "what is the weather in sydney" }],
      )

    tool = {
      name: "google",
      description: "Will search using Google",
      parameters: [
        { name: "query", description: "The search query", type: "string", required: true },
      ],
    }
    prompt.tools = [tool]

    response = []
    proxy.generate(prompt, user: user) { |partial| response << partial }

    # The request must be signed.
    expect(request.headers["Authorization"]).to be_present
    expect(request.headers["X-Amz-Content-Sha256"]).to be_present

    expected_response = [
      DiscourseAi::Completions::ToolCall.new(
        id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
        name: "google",
        parameters: {
          query: "sydney weather today",
        },
      ),
    ]

    expect(response).to eq(expected_response)

    expected = {
      "max_tokens" => 4096,
      "anthropic_version" => "bedrock-2023-05-31",
      "messages" => [{ "role" => "user", "content" => "what is the weather in sydney" }],
      "tools" => [
        {
          "name" => "google",
          "description" => "Will search using Google",
          "input_schema" => {
            "type" => "object",
            "properties" => {
              "query" => {
                "type" => "string",
                "description" => "The search query",
              },
            },
            "required" => ["query"],
          },
        },
      ],
    }

    expect(JSON.parse(request.body)).to eq(expected)

    # Token counts are expected to match the invocation metrics event, not the
    # usage figures in message_start / message_delta.
    log = AiApiAuditLog.order(:id).last
    expect(log.request_tokens).to eq(846)
    expect(log.response_tokens).to eq(39)
  end
end
end
describe "Claude 3 support" do
# Non-streaming round trip: stubs the plain "invoke" URL with a canned JSON
# reply, then checks the signed request, the request body, the returned text
# and the token counts written to the audit log.
it "supports regular completions" do
  captured_request = nil

  canned_reply = {
    content: [{ text: "hello sam" }],
    usage: {
      input_tokens: 10,
      output_tokens: 20,
    },
  }.to_json

  stub_request(
    :post,
    "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke",
  )
    .with do |outgoing|
      # Keep hold of the outgoing request for the assertions below.
      captured_request = outgoing
      true
    end
    .to_return(status: 200, body: canned_reply)

  llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
  reply = llm.generate("hello world", user: user)

  # The request must be signed.
  expect(captured_request.headers["Authorization"]).to be_present
  expect(captured_request.headers["X-Amz-Content-Sha256"]).to be_present

  expected_body = {
    "max_tokens" => 4096,
    "anthropic_version" => "bedrock-2023-05-31",
    "messages" => [{ "role" => "user", "content" => "hello world" }],
    "system" => "You are a helpful bot",
  }
  expect(JSON.parse(captured_request.body)).to eq(expected_body)

  expect(reply).to eq("hello sam")

  audit_entry = AiApiAuditLog.order(:id).last
  expect(audit_entry.request_tokens).to eq(10)
  expect(audit_entry.response_tokens).to eq(20)
end
# Streaming round trip: feeds the endpoint an event stream split into
# single-byte chunks and checks the concatenated partials, the request that was
# sent, and the token counts written to the audit log.
it "supports claude 3 streaming" do
  proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")

  request = nil

  messages =
    [
      { type: "message_start", message: { usage: { input_tokens: 9 } } },
      { type: "content_block_delta", delta: { text: "hello " } },
      { type: "content_block_delta", delta: { text: "sam" } },
      { type: "message_delta", delta: { usage: { output_tokens: 25 } } },
    ].map { |message| encode_message(message) }

  # Stream one byte at a time, because the endpoint has to cope with frames
  # that are split across reads. A bare String#split would not do this: it
  # splits on ASCII whitespace and discards it, which also drops any frame byte
  # that happens to equal a whitespace character.
  messages = messages.join.b.chars

  bedrock_mock.with_chunk_array_support do
    stub_request(
      :post,
      "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
    )
      .with do |inner_request|
        # Capture the outgoing request so its headers and body can be inspected.
        request = inner_request
        true
      end
      .to_return(status: 200, body: messages)

    response = +""
    proxy.generate("hello world", user: user) { |partial| response << partial }

    # The request must be signed.
    expect(request.headers["Authorization"]).to be_present
    expect(request.headers["X-Amz-Content-Sha256"]).to be_present

    expected = {
      "max_tokens" => 4096,
      "anthropic_version" => "bedrock-2023-05-31",
      "messages" => [{ "role" => "user", "content" => "hello world" }],
      "system" => "You are a helpful bot",
    }
    expect(JSON.parse(request.body)).to eq(expected)

    expect(response).to eq("hello sam")

    log = AiApiAuditLog.order(:id).last
    expect(log.request_tokens).to eq(9)
    expect(log.response_tokens).to eq(25)
  end
end
end
end