FEATURE: partial tool call support for OpenAI and Anthropic (#908)

Implement streaming tool call implementation for Anthropic and Open AI.

When calling:

llm.generate(..., partial_tool_calls: true) do ...
Partials may contain ToolCall instances with partial: true, These tool calls are partially populated with json partially parsed.

So for example when performing a search you may get:

ToolCall(..., {search: "hello" })
ToolCall(..., {search: "hello world" })

The library used to parse json is:

https://github.com/dgraham/json-stream

We use a fork cause we need access to the internal buffer.

This prepares internals to perform partial tool calls, but does not implement it yet.
This commit is contained in:
Sam 2024-11-14 06:58:24 +11:00 committed by GitHub
parent f75b13c4fa
commit 823e8ef490
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 844 additions and 45 deletions

View File

@ -185,20 +185,23 @@ module DiscourseAi
end end
def invoke_tool(tool, llm, cancel, context, &update_blk) def invoke_tool(tool, llm, cancel, context, &update_blk)
update_blk.call("", cancel, build_placeholder(tool.summary, "")) show_placeholder = !context[:skip_tool_details]
update_blk.call("", cancel, build_placeholder(tool.summary, "")) if show_placeholder
result = result =
tool.invoke do |progress| tool.invoke do |progress|
if show_placeholder
placeholder = build_placeholder(tool.summary, progress) placeholder = build_placeholder(tool.summary, progress)
update_blk.call("", cancel, placeholder) update_blk.call("", cancel, placeholder)
end end
end
if show_placeholder
tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw) tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw)
if context[:skip_tool_details] && tool.custom_raw.present?
update_blk.call(tool.custom_raw, cancel, nil, :custom_raw)
elsif !context[:skip_tool_details]
update_blk.call(tool_details, cancel, nil, :tool_details) update_blk.call(tool_details, cancel, nil, :tool_details)
elsif tool.custom_raw.present?
update_blk.call(tool.custom_raw, cancel, nil, :custom_raw)
end end
result result

View File

@ -452,7 +452,7 @@ module DiscourseAi
bot.reply(context) do |partial, cancel, placeholder, type| bot.reply(context) do |partial, cancel, placeholder, type|
reply << partial reply << partial
raw = reply.dup raw = reply.dup
raw << "\n\n" << placeholder if placeholder.present? && !context[:skip_tool_details] raw << "\n\n" << placeholder if placeholder.present?
blk.call(partial) if blk && type != :tool_details blk.call(partial) if blk && type != :tool_details

View File

@ -4,28 +4,50 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
class AnthropicToolCall class AnthropicToolCall
attr_reader :name, :raw_json, :id attr_reader :name, :raw_json, :id
def initialize(name, id) def initialize(name, id, partial_tool_calls: false)
@name = name @name = name
@id = id @id = id
@raw_json = +"" @raw_json = +""
@tool_call = DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: {})
@streaming_parser =
DiscourseAi::Completions::ToolCallProgressTracker.new(self) if partial_tool_calls
end end
def append(json) def append(json)
@raw_json << json @raw_json << json
@streaming_parser << json if @streaming_parser
end
def notify_progress(key, value)
@tool_call.partial = true
@tool_call.parameters[key.to_sym] = value
@has_new_data = true
end
def has_partial?
@has_new_data
end
def partial_tool_call
@has_new_data = false
@tool_call
end end
def to_tool_call def to_tool_call
parameters = JSON.parse(raw_json, symbolize_names: true) parameters = JSON.parse(raw_json, symbolize_names: true)
DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters) @tool_call.partial = false
@tool_call.parameters = parameters
@tool_call
end end
end end
attr_reader :tool_calls, :input_tokens, :output_tokens attr_reader :tool_calls, :input_tokens, :output_tokens
def initialize(streaming_mode:) def initialize(streaming_mode:, partial_tool_calls: false)
@streaming_mode = streaming_mode @streaming_mode = streaming_mode
@tool_calls = [] @tool_calls = []
@current_tool_call = nil @current_tool_call = nil
@partial_tool_calls = partial_tool_calls
end end
def to_tool_calls def to_tool_calls
@ -38,11 +60,17 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
tool_name = parsed.dig(:content_block, :name) tool_name = parsed.dig(:content_block, :name)
tool_id = parsed.dig(:content_block, :id) tool_id = parsed.dig(:content_block, :id)
result = @current_tool_call.to_tool_call if @current_tool_call result = @current_tool_call.to_tool_call if @current_tool_call
@current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name @current_tool_call =
AnthropicToolCall.new(
tool_name,
tool_id,
partial_tool_calls: @partial_tool_calls,
) if tool_name
elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta" elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
if @current_tool_call if @current_tool_call
tool_delta = parsed.dig(:delta, :partial_json).to_s tool_delta = parsed.dig(:delta, :partial_json).to_s
@current_tool_call.append(tool_delta) @current_tool_call.append(tool_delta)
result = @current_tool_call.partial_tool_call if @current_tool_call.has_partial?
else else
result = parsed.dig(:delta, :text).to_s result = parsed.dig(:delta, :text).to_s
end end

