mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-08-06 13:13:28 +00:00
Native tools do not work well on Opus. Chain-of-thought prompting means it consumes enormous amounts of tokens and has poor latency. This commit introduces an XML stripper to remove various chain-of-thought XML islands from Anthropic prompts when tools are involved. This means Opus native tools are now functional (albeit slowly). From local testing, XML just works better now. Also fixes enum support in Anthropic native tools.
379 lines
12 KiB
Ruby
379 lines
12 KiB
Ruby
# frozen_string_literal: true
|
|
|
|
require_relative "endpoint_compliance"
|
|
require "aws-eventstream"
|
|
require "aws-sigv4"
|
|
|
|
# Endpoint mock used by the Bedrock examples below; it adds nothing of its own
# and inherits everything from EndpointMock.
class BedrockMock < EndpointMock; end
|
|
|
|
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|
# Endpoint under test, constructed for "claude-2" with the Anthropic tokenizer.
subject(:endpoint) do
  described_class.new("claude-2", DiscourseAi::Tokenizer::AnthropicTokenizer)
end

fab!(:user)

# Mock wrapped around the endpoint; the streaming examples call
# with_chunk_array_support on it.
let(:bedrock_mock) { BedrockMock.new(endpoint) }

# Compliance helper built from the current example, the endpoint, the Claude
# dialect and the fabricated user.
let(:compliance) do
  dialect = DiscourseAi::Completions::Dialects::Claude
  EndpointsCompliance.new(self, endpoint, dialect, user)
end
|
|
|
|
# Builds one event-stream frame for the stubbed streaming responses.
#
# The message is serialised to JSON and Base64-encoded, that text is nested
# under a "bytes" key of a second JSON document, and the document becomes the
# payload of a single Aws::EventStream message.
#
# @param message [#to_json] the event to wrap
# @return whatever Aws::EventStream::Encoder#encode produces for the frame
def encode_message(message)
  encoded_event = Base64.encode64(message.to_json)
  envelope = StringIO.new({ bytes: encoded_event }.to_json)

  Aws::EventStream::Encoder.new.encode(Aws::EventStream::Message.new(payload: envelope))
end
|
|
|
|
# Every example runs with a Bedrock region and a key pair configured; the
# examples assert that signing headers are present on the outgoing request.
before do
  SiteSetting.ai_bedrock_region = "us-east-1"
  SiteSetting.ai_bedrock_access_key_id = "123456"
  SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
end
|
|
|
|
describe "function calling" do
|
|
# XML tool calling: with the native-tool-call model list emptied, the request
# must carry the tool instructions in the system prompt (no "tools" key) and
# the streamed XML answer must come back with the chain-of-thought islands
# removed and a tool_id added.
it "supports old school xml function calls" do
  SiteSetting.ai_anthropic_native_tool_call_models = ""
  llm = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
  stream_url =
    "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream"

  # What the model streams back: three reasoning islands, then a tool call
  # that has no tool_id yet.
  raw_tool_call = <<~XML.strip
    <thinking>I should be ignored</thinking>
    <search_quality_reflection>also ignored</search_quality_reflection>
    <search_quality_score>0</search_quality_score>
    <function_calls>
    <invoke>
    <tool_name>google</tool_name>
    <parameters><query>sydney weather today</query></parameters>
    </invoke>
    </function_calls>
  XML

  events = [
    { type: "message_start", message: { usage: { input_tokens: 9 } } },
    { type: "content_block_delta", delta: { text: "hello\n" } },
    { type: "content_block_delta", delta: { text: raw_tool_call } },
    { type: "message_delta", delta: { usage: { output_tokens: 25 } } },
  ]
  stream_body = events.map { |event| encode_message(event) }

  captured_request = nil
  bedrock_mock.with_chunk_array_support do
    stub_request(:post, stream_url)
      .with do |outgoing|
        captured_request = outgoing
        true # match every request; the block exists only to capture it
      end
      .to_return(status: 200, body: stream_body)

    prompt =
      DiscourseAi::Completions::Prompt.new(
        messages: [{ type: :user, content: "what is the weather in sydney" }],
      )
    prompt.tools = [
      {
        name: "google",
        description: "Will search using Google",
        parameters: [
          { name: "query", description: "The search query", type: "string", required: true },
        ],
      },
    ]

    streamed = +""
    llm.generate(prompt, user: user) { |partial| streamed << partial }

    expect(captured_request.headers["Authorization"]).to be_present
    expect(captured_request.headers["X-Amz-Content-Sha256"]).to be_present

    sent = JSON.parse(captured_request.body)
    expect(sent["system"]).to include("<function_calls>")
    expect(sent["tools"]).to be_nil
    expect(sent["stop_sequences"]).to eq(["</function_calls>"])

    # the reasoning islands are gone, and normalisation has added a tool_id
    normalized_call = <<~XML.strip
      hello


      <function_calls>
      <invoke>
      <tool_name>google</tool_name>
      <parameters><query>sydney weather today</query></parameters>
      <tool_id>tool_0</tool_id>
      </invoke>
      </function_calls>
    XML

    expect(streamed.strip).to eq(normalized_call)
  end
