2023-03-07 14:14:39 -05:00
|
|
|
# frozen_string_literal: true
|
|
|
|
|
2023-03-14 15:03:50 -04:00
|
|
|
module ::DiscourseAi
|
2023-03-07 14:14:39 -05:00
|
|
|
module Inference
|
2023-03-15 16:02:20 -04:00
|
|
|
class OpenAiCompletions
  # Seconds allowed for each Net::HTTP phase (open / read / write).
  TIMEOUT = 60

  # 429 (rate-limit) handling: retry up to DEFAULT_RETRIES times, sleeping
  # DEFAULT_RETRY_TIMEOUT_SECONDS before the first retry and multiplying the
  # sleep by RETRY_TIMEOUT_BACKOFF_MULTIPLIER for each subsequent attempt.
  DEFAULT_RETRIES = 3
  DEFAULT_RETRY_TIMEOUT_SECONDS = 3
  RETRY_TIMEOUT_BACKOFF_MULTIPLIER = 3

  # Raised for any non-200 response that is not retried.
  CompletionFailed = Class.new(StandardError)

  # Performs a chat completion request against the configured OpenAI (or
  # Azure OpenAI) endpoint, audit-logging the exchange via AiApiAuditLog.
  #
  # Without a block, the full response is read, logged, and returned as a
  # Hash (symbolized keys). With a block, `stream: true` is sent and each
  # parsed streamed chunk is yielded as `(partial, cancel)` — calling the
  # `cancel` lambda stops the stream — and the accumulated delta text is
  # returned instead.
  #
  # Raises CompletionFailed on a non-200 response once retries are exhausted.
  def self.perform!(
    messages,
    model,
    temperature: nil,
    top_p: nil,
    max_tokens: nil,
    functions: nil,
    user_id: nil,
    retries: DEFAULT_RETRIES,
    retry_timeout: DEFAULT_RETRY_TIMEOUT_SECONDS,
    &blk
  )
    log = nil
    response_data = +"" # accumulated streamed delta text (unfrozen, mutated with <<)
    response_raw = +"" # raw streamed bytes, kept for the audit log

    # Endpoint URL comes from site settings, chosen by sniffing the model
    # name; the "32k"/"16k" context variants have their own settings.
    url =
      if model.include?("gpt-4")
        if model.include?("32k")
          URI(SiteSetting.ai_openai_gpt4_32k_url)
        else
          URI(SiteSetting.ai_openai_gpt4_url)
        end
      else
        if model.include?("16k")
          URI(SiteSetting.ai_openai_gpt35_16k_url)
        else
          URI(SiteSetting.ai_openai_gpt35_url)
        end
      end

    headers = { "Content-Type" => "application/json" }

    # Azure OpenAI authenticates via an "api-key" header; openai.com uses a
    # Bearer token. The hostname decides which scheme applies.
    if url.host.include? ("azure")
      headers["api-key"] = SiteSetting.ai_openai_api_key
    else
      headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
    end

    if SiteSetting.ai_openai_organization.present?
      headers["OpenAI-Organization"] = SiteSetting.ai_openai_organization
    end

    payload = { model: model, messages: messages }

    # Optional API parameters are only included when explicitly supplied.
    payload[:temperature] = temperature if temperature
    payload[:top_p] = top_p if top_p
    payload[:max_tokens] = max_tokens if max_tokens
    payload[:functions] = functions if functions
    payload[:stream] = true if block_given?

    Net::HTTP.start(
      url.host,
      url.port,
      use_ssl: true,
      read_timeout: TIMEOUT,
      open_timeout: TIMEOUT,
      write_timeout: TIMEOUT,
    ) do |http|
      request = Net::HTTP::Post.new(url, headers)
      request_body = payload.to_json
      request.body = request_body

      http.request(request) do |response|
        # Rate limited: back off, then recurse with a decremented retry
        # budget and an increased sleep; the recursive call's result is
        # returned directly to the caller.
        if retries > 0 && response.code.to_i == 429
          sleep(retry_timeout)
          retries -= 1
          retry_timeout *= RETRY_TIMEOUT_BACKOFF_MULTIPLIER
          return(
            perform!(
              messages,
              model,
              temperature: temperature,
              top_p: top_p,
              max_tokens: max_tokens,
              functions: functions,
              user_id: user_id,
              retries: retries,
              retry_timeout: retry_timeout,
              &blk
            )
          )
        elsif response.code.to_i != 200
          Rails.logger.error(
            "OpenAiCompletions: status: #{response.code.to_i} - body: #{response.body}",
          )
          raise CompletionFailed, "status: #{response.code.to_i} - body: #{response.body}"
        end

        # Record the request up front; the response half is filled in below
        # (non-streaming) or in the ensure block (streaming).
        log =
          AiApiAuditLog.create!(
            provider_id: AiApiAuditLog::Provider::OpenAI,
            raw_request_payload: request_body,
            user_id: user_id,
          )

        # Non-streaming path: read everything, record token usage, return.
        if !blk
          response_body = response.read_body
          parsed_response = JSON.parse(response_body, symbolize_names: true)

          log.update!(
            raw_response_payload: response_body,
            request_tokens: parsed_response.dig(:usage, :prompt_tokens),
            response_tokens: parsed_response.dig(:usage, :completion_tokens),
          )
          return parsed_response
        end

        # Streaming path: parse "data: ..." event lines chunk by chunk.
        begin
          cancelled = false
          cancel = lambda { cancelled = true }

          # Holds a partial event line that arrived split across two chunks
          # so it can be re-parsed once the remainder arrives.
          leftover = ""

          response.read_body do |chunk|
            if cancelled
              # Closing the connection mid-stream makes read_body raise
              # IOError, which is swallowed below when cancelled.
              http.finish
              return
            end

            response_raw << chunk

            (leftover + chunk)
              .split("\n")
              .each do |line|
                data = line.split("data: ", 2)[1]
                next if !data || data == "[DONE]"
                next if cancelled

                partial = nil
                begin
                  partial = JSON.parse(data, symbolize_names: true)
                  leftover = ""
                rescue JSON::ParserError
                  # Incomplete JSON — keep the line and retry with the
                  # next chunk appended.
                  leftover = line
                end

                if partial
                  # Collect both plain-content and function-call deltas so
                  # response token counts can be estimated afterwards.
                  response_data << partial.dig(:choices, 0, :delta, :content).to_s
                  response_data << partial.dig(:choices, 0, :delta, :function_call).to_s

                  blk.call(partial, cancel)
                end
              end
          rescue IOError
            # Expected when we closed the socket after cancellation.
            raise if !cancelled
          end
        end

        return response_data
      end
    end
  ensure
    # Streaming responses are logged here because token counts are only
    # known once the stream has finished (or was cancelled / errored).
    if log && block_given?
      request_tokens = DiscourseAi::Tokenizer::OpenAiTokenizer.size(extract_prompt(messages))
      response_tokens = DiscourseAi::Tokenizer::OpenAiTokenizer.size(response_data)
      log.update!(
        raw_response_payload: response_raw,
        request_tokens: request_tokens,
        response_tokens: response_tokens,
      )
    end

    if log && Rails.env.development?
      puts "OpenAiCompletions: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
    end
  end

  # Joins the text content of every message (symbol or string key, missing
  # content treated as "") with newlines; used above to estimate request
  # token counts for streamed completions.
  def self.extract_prompt(messages)
    messages.map { |message| message[:content] || message["content"] || "" }.join("\n")
  end
end
|
|
|
|
end
|
|
|
|
end
|