FIX: more robust function call support (#581)

For quite a few weeks now, when running function calls on Anthropic we would
sometimes get a stray "calls" line in the output.

This has been enormously frustrating!

I have been unable to find the source of the bug, so I instead decoupled
the implementation and created a very clear "function call normalizer".

This new class is extensively tested and guards against the kinds of
edge cases we saw pre-normalizer.
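
Roughly, the new class is driven like this (a condensed sketch based on the class and its specs further down in this diff; the sample strings are illustrative only):

    streamed = +""
    blk = ->(partial, _cancel) { streamed << partial } # only ever receives plain text
    cancel = -> {} # hook used to stop the provider stream once the tool call is complete

    normalizer = DiscourseAi::Completions::FunctionCallNormalizer.new(blk, cancel)

    # Plain text flows straight through to blk...
    normalizer << "hello<function"
    # ...while everything from <function_calls> onwards is buffered instead of yielded.
    normalizer << "_calls>\n<invoke>\n<tool_name>get_weather</tool_name>\n</invoke>"

    streamed                  # => "hello"
    normalizer.function_calls # => normalized <function_calls> XML with tool_id filled in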

This also simplifies the implementation of "endpoint", which no longer
needs to handle all of this complex logic.
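
For endpoints without native tool support, the non streaming path now reduces to a single call (condensed from the first hunk below):

    if allow_tools
      response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
      # when a tool was invoked, the normalized <function_calls> XML is what we return
      response_data = function_calls if function_calls.present?
    end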
Sam 2024-04-19 06:54:54 +10:00 committed by GitHub
parent 50be66ee63
commit a223d18f1a
6 changed files with 352 additions and 72 deletions

View File

@@ -63,7 +63,11 @@ module DiscourseAi
@tokenizer = tokenizer
end
def perform_completion!(dialect, user, model_params = {})
def native_tool_support?
false
end
def perform_completion!(dialect, user, model_params = {}, &blk)
allow_tools = dialect.prompt.has_tools?
model_params = normalize_model_params(model_params)
@@ -111,14 +115,21 @@ module DiscourseAi
response_data = extract_completion_from(response_raw)
partials_raw = response_data.to_s
if allow_tools && has_tool?(response_data)
function_buffer = build_buffer # Nokogiri document
function_buffer = add_to_function_buffer(function_buffer, payload: response_data)
if native_tool_support?
if allow_tools && has_tool?(response_data)
function_buffer = build_buffer # Nokogiri document
function_buffer =
add_to_function_buffer(function_buffer, payload: response_data)
FunctionCallNormalizer.normalize_function_ids!(function_buffer)
normalize_function_ids!(function_buffer)
response_data = +function_buffer.at("function_calls").to_s
response_data << "\n"
response_data = +function_buffer.at("function_calls").to_s
response_data << "\n"
end
else
if allow_tools
response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
response_data = function_calls if function_calls.present?
end
end
return response_data
@@ -128,7 +139,14 @@ module DiscourseAi
begin
cancelled = false
cancel = lambda { cancelled = true }
cancel = -> { cancelled = true }
wrapped_blk = ->(partial, inner_cancel) do
response_data << partial
blk.call(partial, inner_cancel)
end
normalizer = FunctionCallNormalizer.new(wrapped_blk, cancel)
leftover = ""
function_buffer = build_buffer # Nokogiri document
@@ -159,7 +177,6 @@ module DiscourseAi
end
json_error = false
buffered_partials = []
raw_partials.each do |raw_partial|
json_error = false
@@ -175,31 +192,24 @@ module DiscourseAi
next if response_data.empty? && partial.empty?
partials_raw << partial.to_s
# Stop streaming the response as soon as you find a tool.
# We'll buffer and yield it later.
has_tool = true if allow_tools && has_tool?(partials_raw)
if native_tool_support?
# Stop streaming the response as soon as you find a tool.
# We'll buffer and yield it later.
has_tool = true if allow_tools && has_tool?(partials_raw)
if has_tool
if buffered_partials.present?
joined = buffered_partials.join
joined = joined.gsub(/<.+/, "")
yield joined, cancel if joined.present?
buffered_partials = []
end
function_buffer = add_to_function_buffer(function_buffer, partial: partial)
else
if maybe_has_tool?(partials_raw)
buffered_partials << partial
if has_tool
function_buffer =
add_to_function_buffer(function_buffer, partial: partial)
else
if buffered_partials.present?
buffered_partials.each do |buffered_partial|
response_data << buffered_partial
yield buffered_partial, cancel
end
buffered_partials = []
end
response_data << partial
yield partial, cancel if partial
blk.call(partial, cancel) if partial
end
else
if allow_tools
normalizer << partial
else
response_data << partial
blk.call(partial, cancel) if partial
end
end
rescue JSON::ParserError
@@ -220,20 +230,25 @@ module DiscourseAi
end
# Once we have the full response, try to return the tool as an XML doc.
if has_tool
if has_tool && native_tool_support?
function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw)
if function_buffer.at("tool_name").text.present?
normalize_function_ids!(function_buffer)
FunctionCallNormalizer.normalize_function_ids!(function_buffer)
invocation = +function_buffer.at("function_calls").to_s
invocation << "\n"
response_data << invocation
yield invocation, cancel
blk.call(invocation, cancel)
end
end
if !native_tool_support? && function_calls = normalizer.function_calls
response_data << function_calls
blk.call(function_calls, cancel)
end
return response_data
ensure
if log
@@ -250,21 +265,6 @@ module DiscourseAi
end
end
def normalize_function_ids!(function_buffer)
function_buffer
.css("invoke")
.each_with_index do |invoke, index|
if invoke.at("tool_id")
invoke.at("tool_id").content = "tool_#{index}" if invoke
.at("tool_id")
.content
.blank?
else
invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
end
end
end
def final_log_update(log)
# for people that need to override
end
@@ -341,19 +341,6 @@ module DiscourseAi
response.include?("<function_calls>")
end
def maybe_has_tool?(response)
# 16 is the length of function calls
substring = response[-16..-1] || response
split = substring.split("<")
if split.length > 1
match = "<" + split.last
"<function_calls>".start_with?(match)
else
false
end
end
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
if payload&.include?("</invoke>")
matches = payload.match(%r{<function_calls>.*</invoke>}m)

View File

@@ -105,9 +105,8 @@ module DiscourseAi
@has_function_call
end
def maybe_has_tool?(_partial_raw)
# we always get a full partial
false
def native_tool_support?
true
end
def add_to_function_buffer(function_buffer, payload: nil, partial: nil)

View File

@@ -162,9 +162,8 @@ module DiscourseAi
@has_function_call
end
def maybe_has_tool?(_partial_raw)
# we always get a full partial
false
def native_tool_support?
true
end
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)

View File

@@ -0,0 +1,113 @@
# frozen_string_literal: true
class DiscourseAi::Completions::FunctionCallNormalizer
attr_reader :done
# blk is the block to call with filtered data
def initialize(blk, cancel)
@blk = blk
@cancel = cancel
@done = false
@in_tool = false
@buffer = +""
@function_buffer = +""
end
def self.normalize(data)
text = +""
cancel = -> {}
blk = ->(partial, _) { text << partial }
normalizer = self.new(blk, cancel)
normalizer << data
[text, normalizer.function_calls]
end
def function_calls
return nil if @function_buffer.blank?
xml = Nokogiri::HTML5.fragment(@function_buffer)
self.class.normalize_function_ids!(xml)
last_invoke = xml.at("invoke:last")
if last_invoke
last_invoke.next_sibling.remove while last_invoke.next_sibling
xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling
end
xml.at("function_calls").to_s.dup.force_encoding("UTF-8")
end
def <<(text)
@buffer << text
if !@in_tool
# double check if we are clearly in a tool
search_length = text.length + 20
search_string = @buffer[-search_length..-1] || @buffer
index = search_string.rindex("<function_calls>")
@in_tool = !!index
if @in_tool
@function_buffer = @buffer[index..-1]
text_index = text.rindex("<function_calls>")
@blk.call(text[0..text_index - 1].strip, @cancel) if text_index && text_index > 0
end
else
@function_buffer << text
end
if !@in_tool
if maybe_has_tool?(@buffer)
split_index = text.rindex("<").to_i - 1
if split_index >= 0
@function_buffer = text[split_index + 1..-1] || ""
text = text[0..split_index] || ""
else
@function_buffer << text
text = ""
end
else
if @function_buffer.length > 0
@blk.call(@function_buffer, @cancel)
@function_buffer = +""
end
end
@blk.call(text, @cancel) if text.length > 0
else
if text.include?("</function_calls>")
@done = true
@cancel.call
end
end
end
def self.normalize_function_ids!(function_buffer)
function_buffer
.css("invoke")
.each_with_index do |invoke, index|
if invoke.at("tool_id")
invoke.at("tool_id").content = "tool_#{index}" if invoke.at("tool_id").content.blank?
else
invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
end
end
end
private
def maybe_has_tool?(text)
# 16 is the length of "<function_calls>"
substring = text[-16..-1] || text
split = substring.split("<")
if split.length > 1
match = "<" + split.last
"<function_calls>".start_with?(match)
else
substring.ends_with?("<")
end
end
end

View File

@@ -40,7 +40,7 @@ class EndpointMock
end
def tool_deltas
["Let me use a tool for that<function", <<~REPLY.strip, <<~REPLY.strip, <<~REPLY.strip]
["<function", <<~REPLY.strip, <<~REPLY.strip, <<~REPLY.strip]
_calls>
<invoke>
<tool_name>get_weather</tool_name>
@@ -185,7 +185,7 @@ class EndpointsCompliance
mock.stub_tool_call(a_dialect.translate)
completion_response = endpoint.perform_completion!(a_dialect, user)
expect(completion_response).to eq(mock.invocation_response)
expect(completion_response.strip).to eq(mock.invocation_response.strip)
end
def streaming_mode_simple_prompt(mock)
@@ -223,7 +223,7 @@ class EndpointsCompliance
cancel.call if buffered_partial.include?("<function_calls>")
end
expect(buffered_partial).to eq(mock.invocation_response)
expect(buffered_partial.strip).to eq(mock.invocation_response.strip)
end
end

View File

@@ -0,0 +1,182 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::FunctionCallNormalizer do
let(:buffer) { +"" }
let(:normalizer) do
blk = ->(data, cancel) { buffer << data }
cancel = -> { @done = true }
DiscourseAi::Completions::FunctionCallNormalizer.new(blk, cancel)
end
def pass_through!(data)
normalizer << data
expect(buffer[-data.length..-1]).to eq(data)
end
it "is usable in non streaming mode" do
xml = (<<~XML).strip
hello
<function_calls>
<invoke>
<tool_name>hello</tool_name>
</invoke>
XML
text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml)
expect(text).to eq("hello")
expected_function_calls = (<<~XML).strip
<function_calls>
<invoke>
<tool_name>hello</tool_name>
<tool_id>tool_0</tool_id>
</invoke>
</function_calls>
XML
expect(function_calls).to eq(expected_function_calls)
end
it "strips junk from end of function calls" do
xml = (<<~XML).strip
hello
<function_calls>
<invoke>
<tool_name>hello</tool_name>
</invoke>
junk
XML
_text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml)
expected_function_calls = (<<~XML).strip
<function_calls>
<invoke>
<tool_name>hello</tool_name>
<tool_id>tool_0</tool_id>
</invoke>
</function_calls>
XML
expect(function_calls).to eq(expected_function_calls)
end
it "returns nil for function calls if there are none" do
input = "hello world\n"
text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(input)
expect(text).to eq(input)
expect(function_calls).to eq(nil)
end
it "passes through data if there are no function calls detected" do
pass_through!("hello")
pass_through!("<tool_name>hello</tool_name>")
pass_through!("<parameters><hello>world</hello></parameters>")
pass_through!("<function_call>")
end
it "properly handles non English tools" do
normalizer << "hello<function"
expect(buffer).to eq("hello")
normalizer << "_calls>\n"
normalizer << (<<~XML).strip
<invoke>
<tool_name>hello</tool_name>
<parameters>
<hello></hello>
</parameters>
</invoke>
XML
expected = (<<~XML).strip
<function_calls>
<invoke>
<tool_name>hello</tool_name>
<parameters>
<hello></hello>
</parameters>
<tool_id>tool_0</tool_id>
</invoke>
</function_calls>
XML
function_calls = normalizer.function_calls
expect(function_calls).to eq(expected)
end
it "works correctly even if you only give it 1 letter at a time" do
xml = (<<~XML).strip
abc
<function_calls>
<invoke>
<tool_name>hello</tool_name>
<parameters>
<hello>world</hello>
</parameters>
<tool_id>abc</tool_id>
</invoke>
<invoke>
<tool_name>hello2</tool_name>
<parameters>
<hello>world</hello>
</parameters>
<tool_id>aba</tool_id>
</invoke>
</function_calls>
XML
xml.each_char { |char| normalizer << char }
expect(buffer + normalizer.function_calls).to eq(xml)
end
it "supports multiple invokes" do
xml = (<<~XML).strip
<function_calls>
<invoke>
<tool_name>hello</tool_name>
<parameters>
<hello>world</hello>
</parameters>
<tool_id>abc</tool_id>
</invoke>
<invoke>
<tool_name>hello2</tool_name>
<parameters>
<hello>world</hello>
</parameters>
<tool_id>aba</tool_id>
</invoke>
</function_calls>
XML
normalizer << xml
expect(normalizer.function_calls).to eq(xml)
end
it "can will cancel if it encounteres </function_calls>" do
normalizer << "<function_calls>"
expect(normalizer.done).to eq(false)
normalizer << "</function_calls>"
expect(normalizer.done).to eq(true)
expect(@done).to eq(true)
expect(normalizer.function_calls).to eq("<function_calls></function_calls>")
end
it "pauses on function call and starts buffering" do
normalizer << "hello<function_call"
expect(buffer).to eq("hello")
expect(normalizer.done).to eq(false)
normalizer << ">"
expect(buffer).to eq("hello<function_call>")
expect(normalizer.done).to eq(false)
end
end