FEATURE: support tool progress callbacks
This is anthropic only for now, but we can get a callback as tool is completing, this gives us the ability to show progress to user as the function is populating.
This commit is contained in:
parent
0191b41877
commit
fd7ccfd0ab
|
@ -105,8 +105,15 @@ module DiscourseAi
|
|||
tool_found = false
|
||||
force_tool_if_needed(prompt, context)
|
||||
|
||||
tool_progress = proc { |progress| p progress }
|
||||
|
||||
result =
|
||||
llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
|
||||
llm.generate(
|
||||
prompt,
|
||||
feature_name: "bot",
|
||||
tool_progress: tool_progress,
|
||||
**llm_kwargs,
|
||||
) do |partial, cancel|
|
||||
tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context)
|
||||
|
||||
if (tools.present?)
|
||||
|
|
|
@ -1,25 +1,83 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
class DiscourseAi::Completions::AnthropicMessageProcessor
|
||||
class AnthropicToolCall
|
||||
attr_reader :name, :raw_json, :id
|
||||
class ToolCallProgressTracker
|
||||
attr_reader :current_key, :current_value, :tool_call
|
||||
|
||||
def initialize(name, id)
|
||||
def initialize(tool_call)
|
||||
@tool_call = tool_call
|
||||
@current_key = nil
|
||||
@current_value = nil
|
||||
@parser = DiscourseAi::Utils::JsonStreamingParser.new
|
||||
|
||||
@parser.key do |k|
|
||||
@current_key = k
|
||||
@current_value = nil
|
||||
end
|
||||
@parser.value do |v|
|
||||
@current_value = v
|
||||
|
||||
if @current_key
|
||||
tool_call.tool_progress.call(
|
||||
{ name: tool_call.name, id: tool_call.id, key: @current_key, value: @current_value },
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
def <<(json)
|
||||
# llm could send broken json
|
||||
# in that case just deal with it later
|
||||
# don't stream
|
||||
return if @broken
|
||||
|
||||
begin
|
||||
@parser << json
|
||||
rescue DiscourseAi::Utils::ParserError
|
||||
@broken = true
|
||||
return
|
||||
end
|
||||
|
||||
if @parser.state == :start_string && @current_key
|
||||
# this is is worth notifying
|
||||
tool_call.tool_progress.call(
|
||||
{
|
||||
name: tool_call.name,
|
||||
id: tool_call.id,
|
||||
key: @current_key,
|
||||
value: @parser.buf,
|
||||
done: false,
|
||||
},
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
class AnthropicToolCall
|
||||
attr_reader :name, :raw_json, :id, :tool_progress
|
||||
|
||||
def initialize(name, id, tool_progress)
|
||||
@name = name
|
||||
@id = id
|
||||
@raw_json = +""
|
||||
if tool_progress
|
||||
@tool_progress = tool_progress
|
||||
@tool_call_progress_tracker = ToolCallProgressTracker.new(self)
|
||||
end
|
||||
end
|
||||
|
||||
def append(json)
|
||||
@raw_json << json
|
||||
@tool_call_progress_tracker << json if @tool_progress
|
||||
end
|
||||
end
|
||||
|
||||
attr_reader :tool_calls, :input_tokens, :output_tokens
|
||||
|
||||
def initialize(streaming_mode:)
|
||||
def initialize(streaming_mode:, tool_progress:)
|
||||
@streaming_mode = streaming_mode
|
||||
@tool_calls = []
|
||||
@tool_progress = tool_progress
|
||||
end
|
||||
|
||||
def to_xml_tool_calls(function_buffer)
|
||||
|
@ -58,7 +116,7 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
|
|||
if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
|
||||
tool_name = parsed.dig(:content_block, :name)
|
||||
tool_id = parsed.dig(:content_block, :id)
|
||||
@tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name
|
||||
@tool_calls << AnthropicToolCall.new(tool_name, tool_id, @tool_progress) if tool_name
|
||||
elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
|
||||
if @tool_calls.present?
|
||||
result = parsed.dig(:delta, :partial_json).to_s
|
||||
|
@ -83,7 +141,7 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
|
|||
if content.is_a?(Array)
|
||||
tool_call = content.find { |c| c[:type] == "tool_use" }
|
||||
if tool_call
|
||||
@tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id])
|
||||
@tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id], @tool_progress)
|
||||
@tool_calls.last.append(tool_call[:input].to_json)
|
||||
else
|
||||
result = parsed.dig(:content, 0, :text).to_s
|
||||
|
|
|
@ -92,7 +92,10 @@ module DiscourseAi
|
|||
|
||||
def processor
|
||||
@processor ||=
|
||||
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
|
||||
DiscourseAi::Completions::AnthropicMessageProcessor.new(
|
||||
streaming_mode: @streaming_mode,
|
||||
tool_progress: @tool_progress,
|
||||
)
|
||||
end
|
||||
|
||||
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
|
||||
|
|
|
@ -157,7 +157,10 @@ module DiscourseAi
|
|||
|
||||
def processor
|
||||
@processor ||=
|
||||
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
|
||||
DiscourseAi::Completions::AnthropicMessageProcessor.new(
|
||||
streaming_mode: @streaming_mode,
|
||||
tool_progress: @tool_progress,
|
||||
)
|
||||
end
|
||||
|
||||
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
|
||||
|
|
|
@ -62,12 +62,14 @@ module DiscourseAi
|
|||
model_params = {},
|
||||
feature_name: nil,
|
||||
feature_context: nil,
|
||||
tool_progress: nil,
|
||||
&blk
|
||||
)
|
||||
allow_tools = dialect.prompt.has_tools?
|
||||
model_params = normalize_model_params(model_params)
|
||||
orig_blk = blk
|
||||
|
||||
@tool_progress = tool_progress
|
||||
@streaming_mode = block_given?
|
||||
to_strip = xml_tags_to_strip(dialect)
|
||||
@xml_stripper =
|
||||
|
|
|
@ -28,7 +28,8 @@ module DiscourseAi
|
|||
_user,
|
||||
_model_params,
|
||||
feature_name: nil,
|
||||
feature_context: nil
|
||||
feature_context: nil,
|
||||
tool_progress: nil
|
||||
)
|
||||
@dialect = dialect
|
||||
response = responses[completions]
|
||||
|
|
|
@ -116,7 +116,8 @@ module DiscourseAi
|
|||
user,
|
||||
model_params = {},
|
||||
feature_name: nil,
|
||||
feature_context: nil
|
||||
feature_context: nil,
|
||||
tool_progress: nil
|
||||
)
|
||||
self.class.last_call = { dialect: dialect, user: user, model_params: model_params }
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@ module DiscourseAi
|
|||
model_params = {},
|
||||
feature_name: nil,
|
||||
feature_context: nil,
|
||||
tool_progress: nil,
|
||||
&blk
|
||||
)
|
||||
if dialect.respond_to?(:is_gpt_o?) && dialect.is_gpt_o? && block_given?
|
||||
|
|
|
@ -192,6 +192,7 @@ module DiscourseAi
|
|||
user:,
|
||||
feature_name: nil,
|
||||
feature_context: nil,
|
||||
tool_progress: nil,
|
||||
&partial_read_blk
|
||||
)
|
||||
self.class.record_prompt(prompt)
|
||||
|
@ -226,6 +227,7 @@ module DiscourseAi
|
|||
model_params,
|
||||
feature_name: feature_name,
|
||||
feature_context: feature_context,
|
||||
tool_progress: tool_progress,
|
||||
&partial_read_blk
|
||||
)
|
||||
end
|
||||
|
|
|
@ -0,0 +1,666 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
# This code is copied from the MIT licensed json-stream
|
||||
#
|
||||
# It was copied to avoid the dependency and allow us to make some small changes
|
||||
# particularly we need better access to internal state when parsing
|
||||
|
||||
module DiscourseAi
|
||||
module Utils
|
||||
# Raised on any invalid JSON text.
|
||||
ParserError = Class.new(RuntimeError)
|
||||
|
||||
# A streaming JSON parser that generates SAX-like events for state changes.
|
||||
# Use the json gem for small documents. Use this for huge documents that
|
||||
# won't fit in memory.
|
||||
#
|
||||
# Examples
|
||||
#
|
||||
# parser = JSON::Stream::Parser.new
|
||||
# parser.key { |key| puts key }
|
||||
# parser.value { |value| puts value }
|
||||
# parser << '{"answer":'
|
||||
# parser << ' 42}'
|
||||
class JsonStreamingParser
|
||||
# our changes:
|
||||
attr_reader :state, :buf, :pos
|
||||
|
||||
# A character buffer that expects a UTF-8 encoded stream of bytes.
|
||||
# This handles truncated multi-byte characters properly so we can just
|
||||
# feed it binary data and receive a properly formatted UTF-8 String as
|
||||
# output.
|
||||
#
|
||||
# More UTF-8 parsing details are available at:
|
||||
#
|
||||
# http://en.wikipedia.org/wiki/UTF-8
|
||||
# http://tools.ietf.org/html/rfc3629#section-3
|
||||
class Buffer
|
||||
def initialize
|
||||
@state = :start
|
||||
@buffer = []
|
||||
@need = 0
|
||||
end
|
||||
|
||||
# Fill the buffer with a String of binary UTF-8 encoded bytes. Returns
|
||||
# as much of the data in a UTF-8 String as we have. Truncated multi-byte
|
||||
# characters are saved in the buffer until the next call to this method
|
||||
# where we expect to receive the rest of the multi-byte character.
|
||||
#
|
||||
# data - The partial binary encoded String data.
|
||||
#
|
||||
# Raises JSON::Stream::ParserError if the UTF-8 byte sequence is malformed.
|
||||
#
|
||||
# Returns a UTF-8 encoded String.
|
||||
def <<(data)
|
||||
# Avoid state machine for complete UTF-8.
|
||||
if @buffer.empty?
|
||||
data.force_encoding(Encoding::UTF_8)
|
||||
return data if data.valid_encoding?
|
||||
end
|
||||
|
||||
bytes = []
|
||||
data.each_byte do |byte|
|
||||
case @state
|
||||
when :start
|
||||
if byte < 128
|
||||
bytes << byte
|
||||
elsif byte >= 192
|
||||
@state = :multi_byte
|
||||
@buffer << byte
|
||||
@need =
|
||||
case
|
||||
when byte >= 240
|
||||
4
|
||||
when byte >= 224
|
||||
3
|
||||
when byte >= 192
|
||||
2
|
||||
end
|
||||
else
|
||||
error("Expected start of multi-byte or single byte char")
|
||||
end
|
||||
when :multi_byte
|
||||
if byte > 127 && byte < 192
|
||||
@buffer << byte
|
||||
if @buffer.size == @need
|
||||
bytes += @buffer.slice!(0, @buffer.size)
|
||||
@state = :start
|
||||
end
|
||||
else
|
||||
error("Expected continuation byte")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# Build UTF-8 encoded string from completed codepoints.
|
||||
bytes
|
||||
.pack("C*")
|
||||
.force_encoding(Encoding::UTF_8)
|
||||
.tap { |text| error("Invalid UTF-8 byte sequence") unless text.valid_encoding? }
|
||||
end
|
||||
|
||||
# Determine if the buffer contains partial UTF-8 continuation bytes that
|
||||
# are waiting on subsequent completion bytes before a full codepoint is
|
||||
# formed.
|
||||
#
|
||||
# Examples
|
||||
#
|
||||
# bytes = "é".bytes
|
||||
#
|
||||
# buffer << bytes[0]
|
||||
# buffer.empty?
|
||||
# # => false
|
||||
#
|
||||
# buffer << bytes[1]
|
||||
# buffer.empty?
|
||||
# # => true
|
||||
#
|
||||
# Returns true if the buffer is empty.
|
||||
def empty?
|
||||
@buffer.empty?
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def error(message)
|
||||
raise ParserError, message
|
||||
end
|
||||
end
|
||||
|
||||
BUF_SIZE = 4096
|
||||
CONTROL = /[\x00-\x1F]/
|
||||
WS = /[ \n\t\r]/
|
||||
HEX = /[0-9a-fA-F]/
|
||||
DIGIT = /[0-9]/
|
||||
DIGIT_1_9 = /[1-9]/
|
||||
DIGIT_END = /\d$/
|
||||
TRUE_RE = /[rue]/
|
||||
FALSE_RE = /[alse]/
|
||||
NULL_RE = /[ul]/
|
||||
TRUE_KEYWORD = "true"
|
||||
FALSE_KEYWORD = "false"
|
||||
NULL_KEYWORD = "null"
|
||||
LEFT_BRACE = "{"
|
||||
RIGHT_BRACE = "}"
|
||||
LEFT_BRACKET = "["
|
||||
RIGHT_BRACKET = "]"
|
||||
BACKSLASH = '\\'
|
||||
SLASH = "/"
|
||||
QUOTE = '"'
|
||||
COMMA = ","
|
||||
COLON = ":"
|
||||
ZERO = "0"
|
||||
MINUS = "-"
|
||||
PLUS = "+"
|
||||
POINT = "."
|
||||
EXPONENT = /[eE]/
|
||||
B, F, N, R, T, U = %w[b f n r t u]
|
||||
|
||||
# Create a new parser with an optional initialization block where
|
||||
# we can register event callbacks.
|
||||
#
|
||||
# Examples
|
||||
#
|
||||
# parser = JSON::Stream::Parser.new do
|
||||
# start_document { puts "start document" }
|
||||
# end_document { puts "end document" }
|
||||
# start_object { puts "start object" }
|
||||
# end_object { puts "end object" }
|
||||
# start_array { puts "start array" }
|
||||
# end_array { puts "end array" }
|
||||
# key { |k| puts "key: #{k}" }
|
||||
# value { |v| puts "value: #{v}" }
|
||||
# end
|
||||
def initialize(&block)
|
||||
@state = :start_document
|
||||
@utf8 = Buffer.new
|
||||
@listeners = {
|
||||
start_document: [],
|
||||
end_document: [],
|
||||
start_object: [],
|
||||
end_object: [],
|
||||
start_array: [],
|
||||
end_array: [],
|
||||
key: [],
|
||||
value: [],
|
||||
}
|
||||
|
||||
# Track parse stack.
|
||||
@stack = []
|
||||
@unicode = +""
|
||||
@buf = +""
|
||||
@pos = -1
|
||||
|
||||
# Register any observers in the block.
|
||||
instance_eval(&block) if block_given?
|
||||
end
|
||||
|
||||
def start_document(&block)
|
||||
@listeners[:start_document] << block
|
||||
end
|
||||
|
||||
def end_document(&block)
|
||||
@listeners[:end_document] << block
|
||||
end
|
||||
|
||||
def start_object(&block)
|
||||
@listeners[:start_object] << block
|
||||
end
|
||||
|
||||
def end_object(&block)
|
||||
@listeners[:end_object] << block
|
||||
end
|
||||
|
||||
def start_array(&block)
|
||||
@listeners[:start_array] << block
|
||||
end
|
||||
|
||||
def end_array(&block)
|
||||
@listeners[:end_array] << block
|
||||
end
|
||||
|
||||
def key(&block)
|
||||
@listeners[:key] << block
|
||||
end
|
||||
|
||||
def value(&block)
|
||||
@listeners[:value] << block
|
||||
end
|
||||
|
||||
# Pass data into the parser to advance the state machine and
|
||||
# generate callback events. This is well suited for an EventMachine
|
||||
# receive_data loop.
|
||||
#
|
||||
# data - The String of partial JSON data to parse.
|
||||
#
|
||||
# Raises a JSON::Stream::ParserError if the JSON data is malformed.
|
||||
#
|
||||
# Returns nothing.
|
||||
def <<(data)
|
||||
(@utf8 << data).each_char do |ch|
|
||||
@pos += 1
|
||||
case @state
|
||||
when :start_document
|
||||
start_value(ch)
|
||||
when :start_object
|
||||
case ch
|
||||
when QUOTE
|
||||
@state = :start_string
|
||||
@stack.push(:key)
|
||||
when RIGHT_BRACE
|
||||
end_container(:object)
|
||||
when WS
|
||||
# ignore
|
||||
else
|
||||
error("Expected object key start")
|
||||
end
|
||||
when :start_string
|
||||
case ch
|
||||
when QUOTE
|
||||
if @stack.pop == :string
|
||||
end_value(@buf)
|
||||
else # :key
|
||||
@state = :end_key
|
||||
notify(:key, @buf)
|
||||
end
|
||||
@buf = +""
|
||||
when BACKSLASH
|
||||
@state = :start_escape
|
||||
when CONTROL
|
||||
error("Control characters must be escaped")
|
||||
else
|
||||
@buf << ch
|
||||
end
|
||||
when :start_escape
|
||||
case ch
|
||||
when QUOTE, BACKSLASH, SLASH
|
||||
@buf << ch
|
||||
@state = :start_string
|
||||
when B
|
||||
@buf << "\b"
|
||||
@state = :start_string
|
||||
when F
|
||||
@buf << "\f"
|
||||
@state = :start_string
|
||||
when N
|
||||
@buf << "\n"
|
||||
@state = :start_string
|
||||
when R
|
||||
@buf << "\r"
|
||||
@state = :start_string
|
||||
when T
|
||||
@buf << "\t"
|
||||
@state = :start_string
|
||||
when U
|
||||
@state = :unicode_escape
|
||||
else
|
||||
error("Expected escaped character")
|
||||
end
|
||||
when :unicode_escape
|
||||
case ch
|
||||
when HEX
|
||||
@unicode << ch
|
||||
if @unicode.size == 4
|
||||
codepoint = @unicode.slice!(0, 4).hex
|
||||
if codepoint >= 0xD800 && codepoint <= 0xDBFF
|
||||
error("Expected low surrogate pair half") if @stack[-1].is_a?(Integer)
|
||||
@state = :start_surrogate_pair
|
||||
@stack.push(codepoint)
|
||||
elsif codepoint >= 0xDC00 && codepoint <= 0xDFFF
|
||||
high = @stack.pop
|
||||
error("Expected high surrogate pair half") unless high.is_a?(Integer)
|
||||
pair = ((high - 0xD800) * 0x400) + (codepoint - 0xDC00) + 0x10000
|
||||
@buf << pair
|
||||
@state = :start_string
|
||||
else
|
||||
@buf << codepoint
|
||||
@state = :start_string
|
||||
end
|
||||
end
|
||||
else
|
||||
error("Expected unicode escape hex digit")
|
||||
end
|
||||
when :start_surrogate_pair
|
||||
case ch
|
||||
when BACKSLASH
|
||||
@state = :start_surrogate_pair_u
|
||||
else
|
||||
error("Expected low surrogate pair half")
|
||||
end
|
||||
when :start_surrogate_pair_u
|
||||
case ch
|
||||
when U
|
||||
@state = :unicode_escape
|
||||
else
|
||||
error("Expected low surrogate pair half")
|
||||
end
|
||||
when :start_negative_number
|
||||
case ch
|
||||
when ZERO
|
||||
@state = :start_zero
|
||||
@buf << ch
|
||||
when DIGIT_1_9
|
||||
@state = :start_int
|
||||
@buf << ch
|
||||
else
|
||||
error("Expected 0-9 digit")
|
||||
end
|
||||
when :start_zero
|
||||
case ch
|
||||
when POINT
|
||||
@state = :start_float
|
||||
@buf << ch
|
||||
when EXPONENT
|
||||
@state = :start_exponent
|
||||
@buf << ch
|
||||
else
|
||||
end_value(@buf.to_i)
|
||||
@buf = +""
|
||||
@pos -= 1
|
||||
redo
|
||||
end
|
||||
when :start_float
|
||||
case ch
|
||||
when DIGIT
|
||||
@state = :in_float
|
||||
@buf << ch
|
||||
else
|
||||
error("Expected 0-9 digit")
|
||||
end
|
||||
when :in_float
|
||||
case ch
|
||||
when DIGIT
|
||||
@buf << ch
|
||||
when EXPONENT
|
||||
@state = :start_exponent
|
||||
@buf << ch
|
||||
else
|
||||
end_value(@buf.to_f)
|
||||
@buf = +""
|
||||
@pos -= 1
|
||||
redo
|
||||
end
|
||||
when :start_exponent
|
||||
case ch
|
||||
when MINUS, PLUS, DIGIT
|
||||
@state = :in_exponent
|
||||
@buf << ch
|
||||
else
|
||||
error("Expected +, -, or 0-9 digit")
|
||||
end
|
||||
when :in_exponent
|
||||
case ch
|
||||
when DIGIT
|
||||
@buf << ch
|
||||
else
|
||||
error("Expected 0-9 digit") unless @buf =~ DIGIT_END
|
||||
end_value(@buf.to_f)
|
||||
@buf = +""
|
||||
@pos -= 1
|
||||
redo
|
||||
end
|
||||
when :start_int
|
||||
case ch
|
||||
when DIGIT
|
||||
@buf << ch
|
||||
when POINT
|
||||
@state = :start_float
|
||||
@buf << ch
|
||||
when EXPONENT
|
||||
@state = :start_exponent
|
||||
@buf << ch
|
||||
else
|
||||
end_value(@buf.to_i)
|
||||
@buf = +""
|
||||
@pos -= 1
|
||||
redo
|
||||
end
|
||||
when :start_true
|
||||
keyword(TRUE_KEYWORD, true, TRUE_RE, ch)
|
||||
when :start_false
|
||||
keyword(FALSE_KEYWORD, false, FALSE_RE, ch)
|
||||
when :start_null
|
||||
keyword(NULL_KEYWORD, nil, NULL_RE, ch)
|
||||
when :end_key
|
||||
case ch
|
||||
when COLON
|
||||
@state = :key_sep
|
||||
when WS
|
||||
# ignore
|
||||
else
|
||||
error("Expected colon key separator")
|
||||
end
|
||||
when :key_sep
|
||||
start_value(ch)
|
||||
when :start_array
|
||||
case ch
|
||||
when RIGHT_BRACKET
|
||||
end_container(:array)
|
||||
when WS
|
||||
# ignore
|
||||
else
|
||||
start_value(ch)
|
||||
end
|
||||
when :end_value
|
||||
case ch
|
||||
when COMMA
|
||||
@state = :value_sep
|
||||
when RIGHT_BRACE
|
||||
end_container(:object)
|
||||
when RIGHT_BRACKET
|
||||
end_container(:array)
|
||||
when WS
|
||||
# ignore
|
||||
else
|
||||
error("Expected comma or object or array close")
|
||||
end
|
||||
when :value_sep
|
||||
if @stack[-1] == :object
|
||||
case ch
|
||||
when QUOTE
|
||||
@state = :start_string
|
||||
@stack.push(:key)
|
||||
when WS
|
||||
# ignore
|
||||
else
|
||||
error("Expected object key start")
|
||||
end
|
||||
else
|
||||
start_value(ch)
|
||||
end
|
||||
when :end_document
|
||||
error("Unexpected data") unless ch =~ WS
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# Drain any remaining buffered characters into the parser to complete
|
||||
# the parsing of the document.
|
||||
#
|
||||
# This is only required when parsing a document containing a single
|
||||
# numeric value, integer or float. The parser has no other way to
|
||||
# detect when it should no longer expect additional characters with
|
||||
# which to complete the parse, so it must be signaled by a call to
|
||||
# this method.
|
||||
#
|
||||
# If you're parsing more typical object or array documents, there's no
|
||||
# need to call `finish` because the parse will complete when the final
|
||||
# closing `]` or `}` character is scanned.
|
||||
#
|
||||
# Raises a JSON::Stream::ParserError if the JSON data is malformed.
|
||||
#
|
||||
# Returns nothing.
|
||||
def finish
|
||||
# Partial multi-byte character waiting for completion bytes.
|
||||
error("Unexpected end-of-file") unless @utf8.empty?
|
||||
|
||||
# Partial array, object, or string.
|
||||
error("Unexpected end-of-file") unless @stack.empty?
|
||||
|
||||
case @state
|
||||
when :end_document
|
||||
# done, do nothing
|
||||
when :in_float
|
||||
end_value(@buf.to_f)
|
||||
when :in_exponent
|
||||
error("Unexpected end-of-file") unless @buf =~ DIGIT_END
|
||||
end_value(@buf.to_f)
|
||||
when :start_zero
|
||||
end_value(@buf.to_i)
|
||||
when :start_int
|
||||
end_value(@buf.to_i)
|
||||
else
|
||||
error("Unexpected end-of-file")
|
||||
end
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
# Invoke all registered observer procs for the event type.
|
||||
#
|
||||
# type - The Symbol listener name.
|
||||
# args - The argument list to pass into the observer procs.
|
||||
#
|
||||
# Examples
|
||||
#
|
||||
# # broadcast events for {"answer": 42}
|
||||
# notify(:start_object)
|
||||
# notify(:key, "answer")
|
||||
# notify(:value, 42)
|
||||
# notify(:end_object)
|
||||
#
|
||||
# Returns nothing.
|
||||
def notify(type, *args)
|
||||
@listeners[type].each { |block| block.call(*args) }
|
||||
end
|
||||
|
||||
# Complete an object or array container value type.
|
||||
#
|
||||
# type - The Symbol, :object or :array, of the expected type.
|
||||
#
|
||||
# Raises a JSON::Stream::ParserError if the expected container type
|
||||
# was not completed.
|
||||
#
|
||||
# Returns nothing.
|
||||
def end_container(type)
|
||||
@state = :end_value
|
||||
if @stack.pop == type
|
||||
case type
|
||||
when :object
|
||||
notify(:end_object)
|
||||
when :array
|
||||
notify(:end_array)
|
||||
end
|
||||
else
|
||||
error("Expected end of #{type}")
|
||||
end
|
||||
notify_end_document if @stack.empty?
|
||||
end
|
||||
|
||||
# Broadcast an `end_document` event to observers after a complete JSON
|
||||
# value document (object, array, number, string, true, false, null) has
|
||||
# been parsed from the text. This is the final event sent to observers
|
||||
# and signals the parse has finished.
|
||||
#
|
||||
# Returns nothing.
|
||||
def notify_end_document
|
||||
@state = :end_document
|
||||
notify(:end_document)
|
||||
end
|
||||
|
||||
# Parse one of the three allowed keywords: true, false, null.
|
||||
#
|
||||
# word - The String keyword ('true', 'false', 'null').
|
||||
# value - The Ruby value (true, false, nil).
|
||||
# re - The Regexp of allowed keyword characters.
|
||||
# ch - The current String character being parsed.
|
||||
#
|
||||
# Raises a JSON::Stream::ParserError if the character does not belong
|
||||
# in the expected keyword.
|
||||
#
|
||||
# Returns nothing.
|
||||
def keyword(word, value, re, ch)
|
||||
if ch =~ re
|
||||
@buf << ch
|
||||
else
|
||||
error("Expected #{word} keyword")
|
||||
end
|
||||
|
||||
if @buf.size == word.size
|
||||
if @buf == word
|
||||
@buf = +""
|
||||
end_value(value)
|
||||
else
|
||||
error("Expected #{word} keyword")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# Process the first character of one of the seven possible JSON
|
||||
# values: object, array, string, true, false, null, number.
|
||||
#
|
||||
# ch - The current character String.
|
||||
#
|
||||
# Raises a JSON::Stream::ParserError if the character does not signal
|
||||
# the start of a value.
|
||||
#
|
||||
# Returns nothing.
|
||||
def start_value(ch)
|
||||
case ch
|
||||
when LEFT_BRACE
|
||||
notify(:start_document) if @stack.empty?
|
||||
@state = :start_object
|
||||
@stack.push(:object)
|
||||
notify(:start_object)
|
||||
when LEFT_BRACKET
|
||||
notify(:start_document) if @stack.empty?
|
||||
@state = :start_array
|
||||
@stack.push(:array)
|
||||
notify(:start_array)
|
||||
when QUOTE
|
||||
@state = :start_string
|
||||
@stack.push(:string)
|
||||
when T
|
||||
@state = :start_true
|
||||
@buf << ch
|
||||
when F
|
||||
@state = :start_false
|
||||
@buf << ch
|
||||
when N
|
||||
@state = :start_null
|
||||
@buf << ch
|
||||
when MINUS
|
||||
@state = :start_negative_number
|
||||
@buf << ch
|
||||
when ZERO
|
||||
@state = :start_zero
|
||||
@buf << ch
|
||||
when DIGIT_1_9
|
||||
@state = :start_int
|
||||
@buf << ch
|
||||
when WS
|
||||
# ignore
|
||||
else
|
||||
error("Expected value")
|
||||
end
|
||||
end
|
||||
|
||||
# Advance the state machine and notify `value` observers that a
|
||||
# string, number or keyword (true, false, null) value was parsed.
|
||||
#
|
||||
# value - The object to broadcast to observers.
|
||||
#
|
||||
# Returns nothing.
|
||||
def end_value(value)
|
||||
@state = :end_value
|
||||
notify(:start_document) if @stack.empty?
|
||||
notify(:value, value)
|
||||
notify_end_document if @stack.empty?
|
||||
end
|
||||
|
||||
def error(message)
|
||||
raise ParserError, "#{message}: char #{@pos}"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -106,14 +106,72 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
|||
|
||||
result = +""
|
||||
body = body.scan(/.*\n/)
|
||||
progress = []
|
||||
EndpointMock.with_chunk_array_support do
|
||||
stub_request(:post, url).to_return(status: 200, body: body)
|
||||
|
||||
llm.generate(prompt_with_google_tool, user: Discourse.system_user) do |partial|
|
||||
result << partial
|
||||
end
|
||||
tool_progress = proc { |partial| progress << partial.deep_dup }
|
||||
|
||||
llm.generate(
|
||||
prompt_with_google_tool,
|
||||
user: Discourse.system_user,
|
||||
tool_progress: tool_progress,
|
||||
) { |partial| result << partial }
|
||||
end
|
||||
|
||||
expected_progress = [
|
||||
{
|
||||
name: "search",
|
||||
id: "toolu_01DjrShFRRHp9SnHYRFRc53F",
|
||||
key: "search_query",
|
||||
value: "s",
|
||||
done: false,
|
||||
},
|
||||
{
|
||||
name: "search",
|
||||
id: "toolu_01DjrShFRRHp9SnHYRFRc53F",
|
||||
key: "search_query",
|
||||
value: "s<a>m",
|
||||
done: false,
|
||||
},
|
||||
{
|
||||
name: "search",
|
||||
id: "toolu_01DjrShFRRHp9SnHYRFRc53F",
|
||||
key: "search_query",
|
||||
value: "s<a>m ",
|
||||
done: false,
|
||||
},
|
||||
{
|
||||
name: "search",
|
||||
id: "toolu_01DjrShFRRHp9SnHYRFRc53F",
|
||||
key: "search_query",
|
||||
value: "s<a>m sam",
|
||||
},
|
||||
{
|
||||
name: "search",
|
||||
id: "toolu_01DjrShFRRHp9SnHYRFRc53F",
|
||||
key: "search_query",
|
||||
value: "cate",
|
||||
done: false,
|
||||
},
|
||||
{
|
||||
name: "search",
|
||||
id: "toolu_01DjrShFRRHp9SnHYRFRc53F",
|
||||
key: "search_query",
|
||||
value: "category",
|
||||
done: false,
|
||||
},
|
||||
{
|
||||
name: "search",
|
||||
id: "toolu_01DjrShFRRHp9SnHYRFRc53F",
|
||||
key: "category",
|
||||
value: "gene",
|
||||
done: false,
|
||||
},
|
||||
{ name: "search", id: "toolu_01DjrShFRRHp9SnHYRFRc53F", key: "category", value: "general" },
|
||||
]
|
||||
expect(progress).to eq(expected_progress)
|
||||
|
||||
expected = (<<~TEXT).strip
|
||||
<function_calls>
|
||||
<invoke>
|
||||
|
|
Loading…
Reference in New Issue