View File

@ -107,7 +107,10 @@ module DiscourseAi
def processor def processor
@processor ||= @processor ||=
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode) DiscourseAi::Completions::AnthropicMessageProcessor.new(
streaming_mode: @streaming_mode,
partial_tool_calls: partial_tool_calls,
)
end end
def has_tool?(_response_data) def has_tool?(_response_data)

View File

@ -4,6 +4,8 @@ module DiscourseAi
module Completions module Completions
module Endpoints module Endpoints
class Base class Base
attr_reader :partial_tool_calls
CompletionFailed = Class.new(StandardError) CompletionFailed = Class.new(StandardError)
TIMEOUT = 60 TIMEOUT = 60
@ -58,8 +60,10 @@ module DiscourseAi
model_params = {}, model_params = {},
feature_name: nil, feature_name: nil,
feature_context: nil, feature_context: nil,
partial_tool_calls: false,
&blk &blk
) )
@partial_tool_calls = partial_tool_calls
model_params = normalize_model_params(model_params) model_params = normalize_model_params(model_params)
orig_blk = blk orig_blk = blk

View File

@ -28,7 +28,8 @@ module DiscourseAi
_user, _user,
_model_params, _model_params,
feature_name: nil, feature_name: nil,
feature_context: nil feature_context: nil,
partial_tool_calls: false
) )
@dialect = dialect @dialect = dialect
response = responses[completions] response = responses[completions]

View File

@ -120,7 +120,8 @@ module DiscourseAi
user, user,
model_params = {}, model_params = {},
feature_name: nil, feature_name: nil,
feature_context: nil feature_context: nil,
partial_tool_calls: false
) )
last_call = { dialect: dialect, user: user, model_params: model_params } last_call = { dialect: dialect, user: user, model_params: model_params }
self.class.last_call = last_call self.class.last_call = last_call

View File

@ -33,6 +33,7 @@ module DiscourseAi
model_params = {}, model_params = {},
feature_name: nil, feature_name: nil,
feature_context: nil, feature_context: nil,
partial_tool_calls: false,
&blk &blk
) )
if dialect.respond_to?(:is_gpt_o?) && dialect.is_gpt_o? && block_given? if dialect.respond_to?(:is_gpt_o?) && dialect.is_gpt_o? && block_given?
@ -103,10 +104,16 @@ module DiscourseAi
def decode_chunk(chunk) def decode_chunk(chunk)
@decoder ||= JsonStreamDecoder.new @decoder ||= JsonStreamDecoder.new
elements =
(@decoder << chunk) (@decoder << chunk)
.map { |parsed_json| processor.process_streamed_message(parsed_json) } .map { |parsed_json| processor.process_streamed_message(parsed_json) }
.flatten .flatten
.compact .compact
# Remove duplicate partial tool calls
# sometimes we stream weird chunks
seen_tools = Set.new
elements.select { |item| !item.is_a?(ToolCall) || seen_tools.add?(item) }
end end
def decode_chunk_finish def decode_chunk_finish
@ -120,7 +127,7 @@ module DiscourseAi
private private
def processor def processor
@processor ||= OpenAiMessageProcessor.new @processor ||= OpenAiMessageProcessor.new(partial_tool_calls: partial_tool_calls)
end end
end end
end end

