FEATURE: partial tool call support for OpenAI and Anthropic (#908)
Implement streaming tool call implementation for Anthropic and Open AI. When calling: llm.generate(..., partial_tool_calls: true) do ... Partials may contain ToolCall instances with partial: true, These tool calls are partially populated with json partially parsed. So for example when performing a search you may get: ToolCall(..., {search: "hello" }) ToolCall(..., {search: "hello world" }) The library used to parse json is: https://github.com/dgraham/json-stream We use a fork cause we need access to the internal buffer. This prepares internals to perform partial tool calls, but does not implement it yet.
This commit is contained in:
parent
f75b13c4fa
commit
823e8ef490
|
@ -185,20 +185,23 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def invoke_tool(tool, llm, cancel, context, &update_blk)
|
def invoke_tool(tool, llm, cancel, context, &update_blk)
|
||||||
update_blk.call("", cancel, build_placeholder(tool.summary, ""))
|
show_placeholder = !context[:skip_tool_details]
|
||||||
|
|
||||||
|
update_blk.call("", cancel, build_placeholder(tool.summary, "")) if show_placeholder
|
||||||
|
|
||||||
result =
|
result =
|
||||||
tool.invoke do |progress|
|
tool.invoke do |progress|
|
||||||
|
if show_placeholder
|
||||||
placeholder = build_placeholder(tool.summary, progress)
|
placeholder = build_placeholder(tool.summary, progress)
|
||||||
update_blk.call("", cancel, placeholder)
|
update_blk.call("", cancel, placeholder)
|
||||||
end
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
if show_placeholder
|
||||||
tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw)
|
tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw)
|
||||||
|
|
||||||
if context[:skip_tool_details] && tool.custom_raw.present?
|
|
||||||
update_blk.call(tool.custom_raw, cancel, nil, :custom_raw)
|
|
||||||
elsif !context[:skip_tool_details]
|
|
||||||
update_blk.call(tool_details, cancel, nil, :tool_details)
|
update_blk.call(tool_details, cancel, nil, :tool_details)
|
||||||
|
elsif tool.custom_raw.present?
|
||||||
|
update_blk.call(tool.custom_raw, cancel, nil, :custom_raw)
|
||||||
end
|
end
|
||||||
|
|
||||||
result
|
result
|
||||||
|
|
|
@ -452,7 +452,7 @@ module DiscourseAi
|
||||||
bot.reply(context) do |partial, cancel, placeholder, type|
|
bot.reply(context) do |partial, cancel, placeholder, type|
|
||||||
reply << partial
|
reply << partial
|
||||||
raw = reply.dup
|
raw = reply.dup
|
||||||
raw << "\n\n" << placeholder if placeholder.present? && !context[:skip_tool_details]
|
raw << "\n\n" << placeholder if placeholder.present?
|
||||||
|
|
||||||
blk.call(partial) if blk && type != :tool_details
|
blk.call(partial) if blk && type != :tool_details
|
||||||
|
|
||||||
|
|
|
@ -4,28 +4,50 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
|
||||||
class AnthropicToolCall
|
class AnthropicToolCall
|
||||||
attr_reader :name, :raw_json, :id
|
attr_reader :name, :raw_json, :id
|
||||||
|
|
||||||
def initialize(name, id)
|
def initialize(name, id, partial_tool_calls: false)
|
||||||
@name = name
|
@name = name
|
||||||
@id = id
|
@id = id
|
||||||
@raw_json = +""
|
@raw_json = +""
|
||||||
|
@tool_call = DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: {})
|
||||||
|
@streaming_parser =
|
||||||
|
DiscourseAi::Completions::ToolCallProgressTracker.new(self) if partial_tool_calls
|
||||||
end
|
end
|
||||||
|
|
||||||
def append(json)
|
def append(json)
|
||||||
@raw_json << json
|
@raw_json << json
|
||||||
|
@streaming_parser << json if @streaming_parser
|
||||||
|
end
|
||||||
|
|
||||||
|
def notify_progress(key, value)
|
||||||
|
@tool_call.partial = true
|
||||||
|
@tool_call.parameters[key.to_sym] = value
|
||||||
|
@has_new_data = true
|
||||||
|
end
|
||||||
|
|
||||||
|
def has_partial?
|
||||||
|
@has_new_data
|
||||||
|
end
|
||||||
|
|
||||||
|
def partial_tool_call
|
||||||
|
@has_new_data = false
|
||||||
|
@tool_call
|
||||||
end
|
end
|
||||||
|
|
||||||
def to_tool_call
|
def to_tool_call
|
||||||
parameters = JSON.parse(raw_json, symbolize_names: true)
|
parameters = JSON.parse(raw_json, symbolize_names: true)
|
||||||
DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters)
|
@tool_call.partial = false
|
||||||
|
@tool_call.parameters = parameters
|
||||||
|
@tool_call
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
attr_reader :tool_calls, :input_tokens, :output_tokens
|
attr_reader :tool_calls, :input_tokens, :output_tokens
|
||||||
|
|
||||||
def initialize(streaming_mode:)
|
def initialize(streaming_mode:, partial_tool_calls: false)
|
||||||
@streaming_mode = streaming_mode
|
@streaming_mode = streaming_mode
|
||||||
@tool_calls = []
|
@tool_calls = []
|
||||||
@current_tool_call = nil
|
@current_tool_call = nil
|
||||||
|
@partial_tool_calls = partial_tool_calls
|
||||||
end
|
end
|
||||||
|
|
||||||
def to_tool_calls
|
def to_tool_calls
|
||||||
|
@ -38,11 +60,17 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
|
||||||
tool_name = parsed.dig(:content_block, :name)
|
tool_name = parsed.dig(:content_block, :name)
|
||||||
tool_id = parsed.dig(:content_block, :id)
|
tool_id = parsed.dig(:content_block, :id)
|
||||||
result = @current_tool_call.to_tool_call if @current_tool_call
|
result = @current_tool_call.to_tool_call if @current_tool_call
|
||||||
@current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name
|
@current_tool_call =
|
||||||
|
AnthropicToolCall.new(
|
||||||
|
tool_name,
|
||||||
|
tool_id,
|
||||||
|
partial_tool_calls: @partial_tool_calls,
|
||||||
|
) if tool_name
|
||||||
elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
|
elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
|
||||||
if @current_tool_call
|
if @current_tool_call
|
||||||
tool_delta = parsed.dig(:delta, :partial_json).to_s
|
tool_delta = parsed.dig(:delta, :partial_json).to_s
|
||||||
@current_tool_call.append(tool_delta)
|
@current_tool_call.append(tool_delta)
|
||||||
|
result = @current_tool_call.partial_tool_call if @current_tool_call.has_partial?
|
||||||
else
|
else
|
||||||
result = parsed.dig(:delta, :text).to_s
|
result = parsed.dig(:delta, :text).to_s
|
||||||
end
|
end
|
||||||
|
|
|
@ -107,7 +107,10 @@ module DiscourseAi
|
||||||
|
|
||||||
def processor
|
def processor
|
||||||
@processor ||=
|
@processor ||=
|
||||||
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
|
DiscourseAi::Completions::AnthropicMessageProcessor.new(
|
||||||
|
streaming_mode: @streaming_mode,
|
||||||
|
partial_tool_calls: partial_tool_calls,
|
||||||
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
def has_tool?(_response_data)
|
def has_tool?(_response_data)
|
||||||
|
|
|
@ -4,6 +4,8 @@ module DiscourseAi
|
||||||
module Completions
|
module Completions
|
||||||
module Endpoints
|
module Endpoints
|
||||||
class Base
|
class Base
|
||||||
|
attr_reader :partial_tool_calls
|
||||||
|
|
||||||
CompletionFailed = Class.new(StandardError)
|
CompletionFailed = Class.new(StandardError)
|
||||||
TIMEOUT = 60
|
TIMEOUT = 60
|
||||||
|
|
||||||
|
@ -58,8 +60,10 @@ module DiscourseAi
|
||||||
model_params = {},
|
model_params = {},
|
||||||
feature_name: nil,
|
feature_name: nil,
|
||||||
feature_context: nil,
|
feature_context: nil,
|
||||||
|
partial_tool_calls: false,
|
||||||
&blk
|
&blk
|
||||||
)
|
)
|
||||||
|
@partial_tool_calls = partial_tool_calls
|
||||||
model_params = normalize_model_params(model_params)
|
model_params = normalize_model_params(model_params)
|
||||||
orig_blk = blk
|
orig_blk = blk
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,8 @@ module DiscourseAi
|
||||||
_user,
|
_user,
|
||||||
_model_params,
|
_model_params,
|
||||||
feature_name: nil,
|
feature_name: nil,
|
||||||
feature_context: nil
|
feature_context: nil,
|
||||||
|
partial_tool_calls: false
|
||||||
)
|
)
|
||||||
@dialect = dialect
|
@dialect = dialect
|
||||||
response = responses[completions]
|
response = responses[completions]
|
||||||
|
|
|
@ -120,7 +120,8 @@ module DiscourseAi
|
||||||
user,
|
user,
|
||||||
model_params = {},
|
model_params = {},
|
||||||
feature_name: nil,
|
feature_name: nil,
|
||||||
feature_context: nil
|
feature_context: nil,
|
||||||
|
partial_tool_calls: false
|
||||||
)
|
)
|
||||||
last_call = { dialect: dialect, user: user, model_params: model_params }
|
last_call = { dialect: dialect, user: user, model_params: model_params }
|
||||||
self.class.last_call = last_call
|
self.class.last_call = last_call
|
||||||
|
|
|
@ -33,6 +33,7 @@ module DiscourseAi
|
||||||
model_params = {},
|
model_params = {},
|
||||||
feature_name: nil,
|
feature_name: nil,
|
||||||
feature_context: nil,
|
feature_context: nil,
|
||||||
|
partial_tool_calls: false,
|
||||||
&blk
|
&blk
|
||||||
)
|
)
|
||||||
if dialect.respond_to?(:is_gpt_o?) && dialect.is_gpt_o? && block_given?
|
if dialect.respond_to?(:is_gpt_o?) && dialect.is_gpt_o? && block_given?
|
||||||
|
@ -103,10 +104,16 @@ module DiscourseAi
|
||||||
|
|
||||||
def decode_chunk(chunk)
|
def decode_chunk(chunk)
|
||||||
@decoder ||= JsonStreamDecoder.new
|
@decoder ||= JsonStreamDecoder.new
|
||||||
|
elements =
|
||||||
(@decoder << chunk)
|
(@decoder << chunk)
|
||||||
.map { |parsed_json| processor.process_streamed_message(parsed_json) }
|
.map { |parsed_json| processor.process_streamed_message(parsed_json) }
|
||||||
.flatten
|
.flatten
|
||||||
.compact
|
.compact
|
||||||
|
|
||||||
|
# Remove duplicate partial tool calls
|
||||||
|
# sometimes we stream weird chunks
|
||||||
|
seen_tools = Set.new
|
||||||
|
elements.select { |item| !item.is_a?(ToolCall) || seen_tools.add?(item) }
|
||||||
end
|
end
|
||||||
|
|
||||||
def decode_chunk_finish
|
def decode_chunk_finish
|
||||||
|
@ -120,7 +127,7 @@ module DiscourseAi
|
||||||
private
|
private
|
||||||
|
|
||||||
def processor
|
def processor
|
||||||
@processor ||= OpenAiMessageProcessor.new
|
@processor ||= OpenAiMessageProcessor.new(partial_tool_calls: partial_tool_calls)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -0,0 +1,667 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
# This code is copied from the MIT licensed json-stream
|
||||||
|
# see: https://github.com/dgraham/json-stream
|
||||||
|
#
|
||||||
|
# It was copied to avoid the dependency and allow us to make some small changes
|
||||||
|
# particularly we need better access to internal state when parsing
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Completions
|
||||||
|
# Raised on any invalid JSON text.
|
||||||
|
ParserError = Class.new(RuntimeError)
|
||||||
|
|
||||||
|
# A streaming JSON parser that generates SAX-like events for state changes.
|
||||||
|
# Use the json gem for small documents. Use this for huge documents that
|
||||||
|
# won't fit in memory.
|
||||||
|
#
|
||||||
|
# Examples
|
||||||
|
#
|
||||||
|
# parser = JSON::Stream::Parser.new
|
||||||
|
# parser.key { |key| puts key }
|
||||||
|
# parser.value { |value| puts value }
|
||||||
|
# parser << '{"answer":'
|
||||||
|
# parser << ' 42}'
|
||||||
|
class JsonStreamingParser
|
||||||
|
# our changes:
|
||||||
|
attr_reader :state, :buf, :pos
|
||||||
|
|
||||||
|
# A character buffer that expects a UTF-8 encoded stream of bytes.
|
||||||
|
# This handles truncated multi-byte characters properly so we can just
|
||||||
|
# feed it binary data and receive a properly formatted UTF-8 String as
|
||||||
|
# output.
|
||||||
|
#
|
||||||
|
# More UTF-8 parsing details are available at:
|
||||||
|
#
|
||||||
|
# http://en.wikipedia.org/wiki/UTF-8
|
||||||
|
# http://tools.ietf.org/html/rfc3629#section-3
|
||||||
|
class Buffer
|
||||||
|
def initialize
|
||||||
|
@state = :start
|
||||||
|
@buffer = []
|
||||||
|
@need = 0
|
||||||
|
end
|
||||||
|
|
||||||
|
# Fill the buffer with a String of binary UTF-8 encoded bytes. Returns
|
||||||
|
# as much of the data in a UTF-8 String as we have. Truncated multi-byte
|
||||||
|
# characters are saved in the buffer until the next call to this method
|
||||||
|
# where we expect to receive the rest of the multi-byte character.
|
||||||
|
#
|
||||||
|
# data - The partial binary encoded String data.
|
||||||
|
#
|
||||||
|
# Raises JSON::Stream::ParserError if the UTF-8 byte sequence is malformed.
|
||||||
|
#
|
||||||
|
# Returns a UTF-8 encoded String.
|
||||||
|
def <<(data)
|
||||||
|
# Avoid state machine for complete UTF-8.
|
||||||
|
if @buffer.empty?
|
||||||
|
data.force_encoding(Encoding::UTF_8)
|
||||||
|
return data if data.valid_encoding?
|
||||||
|
end
|
||||||
|
|
||||||
|
bytes = []
|
||||||
|
data.each_byte do |byte|
|
||||||
|
case @state
|
||||||
|
when :start
|
||||||
|
if byte < 128
|
||||||
|
bytes << byte
|
||||||
|
elsif byte >= 192
|
||||||
|
@state = :multi_byte
|
||||||
|
@buffer << byte
|
||||||
|
@need =
|
||||||
|
case
|
||||||
|
when byte >= 240
|
||||||
|
4
|
||||||
|
when byte >= 224
|
||||||
|
3
|
||||||
|
when byte >= 192
|
||||||
|
2
|
||||||
|
end
|
||||||
|
else
|
||||||
|
error("Expected start of multi-byte or single byte char")
|
||||||
|
end
|
||||||
|
when :multi_byte
|
||||||
|
if byte > 127 && byte < 192
|
||||||
|
@buffer << byte
|
||||||
|
if @buffer.size == @need
|
||||||
|
bytes += @buffer.slice!(0, @buffer.size)
|
||||||
|
@state = :start
|
||||||
|
end
|
||||||
|
else
|
||||||
|
error("Expected continuation byte")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
# Build UTF-8 encoded string from completed codepoints.
|
||||||
|
bytes
|
||||||
|
.pack("C*")
|
||||||
|
.force_encoding(Encoding::UTF_8)
|
||||||
|
.tap { |text| error("Invalid UTF-8 byte sequence") unless text.valid_encoding? }
|
||||||
|
end
|
||||||
|
|
||||||
|
# Determine if the buffer contains partial UTF-8 continuation bytes that
|
||||||
|
# are waiting on subsequent completion bytes before a full codepoint is
|
||||||
|
# formed.
|
||||||
|
#
|
||||||
|
# Examples
|
||||||
|
#
|
||||||
|
# bytes = "é".bytes
|
||||||
|
#
|
||||||
|
# buffer << bytes[0]
|
||||||
|
# buffer.empty?
|
||||||
|
# # => false
|
||||||
|
#
|
||||||
|
# buffer << bytes[1]
|
||||||
|
# buffer.empty?
|
||||||
|
# # => true
|
||||||
|
#
|
||||||
|
# Returns true if the buffer is empty.
|
||||||
|
def empty?
|
||||||
|
@buffer.empty?
|
||||||
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
def error(message)
|
||||||
|
raise ParserError, message
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
BUF_SIZE = 4096
|
||||||
|
CONTROL = /[\x00-\x1F]/
|
||||||
|
WS = /[ \n\t\r]/
|
||||||
|
HEX = /[0-9a-fA-F]/
|
||||||
|
DIGIT = /[0-9]/
|
||||||
|
DIGIT_1_9 = /[1-9]/
|
||||||
|
DIGIT_END = /\d$/
|
||||||
|
TRUE_RE = /[rue]/
|
||||||
|
FALSE_RE = /[alse]/
|
||||||
|
NULL_RE = /[ul]/
|
||||||
|
TRUE_KEYWORD = "true"
|
||||||
|
FALSE_KEYWORD = "false"
|
||||||
|
NULL_KEYWORD = "null"
|
||||||
|
LEFT_BRACE = "{"
|
||||||
|
RIGHT_BRACE = "}"
|
||||||
|
LEFT_BRACKET = "["
|
||||||
|
RIGHT_BRACKET = "]"
|
||||||
|
BACKSLASH = '\\'
|
||||||
|
SLASH = "/"
|
||||||
|
QUOTE = '"'
|
||||||
|
COMMA = ","
|
||||||
|
COLON = ":"
|
||||||
|
ZERO = "0"
|
||||||
|
MINUS = "-"
|
||||||
|
PLUS = "+"
|
||||||
|
POINT = "."
|
||||||
|
EXPONENT = /[eE]/
|
||||||
|
B, F, N, R, T, U = %w[b f n r t u]
|
||||||
|
|
||||||
|
# Create a new parser with an optional initialization block where
|
||||||
|
# we can register event callbacks.
|
||||||
|
#
|
||||||
|
# Examples
|
||||||
|
#
|
||||||
|
# parser = JSON::Stream::Parser.new do
|
||||||
|
# start_document { puts "start document" }
|
||||||
|
# end_document { puts "end document" }
|
||||||
|
# start_object { puts "start object" }
|
||||||
|
# end_object { puts "end object" }
|
||||||
|
# start_array { puts "start array" }
|
||||||
|
# end_array { puts "end array" }
|
||||||
|
# key { |k| puts "key: #{k}" }
|
||||||
|
# value { |v| puts "value: #{v}" }
|
||||||
|
# end
|
||||||
|
def initialize(&block)
|
||||||
|
@state = :start_document
|
||||||
|
@utf8 = Buffer.new
|
||||||
|
@listeners = {
|
||||||
|
start_document: [],
|
||||||
|
end_document: [],
|
||||||
|
start_object: [],
|
||||||
|
end_object: [],
|
||||||
|
start_array: [],
|
||||||
|
end_array: [],
|
||||||
|
key: [],
|
||||||
|
value: [],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Track parse stack.
|
||||||
|
@stack = []
|
||||||
|
@unicode = +""
|
||||||
|
@buf = +""
|
||||||
|
@pos = -1
|
||||||
|
|
||||||
|
# Register any observers in the block.
|
||||||
|
instance_eval(&block) if block_given?
|
||||||
|
end
|
||||||
|
|
||||||
|
def start_document(&block)
|
||||||
|
@listeners[:start_document] << block
|
||||||
|
end
|
||||||
|
|
||||||
|
def end_document(&block)
|
||||||
|
@listeners[:end_document] << block
|
||||||
|
end
|
||||||
|
|
||||||
|
def start_object(&block)
|
||||||
|
@listeners[:start_object] << block
|
||||||
|
end
|
||||||
|
|
||||||
|
def end_object(&block)
|
||||||
|
@listeners[:end_object] << block
|
||||||
|
end
|
||||||
|
|
||||||
|
def start_array(&block)
|
||||||
|
@listeners[:start_array] << block
|
||||||
|
end
|
||||||
|
|
||||||
|
def end_array(&block)
|
||||||
|
@listeners[:end_array] << block
|
||||||
|
end
|
||||||
|
|
||||||
|
def key(&block)
|
||||||
|
@listeners[:key] << block
|
||||||
|
end
|
||||||
|
|
||||||
|
def value(&block)
|
||||||
|
@listeners[:value] << block
|
||||||
|
end
|
||||||
|
|
||||||
|
# Pass data into the parser to advance the state machine and
|
||||||
|
# generate callback events. This is well suited for an EventMachine
|
||||||
|
# receive_data loop.
|
||||||
|
#
|
||||||
|
# data - The String of partial JSON data to parse.
|
||||||
|
#
|
||||||
|
# Raises a JSON::Stream::ParserError if the JSON data is malformed.
|
||||||
|
#
|
||||||
|
# Returns nothing.
|
||||||
|
def <<(data)
|
||||||
|
(@utf8 << data).each_char do |ch|
|
||||||
|
@pos += 1
|
||||||
|
case @state
|
||||||
|
when :start_document
|
||||||
|
start_value(ch)
|
||||||
|
when :start_object
|
||||||
|
case ch
|
||||||
|
when QUOTE
|
||||||
|
@state = :start_string
|
||||||
|
@stack.push(:key)
|
||||||
|
when RIGHT_BRACE
|
||||||
|
end_container(:object)
|
||||||
|
when WS
|
||||||
|
# ignore
|
||||||
|
else
|
||||||
|
error("Expected object key start")
|
||||||
|
end
|
||||||
|
when :start_string
|
||||||
|
case ch
|
||||||
|
when QUOTE
|
||||||
|
if @stack.pop == :string
|
||||||
|
end_value(@buf)
|
||||||
|
else # :key
|
||||||
|
@state = :end_key
|
||||||
|
notify(:key, @buf)
|
||||||
|
end
|
||||||
|
@buf = +""
|
||||||
|
when BACKSLASH
|
||||||
|
@state = :start_escape
|
||||||
|
when CONTROL
|
||||||
|
error("Control characters must be escaped")
|
||||||
|
else
|
||||||
|
@buf << ch
|
||||||
|
end
|
||||||
|
when :start_escape
|
||||||
|
case ch
|
||||||
|
when QUOTE, BACKSLASH, SLASH
|
||||||
|
@buf << ch
|
||||||
|
@state = :start_string
|
||||||
|
when B
|
||||||
|
@buf << "\b"
|
||||||
|
@state = :start_string
|
||||||
|
when F
|
||||||
|
@buf << "\f"
|
||||||
|
@state = :start_string
|
||||||
|
when N
|
||||||
|
@buf << "\n"
|
||||||
|
@state = :start_string
|
||||||
|
when R
|
||||||
|
@buf << "\r"
|
||||||
|
@state = :start_string
|
||||||
|
when T
|
||||||
|
@buf << "\t"
|
||||||
|
@state = :start_string
|
||||||
|
when U
|
||||||
|
@state = :unicode_escape
|
||||||
|
else
|
||||||
|
error("Expected escaped character")
|
||||||
|
end
|
||||||
|
when :unicode_escape
|
||||||
|
case ch
|
||||||
|
when HEX
|
||||||
|
@unicode << ch
|
||||||
|
if @unicode.size == 4
|
||||||
|
codepoint = @unicode.slice!(0, 4).hex
|
||||||
|
if codepoint >= 0xD800 && codepoint <= 0xDBFF
|
||||||
|
error("Expected low surrogate pair half") if @stack[-1].is_a?(Integer)
|
||||||
|
@state = :start_surrogate_pair
|
||||||
|
@stack.push(codepoint)
|
||||||
|
elsif codepoint >= 0xDC00 && codepoint <= 0xDFFF
|
||||||
|
high = @stack.pop
|
||||||
|
error("Expected high surrogate pair half") unless high.is_a?(Integer)
|
||||||
|
pair = ((high - 0xD800) * 0x400) + (codepoint - 0xDC00) + 0x10000
|
||||||
|
@buf << pair
|
||||||
|
@state = :start_string
|
||||||
|
else
|
||||||
|
@buf << codepoint
|
||||||
|
@state = :start_string
|
||||||
|
end
|
||||||
|
end
|
||||||
|
else
|
||||||
|
error("Expected unicode escape hex digit")
|
||||||
|
end
|
||||||
|
when :start_surrogate_pair
|
||||||
|
case ch
|
||||||
|
when BACKSLASH
|
||||||
|
@state = :start_surrogate_pair_u
|
||||||
|
else
|
||||||
|
error("Expected low surrogate pair half")
|
||||||
|
end
|
||||||
|
when :start_surrogate_pair_u
|
||||||
|
case ch
|
||||||
|
when U
|
||||||
|
@state = :unicode_escape
|
||||||
|
else
|
||||||
|
error("Expected low surrogate pair half")
|
||||||
|
end
|
||||||
|
when :start_negative_number
|
||||||
|
case ch
|
||||||
|
when ZERO
|
||||||
|
@state = :start_zero
|
||||||
|
@buf << ch
|
||||||
|
when DIGIT_1_9
|
||||||
|
@state = :start_int
|
||||||
|
@buf << ch
|
||||||
|
else
|
||||||
|
error("Expected 0-9 digit")
|
||||||
|
end
|
||||||
|
when :start_zero
|
||||||
|
case ch
|
||||||
|
when POINT
|
||||||
|
@state = :start_float
|
||||||
|
@buf << ch
|
||||||
|
when EXPONENT
|
||||||
|
@state = :start_exponent
|
||||||
|
@buf << ch
|
||||||
|
else
|
||||||
|
end_value(@buf.to_i)
|
||||||
|
@buf = +""
|
||||||
|
@pos -= 1
|
||||||
|
redo
|
||||||
|
end
|
||||||
|
when :start_float
|
||||||
|
case ch
|
||||||
|
when DIGIT
|
||||||
|
@state = :in_float
|
||||||
|
@buf << ch
|
||||||
|
else
|
||||||
|
error("Expected 0-9 digit")
|
||||||
|
end
|
||||||
|
when :in_float
|
||||||
|
case ch
|
||||||
|
when DIGIT
|
||||||
|
@buf << ch
|
||||||
|
when EXPONENT
|
||||||
|
@state = :start_exponent
|
||||||
|
@buf << ch
|
||||||
|
else
|
||||||
|
end_value(@buf.to_f)
|
||||||
|
@buf = +""
|
||||||
|
@pos -= 1
|
||||||
|
redo
|
||||||
|
end
|
||||||
|
when :start_exponent
|
||||||
|
case ch
|
||||||
|
when MINUS, PLUS, DIGIT
|
||||||
|
@state = :in_exponent
|
||||||
|
@buf << ch
|
||||||
|
else
|
||||||
|
error("Expected +, -, or 0-9 digit")
|
||||||
|
end
|
||||||
|
when :in_exponent
|
||||||
|
case ch
|
||||||
|
when DIGIT
|
||||||
|
@buf << ch
|
||||||
|
else
|
||||||
|
error("Expected 0-9 digit") unless @buf =~ DIGIT_END
|
||||||
|
end_value(@buf.to_f)
|
||||||
|
@buf = +""
|
||||||
|
@pos -= 1
|
||||||
|
redo
|
||||||
|
end
|
||||||
|
when :start_int
|
||||||
|
case ch
|
||||||
|
when DIGIT
|
||||||
|
@buf << ch
|
||||||
|
when POINT
|
||||||
|
@state = :start_float
|
||||||
|
@buf << ch
|
||||||
|
when EXPONENT
|
||||||
|
@state = :start_exponent
|
||||||
|
@buf << ch
|
||||||
|
else
|
||||||
|
end_value(@buf.to_i)
|
||||||
|
@buf = +""
|
||||||
|
@pos -= 1
|
||||||
|
redo
|
||||||
|
end
|
||||||
|
when :start_true
|
||||||
|
keyword(TRUE_KEYWORD, true, TRUE_RE, ch)
|
||||||
|
when :start_false
|
||||||
|
keyword(FALSE_KEYWORD, false, FALSE_RE, ch)
|
||||||
|
when :start_null
|
||||||
|
keyword(NULL_KEYWORD, nil, NULL_RE, ch)
|
||||||
|
when :end_key
|
||||||
|
case ch
|
||||||
|
when COLON
|
||||||
|
@state = :key_sep
|
||||||
|
when WS
|
||||||
|
# ignore
|
||||||
|
else
|
||||||
|
error("Expected colon key separator")
|
||||||
|
end
|
||||||
|
when :key_sep
|
||||||
|
start_value(ch)
|
||||||
|
when :start_array
|
||||||
|
case ch
|
||||||
|
when RIGHT_BRACKET
|
||||||
|
end_container(:array)
|
||||||
|
when WS
|
||||||
|
# ignore
|
||||||
|
else
|
||||||
|
start_value(ch)
|
||||||
|
end
|
||||||
|
when :end_value
|
||||||
|
case ch
|
||||||
|
when COMMA
|
||||||
|
@state = :value_sep
|
||||||
|
when RIGHT_BRACE
|
||||||
|
end_container(:object)
|
||||||
|
when RIGHT_BRACKET
|
||||||
|
end_container(:array)
|
||||||
|
when WS
|
||||||
|
# ignore
|
||||||
|
else
|
||||||
|
error("Expected comma or object or array close")
|
||||||
|
end
|
||||||
|
when :value_sep
|
||||||
|
if @stack[-1] == :object
|
||||||
|
case ch
|
||||||
|
when QUOTE
|
||||||
|
@state = :start_string
|
||||||
|
@stack.push(:key)
|
||||||
|
when WS
|
||||||
|
# ignore
|
||||||
|
else
|
||||||
|
error("Expected object key start")
|
||||||
|
end
|
||||||
|
else
|
||||||
|
start_value(ch)
|
||||||
|
end
|
||||||
|
when :end_document
|
||||||
|
error("Unexpected data") unless ch =~ WS
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
# Drain any remaining buffered characters into the parser to complete
|
||||||
|
# the parsing of the document.
|
||||||
|
#
|
||||||
|
# This is only required when parsing a document containing a single
|
||||||
|
# numeric value, integer or float. The parser has no other way to
|
||||||
|
# detect when it should no longer expect additional characters with
|
||||||
|
# which to complete the parse, so it must be signaled by a call to
|
||||||
|
# this method.
|
||||||
|
#
|
||||||
|
# If you're parsing more typical object or array documents, there's no
|
||||||
|
# need to call `finish` because the parse will complete when the final
|
||||||
|
# closing `]` or `}` character is scanned.
|
||||||
|
#
|
||||||
|
# Raises a JSON::Stream::ParserError if the JSON data is malformed.
|
||||||
|
#
|
||||||
|
# Returns nothing.
|
||||||
|
def finish
|
||||||
|
# Partial multi-byte character waiting for completion bytes.
|
||||||
|
error("Unexpected end-of-file") unless @utf8.empty?
|
||||||
|
|
||||||
|
# Partial array, object, or string.
|
||||||
|
error("Unexpected end-of-file") unless @stack.empty?
|
||||||
|
|
||||||
|
case @state
|
||||||
|
when :end_document
|
||||||
|
# done, do nothing
|
||||||
|
when :in_float
|
||||||
|
end_value(@buf.to_f)
|
||||||
|
when :in_exponent
|
||||||
|
error("Unexpected end-of-file") unless @buf =~ DIGIT_END
|
||||||
|
end_value(@buf.to_f)
|
||||||
|
when :start_zero
|
||||||
|
end_value(@buf.to_i)
|
||||||
|
when :start_int
|
||||||
|
end_value(@buf.to_i)
|
||||||
|
else
|
||||||
|
error("Unexpected end-of-file")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
# Invoke all registered observer procs for the event type.
|
||||||
|
#
|
||||||
|
# type - The Symbol listener name.
|
||||||
|
# args - The argument list to pass into the observer procs.
|
||||||
|
#
|
||||||
|
# Examples
|
||||||
|
#
|
||||||
|
# # broadcast events for {"answer": 42}
|
||||||
|
# notify(:start_object)
|
||||||
|
# notify(:key, "answer")
|
||||||
|
# notify(:value, 42)
|
||||||
|
# notify(:end_object)
|
||||||
|
#
|
||||||
|
# Returns nothing.
|
||||||
|
def notify(type, *args)
|
||||||
|
@listeners[type].each { |block| block.call(*args) }
|
||||||
|
end
|
||||||
|
|
||||||
|
# Complete an object or array container value type.
|
||||||
|
#
|
||||||
|
# type - The Symbol, :object or :array, of the expected type.
|
||||||
|
#
|
||||||
|
# Raises a JSON::Stream::ParserError if the expected container type
|
||||||
|
# was not completed.
|
||||||
|
#
|
||||||
|
# Returns nothing.
|
||||||
|
def end_container(type)
|
||||||
|
@state = :end_value
|
||||||
|
if @stack.pop == type
|
||||||
|
case type
|
||||||
|
when :object
|
||||||
|
notify(:end_object)
|
||||||
|
when :array
|
||||||
|
notify(:end_array)
|
||||||
|
end
|
||||||
|
else
|
||||||
|
error("Expected end of #{type}")
|
||||||
|
end
|
||||||
|
notify_end_document if @stack.empty?
|
||||||
|
end
|
||||||
|
|
||||||
|
# Broadcast an `end_document` event to observers after a complete JSON
|
||||||
|
# value document (object, array, number, string, true, false, null) has
|
||||||
|
# been parsed from the text. This is the final event sent to observers
|
||||||
|
# and signals the parse has finished.
|
||||||
|
#
|
||||||
|
# Returns nothing.
|
||||||
|
def notify_end_document
|
||||||
|
@state = :end_document
|
||||||
|
notify(:end_document)
|
||||||
|
end
|
||||||
|
|
||||||
|
# Parse one of the three allowed keywords: true, false, null.
|
||||||
|
#
|
||||||
|
# word - The String keyword ('true', 'false', 'null').
|
||||||
|
# value - The Ruby value (true, false, nil).
|
||||||
|
# re - The Regexp of allowed keyword characters.
|
||||||
|
# ch - The current String character being parsed.
|
||||||
|
#
|
||||||
|
# Raises a JSON::Stream::ParserError if the character does not belong
|
||||||
|
# in the expected keyword.
|
||||||
|
#
|
||||||
|
# Returns nothing.
|
||||||
|
def keyword(word, value, re, ch)
|
||||||
|
if ch =~ re
|
||||||
|
@buf << ch
|
||||||
|
else
|
||||||
|
error("Expected #{word} keyword")
|
||||||
|
end
|
||||||
|
|
||||||
|
if @buf.size == word.size
|
||||||
|
if @buf == word
|
||||||
|
@buf = +""
|
||||||
|
end_value(value)
|
||||||
|
else
|
||||||
|
error("Expected #{word} keyword")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
# Process the first character of one of the seven possible JSON
|
||||||
|
# values: object, array, string, true, false, null, number.
|
||||||
|
#
|
||||||
|
# ch - The current character String.
|
||||||
|
#
|
||||||
|
# Raises a JSON::Stream::ParserError if the character does not signal
|
||||||
|
# the start of a value.
|
||||||
|
#
|
||||||
|
# Returns nothing.
|
||||||
|
def start_value(ch)
|
||||||
|
case ch
|
||||||
|
when LEFT_BRACE
|
||||||
|
notify(:start_document) if @stack.empty?
|
||||||
|
@state = :start_object
|
||||||
|
@stack.push(:object)
|
||||||
|
notify(:start_object)
|
||||||
|
when LEFT_BRACKET
|
||||||
|
notify(:start_document) if @stack.empty?
|
||||||
|
@state = :start_array
|
||||||
|
@stack.push(:array)
|
||||||
|
notify(:start_array)
|
||||||
|
when QUOTE
|
||||||
|
@state = :start_string
|
||||||
|
@stack.push(:string)
|
||||||
|
when T
|
||||||
|
@state = :start_true
|
||||||
|
@buf << ch
|
||||||
|
when F
|
||||||
|
@state = :start_false
|
||||||
|
@buf << ch
|
||||||
|
when N
|
||||||
|
@state = :start_null
|
||||||
|
@buf << ch
|
||||||
|
when MINUS
|
||||||
|
@state = :start_negative_number
|
||||||
|
@buf << ch
|
||||||
|
when ZERO
|
||||||
|
@state = :start_zero
|
||||||
|
@buf << ch
|
||||||
|
when DIGIT_1_9
|
||||||
|
@state = :start_int
|
||||||
|
@buf << ch
|
||||||
|
when WS
|
||||||
|
# ignore
|
||||||
|
else
|
||||||
|
error("Expected value")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
# Advance the state machine and notify `value` observers that a
|
||||||
|
# string, number or keyword (true, false, null) value was parsed.
|
||||||
|
#
|
||||||
|
# value - The object to broadcast to observers.
|
||||||
|
#
|
||||||
|
# Returns nothing.
|
||||||
|
def end_value(value)
|
||||||
|
@state = :end_value
|
||||||
|
notify(:start_document) if @stack.empty?
|
||||||
|
notify(:value, value)
|
||||||
|
notify_end_document if @stack.empty?
|
||||||
|
end
|
||||||
|
|
||||||
|
def error(message)
|
||||||
|
raise ParserError, "#{message}: char #{@pos}"
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -164,24 +164,18 @@ module DiscourseAi
|
||||||
|
|
||||||
# @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object
|
# @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object
|
||||||
# @param user { User } - User requesting the summary.
|
# @param user { User } - User requesting the summary.
|
||||||
|
# @param temperature { Float - Optional } - The temperature to use for the completion.
|
||||||
|
# @param top_p { Float - Optional } - The top_p to use for the completion.
|
||||||
|
# @param max_tokens { Integer - Optional } - The maximum number of tokens to generate.
|
||||||
|
# @param stop_sequences { Array<String> - Optional } - The stop sequences to use for the completion.
|
||||||
|
# @param feature_name { String - Optional } - The feature name to use for the completion.
|
||||||
|
# @param feature_context { Hash - Optional } - The feature context to use for the completion.
|
||||||
|
# @param partial_tool_calls { Boolean - Optional } - If true, the completion will return partial tool calls.
|
||||||
#
|
#
|
||||||
# @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function.
|
# @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function.
|
||||||
#
|
#
|
||||||
# @returns { String } - Completion result.
|
# @returns String | ToolCall - Completion result.
|
||||||
#
|
# if multiple tools or a tool and a message come back, the result will be an array of ToolCall / String objects.
|
||||||
# When the model invokes a tool, we'll wait until the endpoint finishes replying and feed you a fully-formed tool,
|
|
||||||
# even if you passed a partial_read_blk block. Invocations are strings that look like this:
|
|
||||||
#
|
|
||||||
# <function_calls>
|
|
||||||
# <invoke>
|
|
||||||
# <tool_name>get_weather</tool_name>
|
|
||||||
# <tool_id>get_weather</tool_id>
|
|
||||||
# <parameters>
|
|
||||||
# <location>Sydney</location>
|
|
||||||
# <unit>c</unit>
|
|
||||||
# </parameters>
|
|
||||||
# </invoke>
|
|
||||||
# </function_calls>
|
|
||||||
#
|
#
|
||||||
def generate(
|
def generate(
|
||||||
prompt,
|
prompt,
|
||||||
|
@ -192,6 +186,7 @@ module DiscourseAi
|
||||||
user:,
|
user:,
|
||||||
feature_name: nil,
|
feature_name: nil,
|
||||||
feature_context: nil,
|
feature_context: nil,
|
||||||
|
partial_tool_calls: false,
|
||||||
&partial_read_blk
|
&partial_read_blk
|
||||||
)
|
)
|
||||||
self.class.record_prompt(prompt)
|
self.class.record_prompt(prompt)
|
||||||
|
@ -226,6 +221,7 @@ module DiscourseAi
|
||||||
model_params,
|
model_params,
|
||||||
feature_name: feature_name,
|
feature_name: feature_name,
|
||||||
feature_context: feature_context,
|
feature_context: feature_context,
|
||||||
|
partial_tool_calls: partial_tool_calls,
|
||||||
&partial_read_blk
|
&partial_read_blk
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
|
@ -3,11 +3,12 @@ module DiscourseAi::Completions
|
||||||
class OpenAiMessageProcessor
|
class OpenAiMessageProcessor
|
||||||
attr_reader :prompt_tokens, :completion_tokens
|
attr_reader :prompt_tokens, :completion_tokens
|
||||||
|
|
||||||
def initialize
|
def initialize(partial_tool_calls: false)
|
||||||
@tool = nil
|
@tool = nil
|
||||||
@tool_arguments = +""
|
@tool_arguments = +""
|
||||||
@prompt_tokens = nil
|
@prompt_tokens = nil
|
||||||
@completion_tokens = nil
|
@completion_tokens = nil
|
||||||
|
@partial_tool_calls = partial_tool_calls
|
||||||
end
|
end
|
||||||
|
|
||||||
def process_message(json)
|
def process_message(json)
|
||||||
|
@ -57,12 +58,16 @@ module DiscourseAi::Completions
|
||||||
if id.present? && name.present?
|
if id.present? && name.present?
|
||||||
@tool_arguments = +""
|
@tool_arguments = +""
|
||||||
@tool = ToolCall.new(id: id, name: name)
|
@tool = ToolCall.new(id: id, name: name)
|
||||||
|
@streaming_parser = ToolCallProgressTracker.new(self) if @partial_tool_calls
|
||||||
end
|
end
|
||||||
|
|
||||||
@tool_arguments << arguments.to_s
|
@tool_arguments << arguments.to_s
|
||||||
|
@streaming_parser << arguments.to_s if @streaming_parser && !arguments.to_s.empty?
|
||||||
|
rval = current_tool_progress if !rval
|
||||||
elsif finished_tools && @tool
|
elsif finished_tools && @tool
|
||||||
parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
|
parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
|
||||||
@tool.parameters = parsed_args
|
@tool.parameters = parsed_args
|
||||||
|
@tool.partial = false
|
||||||
rval = @tool
|
rval = @tool
|
||||||
@tool = nil
|
@tool = nil
|
||||||
elsif !content.to_s.empty?
|
elsif !content.to_s.empty?
|
||||||
|
@ -75,6 +80,23 @@ module DiscourseAi::Completions
|
||||||
rval
|
rval
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def notify_progress(key, value)
|
||||||
|
if @tool
|
||||||
|
@tool.partial = true
|
||||||
|
@tool.parameters[key.to_sym] = value
|
||||||
|
@has_new_data = true
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def current_tool_progress
|
||||||
|
if @has_new_data
|
||||||
|
@has_new_data = false
|
||||||
|
@tool
|
||||||
|
else
|
||||||
|
nil
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
def finish
|
def finish
|
||||||
rval = []
|
rval = []
|
||||||
if @tool
|
if @tool
|
||||||
|
|
|
@ -4,12 +4,14 @@ module DiscourseAi
|
||||||
module Completions
|
module Completions
|
||||||
class ToolCall
|
class ToolCall
|
||||||
attr_reader :id, :name, :parameters
|
attr_reader :id, :name, :parameters
|
||||||
|
attr_accessor :partial
|
||||||
|
|
||||||
def initialize(id:, name:, parameters: nil)
|
def initialize(id:, name:, parameters: nil)
|
||||||
@id = id
|
@id = id
|
||||||
@name = name
|
@name = name
|
||||||
self.parameters = parameters if parameters
|
self.parameters = parameters if parameters
|
||||||
@parameters ||= {}
|
@parameters ||= {}
|
||||||
|
@partial = false
|
||||||
end
|
end
|
||||||
|
|
||||||
def parameters=(parameters)
|
def parameters=(parameters)
|
||||||
|
@ -24,6 +26,12 @@ module DiscourseAi
|
||||||
def to_s
|
def to_s
|
||||||
"#{name} - #{id} (\n#{parameters.map(&:to_s).join("\n")}\n)"
|
"#{name} - #{id} (\n#{parameters.map(&:to_s).join("\n")}\n)"
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def dup
|
||||||
|
call = ToolCall.new(id: id, name: name, parameters: parameters.deep_dup)
|
||||||
|
call.partial = partial
|
||||||
|
call
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Completions
|
||||||
|
class ToolCallProgressTracker
|
||||||
|
attr_reader :current_key, :current_value, :tool_call
|
||||||
|
|
||||||
|
def initialize(tool_call)
|
||||||
|
@tool_call = tool_call
|
||||||
|
@current_key = nil
|
||||||
|
@current_value = nil
|
||||||
|
@parser = DiscourseAi::Completions::JsonStreamingParser.new
|
||||||
|
|
||||||
|
@parser.key do |k|
|
||||||
|
@current_key = k
|
||||||
|
@current_value = nil
|
||||||
|
end
|
||||||
|
|
||||||
|
@parser.value { |v| tool_call.notify_progress(@current_key, v) if @current_key }
|
||||||
|
end
|
||||||
|
|
||||||
|
def <<(json)
|
||||||
|
# llm could send broken json
|
||||||
|
# in that case just deal with it later
|
||||||
|
# don't stream
|
||||||
|
return if @broken
|
||||||
|
|
||||||
|
begin
|
||||||
|
@parser << json
|
||||||
|
rescue DiscourseAi::Completions::ParserError
|
||||||
|
@broken = true
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
|
if @parser.state == :start_string && @current_key
|
||||||
|
# this is is worth notifying
|
||||||
|
tool_call.notify_progress(@current_key, @parser.buf)
|
||||||
|
end
|
||||||
|
|
||||||
|
@current_key = nil if @parser.state == :end_value
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -109,9 +109,11 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||||
EndpointMock.with_chunk_array_support do
|
EndpointMock.with_chunk_array_support do
|
||||||
stub_request(:post, url).to_return(status: 200, body: body)
|
stub_request(:post, url).to_return(status: 200, body: body)
|
||||||
|
|
||||||
llm.generate(prompt_with_google_tool, user: Discourse.system_user) do |partial|
|
llm.generate(
|
||||||
result << partial
|
prompt_with_google_tool,
|
||||||
end
|
user: Discourse.system_user,
|
||||||
|
partial_tool_calls: true,
|
||||||
|
) { |partial| result << partial.dup }
|
||||||
end
|
end
|
||||||
|
|
||||||
tool_call =
|
tool_call =
|
||||||
|
@ -124,7 +126,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
expect(result).to eq([tool_call])
|
expect(result.last).to eq(tool_call)
|
||||||
|
|
||||||
|
search_queries = result.filter(&:partial).map { |r| r.parameters[:search_query] }
|
||||||
|
categories = result.filter(&:partial).map { |r| r.parameters[:category] }
|
||||||
|
|
||||||
|
expect(categories).to eq([nil, nil, nil, nil, "gene", "general"])
|
||||||
|
expect(search_queries).to eq(["s", "s<a>m", "s<a>m ", "s<a>m sam", "s<a>m sam", "s<a>m sam"])
|
||||||
end
|
end
|
||||||
|
|
||||||
it "can stream a response" do
|
it "can stream a response" do
|
||||||
|
|
|
@ -571,7 +571,7 @@ TEXT
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
it "properly handles spaces in tools payload" do
|
it "properly handles spaces in tools payload and partial tool calls" do
|
||||||
raw_data = <<~TEXT.strip
|
raw_data = <<~TEXT.strip
|
||||||
data: {"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"func_id","type":"function","function":{"name":"go|ogle","arg|uments":""}}]}}]}
|
data: {"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"func_id","type":"function","function":{"name":"go|ogle","arg|uments":""}}]}}]}
|
||||||
|
|
||||||
|
@ -609,7 +609,9 @@ TEXT
|
||||||
partials = []
|
partials = []
|
||||||
|
|
||||||
dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
|
dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
|
||||||
endpoint.perform_completion!(dialect, user) { |partial| partials << partial }
|
endpoint.perform_completion!(dialect, user, partial_tool_calls: true) do |partial|
|
||||||
|
partials << partial.dup
|
||||||
|
end
|
||||||
|
|
||||||
tool_call =
|
tool_call =
|
||||||
DiscourseAi::Completions::ToolCall.new(
|
DiscourseAi::Completions::ToolCall.new(
|
||||||
|
@ -620,7 +622,10 @@ TEXT
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
expect(partials).to eq([tool_call])
|
expect(partials.last).to eq(tool_call)
|
||||||
|
|
||||||
|
progress = partials.map { |p| p.parameters[:query] }
|
||||||
|
expect(progress).to eq(["Ad", "Adabas", "Adabas 9.", "Adabas 9.1"])
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -462,7 +462,6 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
||||||
expect(crafted_system_prompt).to include("fragment-n14")
|
expect(crafted_system_prompt).to include("fragment-n14")
|
||||||
expect(crafted_system_prompt).to include("fragment-n13")
|
expect(crafted_system_prompt).to include("fragment-n13")
|
||||||
expect(crafted_system_prompt).to include("fragment-n12")
|
expect(crafted_system_prompt).to include("fragment-n12")
|
||||||
|
|
||||||
expect(crafted_system_prompt).not_to include("fragment-n4") # Fragment #11 not included
|
expect(crafted_system_prompt).not_to include("fragment-n4") # Fragment #11 not included
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -253,6 +253,9 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
|
||||||
after { SiteSetting.provider = @original_provider }
|
after { SiteSetting.provider = @original_provider }
|
||||||
|
|
||||||
it "returns a 403 error if the user cannot access the secure upload" do
|
it "returns a 403 error if the user cannot access the secure upload" do
|
||||||
|
# hosted-site plugin edge case, it enables embeddings
|
||||||
|
SiteSetting.ai_embeddings_enabled = false
|
||||||
|
|
||||||
create_post(
|
create_post(
|
||||||
title: "Secure upload post",
|
title: "Secure upload post",
|
||||||
raw: "This is a new post <img src=\"#{upload.url}\" />",
|
raw: "This is a new post <img src=\"#{upload.url}\" />",
|
||||||
|
|
Loading…
Reference in New Issue