2023-11-23 10:58:54 -05:00
|
|
|
# frozen_string_literal: true
|
|
|
|
|
|
|
|
module DiscourseAi
|
|
|
|
module Completions
|
|
|
|
module Endpoints
|
2023-11-28 23:17:46 -05:00
|
|
|
class OpenAi < Base
|
2024-07-30 12:44:57 -04:00
|
|
|
# Tells whether this endpoint handles the given provider identifier.
#
# @param model_provider [Object] provider name, compared against the
#   strings "open_ai" and "azure"
# @return [Boolean] true when the provider equals one of those names
def self.can_contact?(model_provider)
  supported_providers = %w[open_ai azure]
  supported_providers.any? { |candidate| candidate == model_provider }
end
|
|
|
|
|
2024-01-04 07:53:47 -05:00
|
|
|
# Translates generic model params into the names used in the request payload.
#
# Works on a shallow copy, so the caller's hash is not modified.
# `:stop_sequences` is renamed to `:stop`. The `:stop_sequences` key is always
# removed from the copy, even when its value is nil/false, so that it never
# reaches the payload under a name the API does not know about.
#
# @param model_params [Hash] params keyed by symbol
# @return [Hash] a new hash with `:stop_sequences` replaced by `:stop`
def normalize_model_params(model_params)
  normalized = model_params.dup

  # max_tokens, temperature are already supported
  stop = normalized.delete(:stop_sequences)
  normalized[:stop] = stop if stop

  normalized
end
|
|
|
|
|
2023-11-23 10:58:54 -05:00
|
|
|
# Base options for every request payload.
#
# @return [Hash] a hash whose only key, `:model`, holds `llm_model.name`
def default_options
  options = {}
  options[:model] = llm_model.name
  options
end
|
|
|
|
|
|
|
|
# Provider identifier for this endpoint.
#
# @return [Object] the value of the AiApiAuditLog::Provider::OpenAI constant
def provider_id
  audit_providers = AiApiAuditLog::Provider
  audit_providers::OpenAI
end
|
|
|
|
|
2024-10-23 01:49:56 -04:00
|
|
|
# Records whether the dialect asked to disable native tools, then delegates
# to the parent implementation with the same arguments and block.
#
# The flag is stored before delegating because `xml_tools_enabled?` reads it.
def perform_completion!(
  dialect,
  user,
  model_params = {},
  feature_name: nil,
  feature_context: nil,
  partial_tool_calls: false,
  &blk
)
  @disable_native_tools = dialect.disable_native_tools?

  super(
    dialect,
    user,
    model_params,
    feature_name: feature_name,
    feature_context: feature_context,
    partial_tool_calls: partial_tool_calls,
    &blk
  )
end
|
|
|
|
|
2023-11-23 10:58:54 -05:00
|
|
|
private
|
|
|
|
|
|
|
|
# Resolves the endpoint URI for the configured model and caches it.
#
# When the model URL starts with "srv://", the rest of the URL is passed to
# DiscourseAi::Utils::DnsSrv.lookup and the returned target/port are used to
# build an HTTPS chat-completions URL. Any other URL is used as-is.
#
# The cached URI is returned before anything else, so the SRV lookup runs at
# most once per instance instead of on every call.
#
# @return [URI] the parsed endpoint URI
def model_uri
  return @uri if @uri

  configured_url = llm_model.url

  api_endpoint =
    if configured_url.to_s.start_with?("srv://")
      service = DiscourseAi::Utils::DnsSrv.lookup(configured_url.sub("srv://", ""))
      "https://#{service.target}:#{service.port}/v1/chat/completions"
    else
      configured_url
    end

  @uri = URI(api_endpoint)