View File

@ -0,0 +1,667 @@
# frozen_string_literal: true
# This code is copied from the MIT licensed json-stream
# see: https://github.com/dgraham/json-stream
#
# It was copied to avoid the dependency and allow us to make some small changes
# particularly we need better access to internal state when parsing
module DiscourseAi
module Completions
# Raised on any invalid JSON text.
ParserError = Class.new(RuntimeError)
# A streaming JSON parser that generates SAX-like events for state changes.
# Use the json gem for small documents. Use this for huge documents that
# won't fit in memory.
#
# Examples
#
# parser = JSON::Stream::Parser.new
# parser.key { |key| puts key }
# parser.value { |value| puts value }
# parser << '{"answer":'
# parser << ' 42}'
class JsonStreamingParser
# our changes:
attr_reader :state, :buf, :pos
# A character buffer that expects a UTF-8 encoded stream of bytes.
# This handles truncated multi-byte characters properly so we can just
# feed it binary data and receive a properly formatted UTF-8 String as
# output.
#
# More UTF-8 parsing details are available at:
#
# http://en.wikipedia.org/wiki/UTF-8
# http://tools.ietf.org/html/rfc3629#section-3
class Buffer
def initialize
@state = :start
@buffer = []
@need = 0
end
# Fill the buffer with a String of binary UTF-8 encoded bytes. Returns
# as much of the data in a UTF-8 String as we have. Truncated multi-byte
# characters are saved in the buffer until the next call to this method
# where we expect to receive the rest of the multi-byte character.
#
# data - The partial binary encoded String data.
#
# Raises JSON::Stream::ParserError if the UTF-8 byte sequence is malformed.
#
# Returns a UTF-8 encoded String.
def <<(data)
# Avoid state machine for complete UTF-8.
if @buffer.empty?
data.force_encoding(Encoding::UTF_8)
return data if data.valid_encoding?
end
bytes = []
data.each_byte do |byte|
case @state
when :start
if byte < 128
bytes << byte
elsif byte >= 192
@state = :multi_byte
@buffer << byte
@need =
case
when byte >= 240
4
when byte >= 224
3
when byte >= 192
2
end
else
error("Expected start of multi-byte or single byte char")
end
when :multi_byte
if byte > 127 && byte < 192
@buffer << byte
if @buffer.size == @need
bytes += @buffer.slice!(0, @buffer.size)
@state = :start
end
else
error("Expected continuation byte")
end
end
end
# Build UTF-8 encoded string from completed codepoints.
bytes
.pack("C*")
.force_encoding(Encoding::UTF_8)
.tap { |text| error("Invalid UTF-8 byte sequence") unless text.valid_encoding? }
end
# Determine if the buffer contains partial UTF-8 continuation bytes that
# are waiting on subsequent completion bytes before a full codepoint is
# formed.
#
# Examples
#
# bytes = "é".bytes
#
# buffer << bytes[0]
# buffer.empty?
# # => false
#
# buffer << bytes[1]
# buffer.empty?
# # => true
#
# Returns true if the buffer is empty.
def empty?
@buffer.empty?
end
private
def error(message)
raise ParserError, message
end
end
BUF_SIZE = 4096
CONTROL = /[\x00-\x1F]/
WS = /[ \n\t\r]/
HEX = /[0-9a-fA-F]/
DIGIT = /[0-9]/
DIGIT_1_9 = /[1-9]/
DIGIT_END = /\d$/
TRUE_RE = /[rue]/
FALSE_RE = /[alse]/
NULL_RE = /[ul]/
TRUE_KEYWORD = "true"
FALSE_KEYWORD = "false"
NULL_KEYWORD = "null"
LEFT_BRACE = "{"
RIGHT_BRACE = "}"
LEFT_BRACKET = "["
RIGHT_BRACKET = "]"
BACKSLASH = '\\'
SLASH = "/"
QUOTE = '"'
COMMA = ","
COLON = ":"
ZERO = "0"
MINUS = "-"
PLUS = "+"
POINT = "."
EXPONENT = /[eE]/
B, F, N, R, T, U = %w[b f n r t u]
# Create a new parser with an optional initialization block where
# we can register event callbacks.
#
# Examples
#
# parser = JSON::Stream::Parser.new do
# start_document { puts "start document" }
# end_document { puts "end document" }
# start_object { puts "start object" }
# end_object { puts "end object" }
# start_array { puts "start array" }
# end_array { puts "end array" }
# key { |k| puts "key: #{k}" }
# value { |v| puts "value: #{v}" }
# end
def initialize(&block)
@state = :start_document
@utf8 = Buffer.new
@listeners = {
start_document: [],
end_document: [],
start_object: [],
end_object: [],
start_array: [],
end_array: [],
key: [],
value: [],
}
# Track parse stack.
@stack = []
@unicode = +""
@buf = +""
@pos = -1
# Register any observers in the block.
instance_eval(&block) if block_given?
end
def start_document(&block)
@listeners[:start_document] << block
end
def end_document(&block)
@listeners[:end_document] << block
end
def start_object(&block)
@listeners[:start_object] << block
end
def end_object(&block)
@listeners[:end_object] << block
end
def start_array(&block)
@listeners[:start_array] << block
end
def end_array(&block)
@listeners[:end_array] << block
end
def key(&block)
@listeners[:key] << block
end
def value(&block)
@listeners[:value] << block
end
# Pass data into the parser to advance the state machine and
# generate callback events. This is well suited for an EventMachine
# receive_data loop.
#
# data - The String of partial JSON data to parse.
#
# Raises a JSON::Stream::ParserError if the JSON data is malformed.
#
# Returns nothing.
def <<(data)
(@utf8 << data).each_char do |ch|
@pos += 1
case @state
when :start_document
start_value(ch)
when :start_object
case ch
when QUOTE
@state = :start_string
@stack.push(:key)
when RIGHT_BRACE
end_container(:object)
when WS
# ignore
else
error("Expected object key start")
end
when :start_string
case ch
when QUOTE
if @stack.pop == :string
end_value(@buf)
else # :key
@state = :end_key
notify(:key, @buf)
end
@buf = +""
when BACKSLASH
@state = :start_escape
when CONTROL
error("Control characters must be escaped")
else
@buf << ch
end
when :start_escape
case ch
when QUOTE, BACKSLASH, SLASH
@buf << ch
@state = :start_string
when B
@buf << "\b"
@state = :start_string
when F
@buf << "\f"
@state = :start_string
when N
@buf << "\n"
@state = :start_string
when R
@buf << "\r"
@state = :start_string
when T
@buf << "\t"
@state = :start_string
when U
@state = :unicode_escape
else
error("Expected escaped character")
end
when :unicode_escape
case ch
when HEX
@unicode << ch
if @unicode.size == 4
codepoint = @unicode.slice!(0, 4).hex
if codepoint >= 0xD800 && codepoint <= 0xDBFF
error("Expected low surrogate pair half") if @stack[-1].is_a?(Integer)
@state = :start_surrogate_pair
@stack.push(codepoint)
elsif codepoint >= 0xDC00 && codepoint <= 0xDFFF
high = @stack.pop
error("Expected high surrogate pair half") unless high.is_a?(Integer)
pair = ((high - 0xD800) * 0x400) + (codepoint - 0xDC00) + 0x10000
@buf << pair
@state = :start_string
else
@buf << codepoint
@state = :start_string
end
end
else
error("Expected unicode escape hex digit")
end
when :start_surrogate_pair
case ch
when BACKSLASH
@state = :start_surrogate_pair_u
else
error("Expected low surrogate pair half")
end
when :start_surrogate_pair_u
case ch
when U
@state = :unicode_escape
else
error("Expected low surrogate pair half")
end
when :start_negative_number
case ch
when ZERO
@state = :start_zero
@buf << ch
when DIGIT_1_9
@state = :start_int
@buf << ch
else
error("Expected 0-9 digit")
end
when :start_zero
case ch
when POINT
@state = :start_float
@buf << ch
when EXPONENT
@state = :start_exponent
@buf << ch
else
end_value(@buf.to_i)
@buf = +""
@pos -= 1
redo
end
when :start_float
case ch
when DIGIT
@state = :in_float
@buf << ch
else
error("Expected 0-9 digit")
end
when :in_float
case ch
when DIGIT
@buf << ch
when EXPONENT
@state = :start_exponent
@buf << ch
else
end_value(@buf.to_f)
@buf = +""
@pos -= 1
redo
end
when :start_exponent
case ch
when MINUS, PLUS, DIGIT
@state = :in_exponent
@buf << ch
else
error("Expected +, -, or 0-9 digit")
end
when :in_exponent
case ch
when DIGIT
@buf << ch
else
error("Expected 0-9 digit") unless @buf =~ DIGIT_END
end_value(@buf.to_f)
@buf = +""
@pos -= 1
redo
end
when :start_int
case ch
when DIGIT
@buf << ch
when POINT
@state = :start_float
@buf << ch
when EXPONENT
@state = :start_exponent
@buf << ch
else
end_value(@buf.to_i)
@buf = +""
@pos -= 1
redo
end
when :start_true
keyword(TRUE_KEYWORD, true, TRUE_RE, ch)
when :start_false
keyword(FALSE_KEYWORD, false, FALSE_RE, ch)
when :start_null
keyword(NULL_KEYWORD, nil, NULL_RE, ch)
when :end_key
case ch
when COLON
@state = :key_sep
when WS
# ignore
else
error("Expected colon key separator")
end
when :key_sep
start_value(ch)
when :start_array
case ch
when RIGHT_BRACKET
end_container(:array)
when WS
# ignore
else
start_value(ch)
end
when :end_value
case ch
when COMMA
@state = :value_sep
when RIGHT_BRACE
end_container(:object)
when RIGHT_BRACKET
end_container(:array)
when WS
# ignore
else
error("Expected comma or object or array close")
end
when :value_sep
if @stack[-1] == :object
case ch
when QUOTE
@state = :start_string
@stack.push(:key)
when WS
# ignore
else
error("Expected object key start")
end
else
start_value(ch)
end
when :end_document
error("Unexpected data") unless ch =~ WS
end
end
end
# Drain any remaining buffered characters into the parser to complete
# the parsing of the document.
#
# This is only required when parsing a document containing a single
# numeric value, integer or float. The parser has no other way to
# detect when it should no longer expect additional characters with
# which to complete the parse, so it must be signaled by a call to
# this method.
#
# If you're parsing more typical object or array documents, there's no
# need to call `finish` because the parse will complete when the final
# closing `]` or `}` character is scanned.
#
# Raises a JSON::Stream::ParserError if the JSON data is malformed.
#
# Returns nothing.
def finish
# Partial multi-byte character waiting for completion bytes.
error("Unexpected end-of-file") unless @utf8.empty?
# Partial array, object, or string.
error("Unexpected end-of-file") unless @stack.empty?
case @state
when :end_document
# done, do nothing
when :in_float
end_value(@buf.to_f)
when :in_exponent
error("Unexpected end-of-file") unless @buf =~ DIGIT_END
end_value(@buf.to_f)
when :start_zero
end_value(@buf.to_i)
when :start_int
end_value(@buf.to_i)
else
error("Unexpected end-of-file")
end
end
private
# Invoke all registered observer procs for the event type.
#
# type - The Symbol listener name.
# args - The argument list to pass into the observer procs.
#
# Examples
#
# # broadcast events for {"answer": 42}
# notify(:start_object)
# notify(:key, "answer")
# notify(:value, 42)
# notify(:end_object)
#
# Returns nothing.
def notify(type, *args)
@listeners[type].each { |block| block.call(*args) }
end
# Complete an object or array container value type.
#
# type - The Symbol, :object or :array, of the expected type.
#
# Raises a JSON::Stream::ParserError if the expected container type
# was not completed.
#
# Returns nothing.
def end_container(type)
@state = :end_value
if @stack.pop == type
case type
when :object
notify(:end_object)
when :array
notify(:end_array)
end
else
error("Expected end of #{type}")
end
notify_end_document if @stack.empty?
end
# Broadcast an `end_document` event to observers after a complete JSON
# value document (object, array, number, string, true, false, null) has
# been parsed from the text. This is the final event sent to observers
# and signals the parse has finished.
#
# Returns nothing.
def notify_end_document
@state = :end_document
notify(:end_document)
end
# Parse one of the three allowed keywords: true, false, null.
#
# word - The String keyword ('true', 'false', 'null').
# value - The Ruby value (true, false, nil).
# re - The Regexp of allowed keyword characters.
# ch - The current String character being parsed.
#
# Raises a JSON::Stream::ParserError if the character does not belong
# in the expected keyword.
#
# Returns nothing.
def keyword(word, value, re, ch)
if ch =~ re
@buf << ch
else
error("Expected #{word} keyword")
end
if @buf.size == word.size
if @buf == word
@buf = +""
end_value(value)
else
error("Expected #{word} keyword")
end
end
end
# Process the first character of one of the seven possible JSON
# values: object, array, string, true, false, null, number.
#
# ch - The current character String.
#
# Raises a JSON::Stream::ParserError if the character does not signal
# the start of a value.
#
# Returns nothing.
def start_value(ch)
case ch
when LEFT_BRACE
notify(:start_document) if @stack.empty?
@state = :start_object
@stack.push(:object)
notify(:start_object)
when LEFT_BRACKET
notify(:start_document) if @stack.empty?
@state = :start_array
@stack.push(:array)
notify(:start_array)
when QUOTE
@state = :start_string
@stack.push(:string)
when T
@state = :start_true
@buf << ch
when F
@state = :start_false
@buf << ch
when N
@state = :start_null
@buf << ch
when MINUS
@state = :start_negative_number
@buf << ch
when ZERO
@state = :start_zero
@buf << ch
when DIGIT_1_9
@state = :start_int
@buf << ch
when WS
# ignore
else
error("Expected value")
end
end
# Advance the state machine and notify `value` observers that a
# string, number or keyword (true, false, null) value was parsed.
#
# value - The object to broadcast to observers.
#
# Returns nothing.
def end_value(value)
@state = :end_value
notify(:start_document) if @stack.empty?
notify(:value, value)
notify_end_document if @stack.empty?
end
def error(message)
raise ParserError, "#{message}: char #{@pos}"
end
end
end
end

