FEATURE: anthropic function calling (#654)

Adds support for native tool calling (both streaming and non streaming) for Anthropic.

This improves general tool support on the Anthropic models.
This commit is contained in:
Sam 2024-06-06 08:34:23 +10:00 committed by GitHub
parent 564d2de534
commit 3993c685e1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 520 additions and 245 deletions

View File

@ -0,0 +1,98 @@
# frozen_string_literal: true
class DiscourseAi::Completions::AnthropicMessageProcessor
class AnthropicToolCall
attr_reader :name, :raw_json, :id
def initialize(name, id)
@name = name
@id = id
@raw_json = +""
end
def append(json)
@raw_json << json
end
end
attr_reader :tool_calls, :input_tokens, :output_tokens
def initialize(streaming_mode:)
@streaming_mode = streaming_mode
@tool_calls = []
end
def to_xml_tool_calls(function_buffer)
return function_buffer if @tool_calls.blank?
function_buffer = Nokogiri::HTML5.fragment(<<~TEXT)
<function_calls>
</function_calls>
TEXT
@tool_calls.each do |tool_call|
node =
function_buffer.at("function_calls").add_child(
Nokogiri::HTML5::DocumentFragment.parse(
DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n",
),
)
params = JSON.parse(tool_call.raw_json, symbolize_names: true)
xml = params.map { |name, value| "<#{name}>#{value}</#{name}>" }.join("\n")
node.at("tool_name").content = tool_call.name
node.at("tool_id").content = tool_call.id
node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present?
end
function_buffer
end
def process_message(payload)
result = ""
parsed = JSON.parse(payload, symbolize_names: true)
if @streaming_mode
if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
tool_name = parsed.dig(:content_block, :name)
tool_id = parsed.dig(:content_block, :id)
@tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name
elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
if @tool_calls.present?
result = parsed.dig(:delta, :partial_json).to_s
@tool_calls.last.append(result)
else
result = parsed.dig(:delta, :text).to_s
end
elsif parsed[:type] == "message_start"
@input_tokens = parsed.dig(:message, :usage, :input_tokens)
elsif parsed[:type] == "message_delta"
@output_tokens =
parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
elsif parsed[:type] == "message_stop"
# bedrock has this ...
if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
@input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
@output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
end
end
else
content = parsed.dig(:content)
if content.is_a?(Array)
tool_call = content.find { |c| c[:type] == "tool_use" }
if tool_call
@tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id])
@tool_calls.last.append(tool_call[:input].to_json)
else
result = parsed.dig(:content, 0, :text).to_s
end
end
@input_tokens = parsed.dig(:usage, :input_tokens)
@output_tokens = parsed.dig(:usage, :output_tokens)
end
result
end
end

View File

@ -15,10 +15,12 @@ module DiscourseAi
class ClaudePrompt
attr_reader :system_prompt
attr_reader :messages
attr_reader :tools
def initialize(system_prompt, messages)
def initialize(system_prompt, messages, tools)
@system_prompt = system_prompt
@messages = messages
@tools = tools
end
end
@ -46,7 +48,11 @@ module DiscourseAi
previous_message = message
end
ClaudePrompt.new(system_prompt.presence, interleving_messages)
ClaudePrompt.new(
system_prompt.presence,
interleving_messages,
tools_dialect.translated_tools,
)
end
def max_prompt_tokens
@ -58,6 +64,18 @@ module DiscourseAi
private
def tools_dialect
@tools_dialect ||= DiscourseAi::Completions::Dialects::ClaudeTools.new(prompt.tools)
end
def tool_call_msg(msg)
tools_dialect.from_raw_tool_call(msg)
end
def tool_msg(msg)
tools_dialect.from_raw_tool(msg)
end
def model_msg(msg)
{ role: "assistant", content: msg[:content] }
end

View File