end
|
|
|
|
|
2023-12-18 16:06:01 -05:00
|
|
|
# Builds the request payload hash.
#
# Starts from `default_options`, overlays `model_params`, and sets `:messages`
# to the prompt. In streaming mode it adds `stream: true`, plus
# `stream_options` when the provider is "open_ai". Unless XML tools are
# enabled, the dialect's tools (and forced tool choice, if any) are attached.
#
# @param prompt [Object] value stored under `:messages`
# @param model_params [Hash] params merged over the default options
# @param dialect [Object] responds to `tools` and `tool_choice`
# @return [Hash] the payload
def prepare_payload(prompt, model_params, dialect)
  request_body = default_options.merge(model_params).merge(messages: prompt)

  if @streaming_mode
    request_body[:stream] = true

    # Usage is not available in Azure yet.
    # We'll fall back to guessing it using the tokenizer.
    request_body[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai"
  end

  # Native tool definitions are only sent when XML tools are off and the
  # dialect actually has tools.
  return request_body if xml_tools_enabled? || !dialect.tools.present?

  request_body[:tools] = dialect.tools

  if dialect.tool_choice.present?
    forced_function = { name: dialect.tool_choice }
    request_body[:tool_choice] = { type: "function", function: forced_function }
  end

  request_body
end
|
|
|
|
|
|
|
|
# Builds the HTTP POST request for `model_uri` carrying the given body.
#
# For the "azure" provider the key is sent in an "api-key" header. Otherwise
# it is sent as a Bearer token, and an "OpenAI-Organization" header is added
# when the model has an "organization" custom param.
#
# @param payload [String] request body, assigned as-is
# @return [Net::HTTP::Post] the prepared request
def prepare_request(payload)
  secret = llm_model.api_key

  auth_headers =
    if llm_model.provider == "azure"
      { "api-key" => secret }
    else
      bearer_headers = { "Authorization" => "Bearer #{secret}" }
      organization = llm_model.lookup_custom_param("organization")
      bearer_headers["OpenAI-Organization"] = organization if organization.present?
      bearer_headers
    end

  request_headers = { "Content-Type" => "application/json" }.merge(auth_headers)

  request = Net::HTTP::Post.new(model_uri, request_headers)
  request.body = payload
  request
end
|
|
|
|
|
2024-05-13 23:28:46 -04:00
|
|
|
# Copies token counts from the processor onto the log record.
#
# Each count is written only when the processor reports a truthy value, so
# whatever the log already holds is kept otherwise.
#
# @param log [Object] responds to `request_tokens=`, `response_tokens=`
#   and `cached_tokens=`
def final_log_update(log)
  prompt_count = processor.prompt_tokens
  log.request_tokens = prompt_count if prompt_count

  completion_count = processor.completion_tokens
  log.response_tokens = completion_count if completion_count

  cached_count = processor.cached_tokens
  log.cached_tokens = cached_count if cached_count
end
|
|
|
|
|
2024-11-11 16:14:30 -05:00
|
|
|
# Parses a complete (non-streamed) JSON response and hands it to the processor.
#
# @param response_raw [String] JSON text of the response
# @return [Object] whatever `processor.process_message` returns
def decode(response_raw)
  message_processor = processor
  parsed_response = JSON.parse(response_raw, symbolize_names: true)
  message_processor.process_message(parsed_response)
end
|
|
|
|
|
2024-11-11 16:14:30 -05:00
|
|
|
# Feeds a raw streamed chunk to the JSON stream decoder and returns the
# processed elements.
#
# Every JSON object the decoder yields goes through
# `processor.process_streamed_message`; the results are flattened and nils
# are dropped.
#
# @param chunk [Object] raw data appended to the stream decoder
# @return [Array] processed elements, with repeated ToolCall values removed
def decode_chunk(chunk)
  @decoder ||= JsonStreamDecoder.new

  parsed_messages = @decoder << chunk
  processed = parsed_messages.map { |message| processor.process_streamed_message(message) }
  processed = processed.flatten.compact

  # Remove duplicate partial tool calls
  # sometimes we stream weird chunks
  emitted_tool_calls = Set.new
  processed.reject { |entry| entry.is_a?(ToolCall) && !emitted_tool_calls.add?(entry) }
end
|
|
|
|
|
2024-11-11 16:14:30 -05:00
|
|
|
# Signals the end of a streamed response to the processor.
#
# Uses the memoized `processor` accessor rather than reading `@processor`
# directly, so this does not fail on nil when no chunk was decoded first.
#
# @return [Object] whatever `processor.finish` returns
def decode_chunk_finish
  processor.finish
end
|
|
|
|
|
2024-11-11 16:14:30 -05:00
|
|
|
# Whether native tools were disabled for the current completion.
#
# @return [Boolean] true when `@disable_native_tools` is truthy, else false
def xml_tools_enabled?
  @disable_native_tools ? true : false
end
|
|
|
|
|
2024-11-11 16:14:30 -05:00
|
|
|
private
|
2024-01-02 09:21:13 -05:00
|
|
|
|
2024-11-11 16:14:30 -05:00
|
|
|
# Lazily builds the message processor and reuses it on later calls.
#
# @return [OpenAiMessageProcessor] processor created with the current
#   `partial_tool_calls` value
def processor
  return @processor if @processor

  @processor = OpenAiMessageProcessor.new(partial_tool_calls: partial_tool_calls)
end
|
2023-11-23 10:58:54 -05:00
|
|
|
end
|
|
|
|
end
|
|
|
|
end
|
|
|
|
end
|