mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-08-03 11:53:25 +00:00
OpenAI supports function calling, which has a very specific shape that other LLMs have not quite adopted. This simulates a command framework using system prompts on LLMs that are not OpenAI. Features include: - Smart system prompt to steer the LLM - Parameter validation (we ensure all the params are specified correctly) This is being tested on Anthropic at the moment and initial results are promising.
149 lines
4.5 KiB
Ruby
149 lines
4.5 KiB
Ruby
# frozen_string_literal: true
|
|
|
|
module ::DiscourseAi
  module Inference
    # Thin client for the OpenAI-style chat completions HTTP endpoint.
    #
    # Every successful call is recorded in an AiApiAuditLog row holding the raw
    # request, the raw response and the request/response token counts.
    class OpenAiCompletions
      # Seconds allowed for each of the open/read/write phases of the HTTP call.
      TIMEOUT = 60

      # Raised when the endpoint answers with any status other than 200.
      CompletionFailed = Class.new(StandardError)

      # Sends a chat completion request.
      #
      # Without a block the whole response is read and returned as a parsed
      # Hash with symbolized keys. With a block the request is sent with
      # `stream: true`; every parsed server-sent event is yielded together with
      # a `cancel` lambda the caller may invoke to stop consuming the stream.
      #
      # @param messages [Array<Hash>] chat messages, sent as-is in the payload
      # @param model [String] model name; also selects which endpoint URL is used
      # @param temperature [Numeric, nil] sent only when given
      # @param top_p [Numeric, nil] sent only when given
      # @param max_tokens [Integer, nil] sent only when given
      # @param functions [Array, nil] sent only when given
      # @param user_id [Integer, nil] stored on the audit log row
      # @yield [partial, cancel] parsed event Hash and the cancel lambda
      # @return [Hash, String, nil] parsed response (no block); concatenated
      #   streamed content (block); nil when a cancelled stream was aborted on
      #   the next incoming chunk
      # @raise [CompletionFailed] when the HTTP status is not 200
      def self.perform!(
        messages,
        model,
        temperature: nil,
        top_p: nil,
        max_tokens: nil,
        functions: nil,
        user_id: nil,
        &blk
      )
        url = endpoint_for(model)
        headers = request_headers(url)

        payload = { model: model, messages: messages }
        payload[:temperature] = temperature if temperature
        payload[:top_p] = top_p if top_p
        payload[:max_tokens] = max_tokens if max_tokens
        payload[:functions] = functions if functions
        payload[:stream] = true if blk

        Net::HTTP.start(
          url.host,
          url.port,
          use_ssl: true,
          read_timeout: TIMEOUT,
          open_timeout: TIMEOUT,
          write_timeout: TIMEOUT,
        ) do |http|
          request = Net::HTTP::Post.new(url, headers)
          request_body = payload.to_json
          request.body = request_body

          http.request(request) do |response|
            if response.code.to_i != 200
              Rails.logger.error(
                "OpenAiCompletions: status: #{response.code.to_i} - body: #{response.body}",
              )
              raise CompletionFailed
            end

            log =
              AiApiAuditLog.create!(
                provider_id: AiApiAuditLog::Provider::OpenAI,
                raw_request_payload: request_body,
                user_id: user_id,
              )

            return read_full_response(response, log) unless blk
            return stream_response(http, response, log, messages, &blk)
          end
        end
      end

      # Joins the content of all messages (symbol or string keyed) with
      # newlines; messages without content contribute an empty line.
      #
      # @param messages [Array<Hash>]
      # @return [String]
      def self.extract_prompt(messages)
        messages.map { |message| message[:content] || message["content"] || "" }.join("\n")
      end

      # Picks the endpoint from site settings based on substrings of the model name.
      def self.endpoint_for(model)
        setting =
          if model.include?("gpt-4")
            model.include?("32k") ? SiteSetting.ai_openai_gpt4_32k_url : SiteSetting.ai_openai_gpt4_url
          else
            model.include?("16k") ? SiteSetting.ai_openai_gpt35_16k_url : SiteSetting.ai_openai_gpt35_url
          end

        URI(setting)
      end

      # Builds request headers; hosts containing "azure" get an `api-key`
      # header, every other host gets a bearer `Authorization` header.
      def self.request_headers(url)
        headers = { "Content-Type" => "application/json" }

        if url.host.include?("azure")
          headers["api-key"] = SiteSetting.ai_openai_api_key
        else
          headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
        end

        headers
      end

      # Reads a non-streamed response, records it on the audit log and returns it parsed.
      def self.read_full_response(response, log)
        response_body = response.read_body
        parsed_response = JSON.parse(response_body, symbolize_names: true)

        log.update!(
          raw_response_payload: response_body,
          request_tokens: parsed_response.dig(:usage, :prompt_tokens),
          response_tokens: parsed_response.dig(:usage, :completion_tokens),
        )

        parsed_response
      end

      # Consumes a streamed response chunk by chunk, yielding each parsed event.
      # Returns the accumulated content, or nil when aborted after a cancel.
      def self.stream_response(http, response, log, messages)
        response_data = +""
        response_raw = +""
        pending = ""
        cancelled = false
        cancel = lambda { cancelled = true }

        begin
          response.read_body do |chunk|
            if cancelled
              http.finish
              return
            end

            response_raw << chunk

            # Carry the unterminated tail over so an event split across chunks
            # (even inside the "data: " prefix) is reassembled instead of lost.
            events, pending = extract_stream_events(pending + chunk)

            events.each do |partial|
              break if cancelled

              response_data << partial.dig(:choices, 0, :delta, :content).to_s
              response_data << partial.dig(:choices, 0, :delta, :function_call).to_s

              yield partial, cancel
            end
          end
        rescue IOError
          # A closed connection is expected once the caller cancelled.
          raise if !cancelled
        ensure
          # Runs once per request (not once per chunk), including on cancel and on errors.
          log.update!(
            raw_response_payload: response_raw,
            request_tokens:
              DiscourseAi::Tokenizer::OpenAiTokenizer.size(extract_prompt(messages)),
            response_tokens: DiscourseAi::Tokenizer::OpenAiTokenizer.size(response_data),
          )
        end

        response_data
      end

      # Splits buffered stream text into parsed events plus the unconsumed tail.
      # Returns `[events, tail]`.
      def self.extract_stream_events(buffer)
        lines = buffer.split("\n", -1)
        tail = lines.pop.to_s # text after the last newline, possibly incomplete
        events = lines.map { |line| parse_stream_line(line) }.compact

        # The tail may already be a complete event whose newline has not arrived yet.
        tail_event = parse_stream_line(tail)
        if tail_event
          events << tail_event
          tail = ""
        end

        [events, tail]
      end

      # Parses one `data: ...` line into a Hash; returns nil for lines without
      # a data payload, for the `[DONE]` marker and for payloads that are not
      # (yet) valid JSON.
      def self.parse_stream_line(line)
        data = line.split("data: ", 2)[1]&.strip
        return if !data || data == "[DONE]"

        JSON.parse(data, symbolize_names: true)
      rescue JSON::ParserError
        nil
      end

      private_class_method :endpoint_for,
                           :request_headers,
                           :read_full_response,
                           :stream_response,
                           :extract_stream_events,
                           :parse_stream_line
    end
  end
end
|