Per https://platform.openai.com/docs/api-reference/authentication, there is an organization option which is useful for large orgs:

> For users who belong to multiple organizations, you can pass a header to specify which organization is used for an API request. Usage from these API requests will count against the specified organization's subscription quota.
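A minimal standalone sketch of what that header looks like on the wire, assuming the stock https://api.openai.com/v1/chat/completions endpoint, a `gpt-3.5-turbo` model, and credentials in ENV (the class below reads its key and organization from site settings instead):

# Hypothetical illustration only; endpoint, ENV variable names, and model are assumptions.
require "net/http"
require "json"

uri = URI("https://api.openai.com/v1/chat/completions")
headers = {
  "Content-Type" => "application/json",
  "Authorization" => "Bearer #{ENV["OPENAI_API_KEY"]}",
  # Optional: bill this request against a specific organization.
  "OpenAI-Organization" => ENV["OPENAI_ORGANIZATION"],
}

request = Net::HTTP::Post.new(uri, headers)
request.body = { model: "gpt-3.5-turbo", messages: [{ role: "user", content: "Hello" }] }.to_json

response = Net::HTTP.start(uri.host, uri.port, use_ssl: true) { |http| http.request(request) }
puts JSON.parse(response.body)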
# frozen_string_literal: true

module ::DiscourseAi
  module Inference
    class OpenAiCompletions
      TIMEOUT = 60
      DEFAULT_RETRIES = 3
      DEFAULT_RETRY_TIMEOUT_SECONDS = 3
      RETRY_TIMEOUT_BACKOFF_MULTIPLIER = 3

      CompletionFailed = Class.new(StandardError)

      def self.perform!(
        messages,
        model,
        temperature: nil,
        top_p: nil,
        max_tokens: nil,
        functions: nil,
        user_id: nil,
        retries: DEFAULT_RETRIES,
        retry_timeout: DEFAULT_RETRY_TIMEOUT_SECONDS,
        &blk
      )
        log = nil
        response_data = +""
        response_raw = +""

        # Pick the configured endpoint for the requested model.
        url =
          if model.include?("gpt-4")
            if model.include?("32k")
              URI(SiteSetting.ai_openai_gpt4_32k_url)
            else
              URI(SiteSetting.ai_openai_gpt4_url)
            end
          else
            if model.include?("16k")
              URI(SiteSetting.ai_openai_gpt35_16k_url)
            else
              URI(SiteSetting.ai_openai_gpt35_url)
            end
          end

        headers = { "Content-Type" => "application/json" }

        # Azure OpenAI endpoints authenticate with an api-key header; OpenAI expects a bearer token.
        if url.host.include?("azure")
          headers["api-key"] = SiteSetting.ai_openai_api_key
        else
          headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
        end

        # Attribute usage to a specific OpenAI organization when one is configured.
        if SiteSetting.ai_openai_organization.present?
          headers["OpenAI-Organization"] = SiteSetting.ai_openai_organization
        end

        payload = { model: model, messages: messages }

        payload[:temperature] = temperature if temperature
        payload[:top_p] = top_p if top_p
        payload[:max_tokens] = max_tokens if max_tokens
        payload[:functions] = functions if functions
        payload[:stream] = true if block_given?

        Net::HTTP.start(
          url.host,
          url.port,
          use_ssl: true,
          read_timeout: TIMEOUT,
          open_timeout: TIMEOUT,
          write_timeout: TIMEOUT,
        ) do |http|
          request = Net::HTTP::Post.new(url, headers)
          request_body = payload.to_json
          request.body = request_body

          http.request(request) do |response|
            if retries > 0 && response.code.to_i == 429
              # Rate limited: back off, then retry with an exponentially growing delay.
              sleep(retry_timeout)
              retries -= 1
              retry_timeout *= RETRY_TIMEOUT_BACKOFF_MULTIPLIER
              return(
                perform!(
                  messages,
                  model,
                  temperature: temperature,
                  top_p: top_p,
                  max_tokens: max_tokens,
                  functions: functions,
                  user_id: user_id,
                  retries: retries,
                  retry_timeout: retry_timeout,
                  &blk
                )
              )
            elsif response.code.to_i != 200
              Rails.logger.error(
                "OpenAiCompletions: status: #{response.code.to_i} - body: #{response.body}",
              )
              raise CompletionFailed, "status: #{response.code.to_i} - body: #{response.body}"
            end

            log =
              AiApiAuditLog.create!(
                provider_id: AiApiAuditLog::Provider::OpenAI,
                raw_request_payload: request_body,
                user_id: user_id,
              )

            # Non-streaming path: parse the full body and record token usage from the response.
            if !blk
              response_body = response.read_body
              parsed_response = JSON.parse(response_body, symbolize_names: true)

              log.update!(
                raw_response_payload: response_body,
                request_tokens: parsed_response.dig(:usage, :prompt_tokens),
                response_tokens: parsed_response.dig(:usage, :completion_tokens),
              )
              return parsed_response
            end

            # Streaming path: read server-sent events and yield each parsed delta to the caller.
            begin
              cancelled = false
              cancel = lambda { cancelled = true }

              leftover = ""

              response.read_body do |chunk|
                if cancelled
                  http.finish
                  return
                end

                response_raw << chunk

                (leftover + chunk)
                  .split("\n")
                  .each do |line|
                    data = line.split("data: ", 2)[1]
                    next if !data || data == "[DONE]"
                    next if cancelled

                    partial = nil
                    begin
                      partial = JSON.parse(data, symbolize_names: true)
                      leftover = ""
                    rescue JSON::ParserError
                      # Incomplete JSON: keep the line and retry once the next chunk arrives.
                      leftover = line
                    end

                    if partial
                      response_data << partial.dig(:choices, 0, :delta, :content).to_s
                      response_data << partial.dig(:choices, 0, :delta, :function_call).to_s

                      blk.call(partial, cancel)
                    end
                  end
              rescue IOError
                raise if !cancelled
              end
            end

            return response_data
          end
        end
      ensure
        if log && block_given?
          # Streamed responses carry no usage data, so estimate token counts locally.
          request_tokens = DiscourseAi::Tokenizer::OpenAiTokenizer.size(extract_prompt(messages))
          response_tokens = DiscourseAi::Tokenizer::OpenAiTokenizer.size(response_data)
          log.update!(
            raw_response_payload: response_raw,
            request_tokens: request_tokens,
            response_tokens: response_tokens,
          )
        end
        if log && Rails.env.development?
          puts "OpenAiCompletions: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
        end
      end

      def self.extract_prompt(messages)
        messages.map { |message| message[:content] || message["content"] || "" }.join("\n")
      end
    end
  end
end