mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-03-01 14:59:22 +00:00
* FEATURE: allows forced LLM tool use Sometimes we need to force LLMs to use tools; for example, in RAG-like use cases we may want to force an unconditional search. The new framework allows the backend to force tool usage. Front-end commit to follow * UI for forcing tools now works, but it does not react right * fix bugs * fix tests, this is now ready for review
179 lines
5.4 KiB
Ruby
179 lines
5.4 KiB
Ruby
# frozen_string_literal: true
|
|
|
|
module DiscourseAi
|
|
module Completions
|
|
module Endpoints
|
|
class OpenAi < Base
|
|
# Reports whether this endpoint serves the given provider identifier.
#
# @param model_provider [String] provider name to check
# @return [Boolean] true only for "open_ai" and "azure"
def self.can_contact?(model_provider)
  case model_provider
  when "open_ai", "azure"
    true
  else
    false
  end
end
|
|
|
|
# Returns a shallow copy of +model_params+ in which a truthy
# :stop_sequences entry has been renamed to :stop. The caller's hash is
# left untouched.
#
# @param model_params [Hash] generic completion parameters
# @return [Hash] the adjusted copy
def normalize_model_params(model_params)
  normalized = model_params.dup

  # max_tokens, temperature are already supported
  stop_sequences = normalized[:stop_sequences]
  if stop_sequences
    normalized.delete(:stop_sequences)
    normalized[:stop] = stop_sequences
  end

  normalized
end
|
|
|
|
# Base payload options for every request: just the configured model name.
#
# @return [Hash] a hash with a single :model key
def default_options
  model_name = llm_model.name
  { model: model_name }
end
|
|
|
|
# Identifier under which requests from this endpoint are recorded in the
# API audit log.
#
# @return [Object] the value of AiApiAuditLog::Provider::OpenAI
def provider_id
  audit_providers = AiApiAuditLog::Provider
  audit_providers::OpenAI
end
|
|
|
|
# Runs a completion, delegating to the parent implementation.
#
# When the dialect reports +is_gpt_o?+ and a block was given, streaming is
# disabled and simulated instead: the block is called once with an empty
# string, the parent implementation is invoked without a block, and the
# block is then called a second time with the complete response. Both calls
# receive a no-op lambda as their second argument.
#
# In every other case the call is forwarded unchanged to the parent.
def perform_completion!(dialect, user, model_params = {}, feature_name: nil, &blk)
  simulate_streaming =
    dialect.respond_to?(:is_gpt_o?) && dialect.is_gpt_o? && block_given?
  return super unless simulate_streaming

  # we need to disable streaming and simulate it
  blk.call("", ->(*) {})
  full_response = super(dialect, user, model_params, feature_name: feature_name, &nil)
  blk.call(full_response, ->(*) {})
end
|
|
|
|
private
|
|
|
|
# Endpoint address for requests, built from the configured model URL.
#
# @return [URI::Generic] the result of passing the model's url to Kernel#URI
def model_uri
  endpoint_url = llm_model.url
  URI(endpoint_url)