@ -0,0 +1,81 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class ClaudeTools
def initialize(tools)
@raw_tools = tools
end
def translated_tools
# Transform the raw tools into the required Anthropic Claude API format
raw_tools.map do |t|
properties = {}
required = []
if t[:parameters]
properties =
t[:parameters].each_with_object({}) do |param, h|
h[param[:name]] = {
type: param[:type],
description: param[:description],
}.tap { |hash| hash[:items] = { type: param[:item_type] } if param[:item_type] }
end
required =
t[:parameters].select { |param| param[:required] }.map { |param| param[:name] }
end
{
name: t[:name],
description: t[:description],
input_schema: {
type: "object",
properties: properties,
required: required,
},
}
end
end
def instructions
"" # Noop. Tools are listed separate.
end
def from_raw_tool_call(raw_message)
call_details = JSON.parse(raw_message[:content], symbolize_names: true)
tool_call_id = raw_message[:id]
{
role: "assistant",
content: [
{
type: "tool_use",
id: tool_call_id,
name: raw_message[:name],
input: call_details[:arguments],
},
],
}
end
def from_raw_tool(raw_message)
{
role: "user",
content: [
{
type: "tool_result",
tool_use_id: raw_message[:id],
content: raw_message[:content],
},
],
}
end
private
attr_reader :raw_tools
end
end
end
end

View File

@ -45,10 +45,7 @@ module DiscourseAi
raise "Unsupported model: #{model}"
end
options = { model: mapped_model, max_tokens: 3_000 }
options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
options
{ model: mapped_model, max_tokens: 3_000 }
end
def provider_id
@ -73,6 +70,7 @@ module DiscourseAi
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
payload[:stream] = true if @streaming_mode
payload[:tools] = prompt.tools if prompt.tools.present?
payload
end
@ -87,30 +85,30 @@ module DiscourseAi
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
def final_log_update(log)
log.request_tokens = @input_tokens if @input_tokens
log.response_tokens = @output_tokens if @output_tokens
def processor
@processor ||=
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
end
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
processor.to_xml_tool_calls(function_buffer) if !partial
end
def extract_completion_from(response_raw)
result = ""
parsed = JSON.parse(response_raw, symbolize_names: true)
processor.process_message(response_raw)
end
if @streaming_mode
if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
result = parsed.dig(:delta, :text).to_s
elsif parsed[:type] == "message_start"
@input_tokens = parsed.dig(:message, :usage, :input_tokens)
elsif parsed[:type] == "message_delta"
@output_tokens = parsed.dig(:delta, :usage, :output_tokens)
end
else
result = parsed.dig(:content, 0, :text).to_s
@input_tokens = parsed.dig(:usage, :input_tokens)
@output_tokens = parsed.dig(:usage, :output_tokens)
end
def has_tool?(_response_data)
processor.tool_calls.present?
end
result
def final_log_update(log)
log.request_tokens = processor.input_tokens if processor.input_tokens
log.response_tokens = processor.output_tokens if processor.output_tokens
end
def native_tool_support?
true
end
def partials_from(decoded_chunk)

View File

@ -36,7 +36,6 @@ module DiscourseAi
def default_options(dialect)
options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" }
options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
options
end
@ -82,6 +81,8 @@ module DiscourseAi
def prepare_payload(prompt, model_params, dialect)
payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
payload[:tools] = prompt.tools if prompt.tools.present?
payload
end
@ -142,35 +143,35 @@ module DiscourseAi
end
def final_log_update(log)
log.request_tokens = @input_tokens if @input_tokens
log.response_tokens = @output_tokens if @output_tokens
log.request_tokens = processor.input_tokens if processor.input_tokens
log.response_tokens = processor.output_tokens if processor.output_tokens
end
def processor
@processor ||=
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
end
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
processor.to_xml_tool_calls(function_buffer) if !partial
end
def extract_completion_from(response_raw)
result = ""
parsed = JSON.parse(response_raw, symbolize_names: true)
processor.process_message(response_raw)
end
if @streaming_mode
if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
result = parsed.dig(:delta, :text).to_s
elsif parsed[:type] == "message_start"
@input_tokens = parsed.dig(:message, :usage, :input_tokens)
elsif parsed[:type] == "message_delta"
@output_tokens = parsed.dig(:delta, :usage, :output_tokens)
end
else
result = parsed.dig(:content, 0, :text).to_s
@input_tokens = parsed.dig(:usage, :input_tokens)
@output_tokens = parsed.dig(:usage, :output_tokens)
end
result
def has_tool?(_response_data)
processor.tool_calls.present?
end
def partials_from(decoded_chunks)
decoded_chunks
end
def native_tool_support?
true
end
def chunk_to_string(chunk)
joined = +chunk.join("\n")
joined << "\n" if joined.length > 0

