discourse-ai/lib/completions/xml_tool_processor.rb
Sam e817b7dc11
FEATURE: improve tool support (#904)
This re-implements tool support in DiscourseAi::Completions::Llm #generate

Previously, tool calls were always returned as XML, and it was the responsibility of the caller to parse that XML.

New implementation has the endpoints return ToolCall objects.

Additionally this simplifies the Llm endpoint interface and gives it more clarity. Llms must implement

decode, decode_chunk (for streaming)

It is the implementer's responsibility to figure out how to decode chunks; the base class no longer implements this. To make this easy, we ship a flexible JSON decoder which is easy to wire up.

Also (new)

    Better debugging for PMs, we now have a next / previous button to see all the Llm messages associated with a PM
    Token accounting is fixed for vllm (we were not correctly counting tokens)
2024-11-12 08:14:30 +11:00

125 lines
3.5 KiB
Ruby

# frozen_string_literal: true
# This class can be used to process a stream of text that may contain XML tool
# calls.
# It will return either text or ToolCall objects.
module DiscourseAi
module Completions
class XmlToolProcessor
  # Incrementally separates plain text from an XML tool-call section in a
  # stream of text chunks. Feed chunks with #<<, then call #finish once the
  # stream has ended to obtain the parsed tool calls.

  # Marker that opens the tool-call section of the stream.
  OPEN_TAG = "<function_calls>"
  # Marker that closes the tool-call section of the stream.
  CLOSE_TAG = "</function_calls>"

  def initialize
    # Everything received so far; needed to spot a tag split across chunks.
    @buffer = +""
    # Before the opening tag is seen: text held back because it may turn out
    # to be the start of the tag. Afterwards: the raw XML of the tool section.
    @function_buffer = +""
    @should_cancel = false
    @in_tool = false
  end

  # Feeds the next chunk of streamed text into the processor.
  #
  # @param text [String] the chunk just received
  # @return [Array<String>] plain-text pieces that are now known not to belong
  #   to a tool call, in stream order. Always empty once the opening tag has
  #   been seen; from then on the text is collected for #finish.
  def <<(text)
    @buffer << text
    result = []

    if @in_tool
      @function_buffer << text
    else
      detect_tool_start(text, result)
    end

    if @in_tool
      @should_cancel = true if closing_tag_received?(text)
    else
      emit_or_hold_back(text, result)
    end

    result
  end

  # Parses the collected tool-call XML.
  #
  # @return [Array<ToolCall>] one entry per <invoke> element; empty when
  #   nothing was collected
  def finish
    return [] if @function_buffer.blank?

    xml = Nokogiri::HTML5.fragment(@function_buffer)
    normalize_function_ids!(xml)

    last_invoke = xml.at("invoke:last")
    if last_invoke
      # drop whatever follows the last invoke, leaving a single newline there
      last_invoke.next_sibling.remove while last_invoke.next_sibling
      xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling
    end

    xml
      .css("invoke")
      .map do |invoke|
        tool_name = invoke.at("tool_name").content.force_encoding("UTF-8")
        tool_id = invoke.at("tool_id").content.force_encoding("UTF-8")
        parameters = {}
        invoke
          .at("parameters")
          &.children
          &.each do |node|
            # text nodes between parameter elements are not parameters
            next if node.text?
            parameters[node.name.to_sym] = node.content.to_s.force_encoding("UTF-8")
          end
        ToolCall.new(id: tool_id, name: tool_name, parameters: parameters)
      end
  end

  # @return [Boolean] true once the closing tag of the tool-call section has
  #   been received
  def should_cancel?
    @should_cancel
  end

  private

  # Looks for the opening tag in the tail of the buffer (it may have started
  # in an earlier chunk). On a match, switches to tool mode, pushes any text
  # that precedes the tag onto +result+ and starts the function buffer at the
  # tag.
  def detect_tool_start(text, result)
    search_length = text.length + 20
    search_string = @buffer[-search_length..-1] || @buffer
    index = search_string.rindex(OPEN_TAG)
    return if index.nil?

    @in_tool = true

    text_index = text.rindex(OPEN_TAG)
    if text_index
      # The tag starts inside this chunk, so anything held back from earlier
      # chunks was ordinary text after all and must not be lost.
      flush_held_text(result)
      result << text[0..text_index - 1].strip if text_index > 0
    end

    # index is an offset into search_string (a tail of @buffer), so the slice
    # has to be taken from search_string rather than from @buffer.
    @function_buffer = search_string[index..-1]
  end

  # Handles a chunk received before the opening tag has been seen: text that
  # may still become the opening tag is held back in the function buffer, the
  # rest is pushed onto +result+.
  def emit_or_hold_back(text, result)
    if !maybe_has_tool?(@buffer)
      flush_held_text(result)
      result << text if text.length > 0
      return
    end

    candidate_start = text.rindex("<")
    if candidate_start
      # A new candidate starts in this chunk; text held back for an older
      # candidate can no longer be part of the tag, so release it first.
      flush_held_text(result)
      result << text[0...candidate_start] if candidate_start > 0
      @function_buffer = text[candidate_start..-1]
    else
      # the candidate started in an earlier chunk and is still plausible
      @function_buffer << text
    end
  end

  # Moves held-back text (if any) onto +result+ and resets the function buffer.
  def flush_held_text(result)
    return if @function_buffer.empty?
    result << @function_buffer
    @function_buffer = +""
  end

  # True when the closing tag is present in the part of the function buffer
  # that the latest chunk could have completed. Checking the buffer tail rather
  # than the chunk alone also catches a closing tag split across chunks.
  def closing_tag_received?(text)
    tail_length = text.length + CLOSE_TAG.length
    tail = @function_buffer[-tail_length..-1] || @function_buffer
    tail.include?(CLOSE_TAG)
  end

  # Ensures every <invoke> element has a non-blank <tool_id>, generating
  # "tool_<position>" where it is missing or blank.
  def normalize_function_ids!(function_buffer)
    function_buffer
      .css("invoke")
      .each_with_index do |invoke, index|
        tool_id = invoke.at("tool_id")
        if tool_id
          tool_id.content = "tool_#{index}" if tool_id.content.blank?
        else
          invoke.add_child("<tool_id>tool_#{index}</tool_id>\n")
        end
      end
  end

  # True when +text+ ends with something that could still grow into the
  # opening tag, i.e. the text from its last "<" onwards is a prefix of it.
  def maybe_has_tool?(text)
    # only the last OPEN_TAG.length characters can hold an unfinished tag
    tail = text[-OPEN_TAG.length..-1] || text
    candidate_start = tail.rindex("<")
    return false if candidate_start.nil?

    OPEN_TAG.start_with?(tail[candidate_start..-1])
  end
end
end
end