FIX: more robust function call support (#581)
For quite a few weeks now, when running function calls on Anthropic, we would sometimes get a "stray" - "calls" line. This has been enormously frustrating! I have been unable to find the source of the bug, so instead I decoupled the implementation and created a very clear "function call normalizer". This new class is extensively tested and guards against the type of edge cases we saw pre-normalizer. This also simplifies the implementation of "endpoint", which no longer needs to handle all this complex logic.
This commit is contained in:
parent
50be66ee63
commit
a223d18f1a
|
@ -63,7 +63,11 @@ module DiscourseAi
|
|||
@tokenizer = tokenizer
|
||||
end
|
||||
|
||||
def perform_completion!(dialect, user, model_params = {})
|
||||
def native_tool_support?
|
||||
false
|
||||
end
|
||||
|
||||
def perform_completion!(dialect, user, model_params = {}, &blk)
|
||||
allow_tools = dialect.prompt.has_tools?
|
||||
model_params = normalize_model_params(model_params)
|
||||
|
||||
|
@ -111,14 +115,21 @@ module DiscourseAi
|
|||
response_data = extract_completion_from(response_raw)
|
||||
partials_raw = response_data.to_s
|
||||
|
||||
if allow_tools && has_tool?(response_data)
|
||||
function_buffer = build_buffer # Nokogiri document
|
||||
function_buffer = add_to_function_buffer(function_buffer, payload: response_data)
|
||||
if native_tool_support?
|
||||
if allow_tools && has_tool?(response_data)
|
||||
function_buffer = build_buffer # Nokogiri document
|
||||
function_buffer =
|
||||
add_to_function_buffer(function_buffer, payload: response_data)
|
||||
FunctionCallNormalizer.normalize_function_ids!(function_buffer)
|
||||
|
||||
normalize_function_ids!(function_buffer)
|
||||
|
||||
response_data = +function_buffer.at("function_calls").to_s
|
||||
response_data << "\n"
|
||||
response_data = +function_buffer.at("function_calls").to_s
|
||||
response_data << "\n"
|
||||
end
|
||||
else
|
||||
if allow_tools
|
||||
response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
|
||||
response_data = function_calls if function_calls.present?
|
||||
end
|
||||
end
|
||||
|
||||
return response_data
|
||||
|
@ -128,7 +139,14 @@ module DiscourseAi
|
|||
|
||||
begin
|
||||
cancelled = false
|
||||
cancel = lambda { cancelled = true }
|
||||
cancel = -> { cancelled = true }
|
||||
|
||||
wrapped_blk = ->(partial, inner_cancel) do
|
||||
response_data << partial
|
||||
blk.call(partial, inner_cancel)
|
||||
end
|
||||
|
||||
normalizer = FunctionCallNormalizer.new(wrapped_blk, cancel)
|
||||
|
||||
leftover = ""
|
||||
function_buffer = build_buffer # Nokogiri document
|
||||
|
@ -159,7 +177,6 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
json_error = false
|
||||
buffered_partials = []
|
||||
|
||||
raw_partials.each do |raw_partial|
|
||||
json_error = false
|
||||
|
@ -175,31 +192,24 @@ module DiscourseAi
|
|||
next if response_data.empty? && partial.empty?
|
||||
partials_raw << partial.to_s
|
||||
|
||||
# Stop streaming the response as soon as you find a tool.
|
||||
# We'll buffer and yield it later.
|
||||
has_tool = true if allow_tools && has_tool?(partials_raw)
|
||||
if native_tool_support?
|
||||
# Stop streaming the response as soon as you find a tool.
|
||||
# We'll buffer and yield it later.
|
||||
has_tool = true if allow_tools && has_tool?(partials_raw)
|
||||
|
||||
if has_tool
|
||||
if buffered_partials.present?
|
||||
joined = buffered_partials.join
|
||||
joined = joined.gsub(/<.+/, "")
|
||||
yield joined, cancel if joined.present?
|
||||
buffered_partials = []
|
||||
end
|
||||
function_buffer = add_to_function_buffer(function_buffer, partial: partial)
|
||||
else
|
||||
if maybe_has_tool?(partials_raw)
|
||||
buffered_partials << partial
|
||||
if has_tool
|
||||
function_buffer =
|
||||
add_to_function_buffer(function_buffer, partial: partial)
|
||||
else
|
||||
if buffered_partials.present?
|
||||
buffered_partials.each do |buffered_partial|
|
||||
response_data << buffered_partial
|
||||
yield buffered_partial, cancel
|
||||
end
|
||||
buffered_partials = []
|
||||
end
|
||||
response_data << partial
|
||||
yield partial, cancel if partial
|
||||
blk.call(partial, cancel) if partial
|
||||
end
|
||||
else
|
||||
if allow_tools
|
||||
normalizer << partial
|
||||
else
|
||||
response_data << partial
|
||||
blk.call(partial, cancel) if partial
|
||||
end
|
||||
end
|
||||
rescue JSON::ParserError
|
||||
|
@ -220,20 +230,25 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
# Once we have the full response, try to return the tool as a XML doc.
|
||||
if has_tool
|
||||
if has_tool && native_tool_support?
|
||||
function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw)
|
||||
|
||||
if function_buffer.at("tool_name").text.present?
|
||||
normalize_function_ids!(function_buffer)
|
||||
FunctionCallNormalizer.normalize_function_ids!(function_buffer)
|
||||
|
||||
invocation = +function_buffer.at("function_calls").to_s
|
||||
invocation << "\n"
|
||||
|
||||
response_data << invocation
|
||||
yield invocation, cancel
|
||||
blk.call(invocation, cancel)
|
||||
end
|
||||
end
|
||||
|
||||
if !native_tool_support? && function_calls = normalizer.function_calls
|
||||
response_data << function_calls
|
||||
blk.call(function_calls, cancel)
|
||||
end
|
||||
|
||||
return response_data
|
||||
ensure
|
||||
if log
|
||||
|
@ -250,21 +265,6 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
def normalize_function_ids!(function_buffer)
|
||||
function_buffer
|
||||
.css("invoke")
|
||||
.each_with_index do |invoke, index|
|
||||
if invoke.at("tool_id")
|
||||
invoke.at("tool_id").content = "tool_#{index}" if invoke
|
||||
.at("tool_id")
|
||||
.content
|
||||
.blank?
|
||||
else
|
||||
invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
def final_log_update(log)
|
||||
# for people that need to override
|
||||
end
|
||||
|
@ -341,19 +341,6 @@ module DiscourseAi
|
|||
response.include?("<function_calls>")
|
||||
end
|
||||
|
||||
def maybe_has_tool?(response)
|
||||
# 16 is the length of function calls
|
||||
substring = response[-16..-1] || response
|
||||
split = substring.split("<")
|
||||
|
||||
if split.length > 1
|
||||
match = "<" + split.last
|
||||
"<function_calls>".start_with?(match)
|
||||
else
|
||||
false
|
||||
end
|
||||
end
|
||||
|
||||
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
|
||||
if payload&.include?("</invoke>")
|
||||
matches = payload.match(%r{<function_calls>.*</invoke>}m)
|
||||
|
|
|
@ -105,9 +105,8 @@ module DiscourseAi
|
|||
@has_function_call
|
||||
end
|
||||
|
||||
def maybe_has_tool?(_partial_raw)
|
||||
# we always get a full partial
|
||||
false
|
||||
def native_tool_support?
|
||||
true
|
||||
end
|
||||
|
||||
def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
|
||||
|
|
|
@ -162,9 +162,8 @@ module DiscourseAi
|
|||
@has_function_call
|
||||
end
|
||||
|
||||
def maybe_has_tool?(_partial_raw)
|
||||
# we always get a full partial
|
||||
false
|
||||
def native_tool_support?
|
||||
true
|
||||
end
|
||||
|
||||
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
|
||||
|
|
|
@ -0,0 +1,113 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
# Splits a (possibly streamed) LLM response into plain text and an XML
# <function_calls> block.
#
# Plain text is forwarded to the supplied callable as soon as it is known not
# to be the start of a "<function_calls>" tag. Anything that could still turn
# into that tag is held back until the ambiguity is resolved. Once the tag has
# been seen, every following chunk is buffered and exposed via #function_calls.
class DiscourseAi::Completions::FunctionCallNormalizer
  OPEN_TAG = "<function_calls>"
  CLOSE_TAG = "</function_calls>"

  # true once the closing </function_calls> tag has been received
  attr_reader :done

  # blk is the callable invoked with (filtered_text, cancel) for plain text
  # cancel is the callable invoked (once) when </function_calls> is reached
  def initialize(blk, cancel)
    @blk = blk
    @cancel = cancel
    @done = false

    @in_tool = false

    # Outside a tool: text received but not yet forwarded (a possible tag
    # prefix, so it is either empty or starts with "<").
    # Inside a tool: everything from "<function_calls>" onwards.
    @function_buffer = +""
  end

  # Non streaming helper.
  # Returns [text_before_the_function_calls, function_calls_xml_or_nil]
  def self.normalize(data)
    text = +""
    normalizer = new(->(partial, _) { text << partial }, -> {})
    normalizer << data

    function_calls = normalizer.function_calls
    [text, function_calls]
  end

  # Returns the normalized <function_calls> XML (every invoke is given a
  # tool_id, anything after the last </invoke> is dropped) or nil when no
  # function call was detected.
  #
  # When no function call was detected, any text still held back as a possible
  # tag prefix (e.g. a response ending in "<") is forwarded to the callable so
  # it is not lost. Calling this method again is then a no-op.
  def function_calls
    unless @in_tool
      held = @function_buffer
      @function_buffer = +""
      emit(held)
      return nil
    end

    xml = Nokogiri::HTML5.fragment(@function_buffer)
    self.class.normalize_function_ids!(xml)

    last_invoke = xml.at("invoke:last")
    if last_invoke
      # strip junk following the last invoke, keep a single trailing newline
      last_invoke.next_sibling.remove while last_invoke.next_sibling
      last_invoke.add_next_sibling("\n")
    end

    node = xml.at("function_calls")
    node && node.to_s.dup.force_encoding("UTF-8")
  end

  # Feeds the next chunk of the response into the normalizer.
  def <<(text)
    text = text.to_s

    if @in_tool
      @function_buffer << text
    else
      consume_outside_tool(text)
    end

    detect_close_tag(text.length) if @in_tool
    self
  end

  # Ensures every <invoke> in the supplied Nokogiri document/fragment has a
  # non blank <tool_id>, generating "tool_N" (N = position) where needed.
  def self.normalize_function_ids!(function_buffer)
    function_buffer
      .css("invoke")
      .each_with_index do |invoke, index|
        tool_id = invoke.at("tool_id")
        if tool_id
          tool_id.content = "tool_#{index}" if tool_id.content.blank?
        else
          invoke.add_child("<tool_id>tool_#{index}</tool_id>\n")
        end
      end
  end

  private

  # Handles a chunk received before the opening tag was seen.
  #
  # All decisions are made on the unsent text (held candidate + new chunk), so
  # offsets always refer to the string being sliced and previously held text
  # is never silently overwritten.
  def consume_outside_tool(text)
    unsent = @function_buffer + text
    tag_at = unsent.rindex(OPEN_TAG)

    if tag_at
      @in_tool = true
      preamble = unsent[0...tag_at].strip
      @function_buffer = unsent[tag_at..-1]
      emit(preamble)
    elsif (candidate_at = partial_tag_index(unsent))
      # the tail may still grow into the opening tag: forward what precedes
      # it and keep holding the candidate
      @function_buffer = unsent[candidate_at..-1]
      emit(unsent[0...candidate_at])
    else
      @function_buffer = +""
      emit(unsent)
    end
  end

  # Index of the last "<" in text when everything from it to the end of text
  # is a prefix of the opening tag, nil otherwise.
  def partial_tag_index(text)
    lt_index = text.rindex("<")
    return nil if lt_index.nil?

    OPEN_TAG.start_with?(text[lt_index..-1]) ? lt_index : nil
  end

  # Marks the normalizer as done and cancels the stream once the closing tag
  # is present. Only the tail that the latest chunk could have completed is
  # searched, so a closing tag split across chunks is still detected.
  def detect_close_tag(appended_length)
    return if @done

    window = appended_length + CLOSE_TAG.length - 1
    tail = @function_buffer[-window..-1] || @function_buffer
    return unless tail.include?(CLOSE_TAG)

    @done = true
    @cancel.call
  end

  def emit(text)
    @blk.call(text, @cancel) unless text.empty?
  end
