diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index 268ea510..0766a9a4 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -63,7 +63,11 @@ module DiscourseAi
           @tokenizer = tokenizer
         end
 
-        def perform_completion!(dialect, user, model_params = {})
+        def native_tool_support?
+          false
+        end
+
+        def perform_completion!(dialect, user, model_params = {}, &blk)
           allow_tools = dialect.prompt.has_tools?
           model_params = normalize_model_params(model_params)
 
@@ -111,14 +115,21 @@ module DiscourseAi
                 response_data = extract_completion_from(response_raw)
                 partials_raw = response_data.to_s
 
-                if allow_tools && has_tool?(response_data)
-                  function_buffer = build_buffer # Nokogiri document
-                  function_buffer = add_to_function_buffer(function_buffer, payload: response_data)
+                if native_tool_support?
+                  if allow_tools && has_tool?(response_data)
+                    function_buffer = build_buffer # Nokogiri document
+                    function_buffer =
+                      add_to_function_buffer(function_buffer, payload: response_data)
+                    FunctionCallNormalizer.normalize_function_ids!(function_buffer)
 
-                  normalize_function_ids!(function_buffer)
-
-                  response_data = +function_buffer.at("function_calls").to_s
-                  response_data << "\n"
+                    response_data = +function_buffer.at("function_calls").to_s
+                    response_data << "\n"
+                  end
+                else
+                  if allow_tools
+                    response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
+                    response_data = function_calls if function_calls.present?
+                  end
                 end
 
                 return response_data
@@ -128,7 +139,14 @@ module DiscourseAi
 
               begin
                 cancelled = false
-                cancel = lambda { cancelled = true }
+                cancel = -> { cancelled = true }
+
+                wrapped_blk = ->(partial, inner_cancel) do
+                  response_data << partial
+                  blk.call(partial, inner_cancel)
+                end
+
+                normalizer = FunctionCallNormalizer.new(wrapped_blk, cancel)
 
                 leftover = ""
                 function_buffer = build_buffer # Nokogiri document
@@ -159,7 +177,6 @@ module DiscourseAi
                   end
 
                   json_error = false
-                  buffered_partials = []
 
                   raw_partials.each do |raw_partial|
                     json_error = false
@@ -175,31 +192,24 @@ module DiscourseAi
                       next if response_data.empty? && partial.empty?
                       partials_raw << partial.to_s
 
-                      # Stop streaming the response as soon as you find a tool.
-                      # We'll buffer and yield it later.
-                      has_tool = true if allow_tools && has_tool?(partials_raw)
+                      if native_tool_support?
+                        # Stop streaming the response as soon as you find a tool.
+                        # We'll buffer and yield it later.
+                        has_tool = true if allow_tools && has_tool?(partials_raw)
 
-                      if has_tool
-                        if buffered_partials.present?
-                          joined = buffered_partials.join
-                          joined = joined.gsub(/<.+/, "")
-                          yield joined, cancel if joined.present?
-                          buffered_partials = []
-                        end
-                        function_buffer = add_to_function_buffer(function_buffer, partial: partial)
-                      else
-                        if maybe_has_tool?(partials_raw)
-                          buffered_partials << partial
+                        if has_tool
+                          function_buffer =
+                            add_to_function_buffer(function_buffer, partial: partial)
                         else
-                          if buffered_partials.present?
-                            buffered_partials.each do |buffered_partial|
-                              response_data << buffered_partial
-                              yield buffered_partial, cancel
-                            end
-                            buffered_partials = []
-                          end
                           response_data << partial
-                          yield partial, cancel if partial
+                          blk.call(partial, cancel) if partial
+                        end
+                      else
+                        if allow_tools
+                          normalizer << partial
+                        else
+                          response_data << partial
+                          blk.call(partial, cancel) if partial
                         end
                       end
                     rescue JSON::ParserError
@@ -220,20 +230,25 @@ module DiscourseAi
                 end
 
                 # Once we have the full response, try to return the tool as a XML doc.
-                if has_tool
+                if has_tool && native_tool_support?
                   function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw)
 
                   if function_buffer.at("tool_name").text.present?
-                    normalize_function_ids!(function_buffer)
+                    FunctionCallNormalizer.normalize_function_ids!(function_buffer)
 
                     invocation = +function_buffer.at("function_calls").to_s
                     invocation << "\n"
 
                     response_data << invocation
-                    yield invocation, cancel
+                    blk.call(invocation, cancel)
                   end
                 end
 
+                if !native_tool_support? && function_calls = normalizer.function_calls
+                  response_data << function_calls
+                  blk.call(function_calls, cancel)
+                end
+
                 return response_data
               ensure
                 if log