View File

@ -242,12 +242,14 @@ module DiscourseAi
else
leftover = ""
end
prev_processed_partials = 0 if leftover.blank?
end
rescue IOError, StandardError
raise if !cancelled
end
has_tool ||= has_tool?(partials_raw)
# Once we have the full response, try to return the tool as a XML doc.
if has_tool && native_tool_support?
function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw)
@ -345,7 +347,7 @@ module DiscourseAi
TEXT
end
def noop_function_call_text
def self.noop_function_call_text
(<<~TEXT).strip
<invoke>
<tool_name></tool_name>
@ -356,6 +358,10 @@ module DiscourseAi
TEXT
end
def noop_function_call_text
self.class.noop_function_call_text
end
def has_tool?(response)
response.include?("<function_calls>")
end

View File

@ -46,7 +46,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
messages = [
{ type: :user, id: "user1", content: "echo something" },
{ type: :tool_call, content: tool_call_prompt.to_json },
{ type: :tool_call, name: "echo", id: "tool_id", content: tool_call_prompt.to_json },
{ type: :tool, id: "tool_id", content: "something".to_json },
{ type: :model, content: "I did it" },
{ type: :user, id: "user1", content: "echo something else" },
@ -63,24 +63,22 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
translated = dialect.translate
expect(translated.system_prompt).to start_with("You are a helpful bot")
expect(translated.system_prompt).to include("echo a string")
expected = [
{ role: "user", content: "user1: echo something" },
{
role: "assistant",
content:
"<function_calls>\n<invoke>\n<tool_name>echo</tool_name>\n<parameters>\n<text>something</text>\n</parameters>\n</invoke>\n</function_calls>",
content: [
{ type: "tool_use", id: "tool_id", name: "echo", input: { text: "something" } },
],
},
{
role: "user",
content:
"<function_results>\n<result>\n<tool_name>tool_id</tool_name>\n<json>\n\"something\"\n</json>\n</result>\n</function_results>",
content: [{ type: "tool_result", tool_use_id: "tool_id", content: "\"something\"" }],
},
{ role: "assistant", content: "I did it" },
{ role: "user", content: "user1: echo something else" },
]
expect(translated.messages).to eq(expected)
end
end

View File

@ -49,152 +49,59 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
it "does not eat spaces with tool calls" do
body = <<~STRING
event: message_start
data: {"type":"message_start","message":{"id":"msg_019kmW9Q3GqfWmuFJbePJTBR","type":"message","role":"assistant","content":[],"model":"claude-3-opus-20240229","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":347,"output_tokens":1}}}
event: message_start
data: {"type":"message_start","message":{"id":"msg_01Ju4j2MiGQb9KV9EEQ522Y3","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":1293,"output_tokens":1}} }
event: content_block_start
data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
event: content_block_start
data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01DjrShFRRHp9SnHYRFRc53F","name":"search","input":{}} }
event: ping
data: {"type": "ping"}
event: ping
data: {"type": "ping"}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<function"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""} }
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\\"searc"} }
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"calls"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"h_qu"} }
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"er"} }
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"y\\": \\"s"} }
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<invoke"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"am"} }
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":" "} }
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"sam\\""} }
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<tool"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":", \\"cate"} }
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"gory"} }
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"name"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"\\": \\"gene"} }
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"ral\\"}"} }
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"google"}}
event: content_block_stop
data: {"type":"content_block_stop","index":0 }
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</tool"}}
event: message_delta
data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"output_tokens":70} }
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"name"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<parameters"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<query"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"top"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"10"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"things"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" to"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" do"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" in"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" japan"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" for"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" tourists"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</query"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</parameters"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</invoke"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
event: content_block_stop
data: {"type":"content_block_stop","index":0}
event: message_delta
data: {"type":"message_delta","delta":{"stop_reason":"stop_sequence","stop_sequence":"</function_calls>"},"usage":{"output_tokens":57}}
event: message_stop
data: {"type":"message_stop"}
event: message_stop
data: {"type":"message_stop"}
STRING
result = +""
@ -213,11 +120,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
expected = (<<~TEXT).strip
<function_calls>
<invoke>
<tool_name>google</tool_name>
<parameters>
<query>top 10 things to do in japan for tourists</query>
</parameters>
<tool_id>tool_0</tool_id>
<tool_name>search</tool_name>
<parameters><search_query>sam sam</search_query>
<category>general</category></parameters>
<tool_id>toolu_01DjrShFRRHp9SnHYRFRc53F</tool_id>
</invoke>
</function_calls>
TEXT
@ -285,71 +191,71 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
log = AiApiAuditLog.order(:id).last
expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
expect(log.request_tokens).to eq(25)
expect(log.response_tokens).to eq(15)
expect(log.feature_name).to eq("testing")
expect(log.response_tokens).to eq(15)
expect(log.request_tokens).to eq(25)
end
it "can return multiple function calls" do
functions = <<~FUNCTIONS
it "supports non streaming tool calls" do
tool = {
name: "calculate",
description: "calculate something",
parameters: [
{
name: "expression",
type: "string",
description: "expression to calculate",
required: true,
},
],
}
prompt =
DiscourseAi::Completions::Prompt.new(
"You a calculator",
messages: [{ type: :user, id: "user1", content: "calculate 2758975 + 21.11" }],
tools: [tool],
)
proxy = DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-haiku")
body = {
id: "msg_01RdJkxCbsEj9VFyFYAkfy2S",
type: "message",
role: "assistant",
model: "claude-3-haiku-20240307",
content: [
{ type: "text", text: "Here is the calculation:" },
{
type: "tool_use",
id: "toolu_012kBdhG4eHaV68W56p4N94h",
name: "calculate",
input: {
expression: "2758975 + 21.11",
},
},
],
stop_reason: "tool_use",
stop_sequence: nil,
usage: {
input_tokens: 345,
output_tokens: 65,
},
}.to_json
stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(body: body)
result = proxy.generate(prompt, user: Discourse.system_user)
expected = <<~TEXT.strip
<function_calls>
<invoke>
<tool_name>echo</tool_name>
<parameters>
<text>something</text>
</parameters>
</invoke>
<invoke>
<tool_name>echo</tool_name>
<parameters>
<text>something else</text>
</parameters>
</invoke>
FUNCTIONS
body = <<~STRING
{
"content": [
{
"text": "Hello!\n\n#{functions}\njunk",
"type": "text"
}
],
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
"model": "claude-3-opus-20240229",
"role": "assistant",
"stop_reason": "end_turn",
"stop_sequence": null,
"type": "message",
"usage": {
"input_tokens": 10,
"output_tokens": 25
}
}
STRING
stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
result = llm.generate(prompt_with_echo_tool, user: Discourse.system_user)
expected = (<<~EXPECTED).strip
<function_calls>
<invoke>
<tool_name>echo</tool_name>
<parameters>
<text>something</text>
</parameters>
<tool_id>tool_0</tool_id>
</invoke>
<invoke>
<tool_name>echo</tool_name>
<parameters>
<text>something else</text>
</parameters>
<tool_id>tool_1</tool_id>
<tool_name>calculate</tool_name>
<parameters><expression>2758975 + 21.11</expression></parameters>
<tool_id>toolu_012kBdhG4eHaV68W56p4N94h</tool_id>
</invoke>
</function_calls>
EXPECTED
TEXT
expect(result.strip).to eq(expected)
end

