2023-12-15 12:32:01 -05:00
|
|
|
# frozen_string_literal: true
|
|
|
|
|
|
|
|
module DiscourseAi
|
|
|
|
module Completions
|
|
|
|
module Dialects
|
2023-12-18 16:06:01 -05:00
|
|
|
class Gemini < Dialect
  # Dialect for Google's Gemini models. Translates the generic prompt
  # structure (insts/examples/conversation_context/input/tools) into
  # Gemini's `role` + `parts` message format.
  class << self
    # @param model_name [String] model identifier requested by the caller
    # @return [Boolean] true when this dialect handles the model
    def can_translate?(model_name)
      %w[gemini-pro].include?(model_name)
    end

    # Tokenizer used for budget/trimming math.
    def tokenizer
      DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
    end
  end

  # Builds the Gemini-formatted message array for the current prompt.
  # @return [Array<Hash>] messages with :role ("user"/"model") and :parts
  def translate
    # Gemini complains if we don't alternate model/user roles.
    noop_model_response = { role: "model", parts: { text: "Ok." } }

    gemini_prompt = [
      {
        role: "user",
        parts: {
          text: [prompt[:insts], prompt[:post_insts].to_s].join("\n"),
        },
      },
      # Filler model turn so the next "user" message keeps the alternation.
      noop_model_response,
    ]

    if prompt[:examples]
      # Each example pair becomes a user turn followed by a model turn.
      prompt[:examples].each do |example_pair|
        gemini_prompt << { role: "user", parts: { text: example_pair.first } }
        gemini_prompt << { role: "model", parts: { text: example_pair.second } }
      end
    end

    gemini_prompt.concat(conversation_context) if prompt[:conversation_context]

    if prompt[:input]
      # Re-establish alternation before appending the final user input.
      gemini_prompt << noop_model_response.dup if gemini_prompt.last[:role] == "user"

      gemini_prompt << { role: "user", parts: { text: prompt[:input] } }
    end

    gemini_prompt
  end

  # Translates tool definitions into Gemini's function_declarations schema.
  # @return [Array<Hash>, nil] nil when the prompt declares no tools
  def tools
    return if prompt[:tools].blank?

    translated_tools =
      prompt[:tools].map do |t|
        tool = t.slice(:name, :description)

        if t[:parameters]
          # Fold the flat parameter list into a JSON-schema-ish object.
          tool[:parameters] = t[:parameters].reduce(
            { type: "object", required: [], properties: {} },
          ) do |memo, p|
            name = p[:name]
            memo[:required] << name if p[:required]

            memo[:properties][name] = p.except(:name, :required, :item_type)

            # Array parameters carry their element type under :items.
            memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type]
            memo
          end
        end

        tool
      end

    [{ function_declarations: translated_tools }]
  end

  # Maps prior conversation turns into Gemini messages, newest-last.
  # Tool calls become model functionCall parts; tool results become
  # "function" role functionResponse parts.
  # @return [Array<Hash>]
  def conversation_context
    return [] if prompt[:conversation_context].blank?

    flattened_context = flatten_context(prompt[:conversation_context])
    trimmed_context = trim_context(flattened_context)

    trimmed_context.reverse.map do |context|
      if context[:type] == "tool_call"
        function = JSON.parse(context[:content], symbolize_names: true)

        {
          role: "model",
          parts: {
            functionCall: {
              name: function[:name],
              args: function[:arguments],
            },
          },
        }
      elsif context[:type] == "tool"
        {
          role: "function",
          parts: {
            functionResponse: {
              name: context[:name],
              response: {
                content: context[:content],
              },
            },
          },
        }
      else
        {
          role: context[:type] == "assistant" ? "model" : "user",
          parts: {
            text: context[:content],
          },
        }
      end
    end
  end

  # Token budget for the prompt.
  def max_prompt_tokens
    16_384 # 50% of model tokens
  end

  protected

  # Token cost of a single context message (content plus tool name, if any).
  def calculate_message_token(context)
    self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
  end

  private

  # Expands "multi_turn" entries into a flat list of context messages.
  def flatten_context(context)
    flattened = []
    context.each do |c|
      if c[:type] == "multi_turn"
        # gemini quirk: a tool result can't open a turn, so seed a
        # filler assistant message first.
        if c[:content].first[:type] == "tool"
          # FIX: was `flattend` — a typo that raised NameError whenever a
          # multi_turn context began with a tool result.
          flattened << { type: "assistant", content: "ok." }
        end

        flattened.concat(c[:content])
      else
        flattened << c
      end
    end
    flattened
  end
end
|
|
|
|
end
|
|
|
|
end
|
|
|
|
end
|