Mirror of https://github.com/discourse/discourse-ai.git, synced 2025-02-20 18:35:55 +00:00
Refactor dialect selection and add Nova API support

- Change dialect selection to use the llm_model object instead of just the provider name
- Add support for Amazon Bedrock's Nova API with native tools
- Implement Nova-specific message processing and formatting
- Update specs for Nova and AWS Bedrock endpoints
- Enhance AWS Bedrock support to handle Nova models
- Fix Gemini beta API detection logic
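For illustration, a minimal sketch of the request body NovaPrompt#to_payload (below) produces for a simple prompt when temperature and max-token options are set. Values are hypothetical, and the surrounding Bedrock request envelope is built by the endpoint code, which is not shown here:

    {
      system: [{ text: "You are a helpful assistant" }],
      messages: [{ role: "user", content: [{ text: "Hello" }] }],
      inferenceConfig: { temperature: 0.7, max_new_tokens: 500 },
    }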
178 lines · 4.8 KiB · Ruby
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Dialects
      class Nova < Dialect
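        # Translates the generic prompt format into the payload shape Amazon
        # Bedrock's Nova models expect: system blocks, typed content arrays
        # per message, inferenceConfig and (optionally) toolConfig.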
        class << self
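          # Selection receives the full llm_model (not just the provider name)
          # so Nova models can be told apart from other Bedrock models.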
          def can_translate?(llm_model)
            llm_model.provider == "aws_bedrock" && llm_model.name.include?("amazon.nova")
          end
        end

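        # Value object holding the pieces of a Nova request; #to_payload
        # assembles them into the request body.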
        class NovaPrompt
          attr_reader :system, :messages, :inference_config, :tool_config

          def initialize(system, messages, inference_config = nil, tool_config = nil)
            @system = system
            @messages = messages
            @inference_config = inference_config
            @tool_config = tool_config
          end

          def system_prompt
            # small hack for size estimation: system is nil or an array of
            # { text: } blocks, so stringify it to get something measurable
            system.to_s
          end

          def has_tools?
            tool_config.present?
          end

          def to_payload(options = nil)
            # guard the nil default so the option lookups below are safe
            options ||= {}

            stop_sequences = options[:stop_sequences]
            max_tokens = options[:max_tokens]

            inference_config = options.slice(:temperature, :top_p, :top_k)

            inference_config[:stopSequences] = stop_sequences if stop_sequences.present?
            inference_config[:max_new_tokens] = max_tokens if max_tokens.present?

            result = { system: system, messages: messages }
            result[:inferenceConfig] = inference_config if inference_config.present?
            result[:toolConfig] = tool_config if tool_config.present?

            result
          end
        end

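        # Converts the canonical message list into Nova's shape: the system
        # message is hoisted into its own top-level block and each remaining
        # message gets a typed content array.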
        def translate
          messages = super

          system = messages.shift[:content] if messages.first&.dig(:role) == "system"
          nova_messages = messages.map { |msg| { role: msg[:role], content: build_content(msg) } }

          inference_config = build_inference_config
          tool_config = tools_dialect.translated_tools if native_tool_support?

          NovaPrompt.new(
            system.presence && [{ text: system }],
            nova_messages,
            inference_config,
            tool_config,
          )
        end

        def max_prompt_tokens
          llm_model.max_prompt_tokens
        end

        def native_tool_support?
          !llm_model.lookup_custom_param("disable_native_tools")
        end

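        # Nova has its own native tool encoding (NovaTools); fall back to the
        # base dialect's tool handling when native tools are disabled.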
        def tools_dialect
          if native_tool_support?
            @tools_dialect ||= DiscourseAi::Completions::Dialects::NovaTools.new(prompt.tools)
          else
            super
          end
        end

        private

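        # Nova expects message content as an array of typed blocks; wrap bare
        # strings as { text: } and append any image blocks.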
        def build_content(msg)
          content = []

          existing_content = msg[:content]

          if existing_content.is_a?(Hash)
            content << existing_content
          elsif existing_content.is_a?(String)
            content << { text: existing_content }
          end

          msg[:images]&.each { |image| content << image }

          content
        end

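        # Maps opts[:inference_config] onto Nova's inferenceConfig keys,
        # returning nil when nothing is set so the key is omitted entirely.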
        def build_inference_config
          return unless opts[:inference_config]

          config = {}
          ic = opts[:inference_config]

          config[:max_new_tokens] = ic[:max_new_tokens] if ic[:max_new_tokens]
          config[:temperature] = ic[:temperature] if ic[:temperature]
          config[:top_p] = ic[:top_p] if ic[:top_p]
          config[:top_k] = ic[:top_k] if ic[:top_k]
          config[:stopSequences] = ic[:stop_sequences] if ic[:stop_sequences]

          config.present? ? config : nil
        end

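        # Nova image blocks take a bare format name ("jpeg", "png", ...)
        # rather than a full MIME type.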
        def detect_format(mime_type)
          case mime_type
          when "image/jpeg"
            "jpeg"
          when "image/png"
            "png"
          when "image/gif"
            "gif"
          when "image/webp"
            "webp"
          else
            "jpeg" # default
          end
        end

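        # When tools are handled without native support, their usage
        # instructions are appended to the system message.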
        def system_msg(msg)
          msg = { role: "system", content: msg[:content] }

          if tools_dialect.instructions.present?
            msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
          end

          msg
        end

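        # Uploads become Nova image blocks (format plus base64 bytes) when the
        # model has vision support; build_content merges them into the
        # message's content array later.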
        def user_msg(msg)
          images = nil

          if vision_support?
            encoded_uploads = prompt.encoded_uploads(msg)
            encoded_uploads&.each do |upload|
              images ||= []
              images << {
                image: {
                  format: upload[:format] || detect_format(upload[:mime_type]),
                  source: {
                    bytes: upload[:base64],
                  },
                },
              }
            end
          end

          { role: "user", content: msg[:content], images: images }
        end

        def model_msg(msg)
          { role: "assistant", content: msg[:content] }
        end

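        # Tool results are sent back with the user role and tool calls with
        # the assistant role, both encoded by the tools dialect.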
        def tool_msg(msg)
          translated = tools_dialect.from_raw_tool(msg)
          { role: "user", content: translated }
        end

        def tool_call_msg(msg)
          translated = tools_dialect.from_raw_tool_call(msg)
          { role: "assistant", content: translated }
        end
      end
    end
  end
end