View File

@ -164,24 +164,18 @@ module DiscourseAi
# @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object # @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object
# @param user { User } - User requesting the summary. # @param user { User } - User requesting the summary.
# @param temperature { Float - Optional } - The temperature to use for the completion.
# @param top_p { Float - Optional } - The top_p to use for the completion.
# @param max_tokens { Integer - Optional } - The maximum number of tokens to generate.
# @param stop_sequences { Array<String> - Optional } - The stop sequences to use for the completion.
# @param feature_name { String - Optional } - The feature name to use for the completion.
# @param feature_context { Hash - Optional } - The feature context to use for the completion.
# @param partial_tool_calls { Boolean - Optional } - If true, the completion will return partial tool calls.
# #
# @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function. # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function.
# #
# @returns { String } - Completion result. # @returns String | ToolCall - Completion result.
# # if multiple tools or a tool and a message come back, the result will be an array of ToolCall / String objects.
# When the model invokes a tool, we'll wait until the endpoint finishes replying and feed you a fully-formed tool,
# even if you passed a partial_read_blk block. Invocations are strings that look like this:
#
# <function_calls>
# <invoke>
# <tool_name>get_weather</tool_name>
# <tool_id>get_weather</tool_id>
# <parameters>
# <location>Sydney</location>
# <unit>c</unit>
# </parameters>
# </invoke>
# </function_calls>
# #
def generate( def generate(
prompt, prompt,
@ -192,6 +186,7 @@ module DiscourseAi
user:, user:,
feature_name: nil, feature_name: nil,
feature_context: nil, feature_context: nil,
partial_tool_calls: false,
&partial_read_blk &partial_read_blk
) )
self.class.record_prompt(prompt) self.class.record_prompt(prompt)
@ -226,6 +221,7 @@ module DiscourseAi
model_params, model_params,
feature_name: feature_name, feature_name: feature_name,
feature_context: feature_context, feature_context: feature_context,
partial_tool_calls: partial_tool_calls,
&partial_read_blk &partial_read_blk
) )
end end