@@ -250,21 +265,6 @@ module DiscourseAi
           end
         end
 
-        def normalize_function_ids!(function_buffer)
-          function_buffer
-            .css("invoke")
-            .each_with_index do |invoke, index|
-              if invoke.at("tool_id")
-                invoke.at("tool_id").content = "tool_#{index}" if invoke
-                  .at("tool_id")
-                  .content
-                  .blank?
-              else
-                invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
-              end
-            end
-        end
-
         def final_log_update(log)
           # for people that need to override
         end
@@ -341,19 +341,6 @@ module DiscourseAi
           response.include?("<function_calls>")
         end
 
-        def maybe_has_tool?(response)
-          # 16 is the length of <function_calls>
-          substring = response[-16..-1] || response
-          split = substring.split("<")
-
-          if split.length > 1
-            match = "<" + split.last
-            "<function_calls>".start_with?(match)
-          else
-            false
-          end
-        end
-
         def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
           if payload&.include?("</invoke>")
             matches = payload.match(%r{<function_calls>.*</invoke>}m)
diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb
index dae4b483..55df3e18 100644
--- a/lib/completions/endpoints/gemini.rb
+++ b/lib/completions/endpoints/gemini.rb
@@ -105,9 +105,8 @@ module DiscourseAi
           @has_function_call
         end
 
-        def maybe_has_tool?(_partial_raw)
-          # we always get a full partial
-          false
+        def native_tool_support?
+          true
         end
 
         def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb
index 50044810..f01f254f 100644
--- a/lib/completions/endpoints/open_ai.rb
+++ b/lib/completions/endpoints/open_ai.rb
@@ -162,9 +162,8 @@ module DiscourseAi
           @has_function_call
         end
 
-        def maybe_has_tool?(_partial_raw)
-          # we always get a full partial
-          false
+        def native_tool_support?
+          true
         end
 
         def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
diff --git a/lib/completions/function_call_normalizer.rb b/lib/completions/function_call_normalizer.rb
new file mode 100644
index 00000000..ef40809c
--- /dev/null
+++ b/lib/completions/function_call_normalizer.rb
@@ -0,0 +1,113 @@
+# frozen_string_literal: true
+
+class DiscourseAi::Completions::FunctionCallNormalizer
+  attr_reader :done
+
+  # blk is the block to call with filtered data
+  def initialize(blk, cancel)
+    @blk = blk
+    @cancel = cancel
+    @done = false
+
+    @in_tool = false
+
+    @buffer = +""
+    @function_buffer = +""
+  end
+
+  def self.normalize(data)
+    text = +""
+    cancel = -> {}
+    blk = ->(partial, _) { text << partial }
+
+    normalizer = self.new(blk, cancel)
+    normalizer << data
+
+    [text, normalizer.function_calls]
+  end
+
+  def function_calls
+    return nil if @function_buffer.blank?
+
+    xml = Nokogiri::HTML5.fragment(@function_buffer)
+    self.class.normalize_function_ids!(xml)
+    last_invoke = xml.at("invoke:last")
+    if last_invoke
+      last_invoke.next_sibling.remove while last_invoke.next_sibling
+      xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling
+    end
+    xml.at("function_calls").to_s.dup.force_encoding("UTF-8")
+  end
+
+  def <<(text)
+    @buffer << text
+
+    if !@in_tool
+      # double check if we are clearly in a tool
+      search_length = text.length + 20
+      search_string = @buffer[-search_length..-1] || @buffer
+
+      index = search_string.rindex("<function_calls>")
+      @in_tool = !!index
+      if @in_tool
+        @function_buffer = @buffer[index..-1]
+        text_index = text.rindex("<function_calls>")
+        @blk.call(text[0..text_index - 1].strip, @cancel) if text_index && text_index > 0
+      end
+    else
+      @function_buffer << text
+    end
+
+    if !@in_tool
+      if maybe_has_tool?(@buffer)
+        split_index = text.rindex("<").to_i - 1
+        if split_index >= 0
+          @function_buffer = text[split_index + 1..-1] || ""
+          text = text[0..split_index] || ""
+        else
+          @function_buffer << text
+          text = ""
+        end
+      else
+        if @function_buffer.length > 0
+          @blk.call(@function_buffer, @cancel)
+          @function_buffer = +""
+        end
+      end
+
+      @blk.call(text, @cancel) if text.length > 0
+    else
+      if text.include?("</function_calls>")
+        @done = true
+        @cancel.call
+      end
+    end
+  end
+
+  def self.normalize_function_ids!(function_buffer)
+    function_buffer
+      .css("invoke")
+      .each_with_index do |invoke, index|
+        if invoke.at("tool_id")
+          invoke.at("tool_id").content = "tool_#{index}" if invoke.at("tool_id").content.blank?
+        else
+          invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
+        end
+      end
+  end
+
+  private
+
+  def maybe_has_tool?(text)
+    # 16 is the length of <function_calls>
+    substring = text[-16..-1] || text
+    split = substring.split("<")
+
+    if split.length > 1
+      match = "<" + split.last
+      "<function_calls>".start_with?(match)
+    else
+      substring.ends_with?("<")
+    end
+  end
+end
diff --git a/spec/lib/completions/endpoints/endpoint_compliance.rb b/spec/lib/completions/endpoints/endpoint_compliance.rb
index 82c63ecf..d88de9b4 100644
--- a/spec/lib/completions/endpoints/endpoint_compliance.rb
+++ b/spec/lib/completions/endpoints/endpoint_compliance.rb
@@ -40,7 +40,7 @@ class EndpointMock
   end
 
   def tool_deltas
