# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Dialects
      class ChatGpt < Dialect
        class << self
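          # The OpenAI chat models this dialect can translate prompts for.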
          def can_translate?(model_name)
            %w[
              gpt-3.5-turbo
              gpt-4
              gpt-3.5-turbo-16k
              gpt-4-32k
              gpt-4-1106-preview
              gpt-4-turbo
            ].include?(model_name)
          end

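          # Tokenizer used to measure prompt and tool sizes for these models.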
          def tokenizer
            DiscourseAi::Tokenizer::OpenAiTokenizer
          end
        end

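        # Builds the OpenAI chat completion message array from the generic prompt:
        # a system message, optional example exchanges, prior conversation context,
        # and the current user input.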
        def translate
          open_ai_prompt = [
            { role: "system", content: [prompt[:insts], prompt[:post_insts].to_s].join("\n") },
          ]

          if prompt[:examples]
            prompt[:examples].each do |example_pair|
              open_ai_prompt << { role: "user", content: example_pair.first }
              open_ai_prompt << { role: "assistant", content: example_pair.second }
            end
          end

          open_ai_prompt.concat(conversation_context) if prompt[:conversation_context]

          open_ai_prompt << { role: "user", content: prompt[:input] } if prompt[:input]

          open_ai_prompt
        end

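        # Exposes the prompt's tool definitions in OpenAI's function-calling format.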
        def tools
          return if prompt[:tools].blank?

          prompt[:tools].map { |t| { type: "function", function: t } }
        end

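        # Trims the prior conversation and translates each entry into an OpenAI
        # message hash, carrying over tool call ids and participant names when present.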
        def conversation_context
          return [] if prompt[:conversation_context].blank?

          trimmed_context = trim_context(prompt[:conversation_context])

          trimmed_context.reverse.map do |context|
            translated = context.slice(:content)
            translated[:role] = context[:type]

            if context[:name]
              if translated[:role] == "tool"
                translated[:tool_call_id] = context[:name]
              else
                translated[:name] = context[:name]
              end
            end

            translated
          end
        end

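        # Tokens available for the prompt: the model's context window minus a buffer
        # for the response and, when tools are present, their serialized size.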
        def max_prompt_tokens
          # reserve a buffer on top of the response allowance - our token counting
          # is not 100% accurate and getting numbers to align exactly is very hard
          buffer = (opts[:max_tokens_to_sample] || 2500) + 50

          if tools.present?
            # note this is about 100 tokens over, OpenAI have a more optimal representation
            @function_size ||= self.class.tokenizer.size(tools.to_json.to_s)
            buffer += @function_size
          end

          model_max_tokens - buffer
        end

        private

        def per_message_overhead
          # open ai defines about 4 tokens per message of overhead
          4
        end

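        # Token count for a single context message: its content plus any participant name.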
        def calculate_message_token(context)
          self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
        end

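        # Context window size for the current model; unknown models fall back to 8k.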
        def model_max_tokens
          case model_name
          when "gpt-3.5-turbo", "gpt-3.5-turbo-16k"
            16_384
          when "gpt-4"
            8192
          when "gpt-4-32k"
            32_768
          else
            8192
          end
        end
      end
    end
  end
end