View File

@ -3,11 +3,12 @@ module DiscourseAi::Completions
class OpenAiMessageProcessor class OpenAiMessageProcessor
attr_reader :prompt_tokens, :completion_tokens attr_reader :prompt_tokens, :completion_tokens
def initialize def initialize(partial_tool_calls: false)
@tool = nil @tool = nil
@tool_arguments = +"" @tool_arguments = +""
@prompt_tokens = nil @prompt_tokens = nil
@completion_tokens = nil @completion_tokens = nil
@partial_tool_calls = partial_tool_calls
end end
def process_message(json) def process_message(json)
@ -57,12 +58,16 @@ module DiscourseAi::Completions
if id.present? && name.present? if id.present? && name.present?
@tool_arguments = +"" @tool_arguments = +""
@tool = ToolCall.new(id: id, name: name) @tool = ToolCall.new(id: id, name: name)
@streaming_parser = ToolCallProgressTracker.new(self) if @partial_tool_calls
end end
@tool_arguments << arguments.to_s @tool_arguments << arguments.to_s
@streaming_parser << arguments.to_s if @streaming_parser && !arguments.to_s.empty?
rval = current_tool_progress if !rval
elsif finished_tools && @tool elsif finished_tools && @tool
parsed_args = JSON.parse(@tool_arguments, symbolize_names: true) parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
@tool.parameters = parsed_args @tool.parameters = parsed_args
@tool.partial = false
rval = @tool rval = @tool
@tool = nil @tool = nil
elsif !content.to_s.empty? elsif !content.to_s.empty?
@ -75,6 +80,23 @@ module DiscourseAi::Completions
rval rval
end end
def notify_progress(key, value)
if @tool
@tool.partial = true
@tool.parameters[key.to_sym] = value
@has_new_data = true
end
end
def current_tool_progress
if @has_new_data
@has_new_data = false
@tool
else
nil
end
end
def finish def finish
rval = [] rval = []
if @tool if @tool

