Roman Rizzi e0bf6adb5b
DEV: Tool support for the LLM service. (#366)
This PR adds tool support to available LLMs. We'll buffer tool invocations and return them instead of making users of this service parse the response.

It also adds support for conversation context in the generic prompt. It includes bot messages, user messages, and tool invocations, which we'll trim to make sure it doesn't exceed the prompt limit, then translate them to the correct dialect.

Finally, it adds some buffering when reading chunks to handle cases when streaming is extremely slow.
2023-12-18 18:06:01 -03:00

129 lines
3.9 KiB
Ruby

# frozen_string_literal: true
module DiscourseAi
module Completions
module Endpoints
class OpenAi < Base
# Reports whether this endpoint handles the given model name.
#
# @param model_name [String] model identifier to look up
# @return [Boolean] true only when the name exactly matches a supported model
def self.can_contact?(model_name)
  supported_models = %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k]
  supported_models.any? { |candidate| candidate == model_name }
end
# Base request options shared by every payload built by this endpoint.
#
# @return [Hash] a fresh hash holding only the :model key, set to the
#   current value of `model`
def default_options
  options = {}
  options[:model] = model
  options
end
# Provider identifier for this endpoint.
#
# @return [Object] the value of AiApiAuditLog::Provider::OpenAI
#   (presumably used when writing audit log rows — confirm against Base)
def provider_id
  providers = AiApiAuditLog::Provider
  providers::OpenAI
end
private
# Resolves the site-setting URL that matches the configured model.
#
# Any model name containing "gpt-4" uses a GPT-4 setting; everything else
# uses a GPT-3.5 setting. Within each family, a context-size marker in the
# name ("32k" for GPT-4, "16k" for GPT-3.5) selects the dedicated setting.
#
# @return [URI] parsed endpoint URL
def model_uri
  gpt4_family = model.include?("gpt-4")
  extended_context = model.include?(gpt4_family ? "32k" : "16k")

  endpoint =
    case [gpt4_family, extended_context]
    when [true, true] then SiteSetting.ai_openai_gpt4_32k_url
    when [true, false] then SiteSetting.ai_openai_gpt4_url
    when [false, true] then SiteSetting.ai_openai_gpt35_16k_url
    else SiteSetting.ai_openai_gpt35_url
    end

  URI(endpoint)
end
# Builds the request payload hash for a completion call.
#
# Later sources win on key conflicts: default options, then model_params,
# then the prompt stored under :messages.
#
# @param prompt [Object] value stored under :messages
# @param model_params [Hash] caller overrides merged over the defaults
# @param dialect [#tools] its tools are attached only when present
# @return [Hash] new payload hash
def prepare_payload(prompt, model_params, dialect)
  payload = default_options.merge(model_params)
  payload[:messages] = prompt

  if @streaming_mode
    payload[:stream] = true
  end

  if dialect.tools.present?
    payload[:tools] = dialect.tools
  end

  payload
end
# Builds the HTTP POST request for the resolved model URI.
#
# Hosts whose name contains "azure" receive the key in an "api-key" header;
# every other host receives it as a Bearer token. The organization header
# is added only when the corresponding site setting is present.
#
# @param payload [String] request body, assigned as-is
# @return [Net::HTTP::Post] request with headers and body set
def prepare_request(payload)
  headers = { "Content-Type" => "application/json" }

  if model_uri.host.include?("azure")
    headers["api-key"] = SiteSetting.ai_openai_api_key
  else
    headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
  end

  if SiteSetting.ai_openai_organization.present?
    headers["OpenAI-Organization"] = SiteSetting.ai_openai_organization
  end

  request = Net::HTTP::Post.new(model_uri, headers)
  request.body = payload
  request
end
# Extracts the useful part of one raw JSON response (or streamed chunk).
#
# Reads the first entry of :choices, then its :delta when streaming or its
# :message otherwise. When that entry carries tool calls, the first call's
# :function hash is returned; otherwise its :content is returned.
#
# @param response_raw [String] JSON document
# @return [Hash, String, nil] function hash, text content, or nil when the
#   document has no first choice or no delta/message to read from
# @raise [JSON::ParserError] when response_raw is not valid JSON
def extract_completion_from(response_raw)
  choice = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
  # A document can arrive with an empty or missing :choices list; there is
  # nothing to extract from it, so return nil instead of calling into nil.
  return if choice.nil?

  message = @streaming_mode ? choice[:delta] : choice[:message]
  return if message.nil?

  tool_calls = message[:tool_calls]
  tool_calls.present? ? tool_calls.dig(0, :function) : message[:content]
