FEATURE: anthropic function calling (#654)
Adds support for native tool calling (both streaming and non-streaming) for Anthropic. This improves general tool support on the Anthropic models.
This commit is contained in:
parent
564d2de534
commit
3993c685e1
|
@ -0,0 +1,98 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
require "cgi/escape"

module DiscourseAi
  module Completions
    # Decodes Anthropic Messages API payloads (a complete response body, or one
    # streamed event at a time) into completion text, token usage counters and
    # the list of tool calls requested by the model.
    class AnthropicMessageProcessor
      # A single tool invocation requested by the model. The JSON encoded
      # arguments are accumulated as raw text because, when streaming, they
      # arrive in arbitrary fragments.
      class AnthropicToolCall
        attr_reader :name, :raw_json, :id

        # @param name [String] tool name reported by the model
        # @param id [String] tool call id reported by the model
        def initialize(name, id)
          @name = name
          @id = id
          @raw_json = +""
        end

        # Appends another fragment of the JSON encoded arguments.
        #
        # @param json [String] partial (or complete) JSON text
        def append(json)
          @raw_json << json
        end

        # Parses the accumulated JSON into a Hash with symbol keys.
        #
        # @return [Hash] the arguments; empty when no JSON was received or when
        #   the JSON document is not an object (for example +null+)
        # @raise [JSON::ParserError] when the accumulated text is not valid JSON
        def parameters
          return {} if raw_json.strip.empty?

          parsed = JSON.parse(raw_json, symbolize_names: true)
          parsed.is_a?(Hash) ? parsed : {}
        end
      end

      attr_reader :tool_calls, :input_tokens, :output_tokens

      # @param streaming_mode [Boolean] true when payloads handed to
      #   #process_message are individual stream events, false when they are
      #   complete response bodies
      def initialize(streaming_mode:)
        @streaming_mode = streaming_mode
        @tool_calls = []
      end

      # Renders the collected tool calls as a <function_calls> fragment.
      #
      # @param function_buffer [Object] the current buffer; returned untouched
      #   when no tool call was collected
      # @return [Object] a new Nokogiri fragment holding one <invoke> node per
      #   tool call, or +function_buffer+ itself when there is nothing to render
      def to_xml_tool_calls(function_buffer)
        return function_buffer if @tool_calls.empty?

        function_buffer = Nokogiri::HTML5.fragment(<<~TEXT)
          <function_calls>
          </function_calls>
        TEXT

        @tool_calls.each do |tool_call|
          node =
            function_buffer.at("function_calls").add_child(
              Nokogiri::HTML5::DocumentFragment.parse(
                DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n",
              ),
            )

          xml = parameters_xml(tool_call)

          node.at("tool_name").content = tool_call.name
          node.at("tool_id").content = tool_call.id
          unless xml.empty?
            node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml)
          end
        end

        function_buffer
      end

      # Consumes one payload and updates tool calls and token counters.
      #
      # @param payload [String] JSON text: a stream event in streaming mode,
      #   otherwise the complete response body
      # @return [String] the text (or, while a tool call is being streamed, the
      #   partial JSON) carried by this payload; "" when it carries none
      # @raise [JSON::ParserError] when +payload+ is not valid JSON
      def process_message(payload)
        parsed = JSON.parse(payload, symbolize_names: true)

        @streaming_mode ? process_streaming_event(parsed) : process_complete_message(parsed)
      end

      private

      # Builds the inner markup of <parameters> for a tool call. Values are
      # HTML escaped because the markup is parsed again afterwards; an
      # unescaped "<" or "&" inside an argument would otherwise be read as
      # markup instead of text.
      def parameters_xml(tool_call)
        tool_call
          .parameters
          .map { |name, value| "<#{name}>#{CGI.escapeHTML(value.to_s)}</#{name}>" }
          .join("\n")
      end

      # Handles a single event of a streamed response.
      def process_streaming_event(parsed)
        result = ""

        case parsed[:type]
        when "content_block_start", "content_block_delta"
          if parsed[:type] == "content_block_start" &&
               parsed.dig(:content_block, :type) == "tool_use"
            tool_name = parsed.dig(:content_block, :name)
            tool_id = parsed.dig(:content_block, :id)
            @tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name
          elsif !@tool_calls.empty?
            # once a tool call has started, deltas are fragments of its arguments
            result = parsed.dig(:delta, :partial_json).to_s
            @tool_calls.last.append(result)
          else
            result = parsed.dig(:delta, :text).to_s
          end
        when "message_start"
          @input_tokens = parsed.dig(:message, :usage, :input_tokens)
        when "message_delta"
          @output_tokens =
            parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
        when "message_stop"
          # bedrock attaches its own token metrics to the final event
          bedrock_stats = parsed[:"amazon-bedrock-invocationMetrics"]
          if bedrock_stats
            @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
            @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
          end
        end

        result
      end

      # Handles a complete (non streamed) response body.
      def process_complete_message(parsed)
        result = ""
        content = parsed[:content]

        if content.is_a?(Array)
          tool_uses = content.select { |block| block[:type] == "tool_use" }

          if tool_uses.empty?
            result = parsed.dig(:content, 0, :text).to_s
          else
            # keep every requested tool call, mirroring what streaming mode collects
            tool_uses.each do |tool_use|
              tool_call = AnthropicToolCall.new(tool_use[:name], tool_use[:id])
              tool_call.append(tool_use[:input].to_json)
              @tool_calls << tool_call
            end
          end
        end

        @input_tokens = parsed.dig(:usage, :input_tokens)
        @output_tokens = parsed.dig(:usage, :output_tokens)

        result
      end
    end
  end