View File

@ -4,12 +4,14 @@ module DiscourseAi
module Completions module Completions
class ToolCall class ToolCall
attr_reader :id, :name, :parameters attr_reader :id, :name, :parameters
attr_accessor :partial
def initialize(id:, name:, parameters: nil) def initialize(id:, name:, parameters: nil)
@id = id @id = id
@name = name @name = name
self.parameters = parameters if parameters self.parameters = parameters if parameters
@parameters ||= {} @parameters ||= {}
@partial = false
end end
def parameters=(parameters) def parameters=(parameters)
@ -24,6 +26,12 @@ module DiscourseAi
def to_s def to_s
"#{name} - #{id} (\n#{parameters.map(&:to_s).join("\n")}\n)" "#{name} - #{id} (\n#{parameters.map(&:to_s).join("\n")}\n)"
end end
def dup
call = ToolCall.new(id: id, name: name, parameters: parameters.deep_dup)
call.partial = partial
call
end
end end
end end
end end

View File

@ -0,0 +1,44 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
class ToolCallProgressTracker
attr_reader :current_key, :current_value, :tool_call
def initialize(tool_call)
@tool_call = tool_call
@current_key = nil
@current_value = nil
@parser = DiscourseAi::Completions::JsonStreamingParser.new
@parser.key do |k|
@current_key = k
@current_value = nil
end
@parser.value { |v| tool_call.notify_progress(@current_key, v) if @current_key }
end
def <<(json)
# llm could send broken json
# in that case just deal with it later
# don't stream
return if @broken
begin
@parser << json
rescue DiscourseAi::Completions::ParserError
@broken = true
return
end
if @parser.state == :start_string && @current_key
# this is is worth notifying
tool_call.notify_progress(@current_key, @parser.buf)
end
@current_key = nil if @parser.state == :end_value
end
end
end
end