end
# Splits a decoded chunk into the payloads of its "data: " lines.
#
# Each line is cut at the first "data: " marker and the remainder is kept.
# Lines without the marker and the "[DONE]" sentinel are dropped.
#
# @param decoded_chunk [String] newline-separated text
# @return [Array<String>] payloads in their original order
def partials_from(decoded_chunk)
  decoded_chunk.split("\n").filter_map do |line|
    _prefix, data = line.split("data: ", 2)
    data unless data == "[DONE]"
  end
end
# Flattens a list of messages into one newline-joined string.
#
# Each message contributes its :content, falling back to the "content"
# string key and finally to an empty string.
#
# @param prompt [Array<Hash>] messages keyed by symbol or string
# @return [String] message contents joined with "\n"
def extract_prompt_for_tokenizer(prompt)
  contents =
    prompt.map do |message|
      text = message[:content]
      text ||= message["content"]
      text || ""
    end

  contents.join("\n")
end
# Tells whether a partial describes a tool invocation.
#
# @param _response_data [Object] unused
# @param partial [Object] value extracted from a response chunk
# @return [Boolean] true when partial is a Hash with a :name key (the
#   function name), whatever that key's value is
def has_tool?(_response_data, partial)
  return false unless partial.is_a?(Hash)

  partial.key?(:name) # Has function name
end
# Copies the fields of a tool-call partial into the function buffer.
#
# The tool name and id overwrite the matching nodes only when present.
# When :arguments is given, the children of the "parameters" node are
# replaced with one <name>value</name> element per argument.
#
# @param function_buffer [#at] document holding "tool_name", "tool_id" and
#   "parameters" nodes
# @param _response_data [Object] unused
# @param partial [Hash] may carry :name, :id and :arguments
# @return [Object] the same function_buffer, mutated in place
def add_to_buffer(function_buffer, _response_data, partial)
  if partial[:name].present?
    function_buffer.at("tool_name").content = partial[:name]
  end

  if partial[:id].present?
    function_buffer.at("tool_id").content = partial[:id]
  end

  if partial[:arguments]
    # NOTE(review): :arguments must be enumerable as name/value pairs here;
    # a plain String would raise. Names and values are interpolated without
    # escaping, so markup inside a value is parsed as markup — confirm that
    # upstream guarantees both.
    markup =
      partial[:arguments].each_with_object(+"") do |(arg_name, value), xml|
        xml << "\n<#{arg_name}>#{value}</#{arg_name}>"
      end
    markup << "\n"

    function_buffer.at("parameters").children =
      Nokogiri::HTML5::DocumentFragment.parse(markup)
  end

  function_buffer
end
# Decides whether the buffer holds everything needed to invoke its tool.
#
# The buffered tool name is matched against the available function
# signatures; the call is complete once every parameter flagged as
# required has a node in the buffer. Optional parameters never block
# completion, whether or not they are present.
#
# @param available_functions [Array<Hash>] entries shaped like
#   { tool: { name:, parameters: [{ name:, required: }, ...] } }
# @param buffer [#at] document holding the buffered invocation
# @return [Boolean] false when no tool name is buffered yet, when the name
#   matches no known signature, or when a required parameter is missing
def buffering_finished?(available_functions, buffer)
  tool_name = buffer.at("tool_name")&.text
  return false if tool_name.blank?

  # An unrecognized tool name has no signature to validate against; report
  # "not finished" instead of failing on a nil lookup.
  signature = available_functions.find { |f| f.dig(:tool, :name) == tool_name }&.dig(:tool)
  return false if signature.nil?

  # Array() lets a signature without a :parameters list count as complete.
  Array(signature[:parameters]).all? do |param|
    !param[:required] || buffer.at(param[:name]).present?
  end
end
end
end
end
end