end
|
|
@ -15,10 +15,12 @@ module DiscourseAi
|
||||||
class ClaudePrompt
|
class ClaudePrompt
|
||||||
attr_reader :system_prompt
|
attr_reader :system_prompt
|
||||||
attr_reader :messages
|
attr_reader :messages
|
||||||
|
attr_reader :tools
|
||||||
|
|
||||||
def initialize(system_prompt, messages)
|
def initialize(system_prompt, messages, tools)
|
||||||
@system_prompt = system_prompt
|
@system_prompt = system_prompt
|
||||||
@messages = messages
|
@messages = messages
|
||||||
|
@tools = tools
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -46,7 +48,11 @@ module DiscourseAi
|
||||||
previous_message = message
|
previous_message = message
|
||||||
end
|
end
|
||||||
|
|
||||||
ClaudePrompt.new(system_prompt.presence, interleving_messages)
|
ClaudePrompt.new(
|
||||||
|
system_prompt.presence,
|
||||||
|
interleving_messages,
|
||||||
|
tools_dialect.translated_tools,
|
||||||
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
def max_prompt_tokens
|
def max_prompt_tokens
|
||||||
|
@ -58,6 +64,18 @@ module DiscourseAi
|
||||||
|
|
||||||
private
|
private
|
||||||
|
|
||||||
|
def tools_dialect
|
||||||
|
@tools_dialect ||= DiscourseAi::Completions::Dialects::ClaudeTools.new(prompt.tools)
|
||||||
|
end
|
||||||
|
|
||||||
|
def tool_call_msg(msg)
|
||||||
|
tools_dialect.from_raw_tool_call(msg)
|
||||||
|
end
|
||||||
|
|
||||||
|
def tool_msg(msg)
|
||||||
|
tools_dialect.from_raw_tool(msg)
|
||||||
|
end
|
||||||
|
|
||||||
def model_msg(msg)
|
def model_msg(msg)
|
||||||
{ role: "assistant", content: msg[:content] }
|
{ role: "assistant", content: msg[:content] }
|
||||||
end
|
end
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
  module Completions
    module Dialects
      # Converts generic tool definitions and tool related messages into the
      # structures used by the Anthropic Claude messages API.
      class ClaudeTools
        # @param tools [Array<Hash>, nil] tool definitions; each one is read
        #   through its :name, :description and :parameters keys
        def initialize(tools)
          @raw_tools = tools
        end

        # Transforms the raw tools into the Anthropic Claude API format.
        #
        # @return [Array<Hash>] one hash per tool with :name, :description and
        #   an :input_schema object; empty when no tools were supplied
        def translated_tools
          (raw_tools || []).map do |tool|
            parameters = tool[:parameters] || []

            properties =
              parameters.each_with_object({}) do |param, schema|
                schema[param[:name]] = property_schema(param)
              end
            required = parameters.select { |param| param[:required] }.map { |param| param[:name] }

            {
              name: tool[:name],
              description: tool[:description],
              input_schema: {
                type: "object",
                properties: properties,
                required: required,
              },
            }
          end
        end

        def instructions
          "" # Noop. Tools are listed separate.
        end

        # Builds the assistant message that replays a tool call made earlier.
        #
        # @param raw_message [Hash] reads :content (JSON text holding an
        #   "arguments" key), :id and :name
        # @return [Hash] assistant message with a single tool_use content block
        # @raise [JSON::ParserError] when :content is not valid JSON
        def from_raw_tool_call(raw_message)
          call_details = JSON.parse(raw_message[:content], symbolize_names: true)

          {
            role: "assistant",
            content: [
              {
                type: "tool_use",
                id: raw_message[:id],
                name: raw_message[:name],
                # NOTE(review): falls back to an empty object when no arguments were
                # recorded; assumes the API rejects a null input - confirm
                input: call_details[:arguments] || {},
              },
            ],
          }
        end

        # Builds the user message that carries the result of a tool call.
        #
        # @param raw_message [Hash] reads :id (the tool call id) and :content
        # @return [Hash] user message with a single tool_result content block
        def from_raw_tool(raw_message)
          {
            role: "user",
            content: [
              { type: "tool_result", tool_use_id: raw_message[:id], content: raw_message[:content] },
            ],
          }
        end

        private

        attr_reader :raw_tools

        # Schema entry for one parameter; :items is only present when the
        # parameter declares an :item_type.
        def property_schema(param)
          schema = { type: param[:type], description: param[:description] }
          schema[:items] = { type: param[:item_type] } if param[:item_type]
          schema
        end
      end
    end
  end
end
|
|
@ -45,10 +45,7 @@ module DiscourseAi
|
||||||
raise "Unsupported model: #{model}"
|
raise "Unsupported model: #{model}"
|
||||||
end
|
end
|
||||||
|
|
||||||
options = { model: mapped_model, max_tokens: 3_000 }
|
{ model: mapped_model, max_tokens: 3_000 }
|
||||||
|
|
||||||
options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
|
|
||||||
options
|
|
||||||
end
|
end
|
||||||
|
|
||||||
def provider_id
|
def provider_id
|
||||||
|
@ -73,6 +70,7 @@ module DiscourseAi
|
||||||
|
|
||||||
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
|
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
|
||||||
payload[:stream] = true if @streaming_mode
|
payload[:stream] = true if @streaming_mode
|
||||||
|
payload[:tools] = prompt.tools if prompt.tools.present?
|
||||||
|
|
||||||
payload
|
payload
|
||||||
end
|
end
|
||||||
|
@ -87,30 +85,30 @@ module DiscourseAi
|
||||||
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
|
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
|
||||||
end
|
end
|
||||||
|
|
||||||
def final_log_update(log)
|
def processor
|
||||||
log.request_tokens = @input_tokens if @input_tokens
|
@processor ||=
|
||||||
log.response_tokens = @output_tokens if @output_tokens
|
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
|
||||||
|
end
|
||||||
|
|
||||||
|
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
|
||||||
|
processor.to_xml_tool_calls(function_buffer) if !partial
|
||||||
end
|
end
|
||||||
|
|
||||||
def extract_completion_from(response_raw)
|
def extract_completion_from(response_raw)
|
||||||
result = ""
|
processor.process_message(response_raw)
|
||||||
parsed = JSON.parse(response_raw, symbolize_names: true)
|
|
||||||
|
|
||||||
if @streaming_mode
|
|
||||||
if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
|
|
||||||
result = parsed.dig(:delta, :text).to_s
|
|
||||||
elsif parsed[:type] == "message_start"
|
|
||||||
@input_tokens = parsed.dig(:message, :usage, :input_tokens)
|
|
||||||
elsif parsed[:type] == "message_delta"
|
|
||||||
@output_tokens = parsed.dig(:delta, :usage, :output_tokens)
|
|
||||||
end
|
|
||||||
else
|
|
||||||
result = parsed.dig(:content, 0, :text).to_s
|
|
||||||
@input_tokens = parsed.dig(:usage, :input_tokens)
|
|
||||||
@output_tokens = parsed.dig(:usage, :output_tokens)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
result
|
def has_tool?(_response_data)
|
||||||
|
processor.tool_calls.present?
|
||||||
|
end
|
||||||
|
|
||||||
|
def final_log_update(log)
|
||||||
|
log.request_tokens = processor.input_tokens if processor.input_tokens
|
||||||
|
log.response_tokens = processor.output_tokens if processor.output_tokens
|
||||||
|
end
|
||||||
|
|
||||||
|
def native_tool_support?
|
||||||
|
true
|
||||||
end
|
end
|
||||||
|
|
||||||
def partials_from(decoded_chunk)
|
def partials_from(decoded_chunk)
|
||||||
|
|
|
@ -36,7 +36,6 @@ module DiscourseAi
|
||||||
|
|
||||||
def default_options(dialect)
|
def default_options(dialect)
|
||||||
options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" }
|
options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" }
|
||||||
options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
|
|
||||||
options
|
options
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -82,6 +81,8 @@ module DiscourseAi
|
||||||
def prepare_payload(prompt, model_params, dialect)
|
def prepare_payload(prompt, model_params, dialect)
|
||||||
payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
|
payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
|
||||||
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
|
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
|
||||||
|
payload[:tools] = prompt.tools if prompt.tools.present?
|
||||||
|
|
||||||
payload
|
payload
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -142,35 +143,35 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def final_log_update(log)
|
def final_log_update(log)
|
||||||
log.request_tokens = @input_tokens if @input_tokens
|
log.request_tokens = processor.input_tokens if processor.input_tokens
|
||||||
log.response_tokens = @output_tokens if @output_tokens
|
log.response_tokens = processor.output_tokens if processor.output_tokens
|
||||||
|
end
|
||||||
|
|
||||||
|
def processor
|
||||||
|
@processor ||=
|
||||||
|
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
|
||||||
|
end
|
||||||
|
|
||||||
|
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
|
||||||
|
processor.to_xml_tool_calls(function_buffer) if !partial
|
||||||
end
|
end
|
||||||
|
|
||||||
def extract_completion_from(response_raw)
|
def extract_completion_from(response_raw)
|
||||||
result = ""
|
processor.process_message(response_raw)
|
||||||
parsed = JSON.parse(response_raw, symbolize_names: true)
|
|
||||||
|
|
||||||
if @streaming_mode
|
|
||||||
if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
|
|
||||||
result = parsed.dig(:delta, :text).to_s
|
|
||||||
elsif parsed[:type] == "message_start"
|
|
||||||
@input_tokens = parsed.dig(:message, :usage, :input_tokens)
|
|
||||||
elsif parsed[:type] == "message_delta"
|
|
||||||
@output_tokens = parsed.dig(:delta, :usage, :output_tokens)
|
|
||||||
end
|
|
||||||
else
|
|
||||||
result = parsed.dig(:content, 0, :text).to_s
|
|
||||||
@input_tokens = parsed.dig(:usage, :input_tokens)
|
|
||||||
@output_tokens = parsed.dig(:usage, :output_tokens)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
result
|
def has_tool?(_response_data)
|
||||||
|
processor.tool_calls.present?
|
||||||
end
|
end
|
||||||
|
|
||||||
def partials_from(decoded_chunks)
|
def partials_from(decoded_chunks)
|
||||||
decoded_chunks
|
decoded_chunks
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def native_tool_support?
|
||||||
|
true
|
||||||
|
end
|
||||||
|
|
||||||
def chunk_to_string(chunk)
|
def chunk_to_string(chunk)
|
||||||
joined = +chunk.join("\n")
|
joined = +chunk.join("\n")
|
||||||
joined << "\n" if joined.length > 0
|
joined << "\n" if joined.length > 0
|
||||||
|
|
|
@ -242,12 +242,14 @@ module DiscourseAi
|
||||||
else
|
else
|
||||||
leftover = ""
|
leftover = ""
|
||||||
end
|
end
|
||||||
|
|
||||||
prev_processed_partials = 0 if leftover.blank?
|
prev_processed_partials = 0 if leftover.blank?
|
||||||
end
|
end
|
||||||
rescue IOError, StandardError
|
rescue IOError, StandardError
|
||||||
raise if !cancelled
|
raise if !cancelled
|
||||||
end
|
end
|
||||||
|
|
||||||
|
has_tool ||= has_tool?(partials_raw)
|
||||||
# Once we have the full response, try to return the tool as a XML doc.
|
# Once we have the full response, try to return the tool as a XML doc.
|
||||||
if has_tool && native_tool_support?
|
if has_tool && native_tool_support?
|
||||||
function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw)
|
function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw)
|
||||||
|
@ -345,7 +347,7 @@ module DiscourseAi
|
||||||
TEXT
|
TEXT
|
||||||
end
|
end
|
||||||
|
|
||||||
def noop_function_call_text
|
def self.noop_function_call_text
|
||||||
(<<~TEXT).strip
|
(<<~TEXT).strip
|
||||||
<invoke>
|
<invoke>
|
||||||
<tool_name></tool_name>
|
<tool_name></tool_name>
|
||||||
|
@ -356,6 +358,10 @@ module DiscourseAi
|
||||||
TEXT
|
TEXT
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def noop_function_call_text
|
||||||
|
self.class.noop_function_call_text
|
||||||
|
end
|
||||||
|
|
||||||
def has_tool?(response)
|
def has_tool?(response)
|
||||||
response.include?("<function_calls>")
|
response.include?("<function_calls>")
|
||||||
end
|
end
|
||||||
|
|
|
@ -46,7 +46,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{ type: :user, id: "user1", content: "echo something" },
|
{ type: :user, id: "user1", content: "echo something" },
|
||||||
{ type: :tool_call, content: tool_call_prompt.to_json },
|
{ type: :tool_call, name: "echo", id: "tool_id", content: tool_call_prompt.to_json },
|
||||||
{ type: :tool, id: "tool_id", content: "something".to_json },
|
{ type: :tool, id: "tool_id", content: "something".to_json },
|
||||||
{ type: :model, content: "I did it" },
|
{ type: :model, content: "I did it" },
|
||||||
{ type: :user, id: "user1", content: "echo something else" },
|
{ type: :user, id: "user1", content: "echo something else" },
|
||||||
|
@ -63,24 +63,22 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
||||||
translated = dialect.translate
|
translated = dialect.translate
|
||||||
|
|
||||||
expect(translated.system_prompt).to start_with("You are a helpful bot")
|
expect(translated.system_prompt).to start_with("You are a helpful bot")
|
||||||
expect(translated.system_prompt).to include("echo a string")
|
|
||||||
|
|
||||||
expected = [
|
expected = [
|
||||||
{ role: "user", content: "user1: echo something" },
|
{ role: "user", content: "user1: echo something" },
|
||||||
{
|
{
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content:
|
content: [
|
||||||
"<function_calls>\n<invoke>\n<tool_name>echo</tool_name>\n<parameters>\n<text>something</text>\n</parameters>\n</invoke>\n</function_calls>",
|
{ type: "tool_use", id: "tool_id", name: "echo", input: { text: "something" } },
|
||||||
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
role: "user",
|
role: "user",
|
||||||
content:
|
content: [{ type: "tool_result", tool_use_id: "tool_id", content: "\"something\"" }],
|
||||||
"<function_results>\n<result>\n<tool_name>tool_id</tool_name>\n<json>\n\"something\"\n</json>\n</result>\n</function_results>",
|
|
||||||
},
|
},
|
||||||
{ role: "assistant", content: "I did it" },
|
{ role: "assistant", content: "I did it" },
|
||||||
{ role: "user", content: "user1: echo something else" },
|
{ role: "user", content: "user1: echo something else" },
|
||||||
]
|
]
|
||||||
|
|
||||||
expect(translated.messages).to eq(expected)
|
expect(translated.messages).to eq(expected)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -50,148 +50,55 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||||
it "does not eat spaces with tool calls" do
|
it "does not eat spaces with tool calls" do
|
||||||
body = <<~STRING
|
body = <<~STRING
|
||||||
event: message_start
|
event: message_start
|
||||||
data: {"type":"message_start","message":{"id":"msg_019kmW9Q3GqfWmuFJbePJTBR","type":"message","role":"assistant","content":[],"model":"claude-3-opus-20240229","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":347,"output_tokens":1}}}
|
data: {"type":"message_start","message":{"id":"msg_01Ju4j2MiGQb9KV9EEQ522Y3","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":1293,"output_tokens":1}} }
|
||||||
|
|
||||||
event: content_block_start
|
event: content_block_start
|
||||||
data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
|
data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01DjrShFRRHp9SnHYRFRc53F","name":"search","input":{}} }
|
||||||
|
|
||||||
event: ping
|
event: ping
|
||||||
data: {"type": "ping"}
|
data: {"type": "ping"}
|
||||||
|
|
||||||
event: content_block_delta
|
event: content_block_delta
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<function"}}
|
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""} }
|
||||||
|
|
||||||
event: content_block_delta
|
event: content_block_delta
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}
|
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{\\"searc"} }
|
||||||
|
|
||||||
event: content_block_delta
|
event: content_block_delta
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"calls"}}
|
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"h_qu"} }
|
||||||
|
|
||||||
event: content_block_delta
|
event: content_block_delta
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"er"} }
|
||||||
|
|
||||||
event: content_block_delta
|
event: content_block_delta
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"y\\": \\"s"} }
|
||||||
|
|
||||||
event: content_block_delta
|
event: content_block_delta
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<invoke"}}
|
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"am"} }
|
||||||
|
|
||||||
event: content_block_delta
|
event: content_block_delta
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":" "} }
|
||||||
|
|
||||||
event: content_block_delta
|
event: content_block_delta
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"sam\\""} }
|
||||||
|
|
||||||
event: content_block_delta
|
event: content_block_delta
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<tool"}}
|
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":", \\"cate"} }
|
||||||
|
|
||||||
event: content_block_delta
|
event: content_block_delta
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}
|
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"gory"} }
|
||||||
|
|
||||||
event: content_block_delta
|
event: content_block_delta
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"name"}}
|
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"\\": \\"gene"} }
|
||||||
|
|
||||||
event: content_block_delta
|
event: content_block_delta
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"ral\\"}"} }
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"google"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</tool"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"name"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<parameters"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<query"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"top"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"10"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"things"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" to"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" do"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" in"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" japan"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" for"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" tourists"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</query"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</parameters"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</invoke"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
|
||||||
|
|
||||||
event: content_block_delta
|
|
||||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
|
||||||
|
|
||||||
event: content_block_stop
|
event: content_block_stop
|
||||||
data: {"type":"content_block_stop","index":0 }
|
data: {"type":"content_block_stop","index":0 }
|
||||||
|
|
||||||
event: message_delta
|
event: message_delta
|
||||||
data: {"type":"message_delta","delta":{"stop_reason":"stop_sequence","stop_sequence":"</function_calls>"},"usage":{"output_tokens":57}}
|
data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"output_tokens":70} }
|
||||||
|
|
||||||
event: message_stop
|
event: message_stop
|
||||||
data: {"type":"message_stop"}
|
data: {"type":"message_stop"}
|
||||||
|
@ -213,11 +120,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||||
expected = (<<~TEXT).strip
|
expected = (<<~TEXT).strip
|
||||||
<function_calls>
|
<function_calls>
|
||||||
<invoke>
|
<invoke>
|
||||||
<tool_name>google</tool_name>
|
<tool_name>search</tool_name>
|
||||||
<parameters>
|
<parameters><search_query>sam sam</search_query>
|
||||||
<query>top 10 things to do in japan for tourists</query>
|
<category>general</category></parameters>
|
||||||
</parameters>
|
<tool_id>toolu_01DjrShFRRHp9SnHYRFRc53F</tool_id>
|
||||||
<tool_id>tool_0</tool_id>
|
|
||||||
</invoke>
|
</invoke>
|
||||||
</function_calls>
|
</function_calls>
|
||||||
TEXT
|
TEXT
|
||||||
|
@ -285,71 +191,71 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||||
|
|
||||||
log = AiApiAuditLog.order(:id).last
|
log = AiApiAuditLog.order(:id).last
|
||||||
expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
|
expect(log.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
|
||||||
expect(log.request_tokens).to eq(25)
|
|
||||||
expect(log.response_tokens).to eq(15)
|
|
||||||
expect(log.feature_name).to eq("testing")
|
expect(log.feature_name).to eq("testing")
|
||||||
|
expect(log.response_tokens).to eq(15)
|
||||||
|
expect(log.request_tokens).to eq(25)
|
||||||
end
|
end
|
||||||
|
|
||||||
it "can return multiple function calls" do
|
it "supports non streaming tool calls" do
|
||||||
functions = <<~FUNCTIONS
|
tool = {
|
||||||
<function_calls>
|
name: "calculate",
|
||||||
<invoke>
|
description: "calculate something",
|
||||||
<tool_name>echo</tool_name>
|
parameters: [
|
||||||
<parameters>
|
|
||||||
<text>something</text>
|
|
||||||
</parameters>
|
|
||||||
</invoke>
|
|
||||||
<invoke>
|
|
||||||
<tool_name>echo</tool_name>
|
|
||||||
<parameters>
|
|
||||||
<text>something else</text>
|
|
||||||
</parameters>
|
|
||||||
</invoke>
|
|
||||||
FUNCTIONS
|
|
||||||
|
|
||||||
body = <<~STRING
|
|
||||||
{
|
{
|
||||||
"content": [
|
name: "expression",
|
||||||
{
|
type: "string",
|
||||||
"text": "Hello!\n\n#{functions}\njunk",
|
description: "expression to calculate",
|
||||||
"type": "text"
|
required: true,
|
||||||
}
|
},
|
||||||
],
|
],
|
||||||
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
|
|
||||||
"model": "claude-3-opus-20240229",
|
|
||||||
"role": "assistant",
|
|
||||||
"stop_reason": "end_turn",
|
|
||||||
"stop_sequence": null,
|
|
||||||
"type": "message",
|
|
||||||
"usage": {
|
|
||||||
"input_tokens": 10,
|
|
||||||
"output_tokens": 25
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
STRING
|
|
||||||
|
|
||||||
stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
|
prompt =
|
||||||
|
DiscourseAi::Completions::Prompt.new(
|
||||||
|
"You a calculator",
|
||||||
|
messages: [{ type: :user, id: "user1", content: "calculate 2758975 + 21.11" }],
|
||||||
|
tools: [tool],
|
||||||
|
)
|
||||||
|
|
||||||
result = llm.generate(prompt_with_echo_tool, user: Discourse.system_user)
|
proxy = DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-haiku")
|
||||||
|
|
||||||
expected = (<<~EXPECTED).strip
|
body = {
|
||||||
|
id: "msg_01RdJkxCbsEj9VFyFYAkfy2S",
|
||||||
|
type: "message",
|
||||||
|
role: "assistant",
|
||||||
|
model: "claude-3-haiku-20240307",
|
||||||
|
content: [
|
||||||
|
{ type: "text", text: "Here is the calculation:" },
|
||||||
|
{
|
||||||
|
type: "tool_use",
|
||||||
|
id: "toolu_012kBdhG4eHaV68W56p4N94h",
|
||||||
|
name: "calculate",
|
||||||
|
input: {
|
||||||
|
expression: "2758975 + 21.11",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
stop_reason: "tool_use",
|
||||||
|
stop_sequence: nil,
|
||||||
|
usage: {
|
||||||
|
input_tokens: 345,
|
||||||
|
output_tokens: 65,
|
||||||
|
},
|
||||||
|
}.to_json
|
||||||
|
|
||||||
|
stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(body: body)
|
||||||
|
|
||||||
|
result = proxy.generate(prompt, user: Discourse.system_user)
|
||||||
|
|
||||||
|
expected = <<~TEXT.strip
|
||||||
<function_calls>
|
<function_calls>
|
||||||
<invoke>
|
<invoke>
|
||||||
<tool_name>echo</tool_name>
|
<tool_name>calculate</tool_name>
|
||||||
<parameters>
|
<parameters><expression>2758975 + 21.11</expression></parameters>
|
||||||
<text>something</text>
|
<tool_id>toolu_012kBdhG4eHaV68W56p4N94h</tool_id>
|
||||||
</parameters>
|
|
||||||
<tool_id>tool_0</tool_id>
|
|
||||||
</invoke>
|
|
||||||
<invoke>
|
|
||||||
<tool_name>echo</tool_name>
|
|
||||||
<parameters>
|
|
||||||
<text>something else</text>
|
|
||||||
</parameters>
|
|
||||||
<tool_id>tool_1</tool_id>
|
|
||||||
</invoke>
|
</invoke>
|
||||||
</function_calls>
|
</function_calls>
|
||||||
EXPECTED
|
TEXT
|
||||||
|
|
||||||
expect(result.strip).to eq(expected)
|
expect(result.strip).to eq(expected)
|
||||||
end
|
end
|
||||||
|
|
|
@ -24,6 +24,175 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||||
SiteSetting.ai_bedrock_region = "us-east-1"
|
SiteSetting.ai_bedrock_region = "us-east-1"
|
||||||
end
|
end
|
||||||
|
|
||||||
|
describe "function calling" do
|
||||||
|
it "supports streaming function calls" do
|
||||||
|
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
|
||||||
|
|
||||||
|
request = nil
|
||||||
|
|
||||||
|
messages =
|
||||||
|
[
|
||||||
|
{
|
||||||
|
type: "message_start",
|
||||||
|
message: {
|
||||||
|
id: "msg_bdrk_01WYxeNMk6EKn9s98r6XXrAB",
|
||||||
|
type: "message",
|
||||||
|
role: "assistant",
|
||||||
|
model: "claude-3-haiku-20240307",
|
||||||
|
stop_sequence: nil,
|
||||||
|
usage: {
|
||||||
|
input_tokens: 840,
|
||||||
|
output_tokens: 1,
|
||||||
|
},
|
||||||
|
content: [],
|
||||||
|
stop_reason: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: "content_block_start",
|
||||||
|
index: 0,
|
||||||
|
content_block: {
|
||||||
|
type: "tool_use",
|
||||||
|
id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
|
||||||
|
name: "google",
|
||||||
|
input: {
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: "content_block_delta",
|
||||||
|
index: 0,
|
||||||
|
delta: {
|
||||||
|
type: "input_json_delta",
|
||||||
|
partial_json: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: "content_block_delta",
|
||||||
|
index: 0,
|
||||||
|
delta: {
|
||||||
|
type: "input_json_delta",
|
||||||
|
partial_json: "{\"query\": \"s",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: "content_block_delta",
|
||||||
|
index: 0,
|
||||||
|
delta: {
|
||||||
|
type: "input_json_delta",
|
||||||
|
partial_json: "ydney weat",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: "content_block_delta",
|
||||||
|
index: 0,
|
||||||
|
delta: {
|
||||||
|
type: "input_json_delta",
|
||||||
|
partial_json: "her today\"}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{ type: "content_block_stop", index: 0 },
|
||||||
|
{
|
||||||
|
type: "message_delta",
|
||||||
|
delta: {
|
||||||
|
stop_reason: "tool_use",
|
||||||
|
stop_sequence: nil,
|
||||||
|
},
|
||||||
|
usage: {
|
||||||
|
output_tokens: 53,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: "message_stop",
|
||||||
|
"amazon-bedrock-invocationMetrics": {
|
||||||
|
inputTokenCount: 846,
|
||||||
|
outputTokenCount: 39,
|
||||||
|
invocationLatency: 880,
|
||||||
|
firstByteLatency: 402,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
].map do |message|
|
||||||
|
wrapped = { bytes: Base64.encode64(message.to_json) }.to_json
|
||||||
|
io = StringIO.new(wrapped)
|
||||||
|
aws_message = Aws::EventStream::Message.new(payload: io)
|
||||||
|
Aws::EventStream::Encoder.new.encode(aws_message)
|
||||||
|
end
|
||||||
|
|
||||||
|
messages = messages.join("").split
|
||||||
|
|
||||||
|
bedrock_mock.with_chunk_array_support do
|
||||||
|
stub_request(
|
||||||
|
:post,
|
||||||
|
"https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
|
||||||
|
)
|
||||||
|
.with do |inner_request|
|
||||||
|
request = inner_request
|
||||||
|
true
|
||||||
|
end
|
||||||
|
.to_return(status: 200, body: messages)
|
||||||
|
|
||||||
|
prompt =
|
||||||
|
DiscourseAi::Completions::Prompt.new(
|
||||||
|
messages: [{ type: :user, content: "what is the weather in sydney" }],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool = {
|
||||||
|
name: "google",
|
||||||
|
description: "Will search using Google",
|
||||||
|
parameters: [
|
||||||
|
{ name: "query", description: "The search query", type: "string", required: true },
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt.tools = [tool]
|
||||||
|
response = +""
|
||||||
|
proxy.generate(prompt, user: user) { |partial| response << partial }
|
||||||
|
|
||||||
|
expect(request.headers["Authorization"]).to be_present
|
||||||
|
expect(request.headers["X-Amz-Content-Sha256"]).to be_present
|
||||||
|
|
||||||
|
expected_response = (<<~RESPONSE).strip
|
||||||
|
<function_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>google</tool_name>
|
||||||
|
<parameters><query>sydney weather today</query></parameters>
|
||||||
|
<tool_id>toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7</tool_id>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
RESPONSE
|
||||||
|
|
||||||
|
expect(response.strip).to eq(expected_response)
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"max_tokens" => 3000,
|
||||||
|
"anthropic_version" => "bedrock-2023-05-31",
|
||||||
|
"messages" => [{ "role" => "user", "content" => "what is the weather in sydney" }],
|
||||||
|
"tools" => [
|
||||||
|
{
|
||||||
|
"name" => "google",
|
||||||
|
"description" => "Will search using Google",
|
||||||
|
"input_schema" => {
|
||||||
|
"type" => "object",
|
||||||
|
"properties" => {
|
||||||
|
"query" => {
|
||||||
|
"type" => "string",
|
||||||
|
"description" => "The search query",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required" => ["query"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
expect(JSON.parse(request.body)).to eq(expected)
|
||||||
|
|
||||||
|
log = AiApiAuditLog.order(:id).last
|
||||||
|
expect(log.request_tokens).to eq(846)
|
||||||
|
expect(log.response_tokens).to eq(39)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
describe "Claude 3 Sonnet support" do
|
describe "Claude 3 Sonnet support" do
|
||||||
it "supports the sonnet model" do
|
it "supports the sonnet model" do
|
||||||
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
|
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
|
||||||
|
|
Loading…
Reference in New Issue