View File

@ -24,6 +24,175 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
SiteSetting.ai_bedrock_region = "us-east-1"
end
describe "function calling" do
it "supports streaming function calls" do
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
request = nil
messages =
[
{
type: "message_start",
message: {
id: "msg_bdrk_01WYxeNMk6EKn9s98r6XXrAB",
type: "message",
role: "assistant",
model: "claude-3-haiku-20240307",
stop_sequence: nil,
usage: {
input_tokens: 840,
output_tokens: 1,
},
content: [],
stop_reason: nil,
},
},
{
type: "content_block_start",
index: 0,
content_block: {
type: "tool_use",
id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
name: "google",
input: {
},
},
},
{
type: "content_block_delta",
index: 0,
delta: {
type: "input_json_delta",
partial_json: "",
},
},
{
type: "content_block_delta",
index: 0,
delta: {
type: "input_json_delta",
partial_json: "{\"query\": \"s",
},
},
{
type: "content_block_delta",
index: 0,
delta: {
type: "input_json_delta",
partial_json: "ydney weat",
},
},
{
type: "content_block_delta",
index: 0,
delta: {
type: "input_json_delta",
partial_json: "her today\"}",
},
},
{ type: "content_block_stop", index: 0 },
{
type: "message_delta",
delta: {
stop_reason: "tool_use",
stop_sequence: nil,
},
usage: {
output_tokens: 53,
},
},
{
type: "message_stop",
"amazon-bedrock-invocationMetrics": {
inputTokenCount: 846,
outputTokenCount: 39,
invocationLatency: 880,
firstByteLatency: 402,
},
},
].map do |message|
wrapped = { bytes: Base64.encode64(message.to_json) }.to_json
io = StringIO.new(wrapped)
aws_message = Aws::EventStream::Message.new(payload: io)
Aws::EventStream::Encoder.new.encode(aws_message)
end
messages = messages.join("").split
bedrock_mock.with_chunk_array_support do
stub_request(
:post,
"https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
)
.with do |inner_request|
request = inner_request
true
end
.to_return(status: 200, body: messages)
prompt =
DiscourseAi::Completions::Prompt.new(
messages: [{ type: :user, content: "what is the weather in sydney" }],
)
tool = {
name: "google",
description: "Will search using Google",
parameters: [
{ name: "query", description: "The search query", type: "string", required: true },
],
}
prompt.tools = [tool]
response = +""
proxy.generate(prompt, user: user) { |partial| response << partial }
expect(request.headers["Authorization"]).to be_present
expect(request.headers["X-Amz-Content-Sha256"]).to be_present
expected_response = (<<~RESPONSE).strip
<function_calls>
<invoke>
<tool_name>google</tool_name>
<parameters><query>sydney weather today</query></parameters>
<tool_id>toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7</tool_id>
</invoke>
</function_calls>
RESPONSE
expect(response.strip).to eq(expected_response)
expected = {
"max_tokens" => 3000,
"anthropic_version" => "bedrock-2023-05-31",
"messages" => [{ "role" => "user", "content" => "what is the weather in sydney" }],
"tools" => [
{
"name" => "google",
"description" => "Will search using Google",
"input_schema" => {
"type" => "object",
"properties" => {
"query" => {
"type" => "string",
"description" => "The search query",
},
},
"required" => ["query"],
},
},
],
}
expect(JSON.parse(request.body)).to eq(expected)
log = AiApiAuditLog.order(:id).last
expect(log.request_tokens).to eq(846)
expect(log.response_tokens).to eq(39)
end
end
end
describe "Claude 3 Sonnet support" do
it "supports the sonnet model" do
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")