end
|
|
@ -40,7 +40,7 @@ class EndpointMock
|
|||
end
|
||||
|
||||
def tool_deltas
|
||||
["Let me use a tool for that<function", <<~REPLY.strip, <<~REPLY.strip, <<~REPLY.strip]
|
||||
["<function", <<~REPLY.strip, <<~REPLY.strip, <<~REPLY.strip]
|
||||
_calls>
|
||||
<invoke>
|
||||
<tool_name>get_weather</tool_name>
|
||||
|
@ -185,7 +185,7 @@ class EndpointsCompliance
|
|||
mock.stub_tool_call(a_dialect.translate)
|
||||
|
||||
completion_response = endpoint.perform_completion!(a_dialect, user)
|
||||
expect(completion_response).to eq(mock.invocation_response)
|
||||
expect(completion_response.strip).to eq(mock.invocation_response.strip)
|
||||
end
|
||||
|
||||
def streaming_mode_simple_prompt(mock)
|
||||
|
@ -223,7 +223,7 @@ class EndpointsCompliance
|
|||
cancel.call if buffered_partial.include?("<function_calls>")
|
||||
end
|
||||
|
||||
expect(buffered_partial).to eq(mock.invocation_response)
|
||||
expect(buffered_partial.strip).to eq(mock.invocation_response.strip)
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -0,0 +1,182 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::FunctionCallNormalizer do
|
||||
let(:buffer) { +"" }
|
||||
|
||||
let(:normalizer) do
|
||||
blk = ->(data, cancel) { buffer << data }
|
||||
cancel = -> { @done = true }
|
||||
DiscourseAi::Completions::FunctionCallNormalizer.new(blk, cancel)
|
||||
end
|
||||
|
||||
def pass_through!(data)
|
||||
normalizer << data
|
||||
expect(buffer[-data.length..-1]).to eq(data)
|
||||
end
|
||||
|
||||
it "is usable in non streaming mode" do
|
||||
xml = (<<~XML).strip
|
||||
hello
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>hello</tool_name>
|
||||
</invoke>
|
||||
XML
|
||||
|
||||
text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml)
|
||||
|
||||
expect(text).to eq("hello")
|
||||
|
||||
expected_function_calls = (<<~XML).strip
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>hello</tool_name>
|
||||
<tool_id>tool_0</tool_id>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
XML
|
||||
|
||||
expect(function_calls).to eq(expected_function_calls)
|
||||
end
|
||||
|
||||
it "strips junk from end of function calls" do
|
||||
xml = (<<~XML).strip
|
||||
hello
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>hello</tool_name>
|
||||
</invoke>
|
||||
junk
|
||||
XML
|
||||
|
||||
_text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml)
|
||||
|
||||
expected_function_calls = (<<~XML).strip
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>hello</tool_name>
|
||||
<tool_id>tool_0</tool_id>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
XML
|
||||
|
||||
expect(function_calls).to eq(expected_function_calls)
|
||||
end
|
||||
|
||||
it "returns nil for function calls if there are none" do
|
||||
input = "hello world\n"
|
||||
text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(input)
|
||||
|
||||
expect(text).to eq(input)
|
||||
expect(function_calls).to eq(nil)
|
||||
end
|
||||
|
||||
it "passes through data if there are no function calls detected" do
|
||||
pass_through!("hello")
|
||||
pass_through!("<tool_name>hello</tool_name>")
|
||||
pass_through!("<parameters><hello>world</hello></parameters>")
|
||||
pass_through!("<function_call>")
|
||||
end
|
||||
|
||||
it "properly handles non English tools" do
|
||||
normalizer << "hello<function"
|
||||
expect(buffer).to eq("hello")
|
||||
|
||||
normalizer << "_calls>\n"
|
||||
|
||||
normalizer << (<<~XML).strip
|
||||
<invoke>
|
||||
<tool_name>hello</tool_name>
|
||||
<parameters>
|
||||
<hello>世界</hello>
|
||||
</parameters>
|
||||
</invoke>
|
||||
XML
|
||||
|
||||
expected = (<<~XML).strip
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>hello</tool_name>
|
||||
<parameters>
|
||||
<hello>世界</hello>
|
||||
</parameters>
|
||||
<tool_id>tool_0</tool_id>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
XML
|
||||
|
||||
function_calls = normalizer.function_calls
|
||||
expect(function_calls).to eq(expected)
|
||||
end
|
||||
|
||||
it "works correctly even if you only give it 1 letter at a time" do
|
||||
xml = (<<~XML).strip
|
||||
abc
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>hello</tool_name>
|
||||
<parameters>
|
||||
<hello>world</hello>
|
||||
</parameters>
|
||||
<tool_id>abc</tool_id>
|
||||
</invoke>
|
||||
<invoke>
|
||||
<tool_name>hello2</tool_name>
|
||||
<parameters>
|
||||
<hello>world</hello>
|
||||
</parameters>
|
||||
<tool_id>aba</tool_id>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
XML
|
||||
|
||||
xml.each_char { |char| normalizer << char }
|
||||
|
||||
expect(buffer + normalizer.function_calls).to eq(xml)
|
||||
end
|
||||
|
||||
it "supports multiple invokes" do
|
||||
xml = (<<~XML).strip
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>hello</tool_name>
|
||||
<parameters>
|
||||
<hello>world</hello>
|
||||
</parameters>
|
||||
<tool_id>abc</tool_id>
|
||||
</invoke>
|
||||
<invoke>
|
||||
<tool_name>hello2</tool_name>
|
||||
<parameters>
|
||||
<hello>world</hello>
|
||||
</parameters>
|
||||
<tool_id>aba</tool_id>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
XML
|
||||
|
||||
normalizer << xml
|
||||
|
||||
expect(normalizer.function_calls).to eq(xml)
|
||||
end
|
||||
|
||||
# The cancel callable and the normalizer's own `done` flag must both flip as
# soon as the closing tag is fed in, and the buffered (empty) call block must
# still be returned verbatim.
it "cancels as soon as it encounters </function_calls>" do
  normalizer << "<function_calls>"
  expect(normalizer.done).to eq(false)
  normalizer << "</function_calls>"
  expect(normalizer.done).to eq(true)
  expect(@done).to eq(true)

  expect(normalizer.function_calls).to eq("<function_calls></function_calls>")
end
|
||||
|
||||
it "pauses on function call and starts buffering" do
|
||||
normalizer << "hello<function_call"
|
||||
expect(buffer).to eq("hello")
|
||||
expect(normalizer.done).to eq(false)
|
||||
|
||||
normalizer << ">"
|
||||
expect(buffer).to eq("hello<function_call>")
|
||||
expect(normalizer.done).to eq(false)
|
||||
end
|
||||
end
|
Loading…
Reference in New Issue