end
|
|
|
|
# Native tool calling over a streamed response: the tool_use block's JSON input
# arrives in fragments ("{\"query\": \"s", "ydney weat", "her today\"}") and the
# proxy must emit the normalised <function_calls> XML carrying the tool id.
# Also pins the exact request body (native "tools" schema) and checks the token
# counts written to the audit log.
it "supports streaming function calls" do
  proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")

  request = nil

  messages =
    [
      {
        type: "message_start",
        message: {
          id: "msg_bdrk_01WYxeNMk6EKn9s98r6XXrAB",
          type: "message",
          role: "assistant",
          model: "claude-3-haiku-20240307",
          stop_sequence: nil,
          usage: {
            input_tokens: 840,
            output_tokens: 1,
          },
          content: [],
          stop_reason: nil,
        },
      },
      # a chain-of-thought island that must not show up in the output
      {
        type: "content_block_start",
        index: 0,
        delta: {
          text: "<thinking>I should be ignored</thinking>",
        },
      },
      {
        type: "content_block_start",
        index: 0,
        content_block: {
          type: "tool_use",
          id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
          name: "google",
          input: {
          },
        },
      },
      { type: "content_block_delta", index: 0, delta: { type: "input_json_delta", partial_json: "" } },
      {
        type: "content_block_delta",
        index: 0,
        delta: {
          type: "input_json_delta",
          partial_json: "{\"query\": \"s",
        },
      },
      {
        type: "content_block_delta",
        index: 0,
        delta: {
          type: "input_json_delta",
          partial_json: "ydney weat",
        },
      },
      {
        type: "content_block_delta",
        index: 0,
        delta: {
          type: "input_json_delta",
          partial_json: "her today\"}",
        },
      },
      { type: "content_block_stop", index: 0 },
      {
        type: "message_delta",
        delta: {
          stop_reason: "tool_use",
          stop_sequence: nil,
        },
        usage: {
          output_tokens: 53,
        },
      },
      {
        type: "message_stop",
        "amazon-bedrock-invocationMetrics": {
          inputTokenCount: 846,
          outputTokenCount: 39,
          invocationLatency: 880,
          firstByteLatency: 402,
        },
      },
    ].map { |message| encode_message(message) }

  # Deliver the event stream one byte per chunk so every frame arrives split.
  # A bare String#split would only cut at whitespace bytes of the binary
  # stream, and would throw those bytes away.
  messages = messages.join("").b.chars

  bedrock_mock.with_chunk_array_support do
    stub_request(
      :post,
      "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
    )
      .with do |inner_request|
        request = inner_request
        true # match every request; the block exists only to capture it
      end
      .to_return(status: 200, body: messages)

    prompt =
      DiscourseAi::Completions::Prompt.new(
        messages: [{ type: :user, content: "what is the weather in sydney" }],
      )

    tool = {
      name: "google",
      description: "Will search using Google",
      parameters: [
        { name: "query", description: "The search query", type: "string", required: true },
      ],
    }

    prompt.tools = [tool]
    response = +""
    proxy.generate(prompt, user: user) { |partial| response << partial }

    expect(request.headers["Authorization"]).to be_present
    expect(request.headers["X-Amz-Content-Sha256"]).to be_present

    expected_response = (<<~RESPONSE).strip
      <function_calls>
      <invoke>
      <tool_name>google</tool_name>
      <parameters><query>sydney weather today</query></parameters>
      <tool_id>toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7</tool_id>
      </invoke>
      </function_calls>
    RESPONSE

    expect(response.strip).to eq(expected_response)

    expected = {
      "max_tokens" => 3000,
      "anthropic_version" => "bedrock-2023-05-31",
      "messages" => [{ "role" => "user", "content" => "what is the weather in sydney" }],
      "tools" => [
        {
          "name" => "google",
          "description" => "Will search using Google",
          "input_schema" => {
            "type" => "object",
            "properties" => {
              "query" => {
                "type" => "string",
                "description" => "The search query",
              },
            },
            "required" => ["query"],
          },
        },
      ],
    }
    expect(JSON.parse(request.body)).to eq(expected)

    # the logged counts match the amazon-bedrock-invocationMetrics values above
    log = AiApiAuditLog.order(:id).last
    expect(log.request_tokens).to eq(846)
    expect(log.response_tokens).to eq(39)
  end
