diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index 268ea510..0766a9a4 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -63,7 +63,11 @@ module DiscourseAi
@tokenizer = tokenizer
end
- def perform_completion!(dialect, user, model_params = {})
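+ # Overridden by endpoints (e.g. OpenAI, Gemini) that implement tool calls natively; the default path relies on parsing XML tool calls out of the response text.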
+ def native_tool_support?
+ false
+ end
+
+ def perform_completion!(dialect, user, model_params = {}, &blk)
allow_tools = dialect.prompt.has_tools?
model_params = normalize_model_params(model_params)
@@ -111,14 +115,21 @@ module DiscourseAi
response_data = extract_completion_from(response_raw)
partials_raw = response_data.to_s
- if allow_tools && has_tool?(response_data)
- function_buffer = build_buffer # Nokogiri document
- function_buffer = add_to_function_buffer(function_buffer, payload: response_data)
+ if native_tool_support?
+ if allow_tools && has_tool?(response_data)
+ function_buffer = build_buffer # Nokogiri document
+ function_buffer =
+ add_to_function_buffer(function_buffer, payload: response_data)
+ FunctionCallNormalizer.normalize_function_ids!(function_buffer)
- normalize_function_ids!(function_buffer)
-
- response_data = +function_buffer.at("function_calls").to_s
- response_data << "\n"
+ response_data = +function_buffer.at("function_calls").to_s
+ response_data << "\n"
+ end
+ else
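+ # No native tool support: scan the full response for an XML function call and return it instead of the plain text when one is present.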
+ if allow_tools
+ response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
+ response_data = function_calls if function_calls.present?
+ end
end
return response_data
@@ -128,7 +139,14 @@ module DiscourseAi
begin
cancelled = false
- cancel = lambda { cancelled = true }
+ cancel = -> { cancelled = true }
+
+ wrapped_blk = ->(partial, inner_cancel) do
+ response_data << partial
+ blk.call(partial, inner_cancel)
+ end
+
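+ # For endpoints without native tool support the normalizer filters the stream: plain text is passed through, XML tool calls are buffered.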
+ normalizer = FunctionCallNormalizer.new(wrapped_blk, cancel)
leftover = ""
function_buffer = build_buffer # Nokogiri document
@@ -159,7 +177,6 @@ module DiscourseAi
end
json_error = false
- buffered_partials = []
raw_partials.each do |raw_partial|
json_error = false
@@ -175,31 +192,24 @@ module DiscourseAi
next if response_data.empty? && partial.empty?
partials_raw << partial.to_s
- # Stop streaming the response as soon as you find a tool.
- # We'll buffer and yield it later.
- has_tool = true if allow_tools && has_tool?(partials_raw)
+ if native_tool_support?
+ # Stop streaming the response as soon as you find a tool.
+ # We'll buffer and yield it later.
+ has_tool = true if allow_tools && has_tool?(partials_raw)
- if has_tool
- if buffered_partials.present?
- joined = buffered_partials.join
- joined = joined.gsub(/<.+/, "")
- yield joined, cancel if joined.present?
- buffered_partials = []
- end
- function_buffer = add_to_function_buffer(function_buffer, partial: partial)
- else
- if maybe_has_tool?(partials_raw)
- buffered_partials << partial
+ if has_tool
+ function_buffer =
+ add_to_function_buffer(function_buffer, partial: partial)
else
- if buffered_partials.present?
- buffered_partials.each do |buffered_partial|
- response_data << buffered_partial
- yield buffered_partial, cancel
- end
- buffered_partials = []
- end
response_data << partial
- yield partial, cancel if partial
+ blk.call(partial, cancel) if partial
+ end
+ else
+ if allow_tools
+ normalizer << partial
+ else
+ response_data << partial
+ blk.call(partial, cancel) if partial
end
end
rescue JSON::ParserError
@@ -220,20 +230,25 @@ module DiscourseAi
end
# Once we have the full response, try to return the tool as a XML doc.
- if has_tool
+ if has_tool && native_tool_support?
function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw)
if function_buffer.at("tool_name").text.present?
- normalize_function_ids!(function_buffer)
+ FunctionCallNormalizer.normalize_function_ids!(function_buffer)
invocation = +function_buffer.at("function_calls").to_s
invocation << "\n"
response_data << invocation
- yield invocation, cancel
+ blk.call(invocation, cancel)
end
end
+ if !native_tool_support? && (function_calls = normalizer.function_calls)
+ response_data << function_calls
+ blk.call(function_calls, cancel)
+ end
+
return response_data
ensure
if log
@@ -250,21 +265,6 @@ module DiscourseAi
end
end
- def normalize_function_ids!(function_buffer)
- function_buffer
- .css("invoke")
- .each_with_index do |invoke, index|
- if invoke.at("tool_id")
- invoke.at("tool_id").content = "tool_#{index}" if invoke
- .at("tool_id")
- .content
- .blank?
- else
- invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
- end
- end
- end
-
def final_log_update(log)
# for people that need to override
end
@@ -341,19 +341,6 @@ module DiscourseAi
response.include?("<function_calls>")
end
- def maybe_has_tool?(response)
- # 16 is the length of <function_calls>
- substring = response[-16..-1] || response
- split = substring.split("<")
-
- if split.length > 1
- match = "<" + split.last
- "".start_with?(match)
- else
- false
- end
- end
-
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
if payload&.include?("</invoke>")
matches = payload.match(%r{<function_calls>.*</invoke>}m)
diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb
index dae4b483..55df3e18 100644
--- a/lib/completions/endpoints/gemini.rb
+++ b/lib/completions/endpoints/gemini.rb
@@ -105,9 +105,8 @@ module DiscourseAi
@has_function_call
end
- def maybe_has_tool?(_partial_raw)
- # we always get a full partial
- false
+ def native_tool_support?
+ true
end
def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb
index 50044810..f01f254f 100644
--- a/lib/completions/endpoints/open_ai.rb
+++ b/lib/completions/endpoints/open_ai.rb
@@ -162,9 +162,8 @@ module DiscourseAi
@has_function_call
end
- def maybe_has_tool?(_partial_raw)
- # we always get a full partial
- false
+ def native_tool_support?
+ true
end
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
diff --git a/lib/completions/function_call_normalizer.rb b/lib/completions/function_call_normalizer.rb
new file mode 100644
index 00000000..ef40809c
--- /dev/null
+++ b/lib/completions/function_call_normalizer.rb
@@ -0,0 +1,113 @@
+# frozen_string_literal: true
+
+class DiscourseAi::Completions::FunctionCallNormalizer
+ attr_reader :done
+
+ # blk is the block to call with filtered data
+ def initialize(blk, cancel)
+ @blk = blk
+ @cancel = cancel
+ @done = false
+
+ @in_tool = false
+
+ @buffer = +""
+ @function_buffer = +""
+ end
+
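+ # One-shot helper for non-streaming responses: returns [plain_text, function_calls], where function_calls is nil when no tool call was found.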
+ def self.normalize(data)
+ text = +""
+ cancel = -> {}
+ blk = ->(partial, _) { text << partial }
+
+ normalizer = self.new(blk, cancel)
+ normalizer << data
+
+ [text, normalizer.function_calls]
+ end
+
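+ # Returns the buffered tool invocation as <function_calls> XML with tool_ids normalized and trailing junk removed, or nil if none was seen.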
+ def function_calls
+ return nil if @function_buffer.blank?
+
+ xml = Nokogiri::HTML5.fragment(@function_buffer)
+ self.class.normalize_function_ids!(xml)
+ last_invoke = xml.at("invoke:last")
+ if last_invoke
+ last_invoke.next_sibling.remove while last_invoke.next_sibling
+ xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling
+ end
+ xml.at("function_calls").to_s.dup.force_encoding("UTF-8")
+ end
+
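+ # Feeds a streamed partial. Text outside a tool call is forwarded to blk; once a <function_calls> tag is detected everything is buffered instead.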
+ def <<(text)
+ @buffer << text
+
+ if !@in_tool
+ # double check if we are clearly in a tool
+ search_length = text.length + 20
+ search_string = @buffer[-search_length..-1] || @buffer
+
+ index = search_string.rindex("<function_calls>")
+ @in_tool = !!index
+ if @in_tool
+ @function_buffer = @buffer[index..-1]
+ text_index = text.rindex("<function_calls>")
+ @blk.call(text[0..text_index - 1].strip, @cancel) if text_index && text_index > 0
+ end
+ else
+ @function_buffer << text
+ end
+
+ if !@in_tool
+ if maybe_has_tool?(@buffer)
+ split_index = text.rindex("<").to_i - 1
+ if split_index >= 0
+ @function_buffer = text[split_index + 1..-1] || ""
+ text = text[0..split_index] || ""
+ else
+ @function_buffer << text
+ text = ""
+ end
+ else
+ if @function_buffer.length > 0
+ @blk.call(@function_buffer, @cancel)
+ @function_buffer = +""
+ end
+ end
+
+ @blk.call(text, @cancel) if text.length > 0
+ else
+ if text.include?("</function_calls>")
+ @done = true
+ @cancel.call
+ end
+ end
+ end
+
+ def self.normalize_function_ids!(function_buffer)
+ function_buffer
+ .css("invoke")
+ .each_with_index do |invoke, index|
+ if invoke.at("tool_id")
+ invoke.at("tool_id").content = "tool_#{index}" if invoke.at("tool_id").content.blank?
+ else
+ invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
+ end
+ end
+ end
+
+ private
+
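+ # Heuristic: does the tail of the buffer look like the beginning of a <function_calls> tag? If so, hold the text back until we know for sure.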
+ def maybe_has_tool?(text)
+ # 16 is the length of <function_calls>
+ substring = text[-16..-1] || text
+ split = substring.split("<")
+
+ if split.length > 1
+ match = "<" + split.last
+ "".start_with?(match)
+ else
+ substring.ends_with?("<")
+ end
+ end
+end
diff --git a/spec/lib/completions/endpoints/endpoint_compliance.rb b/spec/lib/completions/endpoints/endpoint_compliance.rb
index 82c63ecf..d88de9b4 100644
--- a/spec/lib/completions/endpoints/endpoint_compliance.rb
+++ b/spec/lib/completions/endpoints/endpoint_compliance.rb
@@ -40,7 +40,7 @@ class EndpointMock
end
def tool_deltas
- ["Let me use a tool for that
<tool_name>get_weather</tool_name>
@@ -185,7 +185,7 @@ class EndpointsCompliance
mock.stub_tool_call(a_dialect.translate)
completion_response = endpoint.perform_completion!(a_dialect, user)
- expect(completion_response).to eq(mock.invocation_response)
+ expect(completion_response.strip).to eq(mock.invocation_response.strip)
end
def streaming_mode_simple_prompt(mock)
@@ -223,7 +223,7 @@ class EndpointsCompliance
cancel.call if buffered_partial.include?("<function_calls>")
end
- expect(buffered_partial).to eq(mock.invocation_response)
+ expect(buffered_partial.strip).to eq(mock.invocation_response.strip)
end
end
diff --git a/spec/lib/completions/function_call_normalizer_spec.rb b/spec/lib/completions/function_call_normalizer_spec.rb
new file mode 100644
index 00000000..dd78ed7f
--- /dev/null
+++ b/spec/lib/completions/function_call_normalizer_spec.rb
@@ -0,0 +1,182 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Completions::FunctionCallNormalizer do
+ let(:buffer) { +"" }
+
+ let(:normalizer) do
+ blk = ->(data, cancel) { buffer << data }
+ cancel = -> { @done = true }
+ DiscourseAi::Completions::FunctionCallNormalizer.new(blk, cancel)
+ end
+
+ def pass_through!(data)
+ normalizer << data
+ expect(buffer[-data.length..-1]).to eq(data)
+ end
+
+ it "is usable in non streaming mode" do
+ xml = (<<~XML).strip
+ hello
+ <function_calls>
+ <invoke>
+ <tool_name>hello</tool_name>
+ </invoke>
+ </function_calls>
+ XML
+
+ text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml)
+
+ expect(text).to eq("hello")
+
+ expected_function_calls = (<<~XML).strip
+ <function_calls>
+ <invoke>
+ <tool_name>hello</tool_name>
+ <tool_id>tool_0</tool_id>
+ </invoke>
+ </function_calls>
+ XML
+
+ expect(function_calls).to eq(expected_function_calls)
+ end
+
+ it "strips junk from end of function calls" do
+ xml = (<<~XML).strip
+ hello
+ <function_calls>
+ <invoke>
+ <tool_name>hello</tool_name>
+ </invoke>
+ junk
+ XML
+
+ _text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml)
+
+ expected_function_calls = (<<~XML).strip
+ <function_calls>
+ <invoke>
+ <tool_name>hello</tool_name>
+ <tool_id>tool_0</tool_id>
+ </invoke>
+ </function_calls>
+ XML
+
+ expect(function_calls).to eq(expected_function_calls)
+ end
+
+ it "returns nil for function calls if there are none" do
+ input = "hello world\n"
+ text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(input)
+
+ expect(text).to eq(input)
+ expect(function_calls).to eq(nil)
+ end
+
+ it "passes through data if there are no function calls detected" do
+ pass_through!("hello")
+ pass_through!("hello")
+ pass_through!("world")
+ pass_through!("")
+ end
+
+ it "properly handles non English tools" do
+ normalizer << "hello\n"
+
+ normalizer << (<<~XML).strip
+ <function_calls>
+ <invoke>
+ <tool_name>hello</tool_name>
+ <parameters>
+ <hello>世界</hello>
+ </parameters>
+ </invoke>
+ </function_calls>
+ XML
+
+ expected = (<<~XML).strip
+ <function_calls>
+ <invoke>
+ <tool_name>hello</tool_name>
+ <parameters>
+ <hello>世界</hello>
+ </parameters>
+ <tool_id>tool_0</tool_id>
+ </invoke>
+ </function_calls>
+ XML
+
+ function_calls = normalizer.function_calls
+ expect(function_calls).to eq(expected)
+ end
+
+ it "works correctly even if you only give it 1 letter at a time" do
+ xml = (<<~XML).strip
+ abc
+ <function_calls>
+ <invoke>
+ <tool_name>hello</tool_name>
+ <parameters>
+ <hello>world</hello>
+ </parameters>
+ <tool_id>abc</tool_id>
+ </invoke>
+ <invoke>
+ <tool_name>hello2</tool_name>
+ <parameters>
+ <hello>world</hello>
+ </parameters>
+ <tool_id>aba</tool_id>
+ </invoke>
+ </function_calls>
+ XML
+
+ xml.each_char { |char| normalizer << char }
+
+ expect(buffer + normalizer.function_calls).to eq(xml)
+ end
+
+ it "supports multiple invokes" do
+ xml = (<<~XML).strip
+ <function_calls>
+ <invoke>
+ <tool_name>hello</tool_name>
+ <parameters>
+ <hello>world</hello>
+ </parameters>
+ <tool_id>abc</tool_id>
+ </invoke>
+ <invoke>
+ <tool_name>hello2</tool_name>
+ <parameters>
+ <hello>world</hello>
+ </parameters>
+ <tool_id>aba</tool_id>
+ </invoke>
+ </function_calls>
+ XML
+
+ normalizer << xml
+
+ expect(normalizer.function_calls).to eq(xml)
+ end
+
+ it "can will cancel if it encounteres " do
+ normalizer << ""
+ expect(normalizer.done).to eq(false)
+ normalizer << "</function_calls>"
+ expect(normalizer.done).to eq(true)
+ expect(@done).to eq(true)
+
+ expect(normalizer.function_calls).to eq("<function_calls></function_calls>")
+ end
+
+ it "pauses on function call and starts buffering" do
+ normalizer << "hello<function_calls>"
+ expect(buffer).to eq("hello")
+ expect(normalizer.done).to eq(false)
+ end
+end