diff --git a/lib/completions/anthropic_message_processor.rb b/lib/completions/anthropic_message_processor.rb
new file mode 100644
index 00000000..96a9a169
--- /dev/null
+++ b/lib/completions/anthropic_message_processor.rb
@@ -0,0 +1,98 @@
+# frozen_string_literal: true
+
+class DiscourseAi::Completions::AnthropicMessageProcessor
+ class AnthropicToolCall
+ attr_reader :name, :raw_json, :id
+
+ def initialize(name, id)
+ @name = name
+ @id = id
+ @raw_json = +""
+ end
+
+ def append(json)
+ @raw_json << json
+ end
+ end
+
+ attr_reader :tool_calls, :input_tokens, :output_tokens
+
+ def initialize(streaming_mode:)
+ @streaming_mode = streaming_mode
+ @tool_calls = []
+ end
+
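+  # Renders the accumulated tool calls as the XML <function_calls> document the
+  # rest of the pipeline already understands, so native tool calls look the
+  # same as prompt-based ones downstream.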
+ def to_xml_tool_calls(function_buffer)
+ return function_buffer if @tool_calls.blank?
+
+ function_buffer = Nokogiri::HTML5.fragment(<<~TEXT)
+      <function_calls>
+      </function_calls>
+ TEXT
+
+ @tool_calls.each do |tool_call|
+ node =
+ function_buffer.at("function_calls").add_child(
+ Nokogiri::HTML5::DocumentFragment.parse(
+ DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n",
+ ),
+ )
+
+ params = JSON.parse(tool_call.raw_json, symbolize_names: true)
+      xml = params.map { |name, value| "<#{name}>#{value}</#{name}>" }.join("\n")
+
+ node.at("tool_name").content = tool_call.name
+ node.at("tool_id").content = tool_call.id
+ node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present?
+ end
+
+ function_buffer
+ end
+
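+  # Extracts completion text from a single payload: an SSE event in streaming
+  # mode, or the full response body otherwise. Token usage and tool calls are
+  # captured as a side effect.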
+ def process_message(payload)
+ result = ""
+ parsed = JSON.parse(payload, symbolize_names: true)
+
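+    # Streaming responses arrive as typed events (message_start,
+    # content_block_start/delta, message_delta, message_stop).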
+ if @streaming_mode
+ if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
+ tool_name = parsed.dig(:content_block, :name)
+ tool_id = parsed.dig(:content_block, :id)
+ @tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name
+ elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
+ if @tool_calls.present?
+ result = parsed.dig(:delta, :partial_json).to_s
+ @tool_calls.last.append(result)
+ else
+ result = parsed.dig(:delta, :text).to_s
+ end
+ elsif parsed[:type] == "message_start"
+ @input_tokens = parsed.dig(:message, :usage, :input_tokens)
+ elsif parsed[:type] == "message_delta"
+ @output_tokens =
+ parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
+ elsif parsed[:type] == "message_stop"
+      # Bedrock attaches invocation metrics to the final message_stop event
+      if bedrock_stats = parsed.dig(:"amazon-bedrock-invocationMetrics")
+ @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
+ @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
+ end
+ end
+ else
+ content = parsed.dig(:content)
+ if content.is_a?(Array)
+ tool_call = content.find { |c| c[:type] == "tool_use" }
+ if tool_call
+ @tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id])
+ @tool_calls.last.append(tool_call[:input].to_json)
+ else
+ result = parsed.dig(:content, 0, :text).to_s
+ end
+ end
+
+ @input_tokens = parsed.dig(:usage, :input_tokens)
+ @output_tokens = parsed.dig(:usage, :output_tokens)
+ end
+
+ result
+ end
+end
diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb
index d8c13374..6a8a9543 100644
--- a/lib/completions/dialects/claude.rb
+++ b/lib/completions/dialects/claude.rb
@@ -15,10 +15,12 @@ module DiscourseAi
class ClaudePrompt
attr_reader :system_prompt
attr_reader :messages
+ attr_reader :tools
- def initialize(system_prompt, messages)
+ def initialize(system_prompt, messages, tools)
@system_prompt = system_prompt
@messages = messages
+ @tools = tools
end
end
@@ -46,7 +48,11 @@ module DiscourseAi
previous_message = message
end
- ClaudePrompt.new(system_prompt.presence, interleving_messages)
+ ClaudePrompt.new(
+ system_prompt.presence,
+ interleving_messages,
+ tools_dialect.translated_tools,
+ )
end
def max_prompt_tokens
@@ -58,6 +64,18 @@ module DiscourseAi
private
+ def tools_dialect
+ @tools_dialect ||= DiscourseAi::Completions::Dialects::ClaudeTools.new(prompt.tools)
+ end
+
+ def tool_call_msg(msg)
+ tools_dialect.from_raw_tool_call(msg)
+ end
+
+ def tool_msg(msg)
+ tools_dialect.from_raw_tool(msg)
+ end
+
def model_msg(msg)
{ role: "assistant", content: msg[:content] }
end
diff --git a/lib/completions/dialects/claude_tools.rb b/lib/completions/dialects/claude_tools.rb
new file mode 100644
index 00000000..8708497f
--- /dev/null
+++ b/lib/completions/dialects/claude_tools.rb
@@ -0,0 +1,81 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Completions
+ module Dialects
+ class ClaudeTools
+ def initialize(tools)
+ @raw_tools = tools
+ end
+
+ def translated_tools
+ # Transform the raw tools into the required Anthropic Claude API format
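+          # e.g. a "google" tool with one required "query" string parameter becomes:
+          #   { name: "google", description: "...", input_schema: { type: "object",
+          #     properties: { query: { type: "string", description: "..." } },
+          #     required: ["query"] } }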
+ raw_tools.map do |t|
+ properties = {}
+ required = []
+
+ if t[:parameters]
+ properties =
+ t[:parameters].each_with_object({}) do |param, h|
+ h[param[:name]] = {
+ type: param[:type],
+ description: param[:description],
+ }.tap { |hash| hash[:items] = { type: param[:item_type] } if param[:item_type] }
+ end
+ required =
+ t[:parameters].select { |param| param[:required] }.map { |param| param[:name] }
+ end
+
+ {
+ name: t[:name],
+ description: t[:description],
+ input_schema: {
+ type: "object",
+ properties: properties,
+ required: required,
+ },
+ }
+ end
+ end
+
+ def instructions
+          "" # Noop. Tools are listed separately.
+ end
+
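+        # Replays a persisted tool call as an assistant message carrying a
+        # "tool_use" content block, the shape the Anthropic messages API expects.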
+ def from_raw_tool_call(raw_message)
+ call_details = JSON.parse(raw_message[:content], symbolize_names: true)
+ tool_call_id = raw_message[:id]
+
+ {
+ role: "assistant",
+ content: [
+ {
+ type: "tool_use",
+ id: tool_call_id,
+ name: raw_message[:name],
+ input: call_details[:arguments],
+ },
+ ],
+ }
+ end
+
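+        # Tool results are returned as a user message whose "tool_result" block
+        # references the originating tool_use id.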
+ def from_raw_tool(raw_message)
+ {
+ role: "user",
+ content: [
+ {
+ type: "tool_result",
+ tool_use_id: raw_message[:id],
+ content: raw_message[:content],
+ },
+ ],
+ }
+ end
+
+ private
+
+ attr_reader :raw_tools
+ end
+ end
+ end
+end
diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb
index 6253d020..c1ec288f 100644
--- a/lib/completions/endpoints/anthropic.rb
+++ b/lib/completions/endpoints/anthropic.rb
@@ -45,10 +45,7 @@ module DiscourseAi
raise "Unsupported model: #{model}"
end
- options = { model: mapped_model, max_tokens: 3_000 }
-
-          options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
- options
+ { model: mapped_model, max_tokens: 3_000 }
end
def provider_id
@@ -73,6 +70,7 @@ module DiscourseAi
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
payload[:stream] = true if @streaming_mode
+ payload[:tools] = prompt.tools if prompt.tools.present?
payload
end
@@ -87,30 +85,30 @@ module DiscourseAi
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
- def final_log_update(log)
- log.request_tokens = @input_tokens if @input_tokens
- log.response_tokens = @output_tokens if @output_tokens
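+      # A single processor accumulates token usage and tool calls across partials.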
+ def processor
+ @processor ||=
+ DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
+ end
+
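+      # Tool calls can only be rendered once the full payload has arrived, so
+      # partial updates are ignored here.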
+ def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
+ processor.to_xml_tool_calls(function_buffer) if !partial
end
def extract_completion_from(response_raw)
- result = ""
- parsed = JSON.parse(response_raw, symbolize_names: true)
+ processor.process_message(response_raw)
+ end
- if @streaming_mode
- if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
- result = parsed.dig(:delta, :text).to_s
- elsif parsed[:type] == "message_start"
- @input_tokens = parsed.dig(:message, :usage, :input_tokens)
- elsif parsed[:type] == "message_delta"
- @output_tokens = parsed.dig(:delta, :usage, :output_tokens)
- end
- else
- result = parsed.dig(:content, 0, :text).to_s
- @input_tokens = parsed.dig(:usage, :input_tokens)
- @output_tokens = parsed.dig(:usage, :output_tokens)
- end
+ def has_tool?(_response_data)
+ processor.tool_calls.present?
+ end
- result
+ def final_log_update(log)
+ log.request_tokens = processor.input_tokens if processor.input_tokens
+ log.response_tokens = processor.output_tokens if processor.output_tokens
+ end
+
+ def native_tool_support?
+ true
end
def partials_from(decoded_chunk)
diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb
index 2a346d95..bbd87749 100644
--- a/lib/completions/endpoints/aws_bedrock.rb
+++ b/lib/completions/endpoints/aws_bedrock.rb
@@ -36,7 +36,6 @@ module DiscourseAi
def default_options(dialect)
options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" }
-        options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
options
end
@@ -82,6 +81,8 @@ module DiscourseAi
def prepare_payload(prompt, model_params, dialect)
payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
+ payload[:tools] = prompt.tools if prompt.tools.present?
+
payload
end
@@ -142,35 +143,35 @@ module DiscourseAi
end
def final_log_update(log)
- log.request_tokens = @input_tokens if @input_tokens
- log.response_tokens = @output_tokens if @output_tokens
+ log.request_tokens = processor.input_tokens if processor.input_tokens
+ log.response_tokens = processor.output_tokens if processor.output_tokens
+ end
+
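+      # Bedrock streams the same Anthropic event format, so the shared
+      # AnthropicMessageProcessor handles it as well.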
+ def processor
+ @processor ||=
+ DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
+ end
+
+ def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
+ processor.to_xml_tool_calls(function_buffer) if !partial
end
def extract_completion_from(response_raw)
- result = ""
- parsed = JSON.parse(response_raw, symbolize_names: true)
+ processor.process_message(response_raw)
+ end
- if @streaming_mode
- if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
- result = parsed.dig(:delta, :text).to_s
- elsif parsed[:type] == "message_start"
- @input_tokens = parsed.dig(:message, :usage, :input_tokens)
- elsif parsed[:type] == "message_delta"
- @output_tokens = parsed.dig(:delta, :usage, :output_tokens)
- end
- else
- result = parsed.dig(:content, 0, :text).to_s
- @input_tokens = parsed.dig(:usage, :input_tokens)
- @output_tokens = parsed.dig(:usage, :output_tokens)
- end
-
- result
+ def has_tool?(_response_data)
+ processor.tool_calls.present?
end
def partials_from(decoded_chunks)
decoded_chunks
end
+ def native_tool_support?
+ true
+ end
+
def chunk_to_string(chunk)
joined = +chunk.join("\n")
joined << "\n" if joined.length > 0
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index bd558c7d..8d9d5f68 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -242,12 +242,14 @@ module DiscourseAi
else
leftover = ""
end
+
prev_processed_partials = 0 if leftover.blank?
end
rescue IOError, StandardError
raise if !cancelled
end
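+          # Endpoints with native tool support flag tool calls via processor
+          # state, which may only be known once the stream has finished.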
+ has_tool ||= has_tool?(partials_raw)
          # Once we have the full response, try to return the tool as an XML doc.
if has_tool && native_tool_support?
function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw)
@@ -345,7 +347,7 @@ module DiscourseAi
TEXT
end
- def noop_function_call_text
+ def self.noop_function_call_text
        (<<~TEXT).strip
          <function_calls>
          <invoke>
@@ -356,6 +358,10 @@ module DiscourseAi
          </function_calls>
        TEXT
end
+ def noop_function_call_text
+ self.class.noop_function_call_text
+ end
+
def has_tool?(response)
response.include?("")
end
diff --git a/spec/lib/completions/dialects/claude_spec.rb b/spec/lib/completions/dialects/claude_spec.rb
index af082c22..8eb00256 100644
--- a/spec/lib/completions/dialects/claude_spec.rb
+++ b/spec/lib/completions/dialects/claude_spec.rb
@@ -46,7 +46,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
messages = [
{ type: :user, id: "user1", content: "echo something" },
- { type: :tool_call, content: tool_call_prompt.to_json },
+ { type: :tool_call, name: "echo", id: "tool_id", content: tool_call_prompt.to_json },
{ type: :tool, id: "tool_id", content: "something".to_json },
{ type: :model, content: "I did it" },
{ type: :user, id: "user1", content: "echo something else" },
@@ -63,24 +63,22 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
translated = dialect.translate
expect(translated.system_prompt).to start_with("You are a helpful bot")
- expect(translated.system_prompt).to include("echo a string")
expected = [
{ role: "user", content: "user1: echo something" },
{
role: "assistant",
-        content:
-          "<function_calls>\n<invoke>\n<tool_name>echo</tool_name>\n<parameters>\n<text>something</text>\n</parameters>\n</invoke>\n</function_calls>",
+ content: [
+ { type: "tool_use", id: "tool_id", name: "echo", input: { text: "something" } },
+ ],
},
{
role: "user",
-        content:
-          "<function_results>\n<result>\n<tool_name>tool_id</tool_name>\n<json>\n\"something\"\n</json>\n</result>\n</function_results>",
+ content: [{ type: "tool_result", tool_use_id: "tool_id", content: "\"something\"" }],
},
{ role: "assistant", content: "I did it" },
{ role: "user", content: "user1: echo something else" },
]
-
expect(translated.messages).to eq(expected)
end
end
diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb
index 9faae853..ce185dcf 100644
--- a/spec/lib/completions/endpoints/anthropic_spec.rb
+++ b/spec/lib/completions/endpoints/anthropic_spec.rb
@@ -49,152 +49,59 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
it "does not eat spaces with tool calls" do
body = <<~STRING
- event: message_start
- data: {"type":"message_start","message":{"id":"msg_019kmW9Q3GqfWmuFJbePJTBR","type":"message","role":"assistant","content":[],"model":"claude-3-opus-20240229","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":347,"output_tokens":1}}}
+ event: message_start
+ data: {"type":"message_start","message":{"id":"msg_01Ju4j2MiGQb9KV9EEQ522Y3","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":1293,"output_tokens":1}} }
- event: content_block_start
- data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
+ event: content_block_start
+ data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01DjrShFRRHp9SnHYRFRc53F","name":"search","input":{}} }
- event: ping
- data: {"type": "ping"}
+ event: ping
+ data: {"type": "ping"}
-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<function_calls>"}}
+        event: content_block_delta
+        data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\\"search_qu"} }

-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
+        event: content_block_delta
+        data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"er"} }

-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<invoke>"}}
+        event: content_block_delta
+        data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"y\\": \\"s"} }

-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
+        event: content_block_delta
+        data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"am"} }

-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<tool_name>"}}
+        event: content_block_delta
+        data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":" "} }

-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"google"}}
+        event: content_block_delta
+        data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"sam\\""} }

-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</tool_name>"}}
+        event: content_block_delta
+        data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":", \\"category\\": \\"gene"} }

-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
+        event: content_block_delta
+        data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"ral\\"}"} }

-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<parameters>"}}
+        event: content_block_stop
+        data: {"type":"content_block_stop","index":0 }

-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
-
-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<query>"}}
-
-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"top"}}
-
-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
-
-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"10"}}
-
-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
-
-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"things"}}
-
-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" to"}}
-
-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" do"}}
-
-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" in"}}
-
-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" japan"}}
-
-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" for"}}
-
-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" tourists"}}
-
-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</query>"}}
-
-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
-
-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</parameters>"}}
-
-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
-
-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</invoke>"}}
-
-            event: content_block_delta
-            data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
-
-            event: content_block_stop
-            data: {"type":"content_block_stop","index":0}
-
-            event: message_delta
-            data: {"type":"message_delta","delta":{"stop_reason":"stop_sequence","stop_sequence":"</function_calls>"},"usage":{"output_tokens":57}}
-
-            event: message_stop
-            data: {"type":"message_stop"}
+        event: message_stop
+        data: {"type":"message_stop"}
STRING
result = +""
@@ -213,11 +120,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
expected = (<<~TEXT).strip
-        <function_calls>
-        <invoke>
-        <tool_name>google</tool_name>
-        <parameters>
-        <query>top 10 things to do in japan for tourists</query>
-        </parameters>
-        <tool_id>tool_0</tool_id>
-        </invoke>
-        </function_calls>
+        <function_calls>
+        <invoke>
+        <tool_name>search</tool_name>
+        <parameters>
+        <search_query>sam sam</search_query>
+        <category>general</category>
+        </parameters>
+        <tool_id>toolu_01DjrShFRRHp9SnHYRFRc53F</tool_id>
+        </invoke>
+        </function_calls>
TEXT
@@ -285,71 +191,71 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
log = AiApiAuditLog.order(:id).last
expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
- expect(log.request_tokens).to eq(25)
- expect(log.response_tokens).to eq(15)
expect(log.feature_name).to eq("testing")
+ expect(log.response_tokens).to eq(15)
+ expect(log.request_tokens).to eq(25)
end
- it "can return multiple function calls" do
- functions = <<~FUNCTIONS
+ it "supports non streaming tool calls" do
+ tool = {
+ name: "calculate",
+ description: "calculate something",
+ parameters: [
+ {
+ name: "expression",
+ type: "string",
+ description: "expression to calculate",
+ required: true,
+ },
+ ],
+ }
+
+ prompt =
+ DiscourseAi::Completions::Prompt.new(
+          "You are a calculator",
+ messages: [{ type: :user, id: "user1", content: "calculate 2758975 + 21.11" }],
+ tools: [tool],
+ )
+
+ proxy = DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-haiku")
+
+ body = {
+ id: "msg_01RdJkxCbsEj9VFyFYAkfy2S",
+ type: "message",
+ role: "assistant",
+ model: "claude-3-haiku-20240307",
+ content: [
+ { type: "text", text: "Here is the calculation:" },
+ {
+ type: "tool_use",
+ id: "toolu_012kBdhG4eHaV68W56p4N94h",
+ name: "calculate",
+ input: {
+ expression: "2758975 + 21.11",
+ },
+ },
+ ],
+ stop_reason: "tool_use",
+ stop_sequence: nil,
+ usage: {
+ input_tokens: 345,
+ output_tokens: 65,
+ },
+ }.to_json
+
+ stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(body: body)
+
+ result = proxy.generate(prompt, user: Discourse.system_user)
+
+ expected = <<~TEXT.strip
-        <invoke>
-        <tool_name>echo</tool_name>
-        <parameters>
-        <text>something</text>
-        </parameters>
-        </invoke>
-        <invoke>
-        <tool_name>echo</tool_name>
-        <parameters>
-        <text>something else</text>
-        </parameters>
-        </invoke>
-      FUNCTIONS
-
- body = <<~STRING
- {
- "content": [
- {
-            "text": "Hello!\n<function_calls>\n#{functions}</function_calls>\njunk",
- "type": "text"
- }
- ],
- "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
- "model": "claude-3-opus-20240229",
- "role": "assistant",
- "stop_reason": "end_turn",
- "stop_sequence": null,
- "type": "message",
- "usage": {
- "input_tokens": 10,
- "output_tokens": 25
- }
- }
- STRING
-
- stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
-
- result = llm.generate(prompt_with_echo_tool, user: Discourse.system_user)
-
- expected = (<<~EXPECTED).strip
-        <function_calls>
-        <invoke>
-        <tool_name>echo</tool_name>
-        <parameters>
-        <text>something</text>
-        </parameters>
-        <tool_id>tool_0</tool_id>
-        </invoke>
-        <invoke>
-        <tool_name>echo</tool_name>
-        <parameters>
-        <text>something else</text>
-        </parameters>
-        <tool_id>tool_1</tool_id>
-        </invoke>
-        </function_calls>
+        <function_calls>
+        <invoke>
+        <tool_name>calculate</tool_name>
+        <parameters>
+        <expression>2758975 + 21.11</expression>
+        </parameters>
+        <tool_id>toolu_012kBdhG4eHaV68W56p4N94h</tool_id>
+        </invoke>
+        </function_calls>
- EXPECTED
+ TEXT
expect(result.strip).to eq(expected)
end
diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
index 5b87c095..c6c2399b 100644
--- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb
+++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
@@ -24,6 +24,175 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
SiteSetting.ai_bedrock_region = "us-east-1"
end
+ describe "function calling" do
+ it "supports streaming function calls" do
+ proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
+
+ request = nil
+
+ messages =
+ [
+ {
+ type: "message_start",
+ message: {
+ id: "msg_bdrk_01WYxeNMk6EKn9s98r6XXrAB",
+ type: "message",
+ role: "assistant",
+ model: "claude-3-haiku-20240307",
+ stop_sequence: nil,
+ usage: {
+ input_tokens: 840,
+ output_tokens: 1,
+ },
+ content: [],
+ stop_reason: nil,
+ },
+ },
+ {
+ type: "content_block_start",
+ index: 0,
+ content_block: {
+ type: "tool_use",
+ id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
+ name: "google",
+ input: {
+ },
+ },
+ },
+ {
+ type: "content_block_delta",
+ index: 0,
+ delta: {
+ type: "input_json_delta",
+ partial_json: "",
+ },
+ },
+ {
+ type: "content_block_delta",
+ index: 0,
+ delta: {
+ type: "input_json_delta",
+ partial_json: "{\"query\": \"s",
+ },
+ },
+ {
+ type: "content_block_delta",
+ index: 0,
+ delta: {
+ type: "input_json_delta",
+ partial_json: "ydney weat",
+ },
+ },
+ {
+ type: "content_block_delta",
+ index: 0,
+ delta: {
+ type: "input_json_delta",
+ partial_json: "her today\"}",
+ },
+ },
+ { type: "content_block_stop", index: 0 },
+ {
+ type: "message_delta",
+ delta: {
+ stop_reason: "tool_use",
+ stop_sequence: nil,
+ },
+ usage: {
+ output_tokens: 53,
+ },
+ },
+ {
+ type: "message_stop",
+ "amazon-bedrock-invocationMetrics": {
+ inputTokenCount: 846,
+ outputTokenCount: 39,
+ invocationLatency: 880,
+ firstByteLatency: 402,
+ },
+ },
+ ].map do |message|
+ wrapped = { bytes: Base64.encode64(message.to_json) }.to_json
+ io = StringIO.new(wrapped)
+ aws_message = Aws::EventStream::Message.new(payload: io)
+ Aws::EventStream::Encoder.new.encode(aws_message)
+ end
+
+      # stream one character at a time to exercise partial-chunk handling
+      messages = messages.join("").split("")
+
+ bedrock_mock.with_chunk_array_support do
+ stub_request(
+ :post,
+ "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
+ )
+ .with do |inner_request|
+ request = inner_request
+ true
+ end
+ .to_return(status: 200, body: messages)
+
+ prompt =
+ DiscourseAi::Completions::Prompt.new(
+ messages: [{ type: :user, content: "what is the weather in sydney" }],
+ )
+
+ tool = {
+ name: "google",
+ description: "Will search using Google",
+ parameters: [
+ { name: "query", description: "The search query", type: "string", required: true },
+ ],
+ }
+
+ prompt.tools = [tool]
+ response = +""
+ proxy.generate(prompt, user: user) { |partial| response << partial }
+
+ expect(request.headers["Authorization"]).to be_present
+ expect(request.headers["X-Amz-Content-Sha256"]).to be_present
+
+ expected_response = (<<~RESPONSE).strip
+        <function_calls>
+        <invoke>
+        <tool_name>google</tool_name>
+        <parameters>
+        <query>sydney weather today</query>
+        </parameters>
+        <tool_id>toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7</tool_id>
+        </invoke>
+        </function_calls>
+ RESPONSE
+
+ expect(response.strip).to eq(expected_response)
+
+ expected = {
+ "max_tokens" => 3000,
+ "anthropic_version" => "bedrock-2023-05-31",
+ "messages" => [{ "role" => "user", "content" => "what is the weather in sydney" }],
+ "tools" => [
+ {
+ "name" => "google",
+ "description" => "Will search using Google",
+ "input_schema" => {
+ "type" => "object",
+ "properties" => {
+ "query" => {
+ "type" => "string",
+ "description" => "The search query",
+ },
+ },
+ "required" => ["query"],
+ },
+ },
+ ],
+ }
+ expect(JSON.parse(request.body)).to eq(expected)
+
+ log = AiApiAuditLog.order(:id).last
+ expect(log.request_tokens).to eq(846)
+ expect(log.response_tokens).to eq(39)
+ end
+ end
+ end
+
describe "Claude 3 Sonnet support" do
it "supports the sonnet model" do
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")