From 3993c685e1cc02ae9f774f5dd28b800ad2f37967 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 6 Jun 2024 08:34:23 +1000 Subject: [PATCH] FEATURE: anthropic function calling (#654) Adds support for native tool calling (both streaming and non streaming) for Anthropic. This improves general tool support on the Anthropic models. --- .../anthropic_message_processor.rb | 98 ++++++ lib/completions/dialects/claude.rb | 22 +- lib/completions/dialects/claude_tools.rb | 81 +++++ lib/completions/endpoints/anthropic.rb | 44 ++- lib/completions/endpoints/aws_bedrock.rb | 41 +-- lib/completions/endpoints/base.rb | 8 +- spec/lib/completions/dialects/claude_spec.rb | 12 +- .../completions/endpoints/anthropic_spec.rb | 290 ++++++------------ .../completions/endpoints/aws_bedrock_spec.rb | 169 ++++++++++ 9 files changed, 520 insertions(+), 245 deletions(-) create mode 100644 lib/completions/anthropic_message_processor.rb create mode 100644 lib/completions/dialects/claude_tools.rb diff --git a/lib/completions/anthropic_message_processor.rb b/lib/completions/anthropic_message_processor.rb new file mode 100644 index 00000000..96a9a169 --- /dev/null +++ b/lib/completions/anthropic_message_processor.rb @@ -0,0 +1,98 @@ +# frozen_string_literal: true + +class DiscourseAi::Completions::AnthropicMessageProcessor + class AnthropicToolCall + attr_reader :name, :raw_json, :id + + def initialize(name, id) + @name = name + @id = id + @raw_json = +"" + end + + def append(json) + @raw_json << json + end + end + + attr_reader :tool_calls, :input_tokens, :output_tokens + + def initialize(streaming_mode:) + @streaming_mode = streaming_mode + @tool_calls = [] + end + + def to_xml_tool_calls(function_buffer) + return function_buffer if @tool_calls.blank? + + function_buffer = Nokogiri::HTML5.fragment(<<~TEXT) + + + TEXT + + @tool_calls.each do |tool_call| + node = + function_buffer.at("function_calls").add_child( + Nokogiri::HTML5::DocumentFragment.parse( + DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n", + ), + ) + + params = JSON.parse(tool_call.raw_json, symbolize_names: true) + xml = params.map { |name, value| "<#{name}>#{value}" }.join("\n") + + node.at("tool_name").content = tool_call.name + node.at("tool_id").content = tool_call.id + node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present? + end + + function_buffer + end + + def process_message(payload) + result = "" + parsed = JSON.parse(payload, symbolize_names: true) + + if @streaming_mode + if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use" + tool_name = parsed.dig(:content_block, :name) + tool_id = parsed.dig(:content_block, :id) + @tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name + elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta" + if @tool_calls.present? + result = parsed.dig(:delta, :partial_json).to_s + @tool_calls.last.append(result) + else + result = parsed.dig(:delta, :text).to_s + end + elsif parsed[:type] == "message_start" + @input_tokens = parsed.dig(:message, :usage, :input_tokens) + elsif parsed[:type] == "message_delta" + @output_tokens = + parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens) + elsif parsed[:type] == "message_stop" + # bedrock has this ... + if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym) + @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens + @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens + end + end + else + content = parsed.dig(:content) + if content.is_a?(Array) + tool_call = content.find { |c| c[:type] == "tool_use" } + if tool_call + @tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id]) + @tool_calls.last.append(tool_call[:input].to_json) + else + result = parsed.dig(:content, 0, :text).to_s + end + end + + @input_tokens = parsed.dig(:usage, :input_tokens) + @output_tokens = parsed.dig(:usage, :output_tokens) + end + + result + end +end diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb index d8c13374..6a8a9543 100644 --- a/lib/completions/dialects/claude.rb +++ b/lib/completions/dialects/claude.rb @@ -15,10 +15,12 @@ module DiscourseAi class ClaudePrompt attr_reader :system_prompt attr_reader :messages + attr_reader :tools - def initialize(system_prompt, messages) + def initialize(system_prompt, messages, tools) @system_prompt = system_prompt @messages = messages + @tools = tools end end @@ -46,7 +48,11 @@ module DiscourseAi previous_message = message end - ClaudePrompt.new(system_prompt.presence, interleving_messages) + ClaudePrompt.new( + system_prompt.presence, + interleving_messages, + tools_dialect.translated_tools, + ) end def max_prompt_tokens @@ -58,6 +64,18 @@ module DiscourseAi private + def tools_dialect + @tools_dialect ||= DiscourseAi::Completions::Dialects::ClaudeTools.new(prompt.tools) + end + + def tool_call_msg(msg) + tools_dialect.from_raw_tool_call(msg) + end + + def tool_msg(msg) + tools_dialect.from_raw_tool(msg) + end + def model_msg(msg) { role: "assistant", content: msg[:content] } end diff --git a/lib/completions/dialects/claude_tools.rb b/lib/completions/dialects/claude_tools.rb new file mode 100644 index 00000000..8708497f --- /dev/null +++ b/lib/completions/dialects/claude_tools.rb @@ -0,0 +1,81 @@ +# frozen_string_literal: true + +module DiscourseAi + module Completions + module Dialects + class ClaudeTools + def initialize(tools) + @raw_tools = tools + end + + def translated_tools + # Transform the raw tools into the required Anthropic Claude API format + raw_tools.map do |t| + properties = {} + required = [] + + if t[:parameters] + properties = + t[:parameters].each_with_object({}) do |param, h| + h[param[:name]] = { + type: param[:type], + description: param[:description], + }.tap { |hash| hash[:items] = { type: param[:item_type] } if param[:item_type] } + end + required = + t[:parameters].select { |param| param[:required] }.map { |param| param[:name] } + end + + { + name: t[:name], + description: t[:description], + input_schema: { + type: "object", + properties: properties, + required: required, + }, + } + end + end + + def instructions + "" # Noop. Tools are listed separate. + end + + def from_raw_tool_call(raw_message) + call_details = JSON.parse(raw_message[:content], symbolize_names: true) + tool_call_id = raw_message[:id] + + { + role: "assistant", + content: [ + { + type: "tool_use", + id: tool_call_id, + name: raw_message[:name], + input: call_details[:arguments], + }, + ], + } + end + + def from_raw_tool(raw_message) + { + role: "user", + content: [ + { + type: "tool_result", + tool_use_id: raw_message[:id], + content: raw_message[:content], + }, + ], + } + end + + private + + attr_reader :raw_tools + end + end + end +end diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb index 6253d020..c1ec288f 100644 --- a/lib/completions/endpoints/anthropic.rb +++ b/lib/completions/endpoints/anthropic.rb @@ -45,10 +45,7 @@ module DiscourseAi raise "Unsupported model: #{model}" end - options = { model: mapped_model, max_tokens: 3_000 } - - options[:stop_sequences] = [""] if dialect.prompt.has_tools? - options + { model: mapped_model, max_tokens: 3_000 } end def provider_id @@ -73,6 +70,7 @@ module DiscourseAi payload[:system] = prompt.system_prompt if prompt.system_prompt.present? payload[:stream] = true if @streaming_mode + payload[:tools] = prompt.tools if prompt.tools.present? payload end @@ -87,30 +85,30 @@ module DiscourseAi Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } end - def final_log_update(log) - log.request_tokens = @input_tokens if @input_tokens - log.response_tokens = @output_tokens if @output_tokens + def processor + @processor ||= + DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode) + end + + def add_to_function_buffer(function_buffer, partial: nil, payload: nil) + processor.to_xml_tool_calls(function_buffer) if !partial end def extract_completion_from(response_raw) - result = "" - parsed = JSON.parse(response_raw, symbolize_names: true) + processor.process_message(response_raw) + end - if @streaming_mode - if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta" - result = parsed.dig(:delta, :text).to_s - elsif parsed[:type] == "message_start" - @input_tokens = parsed.dig(:message, :usage, :input_tokens) - elsif parsed[:type] == "message_delta" - @output_tokens = parsed.dig(:delta, :usage, :output_tokens) - end - else - result = parsed.dig(:content, 0, :text).to_s - @input_tokens = parsed.dig(:usage, :input_tokens) - @output_tokens = parsed.dig(:usage, :output_tokens) - end + def has_tool?(_response_data) + processor.tool_calls.present? + end - result + def final_log_update(log) + log.request_tokens = processor.input_tokens if processor.input_tokens + log.response_tokens = processor.output_tokens if processor.output_tokens + end + + def native_tool_support? + true end def partials_from(decoded_chunk) diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb index 2a346d95..bbd87749 100644 --- a/lib/completions/endpoints/aws_bedrock.rb +++ b/lib/completions/endpoints/aws_bedrock.rb @@ -36,7 +36,6 @@ module DiscourseAi def default_options(dialect) options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" } - options[:stop_sequences] = [""] if dialect.prompt.has_tools? options end @@ -82,6 +81,8 @@ module DiscourseAi def prepare_payload(prompt, model_params, dialect) payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages) payload[:system] = prompt.system_prompt if prompt.system_prompt.present? + payload[:tools] = prompt.tools if prompt.tools.present? + payload end @@ -142,35 +143,35 @@ module DiscourseAi end def final_log_update(log) - log.request_tokens = @input_tokens if @input_tokens - log.response_tokens = @output_tokens if @output_tokens + log.request_tokens = processor.input_tokens if processor.input_tokens + log.response_tokens = processor.output_tokens if processor.output_tokens + end + + def processor + @processor ||= + DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode) + end + + def add_to_function_buffer(function_buffer, partial: nil, payload: nil) + processor.to_xml_tool_calls(function_buffer) if !partial end def extract_completion_from(response_raw) - result = "" - parsed = JSON.parse(response_raw, symbolize_names: true) + processor.process_message(response_raw) + end - if @streaming_mode - if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta" - result = parsed.dig(:delta, :text).to_s - elsif parsed[:type] == "message_start" - @input_tokens = parsed.dig(:message, :usage, :input_tokens) - elsif parsed[:type] == "message_delta" - @output_tokens = parsed.dig(:delta, :usage, :output_tokens) - end - else - result = parsed.dig(:content, 0, :text).to_s - @input_tokens = parsed.dig(:usage, :input_tokens) - @output_tokens = parsed.dig(:usage, :output_tokens) - end - - result + def has_tool?(_response_data) + processor.tool_calls.present? end def partials_from(decoded_chunks) decoded_chunks end + def native_tool_support? + true + end + def chunk_to_string(chunk) joined = +chunk.join("\n") joined << "\n" if joined.length > 0 diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index bd558c7d..8d9d5f68 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -242,12 +242,14 @@ module DiscourseAi else leftover = "" end + prev_processed_partials = 0 if leftover.blank? end rescue IOError, StandardError raise if !cancelled end + has_tool ||= has_tool?(partials_raw) # Once we have the full response, try to return the tool as a XML doc. if has_tool && native_tool_support? function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw) @@ -345,7 +347,7 @@ module DiscourseAi TEXT end - def noop_function_call_text + def self.noop_function_call_text (<<~TEXT).strip @@ -356,6 +358,10 @@ module DiscourseAi TEXT end + def noop_function_call_text + self.class.noop_function_call_text + end + def has_tool?(response) response.include?("") end diff --git a/spec/lib/completions/dialects/claude_spec.rb b/spec/lib/completions/dialects/claude_spec.rb index af082c22..8eb00256 100644 --- a/spec/lib/completions/dialects/claude_spec.rb +++ b/spec/lib/completions/dialects/claude_spec.rb @@ -46,7 +46,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do messages = [ { type: :user, id: "user1", content: "echo something" }, - { type: :tool_call, content: tool_call_prompt.to_json }, + { type: :tool_call, name: "echo", id: "tool_id", content: tool_call_prompt.to_json }, { type: :tool, id: "tool_id", content: "something".to_json }, { type: :model, content: "I did it" }, { type: :user, id: "user1", content: "echo something else" }, @@ -63,24 +63,22 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do translated = dialect.translate expect(translated.system_prompt).to start_with("You are a helpful bot") - expect(translated.system_prompt).to include("echo a string") expected = [ { role: "user", content: "user1: echo something" }, { role: "assistant", - content: - "\n\necho\n\nsomething\n\n\n", + content: [ + { type: "tool_use", id: "tool_id", name: "echo", input: { text: "something" } }, + ], }, { role: "user", - content: - "\n\ntool_id\n\n\"something\"\n\n\n", + content: [{ type: "tool_result", tool_use_id: "tool_id", content: "\"something\"" }], }, { role: "assistant", content: "I did it" }, { role: "user", content: "user1: echo something else" }, ] - expect(translated.messages).to eq(expected) end end diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb index 9faae853..ce185dcf 100644 --- a/spec/lib/completions/endpoints/anthropic_spec.rb +++ b/spec/lib/completions/endpoints/anthropic_spec.rb @@ -49,152 +49,59 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do it "does not eat spaces with tool calls" do body = <<~STRING - event: message_start - data: {"type":"message_start","message":{"id":"msg_019kmW9Q3GqfWmuFJbePJTBR","type":"message","role":"assistant","content":[],"model":"claude-3-opus-20240229","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":347,"output_tokens":1}}} + event: message_start + data: {"type":"message_start","message":{"id":"msg_01Ju4j2MiGQb9KV9EEQ522Y3","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":1293,"output_tokens":1}} } - event: content_block_start - data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} + event: content_block_start + data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01DjrShFRRHp9SnHYRFRc53F","name":"search","input":{}} } - event: ping - data: {"type": "ping"} + event: ping + data: {"type": "ping"} - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}} + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"er"} } - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}} + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"y\\": \\"s"} } - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}} + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":" "} } - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}} + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"sam\\""} } - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}} + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"ral\\"}"} } - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"google"}} + event: content_block_stop + data: {"type":"content_block_stop","index":0 } - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"top"}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"10"}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"things"}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" to"}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" do"}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" in"}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" japan"}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" for"}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" tourists"}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}} - - event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}} - - event: content_block_stop - data: {"type":"content_block_stop","index":0} - - event: message_delta - data: {"type":"message_delta","delta":{"stop_reason":"stop_sequence","stop_sequence":""},"usage":{"output_tokens":57}} - - event: message_stop - data: {"type":"message_stop"} + event: message_stop + data: {"type":"message_stop"} STRING result = +"" @@ -213,11 +120,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do expected = (<<~TEXT).strip - google - - top 10 things to do in japan for tourists - - tool_0 + search + sam sam + general + toolu_01DjrShFRRHp9SnHYRFRc53F TEXT @@ -285,71 +191,71 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do log = AiApiAuditLog.order(:id).last expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic) - expect(log.request_tokens).to eq(25) - expect(log.response_tokens).to eq(15) expect(log.feature_name).to eq("testing") + expect(log.response_tokens).to eq(15) + expect(log.request_tokens).to eq(25) end - it "can return multiple function calls" do - functions = <<~FUNCTIONS + it "supports non streaming tool calls" do + tool = { + name: "calculate", + description: "calculate something", + parameters: [ + { + name: "expression", + type: "string", + description: "expression to calculate", + required: true, + }, + ], + } + + prompt = + DiscourseAi::Completions::Prompt.new( + "You a calculator", + messages: [{ type: :user, id: "user1", content: "calculate 2758975 + 21.11" }], + tools: [tool], + ) + + proxy = DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-haiku") + + body = { + id: "msg_01RdJkxCbsEj9VFyFYAkfy2S", + type: "message", + role: "assistant", + model: "claude-3-haiku-20240307", + content: [ + { type: "text", text: "Here is the calculation:" }, + { + type: "tool_use", + id: "toolu_012kBdhG4eHaV68W56p4N94h", + name: "calculate", + input: { + expression: "2758975 + 21.11", + }, + }, + ], + stop_reason: "tool_use", + stop_sequence: nil, + usage: { + input_tokens: 345, + output_tokens: 65, + }, + }.to_json + + stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(body: body) + + result = proxy.generate(prompt, user: Discourse.system_user) + + expected = <<~TEXT.strip - echo - - something - - - - echo - - something else - - - FUNCTIONS - - body = <<~STRING - { - "content": [ - { - "text": "Hello!\n\n#{functions}\njunk", - "type": "text" - } - ], - "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF", - "model": "claude-3-opus-20240229", - "role": "assistant", - "stop_reason": "end_turn", - "stop_sequence": null, - "type": "message", - "usage": { - "input_tokens": 10, - "output_tokens": 25 - } - } - STRING - - stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body) - - result = llm.generate(prompt_with_echo_tool, user: Discourse.system_user) - - expected = (<<~EXPECTED).strip - - - echo - - something - - tool_0 - - - echo - - something else - - tool_1 + calculate + 2758975 + 21.11 + toolu_012kBdhG4eHaV68W56p4N94h - EXPECTED + TEXT expect(result.strip).to eq(expected) end diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb index 5b87c095..c6c2399b 100644 --- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb +++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb @@ -24,6 +24,175 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do SiteSetting.ai_bedrock_region = "us-east-1" end + describe "function calling" do + it "supports streaming function calls" do + proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet") + + request = nil + + messages = + [ + { + type: "message_start", + message: { + id: "msg_bdrk_01WYxeNMk6EKn9s98r6XXrAB", + type: "message", + role: "assistant", + model: "claude-3-haiku-20240307", + stop_sequence: nil, + usage: { + input_tokens: 840, + output_tokens: 1, + }, + content: [], + stop_reason: nil, + }, + }, + { + type: "content_block_start", + index: 0, + content_block: { + type: "tool_use", + id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7", + name: "google", + input: { + }, + }, + }, + { + type: "content_block_delta", + index: 0, + delta: { + type: "input_json_delta", + partial_json: "", + }, + }, + { + type: "content_block_delta", + index: 0, + delta: { + type: "input_json_delta", + partial_json: "{\"query\": \"s", + }, + }, + { + type: "content_block_delta", + index: 0, + delta: { + type: "input_json_delta", + partial_json: "ydney weat", + }, + }, + { + type: "content_block_delta", + index: 0, + delta: { + type: "input_json_delta", + partial_json: "her today\"}", + }, + }, + { type: "content_block_stop", index: 0 }, + { + type: "message_delta", + delta: { + stop_reason: "tool_use", + stop_sequence: nil, + }, + usage: { + output_tokens: 53, + }, + }, + { + type: "message_stop", + "amazon-bedrock-invocationMetrics": { + inputTokenCount: 846, + outputTokenCount: 39, + invocationLatency: 880, + firstByteLatency: 402, + }, + }, + ].map do |message| + wrapped = { bytes: Base64.encode64(message.to_json) }.to_json + io = StringIO.new(wrapped) + aws_message = Aws::EventStream::Message.new(payload: io) + Aws::EventStream::Encoder.new.encode(aws_message) + end + + messages = messages.join("").split + + bedrock_mock.with_chunk_array_support do + stub_request( + :post, + "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream", + ) + .with do |inner_request| + request = inner_request + true + end + .to_return(status: 200, body: messages) + + prompt = + DiscourseAi::Completions::Prompt.new( + messages: [{ type: :user, content: "what is the weather in sydney" }], + ) + + tool = { + name: "google", + description: "Will search using Google", + parameters: [ + { name: "query", description: "The search query", type: "string", required: true }, + ], + } + + prompt.tools = [tool] + response = +"" + proxy.generate(prompt, user: user) { |partial| response << partial } + + expect(request.headers["Authorization"]).to be_present + expect(request.headers["X-Amz-Content-Sha256"]).to be_present + + expected_response = (<<~RESPONSE).strip + + + google + sydney weather today + toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7 + + + RESPONSE + + expect(response.strip).to eq(expected_response) + + expected = { + "max_tokens" => 3000, + "anthropic_version" => "bedrock-2023-05-31", + "messages" => [{ "role" => "user", "content" => "what is the weather in sydney" }], + "tools" => [ + { + "name" => "google", + "description" => "Will search using Google", + "input_schema" => { + "type" => "object", + "properties" => { + "query" => { + "type" => "string", + "description" => "The search query", + }, + }, + "required" => ["query"], + }, + }, + ], + } + expect(JSON.parse(request.body)).to eq(expected) + + log = AiApiAuditLog.order(:id).last + expect(log.request_tokens).to eq(846) + expect(log.response_tokens).to eq(39) + end + end + end + describe "Claude 3 Sonnet support" do it "supports the sonnet model" do proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")