discourse-ai/lib/completions/dialects/gemini.rb

110 lines
3.2 KiB
Ruby

# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class Gemini < Dialect
class << self
def can_translate?(model_name)
%w[gemini-pro].include?(model_name)
end
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
end
end
def translate
# Gemini complains if we don't alternate model/user roles.
noop_model_response = { role: "model", parts: { text: "Ok." } }
messages = prompt.messages
# Gemini doesn't use an assistant msg to improve long-context responses.
messages.pop if messages.last[:type] == :model
trim_messages(messages).reduce([]) do |memo, msg|
if msg[:type] == :system
memo << { role: "user", parts: { text: msg[:content] } }
memo << noop_model_response.dup
elsif msg[:type] == :model
memo << { role: "model", parts: { text: msg[:content] } }
elsif msg[:type] == :tool_call
call_details = JSON.parse(msg[:content], symbolize_names: true)
memo << {
role: "model",
parts: {
functionCall: {
name: call_details[:name],
args: call_details[:arguments],
},
},
}
elsif msg[:type] == :tool
memo << {
role: "function",
parts: {
functionResponse: {
name: msg[:id],
response: {
content: msg[:content],
},
},
},
}
else
# Gemini quirk. Doesn't accept tool -> user or user -> user msgs.
previous_msg_role = memo.last&.dig(:role)
if previous_msg_role == "user" || previous_msg_role == "tool"
memo << noop_model_response.dup
end
memo << { role: "user", parts: { text: msg[:content] } }
end
memo
end
end
def tools
return if prompt.tools.blank?
translated_tools =
prompt.tools.map do |t|
tool = t.slice(:name, :description)
if t[:parameters]
tool[:parameters] = t[:parameters].reduce(
{ type: "object", required: [], properties: {} },
) do |memo, p|
name = p[:name]
memo[:required] << name if p[:required]
memo[:properties][name] = p.except(:name, :required, :item_type)
memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type]
memo
end
end
tool
end
[{ function_declarations: translated_tools }]
end
def max_prompt_tokens
16_384 # 50% of model tokens
end
protected
def calculate_message_token(context)
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
end
end
end
end
end