193 lines
5.8 KiB
Ruby
193 lines
5.8 KiB
Ruby
# frozen_string_literal: true
|
|
|
|
module DiscourseAi
|
|
module Completions
|
|
module Endpoints
|
|
class OpenAi < Base
|
|
def self.can_contact?(model_provider)
|
|
%w[open_ai azure].include?(model_provider)
|
|
end
|
|
|
|
def normalize_model_params(model_params)
|
|
model_params = model_params.dup
|
|
|
|
# max_tokens, temperature are already supported
|
|
if model_params[:stop_sequences]
|
|
model_params[:stop] = model_params.delete(:stop_sequences)
|
|
end
|
|
|
|
model_params
|
|
end
|
|
|
|
def default_options
|
|
{ model: llm_model.name }
|
|
end
|
|
|
|
def provider_id
|
|
AiApiAuditLog::Provider::OpenAI
|
|
end
|
|
|
|
def perform_completion!(
|
|
dialect,
|
|
user,
|
|
model_params = {},
|
|
feature_name: nil,
|
|
feature_context: nil,
|
|
&blk
|
|
)
|
|
if dialect.respond_to?(:is_gpt_o?) && dialect.is_gpt_o? && block_given?
|
|
# we need to disable streaming and simulate it
|
|
blk.call "", lambda { |*| }
|
|
response = super(dialect, user, model_params, feature_name: feature_name, &nil)
|
|
blk.call response, lambda { |*| }
|
|
else
|
|
super
|
|
end
|
|
end
|
|
|
|
private
|
|
|
|
def model_uri
|
|
if llm_model.url.to_s.starts_with?("srv://")
|
|
service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", ""))
|
|
api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"
|
|
else
|
|
api_endpoint = llm_model.url
|
|
end
|
|
|
|
@uri ||= URI(api_endpoint)
|
|
end
|
|
|
|
def prepare_payload(prompt, model_params, dialect)
|
|
payload = default_options.merge(model_params).merge(messages: prompt)
|
|
|
|
if @streaming_mode
|
|
payload[:stream] = true
|
|
|
|
# Usage is not available in Azure yet.
|
|
# We'll fallback to guess this using the tokenizer.
|
|
payload[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai"
|
|
end
|
|
if dialect.tools.present?
|
|
payload[:tools] = dialect.tools
|
|
if dialect.tool_choice.present?
|
|
payload[:tool_choice] = { type: "function", function: { name: dialect.tool_choice } }
|
|
end
|
|
end
|
|
payload
|
|
end
|
|
|
|
def prepare_request(payload)
|
|
headers = { "Content-Type" => "application/json" }
|
|
api_key = llm_model.api_key
|
|
|
|
if llm_model.provider == "azure"
|
|
headers["api-key"] = api_key
|
|
else
|
|
headers["Authorization"] = "Bearer #{api_key}"
|
|
org_id = llm_model.lookup_custom_param("organization")
|
|
headers["OpenAI-Organization"] = org_id if org_id.present?
|
|
end
|
|
|
|
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
|
|
end
|
|
|
|
def final_log_update(log)
|
|
log.request_tokens = @prompt_tokens if @prompt_tokens
|
|
log.response_tokens = @completion_tokens if @completion_tokens
|
|
end
|
|
|
|
def extract_completion_from(response_raw)
|
|
json = JSON.parse(response_raw, symbolize_names: true)
|
|
|
|
if @streaming_mode
|
|
@prompt_tokens ||= json.dig(:usage, :prompt_tokens)
|
|
@completion_tokens ||= json.dig(:usage, :completion_tokens)
|
|
end
|
|
|
|
parsed = json.dig(:choices, 0)
|
|
return if !parsed
|
|
|
|
response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
|
|
@has_function_call ||= response_h.dig(:tool_calls).present?
|
|
@has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
|
|
end
|
|
|
|
def partials_from(decoded_chunk)
|
|
decoded_chunk
|
|
.split("\n")
|
|
.map do |line|
|
|
data = line.split("data: ", 2)[1]
|
|
data == "[DONE]" ? nil : data
|
|
end
|
|
.compact
|
|
end
|
|
|
|
def has_tool?(_response_data)
|
|
@has_function_call
|
|
end
|
|
|
|
def native_tool_support?
|
|
true
|
|
end
|
|
|
|
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
|
|
if @streaming_mode
|
|
return function_buffer if !partial
|
|
else
|
|
partial = payload
|
|
end
|
|
|
|
@args_buffer ||= +""
|
|
|
|
f_name = partial.dig(:function, :name)
|
|
|
|
@current_function ||= function_buffer.at("invoke")
|
|
|
|
if f_name
|
|
current_name = function_buffer.at("tool_name").content
|
|
|
|
if current_name.blank?
|
|
# first call
|
|
else
|
|
# we have a previous function, so we need to add a noop
|
|
@args_buffer = +""
|
|
@current_function =
|
|
function_buffer.at("function_calls").add_child(
|
|
Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
|
|
)
|
|
end
|
|
end
|
|
|
|
@current_function.at("tool_name").content = f_name if f_name
|
|
@current_function.at("tool_id").content = partial[:id] if partial[:id]
|
|
|
|
args = partial.dig(:function, :arguments)
|
|
|
|
# allow for SPACE within arguments
|
|
if args && args != ""
|
|
@args_buffer << args
|
|
|
|
begin
|
|
json_args = JSON.parse(@args_buffer, symbolize_names: true)
|
|
|
|
argument_fragments =
|
|
json_args.reduce(+"") do |memo, (arg_name, value)|
|
|
memo << "\n<#{arg_name}>#{CGI.escapeHTML(value.to_s)}</#{arg_name}>"
|
|
end
|
|
argument_fragments << "\n"
|
|
|
|
@current_function.at("parameters").children =
|
|
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
|
|
rescue JSON::ParserError
|
|
return function_buffer
|
|
end
|
|
end
|
|
|
|
function_buffer
|
|
end
|
|
end
|
|
end
|
|
end
|
|
end
|