# frozen_string_literal: true # This class can be used to process a stream of text that may contain XML tool # calls. # It will return either text or ToolCall objects. module DiscourseAi module Completions class XmlToolProcessor def initialize(partial_tool_calls: false) @buffer = +"" @function_buffer = +"" @should_cancel = false @in_tool = false @partial_tool_calls = partial_tool_calls @partial_tools = [] if @partial_tool_calls end def <<(text) @buffer << text result = [] if !@in_tool # double check if we are clearly in a tool search_length = text.length + 20 search_string = @buffer[-search_length..-1] || @buffer index = search_string.rindex("") @in_tool = !!index if @in_tool @function_buffer = @buffer[index..-1] text_index = text.rindex("") result << text[0..text_index - 1].rstrip if text_index && text_index > 0 end else add_to_function_buffer(text) end if !@in_tool if maybe_has_tool?(@buffer) split_index = text.rindex("<").to_i - 1 if split_index >= 0 @function_buffer = text[split_index + 1..-1] || "" text = text[0..split_index] || "" else add_to_function_buffer(text) text = "" end else if @function_buffer.length > 0 result << @function_buffer @function_buffer = +"" end end result << text if text.length > 0 else @should_cancel = true if text.include?("") end if @should_notify_partial_tool @should_notify_partial_tool = false result << @partial_tools.last end result end def finish return [] if @function_buffer.blank? idx = -1 parse_malformed_xml(@function_buffer).map do |tool| ToolCall.new( id: "tool_#{idx += 1}", name: tool[:tool_name], parameters: tool[:parameters], ) end end def should_cancel? @should_cancel end private def add_to_function_buffer(text) @function_buffer << text detect_partial_tool_calls(@function_buffer, text) if @partial_tool_calls end def detect_partial_tool_calls(buffer, delta) parse_partial_tool_call(buffer) end def parse_partial_tool_call(buffer) match = buffer .scan( %r{ \s* ([^<]+) \s* (.*?) (|\Z) }mx, ) .to_a .last if match params = partial_parse_params(match[1]) if params.present? current_tool = @partial_tools.last if !current_tool || current_tool.name != match[0].strip current_tool = ToolCall.new( id: "tool_#{@partial_tools.length}", name: match[0].strip, parameters: params, ) @partial_tools << current_tool current_tool.partial = true @should_notify_partial_tool = true end if current_tool.parameters != params current_tool.parameters = params @should_notify_partial_tool = true end end end end def partial_parse_params(params) params .scan(%r{ <([^>]+)> (.*?) (|\Z) }mx) .each_with_object({}) do |(name, value), hash| next if "$/, "") end end def parse_malformed_xml(input) input .scan( %r{ \s* ([^<]+) \s* (.*?) \s* }mx, ) .map do |tool_name, params| { tool_name: tool_name.strip, parameters: params .scan(%r{ <([^>]+)> (.*?) }mx) .each_with_object({}) do |(name, value), hash| hash[name.to_sym] = value.gsub(/^$/, "") end, } end end def normalize_function_ids!(function_buffer) function_buffer .css("invoke") .each_with_index do |invoke, index| if invoke.at("tool_id") invoke.at("tool_id").content = "tool_#{index}" if invoke.at("tool_id").content.blank? else invoke.add_child("tool_#{index}\n") if !invoke.at("tool_id") end end end def maybe_has_tool?(text) # 16 is the length of function calls substring = text[-16..-1] || text split = substring.split("<") if split.length > 1 match = "<" + split.last "".start_with?(match) else substring.ends_with?("<") end end end end end