end
|
|
end
|
|
|
|
describe "Claude 3 Sonnet support" do
|
|
# Non-streaming completion against the plain /invoke URL: checks the signing
# headers, the exact request body, the returned text and the token counts
# written to the audit log.
it "supports the sonnet model" do
  llm = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
  invoke_url =
    "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke"

  canned_body = {
    content: [{ text: "hello sam" }],
    usage: {
      input_tokens: 10,
      output_tokens: 20,
    },
  }.to_json

  captured_request = nil
  stub_request(:post, invoke_url)
    .with do |outgoing|
      captured_request = outgoing
      true # match every request; the block exists only to capture it
    end
    .to_return(status: 200, body: canned_body)

  answer = llm.generate("hello world", user: user)

  expect(captured_request.headers["Authorization"]).to be_present
  expect(captured_request.headers["X-Amz-Content-Sha256"]).to be_present

  expected_payload = {
    "max_tokens" => 3000,
    "anthropic_version" => "bedrock-2023-05-31",
    "messages" => [{ "role" => "user", "content" => "hello world" }],
    "system" => "You are a helpful bot",
  }
  expect(JSON.parse(captured_request.body)).to eq(expected_payload)

  expect(answer).to eq("hello sam")

  # token counts come from the "usage" section of the canned response
  audit = AiApiAuditLog.order(:id).last
  expect(audit.request_tokens).to eq(10)
  expect(audit.response_tokens).to eq(20)
end
|
|
|
|
# Streaming completion: the response text is assembled from two
# content_block_delta events, and the audit log takes its token counts from
# the message_start / message_delta usage fields.
it "supports claude 3 sonnet streaming" do
  proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")

  request = nil

  messages =
    [
      { type: "message_start", message: { usage: { input_tokens: 9 } } },
      { type: "content_block_delta", delta: { text: "hello " } },
      { type: "content_block_delta", delta: { text: "sam" } },
      { type: "message_delta", delta: { usage: { output_tokens: 25 } } },
    ].map { |message| encode_message(message) }

  # Stream one byte at a time, because the endpoint has to cope with frames
  # that arrive split. A bare String#split would only cut at whitespace bytes
  # of the binary stream, and would throw those bytes away.
  messages = messages.join("").b.chars

  bedrock_mock.with_chunk_array_support do
    stub_request(
      :post,
      "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
    )
      .with do |inner_request|
        request = inner_request
        true # match every request; the block exists only to capture it
      end
      .to_return(status: 200, body: messages)

    response = +""
    proxy.generate("hello world", user: user) { |partial| response << partial }

    expect(request.headers["Authorization"]).to be_present
    expect(request.headers["X-Amz-Content-Sha256"]).to be_present

    expected = {
      "max_tokens" => 3000,
      "anthropic_version" => "bedrock-2023-05-31",
      "messages" => [{ "role" => "user", "content" => "hello world" }],
      "system" => "You are a helpful bot",
    }
    expect(JSON.parse(request.body)).to eq(expected)

    expect(response).to eq("hello sam")

    log = AiApiAuditLog.order(:id).last
    expect(log.request_tokens).to eq(9)
    expect(log.response_tokens).to eq(25)
  end
end
|
|
end
|
|
end
|