end
|
|
|
|
# Builds the request payload hash.
#
# Starts from the default options, overlays +model_params+, and sets
# :messages to +prompt+. In streaming mode :stream is enabled, and usage
# reporting is requested for the "open_ai" provider only. When the dialect
# has tools they are attached, together with a forced function choice if
# the dialect names one.
#
# @param prompt [Object] value stored under :messages
# @param model_params [Hash] options merged over the defaults
# @param dialect [Object] responds to +tools+ and +tool_choice+
# @return [Hash] the payload
def prepare_payload(prompt, model_params, dialect)
  request_body = default_options.merge(model_params)
  request_body[:messages] = prompt

  if @streaming_mode
    request_body[:stream] = true

    # Usage is not available in Azure yet.
    # We'll fallback to guess this using the tokenizer.
    request_body[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai"
  end

  return request_body unless dialect.tools.present?

  request_body[:tools] = dialect.tools
  if dialect.tool_choice.present?
    forced_function = { name: dialect.tool_choice }
    request_body[:tool_choice] = { type: "function", function: forced_function }
  end

  request_body
end
|
|
|
|
# Builds the HTTP POST request carrying +payload+.
#
# For the "azure" provider the key is sent in an "api-key" header. For any
# other provider it is sent as a Bearer token, and an "OpenAI-Organization"
# header is added when the model has a non-blank "organization" custom
# param.
#
# @param payload [String] request body
# @return [Net::HTTP::Post] request addressed to the model URI
def prepare_request(payload)
  api_key = llm_model.api_key
  headers = { "Content-Type" => "application/json" }

  case llm_model.provider
  when "azure"
    headers["api-key"] = api_key
  else
    headers["Authorization"] = "Bearer #{api_key}"
    org_id = llm_model.lookup_custom_param("organization")
    headers["OpenAI-Organization"] = org_id if org_id.present?
  end

  request = Net::HTTP::Post.new(model_uri, headers)
  request.body = payload
  request
end
|
|
|
|
# Copies the token counts captured during the completion onto the audit
# log entry. A count that was never captured leaves the matching log field
# as it was.
#
# @param log [Object] responds to +request_tokens=+ and +response_tokens=+
def final_log_update(log)
  prompt_count = @prompt_tokens
  completion_count = @completion_tokens

  log.request_tokens = prompt_count if prompt_count
  log.response_tokens = completion_count if completion_count
end
|
|
|
|
# Extracts the completion fragment carried by one raw JSON payload.
#
# In streaming mode the first usage figures seen are remembered in
# @prompt_tokens / @completion_tokens. Once a payload contains tool calls,
# @has_function_call stays set and the first tool call of each payload is
# returned instead of text content.
#
# @param response_raw [String] JSON document (a full response, or one
#   streamed chunk when in streaming mode)
# @return [Hash, String, nil] the first tool call, the text content, or nil
#   when the payload has no choices or its first choice carries no
#   delta/message
# @raise [JSON::ParserError] when +response_raw+ is not valid JSON
def extract_completion_from(response_raw)
  json = JSON.parse(response_raw, symbolize_names: true)

  if @streaming_mode
    @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
    @completion_tokens ||= json.dig(:usage, :completion_tokens)
  end

  parsed = json.dig(:choices, 0)
  return unless parsed

  response_h = @streaming_mode ? parsed[:delta] : parsed[:message]
  # Some payloads contain a choice without any delta/message (for example
  # annotation-only chunks); there is nothing to extract from those.
  return unless response_h

  @has_function_call ||= response_h[:tool_calls].present?
  @has_function_call ? response_h.dig(:tool_calls, 0) : response_h[:content]
end
|
|
|
|
# Splits a decoded chunk into its individual data payloads.
#
# Each line contributes whatever follows its first "data: " marker. Lines
# without the marker and the "[DONE]" payload are dropped.
#
# @param decoded_chunk [String] newline-separated chunk text
# @return [Array<String>] the payloads, in order
def partials_from(decoded_chunk)
  decoded_chunk
    .split("\n")
    .map { |line| line.split("data: ", 2)[1] }
    .reject { |data| data.nil? || data == "[DONE]" }
end
|
|
|
|
# Reports whether a tool call has been seen so far. The answer comes from
# the @has_function_call flag; the argument is ignored.
#
# @param _response_data [Object] unused
# @return [Boolean, nil] current value of @has_function_call
def has_tool?(_response_data)
  @has_function_call
end
|
|
|
|
# This endpoint always reports native tool support.
#
# @return [Boolean] always true
def native_tool_support?
  true
end
|
|
|
|
# Folds one tool-call fragment into +function_buffer+ and returns the buffer.
#
# In streaming mode the fragment is +partial+ (a nil partial is a no-op);
# otherwise +payload+ is used. Argument text is accumulated in @args_buffer
# across calls, and the <parameters> node of the current invocation is only
# rewritten once the accumulated text parses as JSON.
#
# @param function_buffer [Object] node-like buffer queried with +at+ for
#   "invoke", "tool_name" and "function_calls"
# @param partial [Hash, nil] streamed tool-call fragment
# @param payload [Hash, nil] complete tool call (non-streaming mode)
# @return [Object] +function_buffer+
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
  if @streaming_mode
    return function_buffer if !partial
  else
    partial = payload
  end

  @args_buffer ||= +""

  f_name = partial.dig(:function, :name)

  @current_function ||= function_buffer.at("invoke")

  if f_name
    current_name = function_buffer.at("tool_name").content

    # A non-blank name means an earlier call already occupies the buffer, so
    # append a fresh invocation and start collecting arguments from scratch.
    # A blank name means this is the first call and the existing node is used.
    unless current_name.blank?
      @args_buffer = +""
      @current_function =
        function_buffer.at("function_calls").add_child(
          Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
        )
    end
  end

  @current_function.at("tool_name").content = f_name if f_name
  @current_function.at("tool_id").content = partial[:id] if partial[:id]

  args = partial.dig(:function, :arguments)

  # allow for SPACE within arguments
  if args && args != ""
    @args_buffer << args

    begin
      json_args = JSON.parse(@args_buffer, symbolize_names: true)

      argument_fragments =
        json_args.reduce(+"") do |memo, (arg_name, value)|
          # The fragment is parsed as markup below, so escape the characters
          # that would otherwise turn argument text into tags or entities.
          escaped_value =
            value.to_s.gsub(/[&<>]/, "&" => "&amp;", "<" => "&lt;", ">" => "&gt;")
          memo << "\n<#{arg_name}>#{escaped_value}</#{arg_name}>"
        end
      argument_fragments << "\n"

      @current_function.at("parameters").children =
        Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
    rescue JSON::ParserError
      # Arguments are still incomplete; wait for more fragments.
      return function_buffer
    end
  end

  function_buffer
end
|
|
end
|
|
end
|
|
end
|
|
end
|