View File

@ -109,9 +109,11 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
EndpointMock.with_chunk_array_support do EndpointMock.with_chunk_array_support do
stub_request(:post, url).to_return(status: 200, body: body) stub_request(:post, url).to_return(status: 200, body: body)
llm.generate(prompt_with_google_tool, user: Discourse.system_user) do |partial| llm.generate(
result << partial prompt_with_google_tool,
end user: Discourse.system_user,
partial_tool_calls: true,
) { |partial| result << partial.dup }
end end
tool_call = tool_call =
@ -124,7 +126,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
}, },
) )
expect(result).to eq([tool_call]) expect(result.last).to eq(tool_call)
search_queries = result.filter(&:partial).map { |r| r.parameters[:search_query] }
categories = result.filter(&:partial).map { |r| r.parameters[:category] }
expect(categories).to eq([nil, nil, nil, nil, "gene", "general"])
expect(search_queries).to eq(["s", "s<a>m", "s<a>m ", "s<a>m sam", "s<a>m sam", "s<a>m sam"])
end end
it "can stream a response" do it "can stream a response" do

View File

@ -571,7 +571,7 @@ TEXT
end end
end end
it "properly handles spaces in tools payload" do it "properly handles spaces in tools payload and partial tool calls" do
raw_data = <<~TEXT.strip raw_data = <<~TEXT.strip
data: {"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"func_id","type":"function","function":{"name":"go|ogle","arg|uments":""}}]}}]} data: {"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"func_id","type":"function","function":{"name":"go|ogle","arg|uments":""}}]}}]}
@ -609,7 +609,9 @@ TEXT
partials = [] partials = []
dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools)) dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
endpoint.perform_completion!(dialect, user) { |partial| partials << partial } endpoint.perform_completion!(dialect, user, partial_tool_calls: true) do |partial|
partials << partial.dup
end
tool_call = tool_call =
DiscourseAi::Completions::ToolCall.new( DiscourseAi::Completions::ToolCall.new(
@ -620,7 +622,10 @@ TEXT
}, },
) )
expect(partials).to eq([tool_call]) expect(partials.last).to eq(tool_call)
progress = partials.map { |p| p.parameters[:query] }
expect(progress).to eq(["Ad", "Adabas", "Adabas 9.", "Adabas 9.1"])
end end
end end
end end

View File

@ -462,7 +462,6 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
expect(crafted_system_prompt).to include("fragment-n14") expect(crafted_system_prompt).to include("fragment-n14")
expect(crafted_system_prompt).to include("fragment-n13") expect(crafted_system_prompt).to include("fragment-n13")
expect(crafted_system_prompt).to include("fragment-n12") expect(crafted_system_prompt).to include("fragment-n12")
expect(crafted_system_prompt).not_to include("fragment-n4") # Fragment #11 not included expect(crafted_system_prompt).not_to include("fragment-n4") # Fragment #11 not included
end end
end end

View File

@ -253,6 +253,9 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
after { SiteSetting.provider = @original_provider } after { SiteSetting.provider = @original_provider }
it "returns a 403 error if the user cannot access the secure upload" do it "returns a 403 error if the user cannot access the secure upload" do
# hosted-site plugin edge case, it enables embeddings
SiteSetting.ai_embeddings_enabled = false
create_post( create_post(
title: "Secure upload post", title: "Secure upload post",
raw: "This is a new post <img src=\"#{upload.url}\" />", raw: "This is a new post <img src=\"#{upload.url}\" />",