diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb index b965b1f6..c00d5b65 100644 --- a/lib/ai_bot/bot.rb +++ b/lib/ai_bot/bot.rb @@ -185,20 +185,23 @@ module DiscourseAi end def invoke_tool(tool, llm, cancel, context, &update_blk) - update_blk.call("", cancel, build_placeholder(tool.summary, "")) + show_placeholder = !context[:skip_tool_details] + + update_blk.call("", cancel, build_placeholder(tool.summary, "")) if show_placeholder result = tool.invoke do |progress| - placeholder = build_placeholder(tool.summary, progress) - update_blk.call("", cancel, placeholder) + if show_placeholder + placeholder = build_placeholder(tool.summary, progress) + update_blk.call("", cancel, placeholder) + end end - tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw) - - if context[:skip_tool_details] && tool.custom_raw.present? - update_blk.call(tool.custom_raw, cancel, nil, :custom_raw) - elsif !context[:skip_tool_details] + if show_placeholder + tool_details = build_placeholder(tool.summary, tool.details, custom_raw: tool.custom_raw) update_blk.call(tool_details, cancel, nil, :tool_details) + elsif tool.custom_raw.present? + update_blk.call(tool.custom_raw, cancel, nil, :custom_raw) end result diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index af18f989..3873ebc6 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -452,7 +452,7 @@ module DiscourseAi bot.reply(context) do |partial, cancel, placeholder, type| reply << partial raw = reply.dup - raw << "\n\n" << placeholder if placeholder.present? && !context[:skip_tool_details] + raw << "\n\n" << placeholder if placeholder.present? blk.call(partial) if blk && type != :tool_details diff --git a/lib/completions/anthropic_message_processor.rb b/lib/completions/anthropic_message_processor.rb index 5d5602ef..aeca321d 100644 --- a/lib/completions/anthropic_message_processor.rb +++ b/lib/completions/anthropic_message_processor.rb @@ -4,28 +4,50 @@ class DiscourseAi::Completions::AnthropicMessageProcessor class AnthropicToolCall attr_reader :name, :raw_json, :id - def initialize(name, id) + def initialize(name, id, partial_tool_calls: false) @name = name @id = id @raw_json = +"" + @tool_call = DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: {}) + @streaming_parser = + DiscourseAi::Completions::ToolCallProgressTracker.new(self) if partial_tool_calls end def append(json) @raw_json << json + @streaming_parser << json if @streaming_parser + end + + def notify_progress(key, value) + @tool_call.partial = true + @tool_call.parameters[key.to_sym] = value + @has_new_data = true + end + + def has_partial? + @has_new_data + end + + def partial_tool_call + @has_new_data = false + @tool_call end def to_tool_call parameters = JSON.parse(raw_json, symbolize_names: true) - DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters) + @tool_call.partial = false + @tool_call.parameters = parameters + @tool_call end end attr_reader :tool_calls, :input_tokens, :output_tokens - def initialize(streaming_mode:) + def initialize(streaming_mode:, partial_tool_calls: false) @streaming_mode = streaming_mode @tool_calls = [] @current_tool_call = nil + @partial_tool_calls = partial_tool_calls end def to_tool_calls @@ -38,11 +60,17 @@ class DiscourseAi::Completions::AnthropicMessageProcessor tool_name = parsed.dig(:content_block, :name) tool_id = parsed.dig(:content_block, :id) result = @current_tool_call.to_tool_call if @current_tool_call - @current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name + @current_tool_call = + AnthropicToolCall.new( + tool_name, + tool_id, + partial_tool_calls: @partial_tool_calls, + ) if tool_name elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta" if @current_tool_call tool_delta = parsed.dig(:delta, :partial_json).to_s @current_tool_call.append(tool_delta) + result = @current_tool_call.partial_tool_call if @current_tool_call.has_partial? else result = parsed.dig(:delta, :text).to_s end diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb index 6576ef3b..c505e936 100644 --- a/lib/completions/endpoints/anthropic.rb +++ b/lib/completions/endpoints/anthropic.rb @@ -107,7 +107,10 @@ module DiscourseAi def processor @processor ||= - DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode) + DiscourseAi::Completions::AnthropicMessageProcessor.new( + streaming_mode: @streaming_mode, + partial_tool_calls: partial_tool_calls, + ) end def has_tool?(_response_data) diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index c78fcdd9..7abfdf6a 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -4,6 +4,8 @@ module DiscourseAi module Completions module Endpoints class Base + attr_reader :partial_tool_calls + CompletionFailed = Class.new(StandardError) TIMEOUT = 60 @@ -58,8 +60,10 @@ module DiscourseAi model_params = {}, feature_name: nil, feature_context: nil, + partial_tool_calls: false, &blk ) + @partial_tool_calls = partial_tool_calls model_params = normalize_model_params(model_params) orig_blk = blk diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb index bd3ae4ea..c62e0bdd 100644 --- a/lib/completions/endpoints/canned_response.rb +++ b/lib/completions/endpoints/canned_response.rb @@ -28,7 +28,8 @@ module DiscourseAi _user, _model_params, feature_name: nil, - feature_context: nil + feature_context: nil, + partial_tool_calls: false ) @dialect = dialect response = responses[completions] diff --git a/lib/completions/endpoints/fake.rb b/lib/completions/endpoints/fake.rb index 15cc254d..e9b96c77 100644 --- a/lib/completions/endpoints/fake.rb +++ b/lib/completions/endpoints/fake.rb @@ -120,7 +120,8 @@ module DiscourseAi user, model_params = {}, feature_name: nil, - feature_context: nil + feature_context: nil, + partial_tool_calls: false ) last_call = { dialect: dialect, user: user, model_params: model_params } self.class.last_call = last_call diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb index a185a840..1e96215b 100644 --- a/lib/completions/endpoints/open_ai.rb +++ b/lib/completions/endpoints/open_ai.rb @@ -33,6 +33,7 @@ module DiscourseAi model_params = {}, feature_name: nil, feature_context: nil, + partial_tool_calls: false, &blk ) if dialect.respond_to?(:is_gpt_o?) && dialect.is_gpt_o? && block_given? @@ -103,10 +104,16 @@ module DiscourseAi def decode_chunk(chunk) @decoder ||= JsonStreamDecoder.new - (@decoder << chunk) - .map { |parsed_json| processor.process_streamed_message(parsed_json) } - .flatten - .compact + elements = + (@decoder << chunk) + .map { |parsed_json| processor.process_streamed_message(parsed_json) } + .flatten + .compact + + # Remove duplicate partial tool calls + # sometimes we stream weird chunks + seen_tools = Set.new + elements.select { |item| !item.is_a?(ToolCall) || seen_tools.add?(item) } end def decode_chunk_finish @@ -120,7 +127,7 @@ module DiscourseAi private def processor - @processor ||= OpenAiMessageProcessor.new + @processor ||= OpenAiMessageProcessor.new(partial_tool_calls: partial_tool_calls) end end end diff --git a/lib/completions/json_streaming_parser.rb b/lib/completions/json_streaming_parser.rb new file mode 100644 index 00000000..c8466c82 --- /dev/null +++ b/lib/completions/json_streaming_parser.rb @@ -0,0 +1,667 @@ +# frozen_string_literal: true + +# This code is copied from the MIT licensed json-stream +# see: https://github.com/dgraham/json-stream +# +# It was copied to avoid the dependency and allow us to make some small changes +# particularly we need better access to internal state when parsing + +module DiscourseAi + module Completions + # Raised on any invalid JSON text. + ParserError = Class.new(RuntimeError) + + # A streaming JSON parser that generates SAX-like events for state changes. + # Use the json gem for small documents. Use this for huge documents that + # won't fit in memory. + # + # Examples + # + # parser = JSON::Stream::Parser.new + # parser.key { |key| puts key } + # parser.value { |value| puts value } + # parser << '{"answer":' + # parser << ' 42}' + class JsonStreamingParser + # our changes: + attr_reader :state, :buf, :pos + + # A character buffer that expects a UTF-8 encoded stream of bytes. + # This handles truncated multi-byte characters properly so we can just + # feed it binary data and receive a properly formatted UTF-8 String as + # output. + # + # More UTF-8 parsing details are available at: + # + # http://en.wikipedia.org/wiki/UTF-8 + # http://tools.ietf.org/html/rfc3629#section-3 + class Buffer + def initialize + @state = :start + @buffer = [] + @need = 0 + end + + # Fill the buffer with a String of binary UTF-8 encoded bytes. Returns + # as much of the data in a UTF-8 String as we have. Truncated multi-byte + # characters are saved in the buffer until the next call to this method + # where we expect to receive the rest of the multi-byte character. + # + # data - The partial binary encoded String data. + # + # Raises JSON::Stream::ParserError if the UTF-8 byte sequence is malformed. + # + # Returns a UTF-8 encoded String. + def <<(data) + # Avoid state machine for complete UTF-8. + if @buffer.empty? + data.force_encoding(Encoding::UTF_8) + return data if data.valid_encoding? + end + + bytes = [] + data.each_byte do |byte| + case @state + when :start + if byte < 128 + bytes << byte + elsif byte >= 192 + @state = :multi_byte + @buffer << byte + @need = + case + when byte >= 240 + 4 + when byte >= 224 + 3 + when byte >= 192 + 2 + end + else + error("Expected start of multi-byte or single byte char") + end + when :multi_byte + if byte > 127 && byte < 192 + @buffer << byte + if @buffer.size == @need + bytes += @buffer.slice!(0, @buffer.size) + @state = :start + end + else + error("Expected continuation byte") + end + end + end + + # Build UTF-8 encoded string from completed codepoints. + bytes + .pack("C*") + .force_encoding(Encoding::UTF_8) + .tap { |text| error("Invalid UTF-8 byte sequence") unless text.valid_encoding? } + end + + # Determine if the buffer contains partial UTF-8 continuation bytes that + # are waiting on subsequent completion bytes before a full codepoint is + # formed. + # + # Examples + # + # bytes = "é".bytes + # + # buffer << bytes[0] + # buffer.empty? + # # => false + # + # buffer << bytes[1] + # buffer.empty? + # # => true + # + # Returns true if the buffer is empty. + def empty? + @buffer.empty? + end + + private + + def error(message) + raise ParserError, message + end + end + + BUF_SIZE = 4096 + CONTROL = /[\x00-\x1F]/ + WS = /[ \n\t\r]/ + HEX = /[0-9a-fA-F]/ + DIGIT = /[0-9]/ + DIGIT_1_9 = /[1-9]/ + DIGIT_END = /\d$/ + TRUE_RE = /[rue]/ + FALSE_RE = /[alse]/ + NULL_RE = /[ul]/ + TRUE_KEYWORD = "true" + FALSE_KEYWORD = "false" + NULL_KEYWORD = "null" + LEFT_BRACE = "{" + RIGHT_BRACE = "}" + LEFT_BRACKET = "[" + RIGHT_BRACKET = "]" + BACKSLASH = '\\' + SLASH = "/" + QUOTE = '"' + COMMA = "," + COLON = ":" + ZERO = "0" + MINUS = "-" + PLUS = "+" + POINT = "." + EXPONENT = /[eE]/ + B, F, N, R, T, U = %w[b f n r t u] + + # Create a new parser with an optional initialization block where + # we can register event callbacks. + # + # Examples + # + # parser = JSON::Stream::Parser.new do + # start_document { puts "start document" } + # end_document { puts "end document" } + # start_object { puts "start object" } + # end_object { puts "end object" } + # start_array { puts "start array" } + # end_array { puts "end array" } + # key { |k| puts "key: #{k}" } + # value { |v| puts "value: #{v}" } + # end + def initialize(&block) + @state = :start_document + @utf8 = Buffer.new + @listeners = { + start_document: [], + end_document: [], + start_object: [], + end_object: [], + start_array: [], + end_array: [], + key: [], + value: [], + } + + # Track parse stack. + @stack = [] + @unicode = +"" + @buf = +"" + @pos = -1 + + # Register any observers in the block. + instance_eval(&block) if block_given? + end + + def start_document(&block) + @listeners[:start_document] << block + end + + def end_document(&block) + @listeners[:end_document] << block + end + + def start_object(&block) + @listeners[:start_object] << block + end + + def end_object(&block) + @listeners[:end_object] << block + end + + def start_array(&block) + @listeners[:start_array] << block + end + + def end_array(&block) + @listeners[:end_array] << block + end + + def key(&block) + @listeners[:key] << block + end + + def value(&block) + @listeners[:value] << block + end + + # Pass data into the parser to advance the state machine and + # generate callback events. This is well suited for an EventMachine + # receive_data loop. + # + # data - The String of partial JSON data to parse. + # + # Raises a JSON::Stream::ParserError if the JSON data is malformed. + # + # Returns nothing. + def <<(data) + (@utf8 << data).each_char do |ch| + @pos += 1 + case @state + when :start_document + start_value(ch) + when :start_object + case ch + when QUOTE + @state = :start_string + @stack.push(:key) + when RIGHT_BRACE + end_container(:object) + when WS + # ignore + else + error("Expected object key start") + end + when :start_string + case ch + when QUOTE + if @stack.pop == :string + end_value(@buf) + else # :key + @state = :end_key + notify(:key, @buf) + end + @buf = +"" + when BACKSLASH + @state = :start_escape + when CONTROL + error("Control characters must be escaped") + else + @buf << ch + end + when :start_escape + case ch + when QUOTE, BACKSLASH, SLASH + @buf << ch + @state = :start_string + when B + @buf << "\b" + @state = :start_string + when F + @buf << "\f" + @state = :start_string + when N + @buf << "\n" + @state = :start_string + when R + @buf << "\r" + @state = :start_string + when T + @buf << "\t" + @state = :start_string + when U + @state = :unicode_escape + else + error("Expected escaped character") + end + when :unicode_escape + case ch + when HEX + @unicode << ch + if @unicode.size == 4 + codepoint = @unicode.slice!(0, 4).hex + if codepoint >= 0xD800 && codepoint <= 0xDBFF + error("Expected low surrogate pair half") if @stack[-1].is_a?(Integer) + @state = :start_surrogate_pair + @stack.push(codepoint) + elsif codepoint >= 0xDC00 && codepoint <= 0xDFFF + high = @stack.pop + error("Expected high surrogate pair half") unless high.is_a?(Integer) + pair = ((high - 0xD800) * 0x400) + (codepoint - 0xDC00) + 0x10000 + @buf << pair + @state = :start_string + else + @buf << codepoint + @state = :start_string + end + end + else + error("Expected unicode escape hex digit") + end + when :start_surrogate_pair + case ch + when BACKSLASH + @state = :start_surrogate_pair_u + else + error("Expected low surrogate pair half") + end + when :start_surrogate_pair_u + case ch + when U + @state = :unicode_escape + else + error("Expected low surrogate pair half") + end + when :start_negative_number + case ch + when ZERO + @state = :start_zero + @buf << ch + when DIGIT_1_9 + @state = :start_int + @buf << ch + else + error("Expected 0-9 digit") + end + when :start_zero + case ch + when POINT + @state = :start_float + @buf << ch + when EXPONENT + @state = :start_exponent + @buf << ch + else + end_value(@buf.to_i) + @buf = +"" + @pos -= 1 + redo + end + when :start_float + case ch + when DIGIT + @state = :in_float + @buf << ch + else + error("Expected 0-9 digit") + end + when :in_float + case ch + when DIGIT + @buf << ch + when EXPONENT + @state = :start_exponent + @buf << ch + else + end_value(@buf.to_f) + @buf = +"" + @pos -= 1 + redo + end + when :start_exponent + case ch + when MINUS, PLUS, DIGIT + @state = :in_exponent + @buf << ch + else + error("Expected +, -, or 0-9 digit") + end + when :in_exponent + case ch + when DIGIT + @buf << ch + else + error("Expected 0-9 digit") unless @buf =~ DIGIT_END + end_value(@buf.to_f) + @buf = +"" + @pos -= 1 + redo + end + when :start_int + case ch + when DIGIT + @buf << ch + when POINT + @state = :start_float + @buf << ch + when EXPONENT + @state = :start_exponent + @buf << ch + else + end_value(@buf.to_i) + @buf = +"" + @pos -= 1 + redo + end + when :start_true + keyword(TRUE_KEYWORD, true, TRUE_RE, ch) + when :start_false + keyword(FALSE_KEYWORD, false, FALSE_RE, ch) + when :start_null + keyword(NULL_KEYWORD, nil, NULL_RE, ch) + when :end_key + case ch + when COLON + @state = :key_sep + when WS + # ignore + else + error("Expected colon key separator") + end + when :key_sep + start_value(ch) + when :start_array + case ch + when RIGHT_BRACKET + end_container(:array) + when WS + # ignore + else + start_value(ch) + end + when :end_value + case ch + when COMMA + @state = :value_sep + when RIGHT_BRACE + end_container(:object) + when RIGHT_BRACKET + end_container(:array) + when WS + # ignore + else + error("Expected comma or object or array close") + end + when :value_sep + if @stack[-1] == :object + case ch + when QUOTE + @state = :start_string + @stack.push(:key) + when WS + # ignore + else + error("Expected object key start") + end + else + start_value(ch) + end + when :end_document + error("Unexpected data") unless ch =~ WS + end + end + end + + # Drain any remaining buffered characters into the parser to complete + # the parsing of the document. + # + # This is only required when parsing a document containing a single + # numeric value, integer or float. The parser has no other way to + # detect when it should no longer expect additional characters with + # which to complete the parse, so it must be signaled by a call to + # this method. + # + # If you're parsing more typical object or array documents, there's no + # need to call `finish` because the parse will complete when the final + # closing `]` or `}` character is scanned. + # + # Raises a JSON::Stream::ParserError if the JSON data is malformed. + # + # Returns nothing. + def finish + # Partial multi-byte character waiting for completion bytes. + error("Unexpected end-of-file") unless @utf8.empty? + + # Partial array, object, or string. + error("Unexpected end-of-file") unless @stack.empty? + + case @state + when :end_document + # done, do nothing + when :in_float + end_value(@buf.to_f) + when :in_exponent + error("Unexpected end-of-file") unless @buf =~ DIGIT_END + end_value(@buf.to_f) + when :start_zero + end_value(@buf.to_i) + when :start_int + end_value(@buf.to_i) + else + error("Unexpected end-of-file") + end + end + + private + + # Invoke all registered observer procs for the event type. + # + # type - The Symbol listener name. + # args - The argument list to pass into the observer procs. + # + # Examples + # + # # broadcast events for {"answer": 42} + # notify(:start_object) + # notify(:key, "answer") + # notify(:value, 42) + # notify(:end_object) + # + # Returns nothing. + def notify(type, *args) + @listeners[type].each { |block| block.call(*args) } + end + + # Complete an object or array container value type. + # + # type - The Symbol, :object or :array, of the expected type. + # + # Raises a JSON::Stream::ParserError if the expected container type + # was not completed. + # + # Returns nothing. + def end_container(type) + @state = :end_value + if @stack.pop == type + case type + when :object + notify(:end_object) + when :array + notify(:end_array) + end + else + error("Expected end of #{type}") + end + notify_end_document if @stack.empty? + end + + # Broadcast an `end_document` event to observers after a complete JSON + # value document (object, array, number, string, true, false, null) has + # been parsed from the text. This is the final event sent to observers + # and signals the parse has finished. + # + # Returns nothing. + def notify_end_document + @state = :end_document + notify(:end_document) + end + + # Parse one of the three allowed keywords: true, false, null. + # + # word - The String keyword ('true', 'false', 'null'). + # value - The Ruby value (true, false, nil). + # re - The Regexp of allowed keyword characters. + # ch - The current String character being parsed. + # + # Raises a JSON::Stream::ParserError if the character does not belong + # in the expected keyword. + # + # Returns nothing. + def keyword(word, value, re, ch) + if ch =~ re + @buf << ch + else + error("Expected #{word} keyword") + end + + if @buf.size == word.size + if @buf == word + @buf = +"" + end_value(value) + else + error("Expected #{word} keyword") + end + end + end + + # Process the first character of one of the seven possible JSON + # values: object, array, string, true, false, null, number. + # + # ch - The current character String. + # + # Raises a JSON::Stream::ParserError if the character does not signal + # the start of a value. + # + # Returns nothing. + def start_value(ch) + case ch + when LEFT_BRACE + notify(:start_document) if @stack.empty? + @state = :start_object + @stack.push(:object) + notify(:start_object) + when LEFT_BRACKET + notify(:start_document) if @stack.empty? + @state = :start_array + @stack.push(:array) + notify(:start_array) + when QUOTE + @state = :start_string + @stack.push(:string) + when T + @state = :start_true + @buf << ch + when F + @state = :start_false + @buf << ch + when N + @state = :start_null + @buf << ch + when MINUS + @state = :start_negative_number + @buf << ch + when ZERO + @state = :start_zero + @buf << ch + when DIGIT_1_9 + @state = :start_int + @buf << ch + when WS + # ignore + else + error("Expected value") + end + end + + # Advance the state machine and notify `value` observers that a + # string, number or keyword (true, false, null) value was parsed. + # + # value - The object to broadcast to observers. + # + # Returns nothing. + def end_value(value) + @state = :end_value + notify(:start_document) if @stack.empty? + notify(:value, value) + notify_end_document if @stack.empty? + end + + def error(message) + raise ParserError, "#{message}: char #{@pos}" + end + end + end +end diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb index 0d53b413..dc336bf2 100644 --- a/lib/completions/llm.rb +++ b/lib/completions/llm.rb @@ -164,24 +164,18 @@ module DiscourseAi # @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object # @param user { User } - User requesting the summary. + # @param temperature { Float - Optional } - The temperature to use for the completion. + # @param top_p { Float - Optional } - The top_p to use for the completion. + # @param max_tokens { Integer - Optional } - The maximum number of tokens to generate. + # @param stop_sequences { Array - Optional } - The stop sequences to use for the completion. + # @param feature_name { String - Optional } - The feature name to use for the completion. + # @param feature_context { Hash - Optional } - The feature context to use for the completion. + # @param partial_tool_calls { Boolean - Optional } - If true, the completion will return partial tool calls. # # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function. # - # @returns { String } - Completion result. - # - # When the model invokes a tool, we'll wait until the endpoint finishes replying and feed you a fully-formed tool, - # even if you passed a partial_read_blk block. Invocations are strings that look like this: - # - # - # - # get_weather - # get_weather - # - # Sydney - # c - # - # - # + # @returns String | ToolCall - Completion result. + # if multiple tools or a tool and a message come back, the result will be an array of ToolCall / String objects. # def generate( prompt, @@ -192,6 +186,7 @@ module DiscourseAi user:, feature_name: nil, feature_context: nil, + partial_tool_calls: false, &partial_read_blk ) self.class.record_prompt(prompt) @@ -226,6 +221,7 @@ module DiscourseAi model_params, feature_name: feature_name, feature_context: feature_context, + partial_tool_calls: partial_tool_calls, &partial_read_blk ) end diff --git a/lib/completions/open_ai_message_processor.rb b/lib/completions/open_ai_message_processor.rb index 2890083b..7b7378db 100644 --- a/lib/completions/open_ai_message_processor.rb +++ b/lib/completions/open_ai_message_processor.rb @@ -3,11 +3,12 @@ module DiscourseAi::Completions class OpenAiMessageProcessor attr_reader :prompt_tokens, :completion_tokens - def initialize + def initialize(partial_tool_calls: false) @tool = nil @tool_arguments = +"" @prompt_tokens = nil @completion_tokens = nil + @partial_tool_calls = partial_tool_calls end def process_message(json) @@ -57,12 +58,16 @@ module DiscourseAi::Completions if id.present? && name.present? @tool_arguments = +"" @tool = ToolCall.new(id: id, name: name) + @streaming_parser = ToolCallProgressTracker.new(self) if @partial_tool_calls end @tool_arguments << arguments.to_s + @streaming_parser << arguments.to_s if @streaming_parser && !arguments.to_s.empty? + rval = current_tool_progress if !rval elsif finished_tools && @tool parsed_args = JSON.parse(@tool_arguments, symbolize_names: true) @tool.parameters = parsed_args + @tool.partial = false rval = @tool @tool = nil elsif !content.to_s.empty? @@ -75,6 +80,23 @@ module DiscourseAi::Completions rval end + def notify_progress(key, value) + if @tool + @tool.partial = true + @tool.parameters[key.to_sym] = value + @has_new_data = true + end + end + + def current_tool_progress + if @has_new_data + @has_new_data = false + @tool + else + nil + end + end + def finish rval = [] if @tool diff --git a/lib/completions/tool_call.rb b/lib/completions/tool_call.rb index 15be7b3f..1dedc7cf 100644 --- a/lib/completions/tool_call.rb +++ b/lib/completions/tool_call.rb @@ -4,12 +4,14 @@ module DiscourseAi module Completions class ToolCall attr_reader :id, :name, :parameters + attr_accessor :partial def initialize(id:, name:, parameters: nil) @id = id @name = name self.parameters = parameters if parameters @parameters ||= {} + @partial = false end def parameters=(parameters) @@ -24,6 +26,12 @@ module DiscourseAi def to_s "#{name} - #{id} (\n#{parameters.map(&:to_s).join("\n")}\n)" end + + def dup + call = ToolCall.new(id: id, name: name, parameters: parameters.deep_dup) + call.partial = partial + call + end end end end diff --git a/lib/completions/tool_call_progress_tracker.rb b/lib/completions/tool_call_progress_tracker.rb new file mode 100644 index 00000000..f33bd3fc --- /dev/null +++ b/lib/completions/tool_call_progress_tracker.rb @@ -0,0 +1,44 @@ +# frozen_string_literal: true + +module DiscourseAi + module Completions + class ToolCallProgressTracker + attr_reader :current_key, :current_value, :tool_call + + def initialize(tool_call) + @tool_call = tool_call + @current_key = nil + @current_value = nil + @parser = DiscourseAi::Completions::JsonStreamingParser.new + + @parser.key do |k| + @current_key = k + @current_value = nil + end + + @parser.value { |v| tool_call.notify_progress(@current_key, v) if @current_key } + end + + def <<(json) + # llm could send broken json + # in that case just deal with it later + # don't stream + return if @broken + + begin + @parser << json + rescue DiscourseAi::Completions::ParserError + @broken = true + return + end + + if @parser.state == :start_string && @current_key + # this is is worth notifying + tool_call.notify_progress(@current_key, @parser.buf) + end + + @current_key = nil if @parser.state == :end_value + end + end + end +end diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb index 40eca30f..8bdc796e 100644 --- a/spec/lib/completions/endpoints/anthropic_spec.rb +++ b/spec/lib/completions/endpoints/anthropic_spec.rb @@ -109,9 +109,11 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do EndpointMock.with_chunk_array_support do stub_request(:post, url).to_return(status: 200, body: body) - llm.generate(prompt_with_google_tool, user: Discourse.system_user) do |partial| - result << partial - end + llm.generate( + prompt_with_google_tool, + user: Discourse.system_user, + partial_tool_calls: true, + ) { |partial| result << partial.dup } end tool_call = @@ -124,7 +126,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do }, ) - expect(result).to eq([tool_call]) + expect(result.last).to eq(tool_call) + + search_queries = result.filter(&:partial).map { |r| r.parameters[:search_query] } + categories = result.filter(&:partial).map { |r| r.parameters[:category] } + + expect(categories).to eq([nil, nil, nil, nil, "gene", "general"]) + expect(search_queries).to eq(["s", "sm", "sm ", "sm sam", "sm sam", "sm sam"]) end it "can stream a response" do diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb index a1c6702a..e07914a5 100644 --- a/spec/lib/completions/endpoints/open_ai_spec.rb +++ b/spec/lib/completions/endpoints/open_ai_spec.rb @@ -571,7 +571,7 @@ TEXT end end - it "properly handles spaces in tools payload" do + it "properly handles spaces in tools payload and partial tool calls" do raw_data = <<~TEXT.strip data: {"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"func_id","type":"function","function":{"name":"go|ogle","arg|uments":""}}]}}]} @@ -609,7 +609,9 @@ TEXT partials = [] dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools)) - endpoint.perform_completion!(dialect, user) { |partial| partials << partial } + endpoint.perform_completion!(dialect, user, partial_tool_calls: true) do |partial| + partials << partial.dup + end tool_call = DiscourseAi::Completions::ToolCall.new( @@ -620,7 +622,10 @@ TEXT }, ) - expect(partials).to eq([tool_call]) + expect(partials.last).to eq(tool_call) + + progress = partials.map { |p| p.parameters[:query] } + expect(progress).to eq(["Ad", "Adabas", "Adabas 9.", "Adabas 9.1"]) end end end diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb index 5271374d..a77fbbf0 100644 --- a/spec/lib/modules/ai_bot/personas/persona_spec.rb +++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb @@ -462,7 +462,6 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do expect(crafted_system_prompt).to include("fragment-n14") expect(crafted_system_prompt).to include("fragment-n13") expect(crafted_system_prompt).to include("fragment-n12") - expect(crafted_system_prompt).not_to include("fragment-n4") # Fragment #11 not included end end diff --git a/spec/requests/ai_helper/assistant_controller_spec.rb b/spec/requests/ai_helper/assistant_controller_spec.rb index 4dd58ff1..71dab56a 100644 --- a/spec/requests/ai_helper/assistant_controller_spec.rb +++ b/spec/requests/ai_helper/assistant_controller_spec.rb @@ -253,6 +253,9 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do after { SiteSetting.provider = @original_provider } it "returns a 403 error if the user cannot access the secure upload" do + # hosted-site plugin edge case, it enables embeddings + SiteSetting.ai_embeddings_enabled = false + create_post( title: "Secure upload post", raw: "This is a new post ",