FEATURE: partial tool call support for OpenAI and Anthropic (#908)
Implement streaming tool call implementation for Anthropic and Open AI. When calling: llm.generate(..., partial_tool_calls: true) do ... Partials may contain ToolCall instances with partial: true, These tool calls are partially populated with json partially parsed. So for example when performing a search you may get: ToolCall(..., {search: "hello" }) ToolCall(..., {search: "hello world" }) The library used to parse json is: https://github.com/dgraham/json-stream We use a fork cause we need access to the internal buffer. This prepares internals to perform partial tool calls, but does not implement it yet.
This commit is contained in:
parent
f75b13c4fa
commit
823e8ef490
|
@ -185,20 +185,23 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def invoke_tool(tool, llm, cancel, context, &update_blk)
|
||||
update_blk.call("", cancel, build_placeholder(tool.summary, ""))
|
||||
show_placeholder = !context[:skip_tool_details]
|
||||
|
||||
update_blk.call("", cancel, build_placeholder(tool.summary, "")) if show_placeholder
|
||||
|
||||
result =
|
||||
tool.invoke do |progress|
|
||||
placeholder = build_placeholder(tool.summary, progress)
|
||||
update_blk.call("", cancel, placeholder)
|
||||
if show_placeholder
|
||||
placeholder = build_placeholder(tool.summary, progress)
|
||||
update_blk.call("", cancel, placeholder)
|
||||
end
|
||||
end
|
||||
|
||||
tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw)
|
||||
|
||||
if context[:skip_tool_details] && tool.custom_raw.present?
|
||||
update_blk.call(tool.custom_raw, cancel, nil, :custom_raw)
|
||||
elsif !context[:skip_tool_details]
|
||||
if show_placeholder
|
||||
tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw)
|
||||
update_blk.call(tool_details, cancel, nil, :tool_details)
|
||||
elsif tool.custom_raw.present?
|
||||
update_blk.call(tool.custom_raw, cancel, nil, :custom_raw)
|
||||
end
|
||||
|
||||
result
|
||||
|
|
|
@ -452,7 +452,7 @@ module DiscourseAi
|
|||
bot.reply(context) do |partial, cancel, placeholder, type|
|
||||
reply << partial
|
||||
raw = reply.dup
|
||||
raw << "\n\n" << placeholder if placeholder.present? && !context[:skip_tool_details]
|
||||
raw << "\n\n" << placeholder if placeholder.present?
|
||||
|
||||
blk.call(partial) if blk && type != :tool_details
|
||||
|
||||
|
|
|
@ -4,28 +4,50 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
|
|||
class AnthropicToolCall
|
||||
attr_reader :name, :raw_json, :id
|
||||
|
||||
def initialize(name, id)
|
||||
def initialize(name, id, partial_tool_calls: false)
|
||||
@name = name
|
||||
@id = id
|
||||
@raw_json = +""
|
||||
@tool_call = DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: {})
|
||||
@streaming_parser =
|
||||
DiscourseAi::Completions::ToolCallProgressTracker.new(self) if partial_tool_calls
|
||||
end
|
||||
|
||||
def append(json)
|
||||
@raw_json << json
|
||||
@streaming_parser << json if @streaming_parser
|
||||
end
|
||||
|
||||
def notify_progress(key, value)
|
||||
@tool_call.partial = true
|
||||
@tool_call.parameters[key.to_sym] = value
|
||||
@has_new_data = true
|
||||
end
|
||||
|
||||
def has_partial?
|
||||
@has_new_data
|
||||
end
|
||||
|
||||
def partial_tool_call
|
||||
@has_new_data = false
|
||||
@tool_call
|
||||
end
|
||||
|
||||
def to_tool_call
|
||||
parameters = JSON.parse(raw_json, symbolize_names: true)
|
||||
DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters)
|
||||
@tool_call.partial = false
|
||||
@tool_call.parameters = parameters
|
||||
@tool_call
|
||||
end
|
||||
end
|
||||
|
||||
attr_reader :tool_calls, :input_tokens, :output_tokens
|
||||
|
||||
def initialize(streaming_mode:)
|
||||
def initialize(streaming_mode:, partial_tool_calls: false)
|
||||
@streaming_mode = streaming_mode
|
||||
@tool_calls = []
|
||||
@current_tool_call = nil
|
||||
@partial_tool_calls = partial_tool_calls
|
||||
end
|
||||
|
||||
def to_tool_calls
|
||||
|
@ -38,11 +60,17 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
|
|||
tool_name = parsed.dig(:content_block, :name)
|
||||
tool_id = parsed.dig(:content_block, :id)
|
||||
result = @current_tool_call.to_tool_call if @current_tool_call
|
||||
@current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name
|
||||
@current_tool_call =
|
||||
AnthropicToolCall.new(
|
||||
tool_name,
|
||||
tool_id,
|
||||
partial_tool_calls: @partial_tool_calls,
|
||||
) if tool_name
|
||||
elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
|
||||
if @current_tool_call
|
||||
tool_delta = parsed.dig(:delta, :partial_json).to_s
|
||||
@current_tool_call.append(tool_delta)
|
||||
result = @current_tool_call.partial_tool_call if @current_tool_call.has_partial?
|
||||
else
|
||||
result = parsed.dig(:delta, :text).to_s
|
||||
end
|
||||
|
|
|
@ -107,7 +107,10 @@ module DiscourseAi
|
|||
|
||||
def processor
|
||||
@processor ||=
|
||||
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
|
||||
DiscourseAi::Completions::AnthropicMessageProcessor.new(
|
||||
streaming_mode: @streaming_mode,
|
||||
partial_tool_calls: partial_tool_calls,
|
||||
)
|
||||
end
|
||||
|
||||
def has_tool?(_response_data)
|
||||
|
|
|
@ -4,6 +4,8 @@ module DiscourseAi
|
|||
module Completions
|
||||
module Endpoints
|
||||
class Base
|
||||
attr_reader :partial_tool_calls
|
||||
|
||||
CompletionFailed = Class.new(StandardError)
|
||||
TIMEOUT = 60
|
||||
|
||||
|
@ -58,8 +60,10 @@ module DiscourseAi
|
|||
model_params = {},
|
||||
feature_name: nil,
|
||||
feature_context: nil,
|
||||
partial_tool_calls: false,
|
||||
&blk
|
||||
)
|
||||
@partial_tool_calls = partial_tool_calls
|
||||
model_params = normalize_model_params(model_params)
|
||||
orig_blk = blk
|
||||
|
||||
|
|
|
@ -28,7 +28,8 @@ module DiscourseAi
|
|||
_user,
|
||||
_model_params,
|
||||
feature_name: nil,
|
||||
feature_context: nil
|
||||
feature_context: nil,
|
||||
partial_tool_calls: false
|
||||
)
|
||||
@dialect = dialect
|
||||
response = responses[completions]
|
||||
|
|
|
@ -120,7 +120,8 @@ module DiscourseAi
|
|||
user,
|
||||
model_params = {},
|
||||
feature_name: nil,
|
||||
feature_context: nil
|
||||
feature_context: nil,
|
||||
partial_tool_calls: false
|
||||
)
|
||||
last_call = { dialect: dialect, user: user, model_params: model_params }
|
||||
self.class.last_call = last_call
|
||||
|
|
|
@ -33,6 +33,7 @@ module DiscourseAi
|
|||
model_params = {},
|
||||
feature_name: nil,
|
||||
feature_context: nil,
|
||||
partial_tool_calls: false,
|
||||
&blk
|
||||
)
|
||||
if dialect.respond_to?(:is_gpt_o?) && dialect.is_gpt_o? && block_given?
|
||||
|
@ -103,10 +104,16 @@ module DiscourseAi
|
|||
|
||||
def decode_chunk(chunk)
|
||||
@decoder ||= JsonStreamDecoder.new
|
||||
(@decoder << chunk)
|
||||
.map { |parsed_json| processor.process_streamed_message(parsed_json) }
|
||||
.flatten
|
||||
.compact
|
||||
elements =
|
||||
(@decoder << chunk)
|
||||
.map { |parsed_json| processor.process_streamed_message(parsed_json) }
|
||||
.flatten
|
||||
.compact
|
||||
|
||||
# Remove duplicate partial tool calls
|
||||
# sometimes we stream weird chunks
|
||||
seen_tools = Set.new
|
||||
elements.select { |item| !item.is_a?(ToolCall) || seen_tools.add?(item) }
|
||||
end
|
||||
|
||||
def decode_chunk_finish
|
||||
|
@ -120,7 +127,7 @@ module DiscourseAi
|
|||
private
|
||||
|
||||
def processor
|
||||
@processor ||= OpenAiMessageProcessor.new
|
||||
@processor ||= OpenAiMessageProcessor.new(partial_tool_calls: partial_tool_calls)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -0,0 +1,667 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
# This code is copied from the MIT licensed json-stream
|
||||
# see: https://github.com/dgraham/json-stream
|
||||
#
|
||||
# It was copied to avoid the dependency and allow us to make some small changes
|
||||
# particularly we need better access to internal state when parsing
|
||||
|
||||
module DiscourseAi
|
||||
module Completions
|
||||
# Raised on any invalid JSON text.
|
||||
ParserError = Class.new(RuntimeError)
|
||||
|
||||
# A streaming JSON parser that generates SAX-like events for state changes.
|
||||
# Use the json gem for small documents. Use this for huge documents that
|
||||
# won't fit in memory.
|
||||
#
|
||||
# Examples
|
||||
#
|
||||
# parser = JSON::Stream::Parser.new
|
||||
# parser.key { |key| puts key }
|
||||
# parser.value { |value| puts value }
|
||||
# parser << '{"answer":'
|
||||
# parser << ' 42}'
|
||||
class JsonStreamingParser
|
||||
# our changes:
|
||||
attr_reader :state, :buf, :pos
|
||||
|
||||
# A character buffer that expects a UTF-8 encoded stream of bytes.
|
||||
# This handles truncated multi-byte characters properly so we can just
|
||||
# feed it binary data and receive a properly formatted UTF-8 String as
|
||||
# output.
|
||||
#
|
||||
# More UTF-8 parsing details are available at:
|
||||
#
|
||||
# http://en.wikipedia.org/wiki/UTF-8
|
||||
# http://tools.ietf.org/html/rfc3629#section-3
|
||||
class Buffer
|
||||
def initialize
|
||||
@state = :start
|
||||
@buffer = []
|
||||
@need = 0
|
||||
end
|
||||
|
||||
# Fill the buffer with a String of binary UTF-8 encoded bytes. Returns
|
||||
# as much of the data in a UTF-8 String as we have. Truncated multi-byte
|
||||
# characters are saved in the buffer until the next call to this method
|
||||
# where we expect to receive the rest of the multi-byte character.
|
||||
#
|
||||
# data - The partial binary encoded String data.
|
||||
#
|
||||
# Raises JSON::Stream::ParserError if the UTF-8 byte sequence is malformed.
|
||||
#
|
||||
# Returns a UTF-8 encoded String.
|
||||
def <<(data)
|
||||
# Avoid state machine for complete UTF-8.
|
||||
if @buffer.empty?
|
||||
data.force_encoding(Encoding::UTF_8)
|
||||
return data if data.valid_encoding?
|
||||
end
|
||||
|
||||
bytes = []
|
||||
data.each_byte do |byte|
|
||||
case @state
|
||||
when :start
|
||||
if byte < 128
|
||||
bytes << byte
|
||||
elsif byte >= 192
|
||||
@state = :multi_byte
|
||||
@buffer << byte
|
||||
@need =
|
||||
case
|
||||
when byte >= 240
|
||||
4
|
||||
when byte >= 224
|
||||
3
|
||||
when byte >= 192
|
||||
2
|
||||
end
|
||||
else
|
||||
error("Expected start of multi-byte or single byte char")
|
||||
end
|
||||
when :multi_byte
|
||||
if byte > 127 && byte < 192
|
||||
@buffer << byte
|
||||
if @buffer.size == @need
|
||||
bytes += @buffer.slice!(0, @buffer.size)
|
||||
@state = :start
|
||||
end
|
||||
else
|
||||
error("Expected continuation byte")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# Build UTF-8 encoded string from completed codepoints.
|
||||
bytes
|
||||
.pack("C*")
|
||||
.force_encoding(Encoding::UTF_8)
|
||||
.tap { |text| error("Invalid UTF-8 byte sequence") unless text.valid_encoding? }
|
||||
end
|
||||
|
||||
# Determine if the buffer contains partial UTF-8 continuation bytes that
|
||||
# are waiting on subsequent completion bytes before a full codepoint is
|
||||
# formed.
|
||||
#
|
||||
# Examples
|
||||
#
|
||||
# bytes = "é".bytes
|
||||
#
|
||||
# buffer << bytes[0]
|
||||
# buffer.empty?
|
||||
# # => false
|
||||
#
|
||||
# buffer << bytes[1]
|
||||
# buffer.empty?
|
||||
# # => true
|
||||
#
|
||||
# Returns true if the buffer is empty.
|
||||
def empty?
|
||||
@buffer.empty?
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def error(message)
|
||||
raise ParserError, message
|
||||
end
|
||||
end
|
||||
|
||||
BUF_SIZE = 4096
|
||||
CONTROL = /[\x00-\x1F]/
|
||||
WS = /[ \n\t\r]/
|
||||
HEX = /[0-9a-fA-F]/
|
||||
DIGIT = /[0-9]/
|
||||
DIGIT_1_9 = /[1-9]/
|
||||
DIGIT_END = /\d$/
|
||||
TRUE_RE = /[rue]/
|
||||
FALSE_RE = /[alse]/
|
||||
NULL_RE = /[ul]/
|
||||
TRUE_KEYWORD = "true"
|
||||
FALSE_KEYWORD = "false"
|
||||
NULL_KEYWORD = "null"
|
||||
LEFT_BRACE = "{"
|
||||
RIGHT_BRACE = "}"
|
||||
LEFT_BRACKET = "["
|
||||
RIGHT_BRACKET = "]"
|
||||
BACKSLASH = '\\'
|
||||
SLASH = "/"
|
||||
QUOTE = '"'
|
||||
COMMA = ","
|
||||
COLON = ":"
|
||||
ZERO = "0"
|
||||
MINUS = "-"
|
||||
PLUS = "+"
|
||||
POINT = "."
|
||||
EXPONENT = /[eE]/
|
||||
B, F, N, R, T, U = %w[b f n r t u]
|
||||
|
||||
# Create a new parser with an optional initialization block where
|
||||
# we can register event callbacks.
|
||||
#
|
||||
# Examples
|
||||
#
|
||||
# parser = JSON::Stream::Parser.new do
|
||||
# start_document { puts "start document" }
|
||||
# end_document { puts "end document" }
|
||||
# start_object { puts "start object" }
|
||||
# end_object { puts "end object" }
|
||||
# start_array { puts "start array" }
|
||||
# end_array { puts "end array" }
|
||||
# key { |k| puts "key: #{k}" }
|
||||
# value { |v| puts "value: #{v}" }
|
||||
# end
|
||||
def initialize(&block)
|
||||
@state = :start_document
|
||||
@utf8 = Buffer.new
|
||||
@listeners = {
|
||||
start_document: [],
|
||||
end_document: [],
|
||||
start_object: [],
|
||||
end_object: [],
|
||||
start_array: [],
|
||||
end_array: [],
|
||||
key: [],
|
||||
value: [],
|
||||
}
|
||||
|
||||
# Track parse stack.
|
||||
@stack = []
|
||||
@unicode = +""
|
||||
@buf = +""
|
||||
@pos = -1
|
||||
|
||||
# Register any observers in the block.
|
||||
instance_eval(&block) if block_given?
|
||||
end
|
||||
|
||||
def start_document(&block)
|
||||
@listeners[:start_document] << block
|
||||
end
|
||||
|
||||
def end_document(&block)
|
||||
@listeners[:end_document] << block
|
||||
end
|
||||
|
||||
def start_object(&block)
|
||||
@listeners[:start_object] << block
|
||||
end
|
||||
|
||||
def end_object(&block)
|
||||
@listeners[:end_object] << block
|
||||
end
|
||||
|
||||
def start_array(&block)
|
||||
@listeners[:start_array] << block
|
||||
end
|
||||
|
||||
def end_array(&block)
|
||||
@listeners[:end_array] << block
|
||||
end
|
||||
|
||||
def key(&block)
|
||||
@listeners[:key] << block
|
||||
end
|
||||
|
||||
def value(&block)
|
||||
@listeners[:value] << block
|
||||
end
|
||||
|
||||
# Pass data into the parser to advance the state machine and
|
||||
# generate callback events. This is well suited for an EventMachine
|
||||
# receive_data loop.
|
||||
#
|
||||
# data - The String of partial JSON data to parse.
|
||||
#
|
||||
# Raises a JSON::Stream::ParserError if the JSON data is malformed.
|
||||
#
|
||||
# Returns nothing.
|
||||
def <<(data)
|
||||
(@utf8 << data).each_char do |ch|
|
||||
@pos += 1
|
||||
case @state
|
||||
when :start_document
|
||||
start_value(ch)
|
||||
when :start_object
|
||||
case ch
|
||||
when QUOTE
|
||||
@state = :start_string
|
||||
@stack.push(:key)
|
||||
when RIGHT_BRACE
|
||||
end_container(:object)
|
||||
when WS
|
||||
# ignore
|
||||
else
|
||||
error("Expected object key start")
|
||||
end
|
||||
when :start_string
|
||||
case ch
|
||||
when QUOTE
|
||||
if @stack.pop == :string
|
||||
end_value(@buf)
|
||||
else # :key
|
||||
@state = :end_key
|
||||
notify(:key, @buf)
|
||||
end
|
||||
@buf = +""
|
||||
when BACKSLASH
|
||||
@state = :start_escape
|
||||
when CONTROL
|
||||
error("Control characters must be escaped")
|
||||
else
|
||||
@buf << ch
|
||||
end
|
||||
when :start_escape
|
||||
case ch
|
||||
when QUOTE, BACKSLASH, SLASH
|
||||
@buf << ch
|
||||
@state = :start_string
|
||||
when B
|
||||
@buf << "\b"
|
||||
@state = :start_string
|
||||
when F
|
||||
@buf << "\f"
|
||||
@state = :start_string
|
||||
when N
|
||||
@buf << "\n"
|
||||
@state = :start_string
|
||||
when R
|
||||
@buf << "\r"
|
||||
@state = :start_string
|
||||
when T
|
||||
@buf << "\t"
|
||||
@state = :start_string
|
||||
when U
|
||||
@state = :unicode_escape
|
||||
else
|
||||
error("Expected escaped character")
|
||||
end
|
||||
when :unicode_escape
|
||||
case ch
|
||||
when HEX
|
||||
@unicode << ch
|
||||
if @unicode.size == 4
|
||||
codepoint = @unicode.slice!(0, 4).hex
|
||||
if codepoint >= 0xD800 && codepoint <= 0xDBFF
|
||||
error("Expected low surrogate pair half") if @stack[-1].is_a?(Integer)
|
||||
@state = :start_surrogate_pair
|
||||
@stack.push(codepoint)
|
||||
elsif codepoint >= 0xDC00 && codepoint <= 0xDFFF
|
||||
high = @stack.pop
|
||||
error("Expected high surrogate pair half") unless high.is_a?(Integer)
|
||||
pair = ((high - 0xD800) * 0x400) + (codepoint - 0xDC00) + 0x10000
|
||||
@buf << pair
|
||||
@state = :start_string
|
||||
else
|
||||
@buf << codepoint
|
||||
@state = :start_string
|
||||
end
|
||||
end
|
||||
else
|
||||
error("Expected unicode escape hex digit")
|
||||
end
|
||||
when :start_surrogate_pair
|
||||
case ch
|
||||
when BACKSLASH
|
||||
@state = :start_surrogate_pair_u
|
||||
else
|
||||
error("Expected low surrogate pair half")
|
||||
end
|
||||
when :start_surrogate_pair_u
|
||||
case ch
|
||||
when U
|
||||
@state = :unicode_escape
|
||||
else
|
||||
error("Expected low surrogate pair half")
|
||||
end
|
||||
when :start_negative_number
|
||||
case ch
|
||||
when ZERO
|
||||
@state = :start_zero
|
||||
@buf << ch
|
||||
when DIGIT_1_9
|
||||
@state = :start_int
|
||||
@buf << ch
|
||||
else
|
||||
error("Expected 0-9 digit")
|
||||
end
|
||||
when :start_zero
|
||||
case ch
|
||||
when POINT
|
||||
@state = :start_float
|
||||
@buf << ch
|
||||
when EXPONENT
|
||||
@state = :start_exponent
|
||||
@buf << ch
|
||||
else
|
||||
end_value(@buf.to_i)
|
||||
@buf = +""
|
||||
@pos -= 1
|
||||
redo
|
||||
end
|
||||
when :start_float
|
||||
case ch
|
||||
when DIGIT
|
||||
@state = :in_float
|
||||
@buf << ch
|
||||
else
|
||||
error("Expected 0-9 digit")
|
||||
end
|
||||
when :in_float
|
||||
case ch
|
||||
when DIGIT
|
||||
@buf << ch
|
||||
when EXPONENT
|
||||
@state = :start_exponent
|
||||
@buf << ch
|
||||
else
|
||||
end_value(@buf.to_f)
|
||||
@buf = +""
|
||||
@pos -= 1
|
||||
redo
|
||||
end
|
||||
when :start_exponent
|
||||
case ch
|
||||
when MINUS, PLUS, DIGIT
|
||||
@state = :in_exponent
|
||||
@buf << ch
|
||||
else
|
||||
error("Expected +, -, or 0-9 digit")
|
||||
end
|
||||
when :in_exponent
|
||||
case ch
|
||||
when DIGIT
|
||||
@buf << ch
|
||||
else
|
||||
error("Expected 0-9 digit") unless @buf =~ DIGIT_END
|
||||
end_value(@buf.to_f)
|
||||
@buf = +""
|
||||
@pos -= 1
|
||||
redo
|
||||
end
|
||||
when :start_int
|
||||
case ch
|
||||
when DIGIT
|
||||
@buf << ch
|
||||
when POINT
|
||||
@state = :start_float
|
||||
@buf << ch
|
||||
when EXPONENT
|
||||
@state = :start_exponent
|
||||
@buf << ch
|
||||
else
|
||||
end_value(@buf.to_i)
|
||||
@buf = +""
|
||||
@pos -= 1
|
||||
redo
|
||||
end
|
||||
when :start_true
|
||||
keyword(TRUE_KEYWORD, true, TRUE_RE, ch)
|
||||
when :start_false
|
||||
keyword(FALSE_KEYWORD, false, FALSE_RE, ch)
|
||||
when :start_null
|
||||
keyword(NULL_KEYWORD, nil, NULL_RE, ch)
|
||||
when :end_key
|
||||
case ch
|
||||
when COLON
|
||||
@state = :key_sep
|
||||
when WS
|
||||
# ignore
|
||||
else
|
||||
error("Expected colon key separator")
|
||||
end
|
||||
when :key_sep
|
||||
start_value(ch)
|
||||
when :start_array
|
||||
case ch
|
||||
when RIGHT_BRACKET
|
||||
end_container(:array)
|
||||
when WS
|
||||
# ignore
|
||||
else
|
||||
start_value(ch)
|
||||
end
|
||||
when :end_value
|
||||
case ch
|
||||
when COMMA
|
||||
@state = :value_sep
|
||||
when RIGHT_BRACE
|
||||
end_container(:object)
|
||||
when RIGHT_BRACKET
|
||||
end_container(:array)
|
||||
when WS
|
||||
# ignore
|
||||
else
|
||||
error("Expected comma or object or array close")
|
||||
end
|
||||
when :value_sep
|
||||
if @stack[-1] == :object
|
||||
case ch
|
||||
when QUOTE
|
||||
@state = :start_string
|
||||
@stack.push(:key)
|
||||
when WS
|
||||
# ignore
|
||||
else
|
||||
error("Expected object key start")
|
||||
end
|
||||
else
|
||||
start_value(ch)
|
||||
end
|
||||
when :end_document
|
||||
error("Unexpected data") unless ch =~ WS
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# Drain any remaining buffered characters into the parser to complete
|
||||
# the parsing of the document.
|
||||
#
|
||||
# This is only required when parsing a document containing a single
|
||||
# numeric value, integer or float. The parser has no other way to
|
||||
# detect when it should no longer expect additional characters with
|
||||
# which to complete the parse, so it must be signaled by a call to
|
||||
# this method.
|
||||
#
|
||||
# If you're parsing more typical object or array documents, there's no
|
||||
# need to call `finish` because the parse will complete when the final
|
||||
# closing `]` or `}` character is scanned.
|
||||
#
|
||||
# Raises a JSON::Stream::ParserError if the JSON data is malformed.
|
||||
#
|
||||
# Returns nothing.
|
||||
def finish
|
||||
# Partial multi-byte character waiting for completion bytes.
|
||||
error("Unexpected end-of-file") unless @utf8.empty?
|
||||
|
||||
# Partial array, object, or string.
|
||||
error("Unexpected end-of-file") unless @stack.empty?
|
||||
|
||||
case @state
|
||||
when :end_document
|
||||
# done, do nothing
|
||||
when :in_float
|
||||
end_value(@buf.to_f)
|
||||
when :in_exponent
|
||||
error("Unexpected end-of-file") unless @buf =~ DIGIT_END
|
||||
end_value(@buf.to_f)
|
||||
when :start_zero
|
||||
end_value(@buf.to_i)
|
||||
when :start_int
|
||||
end_value(@buf.to_i)
|
||||
else
|
||||
error("Unexpected end-of-file")
|
||||
end
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
# Invoke all registered observer procs for the event type.
|
||||
#
|
||||
# type - The Symbol listener name.
|
||||
# args - The argument list to pass into the observer procs.
|
||||
#
|
||||
# Examples
|
||||
#
|
||||
# # broadcast events for {"answer": 42}
|
||||
# notify(:start_object)
|
||||
# notify(:key, "answer")
|
||||
# notify(:value, 42)
|
||||
# notify(:end_object)
|
||||
#
|
||||
# Returns nothing.
|
||||
def notify(type, *args)
|
||||
@listeners[type].each { |block| block.call(*args) }
|
||||
end
|
||||
|
||||
# Complete an object or array container value type.
|
||||
#
|
||||
# type - The Symbol, :object or :array, of the expected type.
|
||||
#
|
||||
# Raises a JSON::Stream::ParserError if the expected container type
|
||||
# was not completed.
|
||||
#
|
||||
# Returns nothing.
|
||||
def end_container(type)
|
||||
@state = :end_value
|
||||
if @stack.pop == type
|
||||
case type
|
||||
when :object
|
||||
notify(:end_object)
|
||||
when :array
|
||||
notify(:end_array)
|
||||
end
|
||||
else
|
||||
error("Expected end of #{type}")
|
||||
end
|
||||
notify_end_document if @stack.empty?
|
||||
end
|
||||
|
||||
# Broadcast an `end_document` event to observers after a complete JSON
|
||||
# value document (object, array, number, string, true, false, null) has
|
||||
# been parsed from the text. This is the final event sent to observers
|
||||
# and signals the parse has finished.
|
||||
#
|
||||
# Returns nothing.
|
||||
def notify_end_document
|
||||
@state = :end_document
|
||||
notify(:end_document)
|
||||
end
|
||||
|
||||
# Parse one of the three allowed keywords: true, false, null.
|
||||
#
|
||||
# word - The String keyword ('true', 'false', 'null').
|
||||
# value - The Ruby value (true, false, nil).
|
||||
# re - The Regexp of allowed keyword characters.
|
||||
# ch - The current String character being parsed.
|
||||
#
|
||||
# Raises a JSON::Stream::ParserError if the character does not belong
|
||||
# in the expected keyword.
|
||||
#
|
||||
# Returns nothing.
|
||||
def keyword(word, value, re, ch)
|
||||
if ch =~ re
|
||||
@buf << ch
|
||||
else
|
||||
error("Expected #{word} keyword")
|
||||
end
|
||||
|
||||
if @buf.size == word.size
|
||||
if @buf == word
|
||||
@buf = +""
|
||||
end_value(value)
|
||||
else
|
||||
error("Expected #{word} keyword")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# Process the first character of one of the seven possible JSON
|
||||
# values: object, array, string, true, false, null, number.
|
||||
#
|
||||
# ch - The current character String.
|
||||
#
|
||||
# Raises a JSON::Stream::ParserError if the character does not signal
|
||||
# the start of a value.
|
||||
#
|
||||
# Returns nothing.
|
||||
def start_value(ch)
|
||||
case ch
|
||||
when LEFT_BRACE
|
||||
notify(:start_document) if @stack.empty?
|
||||
@state = :start_object
|
||||
@stack.push(:object)
|
||||
notify(:start_object)
|
||||
when LEFT_BRACKET
|
||||
notify(:start_document) if @stack.empty?
|
||||
@state = :start_array
|
||||
@stack.push(:array)
|
||||
notify(:start_array)
|
||||
when QUOTE
|
||||
@state = :start_string
|
||||
@stack.push(:string)
|
||||
when T
|
||||
@state = :start_true
|
||||
@buf << ch
|
||||
when F
|
||||
@state = :start_false
|
||||
@buf << ch
|
||||
when N
|
||||
@state = :start_null
|
||||
@buf << ch
|
||||
when MINUS
|
||||
@state = :start_negative_number
|
||||
@buf << ch
|
||||
when ZERO
|
||||
@state = :start_zero
|
||||
@buf << ch
|
||||
when DIGIT_1_9
|
||||
@state = :start_int
|
||||
@buf << ch
|
||||
when WS
|
||||
# ignore
|
||||
else
|
||||
error("Expected value")
|
||||
end
|
||||
end
|
||||
|
||||
# Advance the state machine and notify `value` observers that a
|
||||
# string, number or keyword (true, false, null) value was parsed.
|
||||
#
|
||||
# value - The object to broadcast to observers.
|
||||
#
|
||||
# Returns nothing.
|
||||
def end_value(value)
|
||||
@state = :end_value
|
||||
notify(:start_document) if @stack.empty?
|
||||
notify(:value, value)
|
||||
notify_end_document if @stack.empty?
|
||||
end
|
||||
|
||||
def error(message)
|
||||
raise ParserError, "#{message}: char #{@pos}"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -164,24 +164,18 @@ module DiscourseAi
|
|||
|
||||
# @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object
|
||||
# @param user { User } - User requesting the summary.
|
||||
# @param temperature { Float - Optional } - The temperature to use for the completion.
|
||||
# @param top_p { Float - Optional } - The top_p to use for the completion.
|
||||
# @param max_tokens { Integer - Optional } - The maximum number of tokens to generate.
|
||||
# @param stop_sequences { Array<String> - Optional } - The stop sequences to use for the completion.
|
||||
# @param feature_name { String - Optional } - The feature name to use for the completion.
|
||||
# @param feature_context { Hash - Optional } - The feature context to use for the completion.
|
||||
# @param partial_tool_calls { Boolean - Optional } - If true, the completion will return partial tool calls.
|
||||
#
|
||||
# @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function.
|
||||
#
|
||||
# @returns { String } - Completion result.
|
||||
#
|
||||
# When the model invokes a tool, we'll wait until the endpoint finishes replying and feed you a fully-formed tool,
|
||||
# even if you passed a partial_read_blk block. Invocations are strings that look like this:
|
||||
#
|
||||
# <function_calls>
|
||||
# <invoke>
|
||||
# <tool_name>get_weather</tool_name>
|
||||
# <tool_id>get_weather</tool_id>
|
||||
# <parameters>
|
||||
# <location>Sydney</location>
|
||||
# <unit>c</unit>
|
||||
# </parameters>
|
||||
# </invoke>
|
||||
# </function_calls>
|
||||
# @returns String | ToolCall - Completion result.
|
||||
# if multiple tools or a tool and a message come back, the result will be an array of ToolCall / String objects.
|
||||
#
|
||||
def generate(
|
||||
prompt,
|
||||
|
@ -192,6 +186,7 @@ module DiscourseAi
|
|||
user:,
|
||||
feature_name: nil,
|
||||
feature_context: nil,
|
||||
partial_tool_calls: false,
|
||||
&partial_read_blk
|
||||
)
|
||||
self.class.record_prompt(prompt)
|
||||
|
@ -226,6 +221,7 @@ module DiscourseAi
|
|||
model_params,
|
||||
feature_name: feature_name,
|
||||
feature_context: feature_context,
|
||||
partial_tool_calls: partial_tool_calls,
|
||||
&partial_read_blk
|
||||
)
|
||||
end
|
||||
|
|
|
@ -3,11 +3,12 @@ module DiscourseAi::Completions
|
|||
class OpenAiMessageProcessor
|
||||
attr_reader :prompt_tokens, :completion_tokens
|
||||
|
||||
def initialize
|
||||
def initialize(partial_tool_calls: false)
|
||||
@tool = nil
|
||||
@tool_arguments = +""
|
||||
@prompt_tokens = nil
|
||||
@completion_tokens = nil
|
||||
@partial_tool_calls = partial_tool_calls
|
||||
end
|
||||
|
||||
def process_message(json)
|
||||
|
@ -57,12 +58,16 @@ module DiscourseAi::Completions
|
|||
if id.present? && name.present?
|
||||
@tool_arguments = +""
|
||||
@tool = ToolCall.new(id: id, name: name)
|
||||
@streaming_parser = ToolCallProgressTracker.new(self) if @partial_tool_calls
|
||||
end
|
||||
|
||||
@tool_arguments << arguments.to_s
|
||||
@streaming_parser << arguments.to_s if @streaming_parser && !arguments.to_s.empty?
|
||||
rval = current_tool_progress if !rval
|
||||
elsif finished_tools && @tool
|
||||
parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
|
||||
@tool.parameters = parsed_args
|
||||
@tool.partial = false
|
||||
rval = @tool
|
||||
@tool = nil
|
||||
elsif !content.to_s.empty?
|
||||
|
@ -75,6 +80,23 @@ module DiscourseAi::Completions
|
|||
rval
|
||||
end
|
||||
|
||||
def notify_progress(key, value)
|
||||
if @tool
|
||||
@tool.partial = true
|
||||
@tool.parameters[key.to_sym] = value
|
||||
@has_new_data = true
|
||||
end
|
||||
end
|
||||
|
||||
def current_tool_progress
|
||||
if @has_new_data
|
||||
@has_new_data = false
|
||||
@tool
|
||||
else
|
||||
nil
|
||||
end
|
||||
end
|
||||
|
||||
def finish
|
||||
rval = []
|
||||
if @tool
|
||||
|
|
|
@ -4,12 +4,14 @@ module DiscourseAi
|
|||
module Completions
|
||||
class ToolCall
|
||||
attr_reader :id, :name, :parameters
|
||||
attr_accessor :partial
|
||||
|
||||
def initialize(id:, name:, parameters: nil)
|
||||
@id = id
|
||||
@name = name
|
||||
self.parameters = parameters if parameters
|
||||
@parameters ||= {}
|
||||
@partial = false
|
||||
end
|
||||
|
||||
def parameters=(parameters)
|
||||
|
@ -24,6 +26,12 @@ module DiscourseAi
|
|||
def to_s
|
||||
"#{name} - #{id} (\n#{parameters.map(&:to_s).join("\n")}\n)"
|
||||
end
|
||||
|
||||
def dup
|
||||
call = ToolCall.new(id: id, name: name, parameters: parameters.deep_dup)
|
||||
call.partial = partial
|
||||
call
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Completions
|
||||
class ToolCallProgressTracker
|
||||
attr_reader :current_key, :current_value, :tool_call
|
||||
|
||||
def initialize(tool_call)
|
||||
@tool_call = tool_call
|
||||
@current_key = nil
|
||||
@current_value = nil
|
||||
@parser = DiscourseAi::Completions::JsonStreamingParser.new
|
||||
|
||||
@parser.key do |k|
|
||||
@current_key = k
|
||||
@current_value = nil
|
||||
end
|
||||
|
||||
@parser.value { |v| tool_call.notify_progress(@current_key, v) if @current_key }
|
||||
end
|
||||
|
||||
def <<(json)
|
||||
# llm could send broken json
|
||||
# in that case just deal with it later
|
||||
# don't stream
|
||||
return if @broken
|
||||
|
||||
begin
|
||||
@parser << json
|
||||
rescue DiscourseAi::Completions::ParserError
|
||||
@broken = true
|
||||
return
|
||||
end
|
||||
|
||||
if @parser.state == :start_string && @current_key
|
||||
# this is is worth notifying
|
||||
tool_call.notify_progress(@current_key, @parser.buf)
|
||||
end
|
||||
|
||||
@current_key = nil if @parser.state == :end_value
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -109,9 +109,11 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
|||
EndpointMock.with_chunk_array_support do
|
||||
stub_request(:post, url).to_return(status: 200, body: body)
|
||||
|
||||
llm.generate(prompt_with_google_tool, user: Discourse.system_user) do |partial|
|
||||
result << partial
|
||||
end
|
||||
llm.generate(
|
||||
prompt_with_google_tool,
|
||||
user: Discourse.system_user,
|
||||
partial_tool_calls: true,
|
||||
) { |partial| result << partial.dup }
|
||||
end
|
||||
|
||||
tool_call =
|
||||
|
@ -124,7 +126,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
|||
},
|
||||
)
|
||||
|
||||
expect(result).to eq([tool_call])
|
||||
expect(result.last).to eq(tool_call)
|
||||
|
||||
search_queries = result.filter(&:partial).map { |r| r.parameters[:search_query] }
|
||||
categories = result.filter(&:partial).map { |r| r.parameters[:category] }
|
||||
|
||||
expect(categories).to eq([nil, nil, nil, nil, "gene", "general"])
|
||||
expect(search_queries).to eq(["s", "s<a>m", "s<a>m ", "s<a>m sam", "s<a>m sam", "s<a>m sam"])
|
||||
end
|
||||
|
||||
it "can stream a response" do
|
||||
|
|
|
@ -571,7 +571,7 @@ TEXT
|
|||
end
|
||||
end
|
||||
|
||||
it "properly handles spaces in tools payload" do
|
||||
it "properly handles spaces in tools payload and partial tool calls" do
|
||||
raw_data = <<~TEXT.strip
|
||||
data: {"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"func_id","type":"function","function":{"name":"go|ogle","arg|uments":""}}]}}]}
|
||||
|
||||
|
@ -609,7 +609,9 @@ TEXT
|
|||
partials = []
|
||||
|
||||
dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
|
||||
endpoint.perform_completion!(dialect, user) { |partial| partials << partial }
|
||||
endpoint.perform_completion!(dialect, user, partial_tool_calls: true) do |partial|
|
||||
partials << partial.dup
|
||||
end
|
||||
|
||||
tool_call =
|
||||
DiscourseAi::Completions::ToolCall.new(
|
||||
|
@ -620,7 +622,10 @@ TEXT
|
|||
},
|
||||
)
|
||||
|
||||
expect(partials).to eq([tool_call])
|
||||
expect(partials.last).to eq(tool_call)
|
||||
|
||||
progress = partials.map { |p| p.parameters[:query] }
|
||||
expect(progress).to eq(["Ad", "Adabas", "Adabas 9.", "Adabas 9.1"])
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -462,7 +462,6 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
|||
expect(crafted_system_prompt).to include("fragment-n14")
|
||||
expect(crafted_system_prompt).to include("fragment-n13")
|
||||
expect(crafted_system_prompt).to include("fragment-n12")
|
||||
|
||||
expect(crafted_system_prompt).not_to include("fragment-n4") # Fragment #11 not included
|
||||
end
|
||||
end
|
||||
|
|
|
@ -253,6 +253,9 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
|
|||
after { SiteSetting.provider = @original_provider }
|
||||
|
||||
it "returns a 403 error if the user cannot access the secure upload" do
|
||||
# hosted-site plugin edge case, it enables embeddings
|
||||
SiteSetting.ai_embeddings_enabled = false
|
||||
|
||||
create_post(
|
||||
title: "Secure upload post",
|
||||
raw: "This is a new post <img src=\"#{upload.url}\" />",
|
||||
|
|
Loading…
Reference in New Issue