mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-02-16 00:14:48 +00:00
Implement streaming tool calls for Anthropic and OpenAI. When calling `llm.generate(..., partial_tool_calls: true) do ... end`, partials may contain ToolCall instances with `partial: true`. These tool calls are populated with partially parsed JSON, so, for example, when performing a search you may get `ToolCall(..., {search: "hello" })` followed by `ToolCall(..., {search: "hello world" })`. The library used to parse the JSON is https://github.com/dgraham/json-stream; we use a fork because we need access to its internal buffer. This prepares internals to perform partial tool calls, but does not implement it yet.
127 lines
3.3 KiB
Ruby
127 lines
3.3 KiB
Ruby
# frozen_string_literal: true
|
|
module DiscourseAi::Completions
  # Translates OpenAI chat-completion payloads into text (String) and tool
  # calls (ToolCall). Handles whole responses (#process_message) as well as
  # streamed chunks (#process_streamed_message followed by #finish), and
  # records token usage reported in the payloads.
  class OpenAiMessageProcessor
    attr_reader :prompt_tokens, :completion_tokens

    # @param partial_tool_calls [Boolean] when true, streamed tool-call
    #   argument fragments are also fed to a ToolCallProgressTracker so that
    #   partially populated ToolCall objects can be emitted while arguments
    #   are still streaming in.
    def initialize(partial_tool_calls: false)
      @tool = nil
      @tool_arguments = +""
      @prompt_tokens = nil
      @completion_tokens = nil
      @partial_tool_calls = partial_tool_calls
      @streaming_parser = nil
      @has_new_data = false
    end

    # Processes a complete (non-streamed) response.
    #
    # @param json [Hash] parsed response body with symbolized keys
    # @return [Array<String, ToolCall>] the first choice's message content
    #   (when present) followed by one ToolCall per entry in its tool_calls
    # @raise [JSON::ParserError] if a tool call's arguments are not valid JSON
    def process_message(json)
      result = []

      message = json.dig(:choices, 0, :message, :content)
      result << message if message.present?

      tool_calls = json.dig(:choices, 0, :message, :tool_calls)
      if tool_calls.present?
        tool_calls.each do |tool_call|
          arguments = tool_call.dig(:function, :arguments)
          parameters = arguments.present? ? JSON.parse(arguments, symbolize_names: true) : {}

          result << ToolCall.new(
            id: tool_call[:id],
            name: tool_call.dig(:function, :name),
            parameters: parameters,
          )
        end
      end

      update_usage(json)

      result
    end

    # Processes one streamed chunk.
    #
    # Tool-call arguments arrive as string fragments and are buffered until the
    # tool call ends: a chunk carries a different tool-call id, a chunk carries
    # a finish_reason (or an empty tool_calls list), or #finish is called.
    #
    # @param json [Hash] a parsed stream chunk with symbolized keys
    # @return [String, ToolCall, nil] a content delta, a completed ToolCall
    #   (partial == false), a partially populated ToolCall (only when
    #   partial_tool_calls is enabled and new arguments were parsed), or nil
    #   when the chunk produced nothing to emit
    # @raise [JSON::ParserError] if a completed tool call's arguments are not valid JSON
    def process_streamed_message(json)
      rval = nil

      tool_calls = json.dig(:choices, 0, :delta, :tool_calls)
      content = json.dig(:choices, 0, :delta, :content)

      finished_tools = json.dig(:choices, 0, :finish_reason) || tool_calls == []

      if tool_calls.present?
        id = tool_calls.dig(0, :id)
        name = tool_calls.dig(0, :function, :name)
        arguments = tool_calls.dig(0, :function, :arguments).to_s

        # TODO: multiple tool support may require index
        # index = tool_calls[0].dig(:index)

        # A different id means the previous tool call's arguments are complete.
        rval = complete_tool if id.present? && @tool && @tool.id != id

        start_tool(id, name) if id.present? && name.present?

        # Argument fragments only belong to an in-flight tool call.
        if @tool
          @tool_arguments << arguments
          @streaming_parser << arguments if @streaming_parser && !arguments.empty?
        end

        rval ||= current_tool_progress
      elsif finished_tools && @tool
        rval = complete_tool
      elsif !content.to_s.empty?
        # we don't want to strip empty content like "\n", do not use present?
        rval = content
      end

      update_usage(json)

      rval
    end

    # Progress callback. The ToolCallProgressTracker built in #start_tool
    # receives this processor, presumably so it can report each argument key
    # as it is parsed (TODO confirm against ToolCallProgressTracker). Stores
    # the value on the in-flight tool call, marks it partial and flags that
    # there is new data to emit. Ignored when no tool call is in flight.
    def notify_progress(key, value)
      return if !@tool

      @tool.partial = true
      @tool.parameters[key.to_sym] = value
      @has_new_data = true
    end

    # @return [ToolCall, nil] the in-flight tool call if #notify_progress
    #   recorded new data since the previous call, otherwise nil. Clears the
    #   new-data flag.
    def current_tool_progress
      return nil if !@has_new_data

      @has_new_data = false
      @tool
    end

    # Flushes a tool call still in flight when the stream ends.
    #
    # @return [Array<ToolCall>] the completed tool call, or an empty array
    # @raise [JSON::ParserError] if its arguments are not valid JSON
    def finish
      @tool ? [complete_tool] : []
    end

    private

    # Begins a new tool call, discarding any previously buffered arguments.
    def start_tool(id, name)
      @tool_arguments = +""
      @tool = ToolCall.new(id: id, name: name)
      @streaming_parser = ToolCallProgressTracker.new(self) if @partial_tool_calls
    end

    # Finalizes the in-flight tool call on every completion path (id change,
    # finish_reason, end of stream) so they all behave identically, then
    # resets per-tool state so nothing leaks into the next tool call.
    #
    # @return [ToolCall] the completed tool call, with partial set to false
    def complete_tool
      tool = @tool
      process_arguments
      tool.partial = false

      @tool = nil
      @tool_arguments = +""
      @streaming_parser = nil
      @has_new_data = false

      tool
    end

    # Parses the buffered arguments into the in-flight tool call's parameters.
    # A blank buffer (a tool called without arguments) leaves the parameters
    # untouched instead of letting JSON.parse("") raise.
    def process_arguments
      return if @tool_arguments.blank?

      @tool.parameters = JSON.parse(@tool_arguments, symbolize_names: true)
    end

    # Records the first usage counts seen; later payloads do not overwrite them.
    def update_usage(json)
      @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
      @completion_tokens ||= json.dig(:usage, :completion_tokens)
    end
  end
end
|