-    ["Let me use a tool for that <function", <<~REPLY, <<~REPLY]
      <invoke>
      <tool_name>get_weather</tool_name>
@@ -185,7 +185,7 @@ class EndpointsCompliance
     mock.stub_tool_call(a_dialect.translate)
 
     completion_response = endpoint.perform_completion!(a_dialect, user)
-    expect(completion_response).to eq(mock.invocation_response)
+    expect(completion_response.strip).to eq(mock.invocation_response.strip)
   end
 
   def streaming_mode_simple_prompt(mock)
@@ -223,7 +223,7 @@ class EndpointsCompliance
       cancel.call if buffered_partial.include?("</function_calls>")
     end
 
-    expect(buffered_partial).to eq(mock.invocation_response)
+    expect(buffered_partial.strip).to eq(mock.invocation_response.strip)
   end
 end
diff --git a/spec/lib/completions/function_call_normalizer_spec.rb b/spec/lib/completions/function_call_normalizer_spec.rb
new file mode 100644
index 00000000..dd78ed7f
--- /dev/null
+++ b/spec/lib/completions/function_call_normalizer_spec.rb
@@ -0,0 +1,182 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Completions::FunctionCallNormalizer do
+  let(:buffer) { +"" }
+
+  let(:normalizer) do
+    blk = ->(data, cancel) { buffer << data }
+    cancel = -> { @done = true }
+    DiscourseAi::Completions::FunctionCallNormalizer.new(blk, cancel)
+  end
+
+  def pass_through!(data)
+    normalizer << data
+    expect(buffer[-data.length..-1]).to eq(data)
+  end
+
+  it "is usable in non streaming mode" do
streaming mode" do + xml = (<<~XML).strip + hello + + + hello + + XML + + text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml) + + expect(text).to eq("hello") + + expected_function_calls = (<<~XML).strip + + + hello + tool_0 + + + XML + + expect(function_calls).to eq(expected_function_calls) + end + + it "strips junk from end of function calls" do + xml = (<<~XML).strip + hello + + + hello + + junk + XML + + _text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml) + + expected_function_calls = (<<~XML).strip + + + hello + tool_0 + + + XML + + expect(function_calls).to eq(expected_function_calls) + end + + it "returns nil for function calls if there are none" do + input = "hello world\n" + text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(input) + + expect(text).to eq(input) + expect(function_calls).to eq(nil) + end + + it "passes through data if there are no function calls detected" do + pass_through!("hello") + pass_through!("hello") + pass_through!("world") + pass_through!("") + end + + it "properly handles non English tools" do + normalizer << "hello\n" + + normalizer << (<<~XML).strip + + hello + + 世界 + + + XML + + expected = (<<~XML).strip + + + hello + + 世界 + + tool_0 + + + XML + + function_calls = normalizer.function_calls + expect(function_calls).to eq(expected) + end + + it "works correctly even if you only give it 1 letter at a time" do + xml = (<<~XML).strip + abc + + + hello + + world + + abc + + + hello2 + + world + + aba + + + XML + + xml.each_char { |char| normalizer << char } + + expect(buffer + normalizer.function_calls).to eq(xml) + end + + it "supports multiple invokes" do + xml = (<<~XML).strip + + + hello + + world + + abc + + + hello2 + + world + + aba + + + XML + + normalizer << xml + + expect(normalizer.function_calls).to eq(xml) + end + + it "can will cancel if it encounteres " do + normalizer << "" + expect(normalizer.done).to eq(false) + normalizer << "" + expect(normalizer.done).to eq(true) + expect(@done).to eq(true) + + expect(normalizer.function_calls).to eq("") + end + + it "pauses on function call and starts buffering" do + normalizer << "hello" + expect(buffer).to eq("hello") + expect(normalizer.done).to eq(false) + end +end