FEATURE: anthropic function calling (#654)
Adds support for native tool calling (both streaming and non streaming) for Anthropic. This improves general tool support on the Anthropic models.
This commit is contained in:
parent
564d2de534
commit
3993c685e1
|
@ -0,0 +1,98 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
class DiscourseAi::Completions::AnthropicMessageProcessor
|
||||
class AnthropicToolCall
|
||||
attr_reader :name, :raw_json, :id
|
||||
|
||||
def initialize(name, id)
|
||||
@name = name
|
||||
@id = id
|
||||
@raw_json = +""
|
||||
end
|
||||
|
||||
def append(json)
|
||||
@raw_json << json
|
||||
end
|
||||
end
|
||||
|
||||
attr_reader :tool_calls, :input_tokens, :output_tokens
|
||||
|
||||
def initialize(streaming_mode:)
|
||||
@streaming_mode = streaming_mode
|
||||
@tool_calls = []
|
||||
end
|
||||
|
||||
def to_xml_tool_calls(function_buffer)
|
||||
return function_buffer if @tool_calls.blank?
|
||||
|
||||
function_buffer = Nokogiri::HTML5.fragment(<<~TEXT)
|
||||
<function_calls>
|
||||
</function_calls>
|
||||
TEXT
|
||||
|
||||
@tool_calls.each do |tool_call|
|
||||
node =
|
||||
function_buffer.at("function_calls").add_child(
|
||||
Nokogiri::HTML5::DocumentFragment.parse(
|
||||
DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n",
|
||||
),
|
||||
)
|
||||
|
||||
params = JSON.parse(tool_call.raw_json, symbolize_names: true)
|
||||
xml = params.map { |name, value| "<#{name}>#{value}</#{name}>" }.join("\n")
|
||||
|
||||
node.at("tool_name").content = tool_call.name
|
||||
node.at("tool_id").content = tool_call.id
|
||||
node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present?
|
||||
end
|
||||
|
||||
function_buffer
|
||||
end
|
||||
|
||||
def process_message(payload)
|
||||
result = ""
|
||||
parsed = JSON.parse(payload, symbolize_names: true)
|
||||
|
||||
if @streaming_mode
|
||||
if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
|
||||
tool_name = parsed.dig(:content_block, :name)
|
||||
tool_id = parsed.dig(:content_block, :id)
|
||||
@tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name
|
||||
elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
|
||||
if @tool_calls.present?
|
||||
result = parsed.dig(:delta, :partial_json).to_s
|
||||
@tool_calls.last.append(result)
|
||||
else
|
||||
result = parsed.dig(:delta, :text).to_s
|
||||
end
|
||||
elsif parsed[:type] == "message_start"
|
||||
@input_tokens = parsed.dig(:message, :usage, :input_tokens)
|
||||
elsif parsed[:type] == "message_delta"
|
||||
@output_tokens =
|
||||
parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
|
||||
elsif parsed[:type] == "message_stop"
|
||||
# bedrock has this ...
|
||||
if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
|
||||
@input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
|
||||
@output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
|
||||
end
|
||||
end
|
||||
else
|
||||
content = parsed.dig(:content)
|
||||
if content.is_a?(Array)
|
||||
tool_call = content.find { |c| c[:type] == "tool_use" }
|
||||
if tool_call
|
||||
@tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id])
|
||||
@tool_calls.last.append(tool_call[:input].to_json)
|
||||
else
|
||||
result = parsed.dig(:content, 0, :text).to_s
|
||||
end
|
||||
end
|
||||
|
||||
@input_tokens = parsed.dig(:usage, :input_tokens)
|
||||
@output_tokens = parsed.dig(:usage, :output_tokens)
|
||||
end
|
||||
|
||||
result
|
||||
end
|
||||
end
|
|
@ -15,10 +15,12 @@ module DiscourseAi
|
|||
class ClaudePrompt
|
||||
attr_reader :system_prompt
|
||||
attr_reader :messages
|
||||
attr_reader :tools
|
||||
|
||||
def initialize(system_prompt, messages)
|
||||
def initialize(system_prompt, messages, tools)
|
||||
@system_prompt = system_prompt
|
||||
@messages = messages
|
||||
@tools = tools
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -46,7 +48,11 @@ module DiscourseAi
|
|||
previous_message = message
|
||||
end
|
||||
|
||||
ClaudePrompt.new(system_prompt.presence, interleving_messages)
|
||||
ClaudePrompt.new(
|
||||
system_prompt.presence,
|
||||
interleving_messages,
|
||||
tools_dialect.translated_tools,
|
||||
)
|
||||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
|
@ -58,6 +64,18 @@ module DiscourseAi
|
|||
|
||||
private
|
||||
|
||||
def tools_dialect
|
||||
@tools_dialect ||= DiscourseAi::Completions::Dialects::ClaudeTools.new(prompt.tools)
|
||||
end
|
||||
|
||||
def tool_call_msg(msg)
|
||||
tools_dialect.from_raw_tool_call(msg)
|
||||
end
|
||||
|
||||
def tool_msg(msg)
|
||||
tools_dialect.from_raw_tool(msg)
|
||||
end
|
||||
|
||||
def model_msg(msg)
|
||||
{ role: "assistant", content: msg[:content] }
|
||||
end
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Completions
|
||||
module Dialects
|
||||
class ClaudeTools
|
||||
def initialize(tools)
|
||||
@raw_tools = tools
|
||||
end
|
||||
|
||||
def translated_tools
|
||||
# Transform the raw tools into the required Anthropic Claude API format
|
||||
raw_tools.map do |t|
|
||||
properties = {}
|
||||
required = []
|
||||
|
||||
if t[:parameters]
|
||||
properties =
|
||||
t[:parameters].each_with_object({}) do |param, h|
|
||||
h[param[:name]] = {
|
||||
type: param[:type],
|
||||
description: param[:description],
|
||||
}.tap { |hash| hash[:items] = { type: param[:item_type] } if param[:item_type] }
|
||||
end
|
||||
required =
|
||||
t[:parameters].select { |param| param[:required] }.map { |param| param[:name] }
|
||||
end
|
||||
|
||||
{
|
||||
name: t[:name],
|
||||
description: t[:description],
|
||||
input_schema: {
|
||||
type: "object",
|
||||
properties: properties,
|
||||
required: required,
|
||||
},
|
||||
}
|
||||
end
|
||||
end
|
||||
|
||||
def instructions
|
||||
"" # Noop. Tools are listed separate.
|
||||
end
|
||||
|
||||
def from_raw_tool_call(raw_message)
|
||||
call_details = JSON.parse(raw_message[:content], symbolize_names: true)
|
||||
tool_call_id = raw_message[:id]
|
||||
|
||||
{
|
||||
role: "assistant",
|
||||
content: [
|
||||
{
|
||||
type: "tool_use",
|
||||
id: tool_call_id,
|
||||
name: raw_message[:name],
|
||||
input: call_details[:arguments],
|
||||
},
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
def from_raw_tool(raw_message)
|
||||
{
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "tool_result",
|
||||
tool_use_id: raw_message[:id],
|
||||
content: raw_message[:content],
|
||||
},
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
attr_reader :raw_tools
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -45,10 +45,7 @@ module DiscourseAi
|
|||
raise "Unsupported model: #{model}"
|
||||
end
|
||||
|
||||
options = { model: mapped_model, max_tokens: 3_000 }
|
||||
|
||||
options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
|
||||
options
|
||||
{ model: mapped_model, max_tokens: 3_000 }
|
||||
end
|
||||
|
||||
def provider_id
|
||||
|
@ -73,6 +70,7 @@ module DiscourseAi
|
|||
|
||||
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
|
||||
payload[:stream] = true if @streaming_mode
|
||||
payload[:tools] = prompt.tools if prompt.tools.present?
|
||||
|
||||
payload
|
||||
end
|
||||
|
@ -87,30 +85,30 @@ module DiscourseAi
|
|||
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
|
||||
end
|
||||
|
||||
def final_log_update(log)
|
||||
log.request_tokens = @input_tokens if @input_tokens
|
||||
log.response_tokens = @output_tokens if @output_tokens
|
||||
def processor
|
||||
@processor ||=
|
||||
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
|
||||
end
|
||||
|
||||
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
|
||||
processor.to_xml_tool_calls(function_buffer) if !partial
|
||||
end
|
||||
|
||||
def extract_completion_from(response_raw)
|
||||
result = ""
|
||||
parsed = JSON.parse(response_raw, symbolize_names: true)
|
||||
processor.process_message(response_raw)
|
||||
end
|
||||
|
||||
if @streaming_mode
|
||||
if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
|
||||
result = parsed.dig(:delta, :text).to_s
|
||||
elsif parsed[:type] == "message_start"
|
||||
@input_tokens = parsed.dig(:message, :usage, :input_tokens)
|
||||
elsif parsed[:type] == "message_delta"
|
||||
@output_tokens = parsed.dig(:delta, :usage, :output_tokens)
|
||||
end
|
||||
else
|
||||
result = parsed.dig(:content, 0, :text).to_s
|
||||
@input_tokens = parsed.dig(:usage, :input_tokens)
|
||||
@output_tokens = parsed.dig(:usage, :output_tokens)
|
||||
end
|
||||
def has_tool?(_response_data)
|
||||
processor.tool_calls.present?
|
||||
end
|
||||
|
||||
result
|
||||
def final_log_update(log)
|
||||
log.request_tokens = processor.input_tokens if processor.input_tokens
|
||||
log.response_tokens = processor.output_tokens if processor.output_tokens
|
||||
end
|
||||
|
||||
def native_tool_support?
|
||||
true
|
||||
end
|
||||
|
||||
def partials_from(decoded_chunk)
|
||||
|
|
|
@ -36,7 +36,6 @@ module DiscourseAi
|
|||
|
||||
def default_options(dialect)
|
||||
options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" }
|
||||
options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
|
||||
options
|
||||
end
|
||||
|
||||
|
@ -82,6 +81,8 @@ module DiscourseAi
|
|||
def prepare_payload(prompt, model_params, dialect)
|
||||
payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
|
||||
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
|
||||
payload[:tools] = prompt.tools if prompt.tools.present?
|
||||
|
||||
payload
|
||||
end
|
||||
|
||||
|
@ -142,35 +143,35 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def final_log_update(log)
|
||||
log.request_tokens = @input_tokens if @input_tokens
|
||||
log.response_tokens = @output_tokens if @output_tokens
|
||||
log.request_tokens = processor.input_tokens if processor.input_tokens
|
||||
log.response_tokens = processor.output_tokens if processor.output_tokens
|
||||
end
|
||||
|
||||
def processor
|
||||
@processor ||=
|
||||
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
|
||||
end
|
||||
|
||||
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
|
||||
processor.to_xml_tool_calls(function_buffer) if !partial
|
||||
end
|
||||
|
||||
def extract_completion_from(response_raw)
|
||||
result = ""
|
||||
parsed = JSON.parse(response_raw, symbolize_names: true)
|
||||
processor.process_message(response_raw)
|
||||
end
|
||||
|
||||
if @streaming_mode
|
||||
if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
|
||||
result = parsed.dig(:delta, :text).to_s
|
||||
elsif parsed[:type] == "message_start"
|
||||
@input_tokens = parsed.dig(:message, :usage, :input_tokens)
|
||||
elsif parsed[:type] == "message_delta"
|
||||
@output_tokens = parsed.dig(:delta, :usage, :output_tokens)
|
||||
end
|
||||
else
|
||||
result = parsed.dig(:content, 0, :text).to_s
|
||||
@input_tokens = parsed.dig(:usage, :input_tokens)
|
||||
@output_tokens = parsed.dig(:usage, :output_tokens)
|
||||
end
|
||||
|
||||
result
|
||||
def has_tool?(_response_data)
|
||||
processor.tool_calls.present?
|
||||
end
|
||||
|
||||
def partials_from(decoded_chunks)
|
||||
decoded_chunks
|
||||
end
|
||||
|
||||
def native_tool_support?
|
||||
true
|
||||
end
|
||||
|
||||
def chunk_to_string(chunk)
|
||||
joined = +chunk.join("\n")
|
||||
joined << "\n" if joined.length > 0
|
||||
|
|
|
@ -242,12 +242,14 @@ module DiscourseAi
|
|||
else
|
||||
leftover = ""
|
||||
end
|
||||
|
||||
prev_processed_partials = 0 if leftover.blank?
|
||||
end
|
||||
rescue IOError, StandardError
|
||||
raise if !cancelled
|
||||
end
|
||||
|
||||
has_tool ||= has_tool?(partials_raw)
|
||||
# Once we have the full response, try to return the tool as a XML doc.
|
||||
if has_tool && native_tool_support?
|
||||
function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw)
|
||||
|
@ -345,7 +347,7 @@ module DiscourseAi
|
|||
TEXT
|
||||
end
|
||||
|
||||
def noop_function_call_text
|
||||
def self.noop_function_call_text
|
||||
(<<~TEXT).strip
|
||||
<invoke>
|
||||
<tool_name></tool_name>
|
||||
|
@ -356,6 +358,10 @@ module DiscourseAi
|
|||
TEXT
|
||||
end
|
||||
|
||||
def noop_function_call_text
|
||||
self.class.noop_function_call_text
|
||||
end
|
||||
|
||||
def has_tool?(response)
|
||||
response.include?("<function_calls>")
|
||||
end
|
||||
|
|
|
@ -46,7 +46,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
|||
|
||||
messages = [
|
||||
{ type: :user, id: "user1", content: "echo something" },
|
||||
{ type: :tool_call, content: tool_call_prompt.to_json },
|
||||
{ type: :tool_call, name: "echo", id: "tool_id", content: tool_call_prompt.to_json },
|
||||
{ type: :tool, id: "tool_id", content: "something".to_json },
|
||||
{ type: :model, content: "I did it" },
|
||||
{ type: :user, id: "user1", content: "echo something else" },
|
||||
|
@ -63,24 +63,22 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
|||
translated = dialect.translate
|
||||
|
||||
expect(translated.system_prompt).to start_with("You are a helpful bot")
|
||||
expect(translated.system_prompt).to include("echo a string")
|
||||
|
||||
expected = [
|
||||
{ role: "user", content: "user1: echo something" },
|
||||
{
|
||||
role: "assistant",
|
||||
content:
|
||||
"<function_calls>\n<invoke>\n<tool_name>echo</tool_name>\n<parameters>\n<text>something</text>\n</parameters>\n</invoke>\n</function_calls>",
|
||||
content: [
|
||||
{ type: "tool_use", id: "tool_id", name: "echo", input: { text: "something" } },
|
||||
],
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content:
|
||||
"<function_results>\n<result>\n<tool_name>tool_id</tool_name>\n<json>\n\"something\"\n</json>\n</result>\n</function_results>",
|
||||
content: [{ type: "tool_result", tool_use_id: "tool_id", content: "\"something\"" }],
|
||||
},
|
||||
{ role: "assistant", content: "I did it" },
|
||||
{ role: "user", content: "user1: echo something else" },
|
||||
]
|
||||
|
||||
expect(translated.messages).to eq(expected)
|
||||
end
|
||||
end
|
||||
|
|
|
@ -49,152 +49,59 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
|||
|
||||
it "does not eat spaces with tool calls" do
|
||||
body = <<~STRING
|
||||
event: message_start
|
||||
data: {"type":"message_start","message":{"id":"msg_019kmW9Q3GqfWmuFJbePJTBR","type":"message","role":"assistant","content":[],"model":"claude-3-opus-20240229","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":347,"output_tokens":1}}}
|
||||
event: message_start
|
||||
data: {"type":"message_start","message":{"id":"msg_01Ju4j2MiGQb9KV9EEQ522Y3","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":1293,"output_tokens":1}} }
|
||||
|
||||
event: content_block_start
|
||||
data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
|
||||
event: content_block_start
|
||||
data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01DjrShFRRHp9SnHYRFRc53F","name":"search","input":{}} }
|
||||
|
||||
event: ping
|
||||
data: {"type": "ping"}
|
||||
event: ping
|
||||
data: {"type": "ping"}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<function"}}
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""} }
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\\"searc"} }
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"calls"}}
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"h_qu"} }
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"er"} }
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"y\\": \\"s"} }
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<invoke"}}
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"am"} }
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":" "} }
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"sam\\""} }
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<tool"}}
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":", \\"cate"} }
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"gory"} }
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"name"}}
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"\\": \\"gene"} }
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"ral\\"}"} }
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"google"}}
|
||||
event: content_block_stop
|
||||
data: {"type":"content_block_stop","index":0 }
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</tool"}}
|
||||
event: message_delta
|
||||
data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"output_tokens":70} }
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"name"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<parameters"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<query"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"top"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"10"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"things"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" to"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" do"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" in"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" japan"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" for"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" tourists"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</query"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</parameters"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</invoke"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
||||
|
||||
event: content_block_stop
|
||||
data: {"type":"content_block_stop","index":0}
|
||||
|
||||
event: message_delta
|
||||
data: {"type":"message_delta","delta":{"stop_reason":"stop_sequence","stop_sequence":"</function_calls>"},"usage":{"output_tokens":57}}
|
||||
|
||||
event: message_stop
|
||||
data: {"type":"message_stop"}
|
||||
event: message_stop
|
||||
data: {"type":"message_stop"}
|
||||
STRING
|
||||
|
||||
result = +""
|
||||
|
@ -213,11 +120,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
|||
expected = (<<~TEXT).strip
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>google</tool_name>
|
||||
<parameters>
|
||||
<query>top 10 things to do in japan for tourists</query>
|
||||
</parameters>
|
||||
<tool_id>tool_0</tool_id>
|
||||
<tool_name>search</tool_name>
|
||||
<parameters><search_query>sam sam</search_query>
|
||||
<category>general</category></parameters>
|
||||
<tool_id>toolu_01DjrShFRRHp9SnHYRFRc53F</tool_id>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
TEXT
|
||||
|
@ -285,71 +191,71 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
|||
|
||||
log = AiApiAuditLog.order(:id).last
|
||||
expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
|
||||
expect(log.request_tokens).to eq(25)
|
||||
expect(log.response_tokens).to eq(15)
|
||||
expect(log.feature_name).to eq("testing")
|
||||
expect(log.response_tokens).to eq(15)
|
||||
expect(log.request_tokens).to eq(25)
|
||||
end
|
||||
|
||||
it "can return multiple function calls" do
|
||||
functions = <<~FUNCTIONS
|
||||
it "supports non streaming tool calls" do
|
||||
tool = {
|
||||
name: "calculate",
|
||||
description: "calculate something",
|
||||
parameters: [
|
||||
{
|
||||
name: "expression",
|
||||
type: "string",
|
||||
description: "expression to calculate",
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
prompt =
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
"You a calculator",
|
||||
messages: [{ type: :user, id: "user1", content: "calculate 2758975 + 21.11" }],
|
||||
tools: [tool],
|
||||
)
|
||||
|
||||
proxy = DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-haiku")
|
||||
|
||||
body = {
|
||||
id: "msg_01RdJkxCbsEj9VFyFYAkfy2S",
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
model: "claude-3-haiku-20240307",
|
||||
content: [
|
||||
{ type: "text", text: "Here is the calculation:" },
|
||||
{
|
||||
type: "tool_use",
|
||||
id: "toolu_012kBdhG4eHaV68W56p4N94h",
|
||||
name: "calculate",
|
||||
input: {
|
||||
expression: "2758975 + 21.11",
|
||||
},
|
||||
},
|
||||
],
|
||||
stop_reason: "tool_use",
|
||||
stop_sequence: nil,
|
||||
usage: {
|
||||
input_tokens: 345,
|
||||
output_tokens: 65,
|
||||
},
|
||||
}.to_json
|
||||
|
||||
stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(body: body)
|
||||
|
||||
result = proxy.generate(prompt, user: Discourse.system_user)
|
||||
|
||||
expected = <<~TEXT.strip
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>echo</tool_name>
|
||||
<parameters>
|
||||
<text>something</text>
|
||||
</parameters>
|
||||
</invoke>
|
||||
<invoke>
|
||||
<tool_name>echo</tool_name>
|
||||
<parameters>
|
||||
<text>something else</text>
|
||||
</parameters>
|
||||
</invoke>
|
||||
FUNCTIONS
|
||||
|
||||
body = <<~STRING
|
||||
{
|
||||
"content": [
|
||||
{
|
||||
"text": "Hello!\n\n#{functions}\njunk",
|
||||
"type": "text"
|
||||
}
|
||||
],
|
||||
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
|
||||
"model": "claude-3-opus-20240229",
|
||||
"role": "assistant",
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": null,
|
||||
"type": "message",
|
||||
"usage": {
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 25
|
||||
}
|
||||
}
|
||||
STRING
|
||||
|
||||
stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
|
||||
|
||||
result = llm.generate(prompt_with_echo_tool, user: Discourse.system_user)
|
||||
|
||||
expected = (<<~EXPECTED).strip
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>echo</tool_name>
|
||||
<parameters>
|
||||
<text>something</text>
|
||||
</parameters>
|
||||
<tool_id>tool_0</tool_id>
|
||||
</invoke>
|
||||
<invoke>
|
||||
<tool_name>echo</tool_name>
|
||||
<parameters>
|
||||
<text>something else</text>
|
||||
</parameters>
|
||||
<tool_id>tool_1</tool_id>
|
||||
<tool_name>calculate</tool_name>
|
||||
<parameters><expression>2758975 + 21.11</expression></parameters>
|
||||
<tool_id>toolu_012kBdhG4eHaV68W56p4N94h</tool_id>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
EXPECTED
|
||||
TEXT
|
||||
|
||||
expect(result.strip).to eq(expected)
|
||||
end
|
||||
|
|
|
@ -24,6 +24,175 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
SiteSetting.ai_bedrock_region = "us-east-1"
|
||||
end
|
||||
|
||||
describe "function calling" do
|
||||
it "supports streaming function calls" do
|
||||
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
|
||||
|
||||
request = nil
|
||||
|
||||
messages =
|
||||
[
|
||||
{
|
||||
type: "message_start",
|
||||
message: {
|
||||
id: "msg_bdrk_01WYxeNMk6EKn9s98r6XXrAB",
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
model: "claude-3-haiku-20240307",
|
||||
stop_sequence: nil,
|
||||
usage: {
|
||||
input_tokens: 840,
|
||||
output_tokens: 1,
|
||||
},
|
||||
content: [],
|
||||
stop_reason: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
type: "content_block_start",
|
||||
index: 0,
|
||||
content_block: {
|
||||
type: "tool_use",
|
||||
id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
|
||||
name: "google",
|
||||
input: {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
type: "content_block_delta",
|
||||
index: 0,
|
||||
delta: {
|
||||
type: "input_json_delta",
|
||||
partial_json: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
type: "content_block_delta",
|
||||
index: 0,
|
||||
delta: {
|
||||
type: "input_json_delta",
|
||||
partial_json: "{\"query\": \"s",
|
||||
},
|
||||
},
|
||||
{
|
||||
type: "content_block_delta",
|
||||
index: 0,
|
||||
delta: {
|
||||
type: "input_json_delta",
|
||||
partial_json: "ydney weat",
|
||||
},
|
||||
},
|
||||
{
|
||||
type: "content_block_delta",
|
||||
index: 0,
|
||||
delta: {
|
||||
type: "input_json_delta",
|
||||
partial_json: "her today\"}",
|
||||
},
|
||||
},
|
||||
{ type: "content_block_stop", index: 0 },
|
||||
{
|
||||
type: "message_delta",
|
||||
delta: {
|
||||
stop_reason: "tool_use",
|
||||
stop_sequence: nil,
|
||||
},
|
||||
usage: {
|
||||
output_tokens: 53,
|
||||
},
|
||||
},
|
||||
{
|
||||
type: "message_stop",
|
||||
"amazon-bedrock-invocationMetrics": {
|
||||
inputTokenCount: 846,
|
||||
outputTokenCount: 39,
|
||||
invocationLatency: 880,
|
||||
firstByteLatency: 402,
|
||||
},
|
||||
},
|
||||
].map do |message|
|
||||
wrapped = { bytes: Base64.encode64(message.to_json) }.to_json
|
||||
io = StringIO.new(wrapped)
|
||||
aws_message = Aws::EventStream::Message.new(payload: io)
|
||||
Aws::EventStream::Encoder.new.encode(aws_message)
|
||||
end
|
||||
|
||||
messages = messages.join("").split
|
||||
|
||||
bedrock_mock.with_chunk_array_support do
|
||||
stub_request(
|
||||
:post,
|
||||
"https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
|
||||
)
|
||||
.with do |inner_request|
|
||||
request = inner_request
|
||||
true
|
||||
end
|
||||
.to_return(status: 200, body: messages)
|
||||
|
||||
prompt =
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
messages: [{ type: :user, content: "what is the weather in sydney" }],
|
||||
)
|
||||
|
||||
tool = {
|
||||
name: "google",
|
||||
description: "Will search using Google",
|
||||
parameters: [
|
||||
{ name: "query", description: "The search query", type: "string", required: true },
|
||||
],
|
||||
}
|
||||
|
||||
prompt.tools = [tool]
|
||||
response = +""
|
||||
proxy.generate(prompt, user: user) { |partial| response << partial }
|
||||
|
||||
expect(request.headers["Authorization"]).to be_present
|
||||
expect(request.headers["X-Amz-Content-Sha256"]).to be_present
|
||||
|
||||
expected_response = (<<~RESPONSE).strip
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>google</tool_name>
|
||||
<parameters><query>sydney weather today</query></parameters>
|
||||
<tool_id>toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7</tool_id>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
RESPONSE
|
||||
|
||||
expect(response.strip).to eq(expected_response)
|
||||
|
||||
expected = {
|
||||
"max_tokens" => 3000,
|
||||
"anthropic_version" => "bedrock-2023-05-31",
|
||||
"messages" => [{ "role" => "user", "content" => "what is the weather in sydney" }],
|
||||
"tools" => [
|
||||
{
|
||||
"name" => "google",
|
||||
"description" => "Will search using Google",
|
||||
"input_schema" => {
|
||||
"type" => "object",
|
||||
"properties" => {
|
||||
"query" => {
|
||||
"type" => "string",
|
||||
"description" => "The search query",
|
||||
},
|
||||
},
|
||||
"required" => ["query"],
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
expect(JSON.parse(request.body)).to eq(expected)
|
||||
|
||||
log = AiApiAuditLog.order(:id).last
|
||||
expect(log.request_tokens).to eq(846)
|
||||
expect(log.response_tokens).to eq(39)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "Claude 3 Sonnet support" do
|
||||
it "supports the sonnet model" do
|
||||
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
|
||||
|
|
Loading…
Reference in New Issue