FEATURE: Amazon Nova support via bedrock (#997)
Refactor dialect selection and add Nova API support:

- Change dialect selection to use the llm_model object instead of just the provider name
- Add support for Amazon Bedrock's Nova API with native tools
- Implement Nova-specific message processing and formatting
- Update specs for Nova and AWS Bedrock endpoints
- Enhance AWS Bedrock support to handle Nova models
- Fix Gemini beta API detection logic
parent 5e87a50202
commit a55216773a
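In practice, the dialect-selection refactor changes call sites from passing a provider string to passing the whole model record, so dialects can also inspect the model name (needed to tell Nova apart from Claude on aws_bedrock). A minimal sketch; the LlmModel lookup is illustrative:

  llm_model = LlmModel.find(model_id) # hypothetical lookup

  # before: selection keyed off the provider name only
  DiscourseAi::Completions::Dialects::Dialect.dialect_for(llm_model.provider)

  # after: the full record is passed through
  DiscourseAi::Completions::Dialects::Dialect.dialect_for(llm_model)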
@@ -5,8 +5,8 @@ module DiscourseAi
     module Dialects
       class ChatGpt < Dialect
         class << self
-          def can_translate?(model_provider)
-            model_provider == "open_ai" || model_provider == "azure"
+          def can_translate?(llm_model)
+            llm_model.provider == "open_ai" || llm_model.provider == "azure"
           end
         end
@@ -5,8 +5,10 @@ module DiscourseAi
     module Dialects
       class Claude < Dialect
         class << self
-          def can_translate?(provider_name)
-            provider_name == "anthropic" || provider_name == "aws_bedrock"
+          def can_translate?(llm_model)
+            llm_model.provider == "anthropic" ||
+              (llm_model.provider == "aws_bedrock") &&
+                (llm_model.name.include?("anthropic") || llm_model.name.include?("claude"))
           end
         end
@@ -6,8 +6,8 @@ module DiscourseAi
   module Completions
     module Dialects
       class Command < Dialect
-        def self.can_translate?(model_provider)
-          model_provider == "cohere"
+        def self.can_translate?(llm_model)
+          llm_model.provider == "cohere"
         end

         VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
@@ -5,7 +5,7 @@ module DiscourseAi
     module Dialects
       class Dialect
         class << self
-          def can_translate?(model_provider)
+          def can_translate?(llm_model)
             raise NotImplemented
           end

@@ -17,11 +17,12 @@ module DiscourseAi
             DiscourseAi::Completions::Dialects::Command,
             DiscourseAi::Completions::Dialects::Ollama,
             DiscourseAi::Completions::Dialects::Mistral,
+            DiscourseAi::Completions::Dialects::Nova,
             DiscourseAi::Completions::Dialects::OpenAiCompatible,
           ]
         end

-        def dialect_for(model_provider)
+        def dialect_for(llm_model)
           dialects = []

           if Rails.env.test? || Rails.env.development?

@@ -30,7 +31,7 @@ module DiscourseAi

           dialects = dialects.concat(all_dialects)

-          dialect = dialects.find { |d| d.can_translate?(model_provider) }
+          dialect = dialects.find { |d| d.can_translate?(llm_model) }
           raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL if !dialect

           dialect
@@ -5,8 +5,8 @@ module DiscourseAi
     module Dialects
       class Fake < Dialect
         class << self
-          def can_translate?(model_name)
-            model_name == "fake"
+          def can_translate?(llm_model)
+            llm_model.provider == "fake"
           end
         end
@@ -5,8 +5,8 @@ module DiscourseAi
     module Dialects
       class Gemini < Dialect
         class << self
-          def can_translate?(model_provider)
-            model_provider == "google"
+          def can_translate?(llm_model)
+            llm_model.provider == "google"
           end
         end

@@ -80,7 +80,7 @@ module DiscourseAi
         end

         def beta_api?
-          @beta_api ||= llm_model.name.start_with?("gemini-1.5")
+          @beta_api ||= !llm_model.name.start_with?("gemini-1.0")
         end

         def system_msg(msg)
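The beta-API check is inverted from an allowlist to a denylist: anything outside the gemini-1.0 family now uses the beta API, not just gemini-1.5. A sketch of the effect (model names illustrative):

  # old check: only 1.5 models were beta
  "gemini-1.5-pro".start_with?("gemini-1.5")    # => true
  "gemini-2.0-flash".start_with?("gemini-1.5")  # => false, so newer models missed the beta API

  # new check: everything except the 1.0 family is beta
  !"gemini-2.0-flash".start_with?("gemini-1.0") # => true
  !"gemini-1.0-pro".start_with?("gemini-1.0")   # => false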
@@ -7,8 +7,8 @@ module DiscourseAi
     module Dialects
       class Mistral < ChatGpt
         class << self
-          def can_translate?(model_provider)
-            model_provider == "mistral"
+          def can_translate?(llm_model)
+            llm_model.provider == "mistral"
           end
         end
@@ -0,0 +1,177 @@
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Dialects
      class Nova < Dialect
        class << self
          def can_translate?(llm_model)
            llm_model.provider == "aws_bedrock" && llm_model.name.include?("amazon.nova")
          end
        end

        class NovaPrompt
          attr_reader :system, :messages, :inference_config, :tool_config

          def initialize(system, messages, inference_config = nil, tool_config = nil)
            @system = system
            @messages = messages
            @inference_config = inference_config
            @tool_config = tool_config
          end

          def system_prompt
            # small hack for size estimation
            system.to_s
          end

          def has_tools?
            tool_config.present?
          end

          def to_payload(options = nil)
            stop_sequences = options[:stop_sequences]
            max_tokens = options[:max_tokens]

            inference_config = options&.slice(:temperature, :top_p, :top_k)

            inference_config[:stopSequences] = stop_sequences if stop_sequences.present?
            inference_config[:max_new_tokens] = max_tokens if max_tokens.present?

            result = { system: system, messages: messages }
            result[:inferenceConfig] = inference_config if inference_config.present?
            result[:toolConfig] = tool_config if tool_config.present?

            result
          end
        end

        def translate
          messages = super

          system = messages.shift[:content] if messages.first&.dig(:role) == "system"
          nova_messages = messages.map { |msg| { role: msg[:role], content: build_content(msg) } }

          inference_config = build_inference_config
          tool_config = tools_dialect.translated_tools if native_tool_support?

          NovaPrompt.new(
            system.presence && [{ text: system }],
            nova_messages,
            inference_config,
            tool_config,
          )
        end

        def max_prompt_tokens
          llm_model.max_prompt_tokens
        end

        def native_tool_support?
          !llm_model.lookup_custom_param("disable_native_tools")
        end

        def tools_dialect
          if native_tool_support?
            @tools_dialect ||= DiscourseAi::Completions::Dialects::NovaTools.new(prompt.tools)
          else
            super
          end
        end

        private

        def build_content(msg)
          content = []

          existing_content = msg[:content]

          if existing_content.is_a?(Hash)
            content << existing_content
          elsif existing_content.is_a?(String)
            content << { text: existing_content }
          end

          msg[:images]&.each { |image| content << image }

          content
        end

        def build_inference_config
          return unless opts[:inference_config]

          config = {}
          ic = opts[:inference_config]

          config[:max_new_tokens] = ic[:max_new_tokens] if ic[:max_new_tokens]
          config[:temperature] = ic[:temperature] if ic[:temperature]
          config[:top_p] = ic[:top_p] if ic[:top_p]
          config[:top_k] = ic[:top_k] if ic[:top_k]
          config[:stopSequences] = ic[:stop_sequences] if ic[:stop_sequences]

          config.present? ? config : nil
        end

        def detect_format(mime_type)
          case mime_type
          when "image/jpeg"
            "jpeg"
          when "image/png"
            "png"
          when "image/gif"
            "gif"
          when "image/webp"
            "webp"
          else
            "jpeg" # default
          end
        end

        def system_msg(msg)
          msg = { role: "system", content: msg[:content] }

          if tools_dialect.instructions.present?
            msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
          end

          msg
        end

        def user_msg(msg)
          images = nil
          if vision_support?
            encoded_uploads = prompt.encoded_uploads(msg)
            encoded_uploads&.each do |upload|
              images ||= []
              images << {
                image: {
                  format: upload[:format] || detect_format(upload[:mime_type]),
                  source: {
                    bytes: upload[:base64],
                  },
                },
              }
            end
          end

          { role: "user", content: msg[:content], images: images }
        end

        def model_msg(msg)
          { role: "assistant", content: msg[:content] }
        end

        def tool_msg(msg)
          translated = tools_dialect.from_raw_tool(msg)
          { role: "user", content: translated }
        end

        def tool_call_msg(msg)
          translated = tools_dialect.from_raw_tool_call(msg)
          { role: "assistant", content: translated }
        end
      end
    end
  end
end
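Putting the pieces together, translate returns a NovaPrompt whose to_payload output matches the request shape asserted in the specs further down. A small usage sketch (values mirror the spec):

  prompt = DiscourseAi::Completions::Prompt.new(
    "You are a helpful bot",
    messages: [{ type: :user, content: "Hello" }],
  )
  dialect = DiscourseAi::Completions::Dialects::Nova.new(prompt, llm_model)
  dialect.translate.to_payload(max_tokens: 100)
  # => {
  #      system: [{ text: "You are a helpful bot" }],
  #      messages: [{ role: "user", content: [{ text: "Hello" }] }],
  #      inferenceConfig: { max_new_tokens: 100 },
  #    }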
@@ -0,0 +1,83 @@
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Dialects
      class NovaTools
        def initialize(tools)
          @raw_tools = tools
        end

        def translated_tools
          return if !@raw_tools.present?

          # note: forced tools are not supported yet; toolChoice is always auto
          {
            tools:
              @raw_tools.map do |tool|
                {
                  toolSpec: {
                    name: tool[:name],
                    description: tool[:description],
                    inputSchema: {
                      json: convert_tool_to_input_schema(tool),
                    },
                  },
                }
              end,
          }
        end

        # native tools require no system instructions
        def instructions
          ""
        end

        def from_raw_tool_call(raw_message)
          {
            toolUse: {
              toolUseId: raw_message[:id],
              name: raw_message[:name],
              input: JSON.parse(raw_message[:content])["arguments"],
            },
          }
        end

        def from_raw_tool(raw_message)
          {
            toolResult: {
              toolUseId: raw_message[:id],
              content: [{ json: JSON.parse(raw_message[:content]) }],
            },
          }
        end

        private

        def convert_tool_to_input_schema(tool)
          tool = tool.transform_keys(&:to_sym)
          properties = {}
          tool[:parameters].each do |param|
            schema = {}
            type = param[:type]
            type = "string" if !%w[string number boolean integer array].include?(type)

            schema[:type] = type

            if enum = param[:enum]
              schema[:enum] = enum
            end

            schema[:items] = { type: param[:item_type] } if type == "array"

            schema[:required] = true if param[:required]

            properties[param[:name]] = schema
          end

          { type: "object", properties: properties }
        end
      end
    end
  end
end
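NovaTools maps the plugin's tool definitions onto Bedrock's toolSpec shape. Note that required is kept per-property rather than hoisted into a top-level JSON Schema required array; the specs below confirm this is the shape sent. A sketch mirroring the spec's weather tool:

  tools = DiscourseAi::Completions::Dialects::NovaTools.new(
    [
      {
        name: "get_weather",
        description: "Get the weather in a city",
        parameters: [{ name: "location", type: "string", required: true }],
      },
    ],
  )
  tools.translated_tools
  # => { tools: [{ toolSpec: { name: "get_weather",
  #      description: "Get the weather in a city",
  #      inputSchema: { json: { type: "object",
  #        properties: { "location" => { type: "string", required: true } } } } } }] }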
@@ -5,8 +5,8 @@ module DiscourseAi
     module Dialects
       class Ollama < Dialect
         class << self
-          def can_translate?(model_provider)
-            model_provider == "ollama"
+          def can_translate?(llm_model)
+            llm_model.provider == "ollama"
           end
         end
@@ -5,7 +5,8 @@ module DiscourseAi
     module Dialects
       class OpenAiCompatible < Dialect
         class << self
-          def can_translate?(_model_name)
+          def can_translate?(_llm_model)
+            # fallback dialect
             true
           end
         end
@@ -6,6 +6,8 @@ module DiscourseAi
   module Completions
     module Endpoints
       class AwsBedrock < Base
+        attr_reader :dialect
+
         def self.can_contact?(model_provider)
           model_provider == "aws_bedrock"
         end

@@ -19,10 +21,15 @@ module DiscourseAi
         end

         def default_options(dialect)
-          max_tokens = 4096
-          max_tokens = 8192 if bedrock_model_id.match?(/3.5/)
-
-          options = { max_tokens: max_tokens, anthropic_version: "bedrock-2023-05-31" }
+          options =
+            if dialect.is_a?(DiscourseAi::Completions::Dialects::Claude)
+              max_tokens = 4096
+              max_tokens = 8192 if bedrock_model_id.match?(/3.5/)
+
+              { max_tokens: max_tokens, anthropic_version: "bedrock-2023-05-31" }
+            else
+              {}
+            end

           options[:stop_sequences] = ["</function_calls>"] if !dialect.native_tool_support? &&
             dialect.prompt.has_tools?

@@ -86,7 +93,11 @@ module DiscourseAi

         def prepare_payload(prompt, model_params, dialect)
           @native_tool_support = dialect.native_tool_support?
+          @dialect = dialect
+
+          payload = nil

+          if dialect.is_a?(DiscourseAi::Completions::Dialects::Claude)
             payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
             payload[:system] = prompt.system_prompt if prompt.system_prompt.present?

@@ -96,7 +107,11 @@ module DiscourseAi
               payload[:tool_choice] = { type: "tool", name: dialect.tool_choice }
             end
           end
+          elsif dialect.is_a?(DiscourseAi::Completions::Dialects::Nova)
+            payload = prompt.to_payload(default_options(dialect).merge(model_params))
+          else
+            raise "Unsupported dialect"
+          end
           payload
         end

@@ -151,6 +166,11 @@ module DiscourseAi
           i = 0
           while decoded
             parsed = JSON.parse(decoded.payload.string)
+            if exception = decoded.headers[":exception-type"]
+              Rails.logger.error("#{self.class.name}: #{exception}: #{parsed}")
+              # TODO based on how often this happens, we may want to raise so we
+              # can retry, this may catch rate limits for example
+            end
             # perhaps some control message we can just ignore
             messages << Base64.decode64(parsed["bytes"]) if parsed && parsed["bytes"]

@@ -180,8 +200,15 @@ module DiscourseAi
         end

         def processor
+          if dialect.is_a?(DiscourseAi::Completions::Dialects::Claude)
             @processor ||=
-              DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
+              DiscourseAi::Completions::AnthropicMessageProcessor.new(
+                streaming_mode: @streaming_mode,
+              )
+          else
+            @processor ||=
+              DiscourseAi::Completions::NovaMessageProcessor.new(streaming_mode: @streaming_mode)
+          end
         end

         def xml_tools_enabled?
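End to end, a Nova model configured against this endpoint is driven the same way as any other model through the proxy; the "custom:<id>" naming scheme below is the one the specs at the bottom of this commit use:

  llm = DiscourseAi::Completions::Llm.proxy("custom:#{nova_model.id}")
  llm.generate("hello world", user: user)
  # posts a NovaPrompt payload to .../model/amazon.nova-pro-v1:0/invoke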
@@ -176,8 +176,7 @@ module DiscourseAi

         raise UNKNOWN_MODEL if llm_model.nil?

-        model_provider = llm_model.provider
-        dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_provider)
+        dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(llm_model)

         if @canned_response
           if @canned_llm && @canned_llm != model

@@ -187,6 +186,7 @@ module DiscourseAi
           return new(dialect_klass, nil, llm_model, gateway: @canned_response)
         end

+        model_provider = llm_model.provider
         gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_provider)

         new(dialect_klass, gateway_klass, llm_model)
@@ -0,0 +1,94 @@
# frozen_string_literal: true

class DiscourseAi::Completions::NovaMessageProcessor
  class NovaToolCall
    attr_reader :name, :raw_json, :id

    def initialize(name, id, partial_tool_calls: false)
      @name = name
      @id = id
      @raw_json = +""
      @tool_call = DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: {})
      @streaming_parser =
        DiscourseAi::Completions::ToolCallProgressTracker.new(self) if partial_tool_calls
    end

    def append(json)
      @raw_json << json
      @streaming_parser << json if @streaming_parser
    end

    def notify_progress(key, value)
      @tool_call.partial = true
      @tool_call.parameters[key.to_sym] = value
      @has_new_data = true
    end

    def has_partial?
      @has_new_data
    end

    def partial_tool_call
      @has_new_data = false
      @tool_call
    end

    def to_tool_call
      parameters = JSON.parse(raw_json, symbolize_names: true)
      # we dupe to avoid poisoning the original tool call
      @tool_call = @tool_call.dup
      @tool_call.partial = false
      @tool_call.parameters = parameters
      @tool_call
    end
  end

  attr_reader :tool_calls, :input_tokens, :output_tokens

  def initialize(streaming_mode:, partial_tool_calls: false)
    @streaming_mode = streaming_mode
    @tool_calls = []
    @current_tool_call = nil
    @partial_tool_calls = partial_tool_calls
  end

  def to_tool_calls
    @tool_calls.map { |tool_call| tool_call.to_tool_call }
  end

  def process_streamed_message(parsed)
    return if !parsed

    result = nil

    if tool_start = parsed.dig(:contentBlockStart, :start, :toolUse)
      @current_tool_call = NovaToolCall.new(tool_start[:name], tool_start[:toolUseId])
    end

    if tool_progress = parsed.dig(:contentBlockDelta, :delta, :toolUse, :input)
      @current_tool_call.append(tool_progress)
    end

    result = @current_tool_call.to_tool_call if parsed[:contentBlockStop] && @current_tool_call

    if metadata = parsed[:metadata]
      @input_tokens = metadata.dig(:usage, :inputTokens)
      @output_tokens = metadata.dig(:usage, :outputTokens)
    end

    result || parsed.dig(:contentBlockDelta, :delta, :text)
  end

  def process_message(payload)
    result = []
    parsed = payload
    parsed = JSON.parse(payload, symbolize_names: true) if payload.is_a?(String)

    result << parsed.dig(:output, :message, :content, 0, :text)

    @input_tokens = parsed.dig(:usage, :inputTokens)
    @output_tokens = parsed.dig(:usage, :outputTokens)

    result
  end
end
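A sketch of how streamed Bedrock events flow through process_streamed_message (event shapes match the endpoint spec at the bottom of this commit):

  processor = DiscourseAi::Completions::NovaMessageProcessor.new(streaming_mode: true)

  processor.process_streamed_message({ contentBlockDelta: { delta: { text: "Hello" } } })
  # => "Hello"

  processor.process_streamed_message({ metadata: { usage: { inputTokens: 14, outputTokens: 18 } } })
  processor.input_tokens  # => 14
  processor.output_tokens # => 18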
@@ -64,6 +64,17 @@ Fabricator(:bedrock_model, from: :anthropic_model) do
   provider_params { { region: "us-east-1", access_key_id: "123456" } }
 end

+Fabricator(:nova_model, from: :llm_model) do
+  display_name "Amazon Nova pro"
+  name "amazon.nova-pro-v1:0"
+  provider "aws_bedrock"
+  tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
+  max_prompt_tokens 300_000
+  api_key "fake"
+  url ""
+  provider_params { { region: "us-east-1", access_key_id: "123456" } }
+end
+
 Fabricator(:cohere_model, from: :llm_model) do
   display_name "Cohere Command R+"
   name "command-r-plus"
@@ -1,12 +1,12 @@
 # frozen_string_literal: true

 RSpec.describe DiscourseAi::Completions::Dialects::Claude do
-  let :opus_dialect_klass do
-    DiscourseAi::Completions::Dialects::Dialect.dialect_for("anthropic")
-  end
-
   fab!(:llm_model) { Fabricate(:anthropic_model, name: "claude-3-opus") }

+  let :opus_dialect_klass do
+    DiscourseAi::Completions::Dialects::Dialect.dialect_for(llm_model)
+  end
+
   describe "#translate" do
     it "can insert OKs to make stuff interleve properly" do
       messages = [
@@ -0,0 +1,190 @@
# frozen_string_literal: true

RSpec.describe DiscourseAi::Completions::Dialects::Nova do
  fab!(:llm_model) { Fabricate(:nova_model, vision_enabled: true) }

  let(:nova_dialect_klass) { DiscourseAi::Completions::Dialects::Dialect.dialect_for(llm_model) }

  it "finds the right dialect" do
    expect(nova_dialect_klass).to eq(DiscourseAi::Completions::Dialects::Nova)
  end

  describe "#translate" do
    it "properly formats a basic conversation" do
      messages = [
        { type: :user, id: "user1", content: "Hello" },
        { type: :model, content: "Hi there!" },
      ]

      prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
      dialect = nova_dialect_klass.new(prompt, llm_model)
      translated = dialect.translate

      expect(translated.system).to eq([{ text: "You are a helpful bot" }])
      expect(translated.messages).to eq(
        [
          { role: "user", content: [{ text: "Hello" }] },
          { role: "assistant", content: [{ text: "Hi there!" }] },
        ],
      )
    end

    context "with image content" do
      let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
      let(:upload) do
        UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
      end

      it "properly formats messages with images" do
        messages = [
          { type: :user, id: "user1", content: "What's in this image?", upload_ids: [upload.id] },
        ]

        prompt = DiscourseAi::Completions::Prompt.new(messages: messages)

        dialect = nova_dialect_klass.new(prompt, llm_model)
        translated = dialect.translate

        encoded = prompt.encoded_uploads(messages.first).first[:base64]

        expect(translated.messages.first[:content]).to eq(
          [
            { text: "What's in this image?" },
            { image: { format: "jpeg", source: { bytes: encoded } } },
          ],
        )
      end
    end

    context "with tools" do
      it "properly formats tool configuration" do
        tools = [
          {
            name: "get_weather",
            description: "Get the weather in a city",
            parameters: [
              { name: "location", type: "string", description: "the city name", required: true },
            ],
          },
        ]

        messages = [{ type: :user, content: "What's the weather?" }]

        prompt =
          DiscourseAi::Completions::Prompt.new(
            "You are a helpful bot",
            messages: messages,
            tools: tools,
          )

        dialect = nova_dialect_klass.new(prompt, llm_model)
        translated = dialect.translate

        expect(translated.tool_config).to eq(
          {
            tools: [
              {
                toolSpec: {
                  name: "get_weather",
                  description: "Get the weather in a city",
                  inputSchema: {
                    json: {
                      type: "object",
                      properties: {
                        "location" => {
                          type: "string",
                          required: true,
                        },
                      },
                    },
                  },
                },
              },
            ],
          },
        )
      end
    end

    context "with inference configuration" do
      it "includes inference configuration when provided" do
        messages = [{ type: :user, content: "Hello" }]

        prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)

        dialect = nova_dialect_klass.new(prompt, llm_model)

        options = { temperature: 0.7, top_p: 0.9, max_tokens: 100, stop_sequences: ["STOP"] }

        translated = dialect.translate

        expected = {
          system: [{ text: "You are a helpful bot" }],
          messages: [{ role: "user", content: [{ text: "Hello" }] }],
          inferenceConfig: {
            temperature: 0.7,
            top_p: 0.9,
            stopSequences: ["STOP"],
            max_new_tokens: 100,
          },
        }

        expect(translated.to_payload(options)).to eq(expected)
      end

      it "omits inference configuration when not provided" do
        messages = [{ type: :user, content: "Hello" }]

        prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)

        dialect = nova_dialect_klass.new(prompt, llm_model)
        translated = dialect.translate

        expect(translated.inference_config).to be_nil
      end
    end

    it "handles tool calls and responses" do
      tool_call_prompt = { name: "get_weather", arguments: { location: "London" } }

      messages = [
        { type: :user, id: "user1", content: "What's the weather in London?" },
        { type: :tool_call, name: "get_weather", id: "tool_id", content: tool_call_prompt.to_json },
        { type: :tool, id: "tool_id", content: "Sunny, 22°C".to_json },
        { type: :model, content: "The weather in London is sunny with 22°C" },
      ]

      prompt =
        DiscourseAi::Completions::Prompt.new(
          "You are a helpful bot",
          messages: messages,
          tools: [
            {
              name: "get_weather",
              description: "Get the weather in a city",
              parameters: [
                { name: "location", type: "string", description: "the city name", required: true },
              ],
            },
          ],
        )

      dialect = nova_dialect_klass.new(prompt, llm_model)
      translated = dialect.translate

      expect(translated.messages.map { |m| m[:role] }).to eq(%w[user assistant user assistant])
      expect(translated.messages.last[:content]).to eq(
        [{ text: "The weather in London is sunny with 22°C" }],
      )
    end
  end

  describe "#max_prompt_tokens" do
    it "returns the model's max prompt tokens" do
      prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot")
      dialect = nova_dialect_klass.new(prompt, llm_model)

      expect(dialect.max_prompt_tokens).to eq(llm_model.max_prompt_tokens)
    end
  end
end
@@ -0,0 +1,284 @@
# frozen_string_literal: true

require_relative "endpoint_compliance"
require "aws-eventstream"
require "aws-sigv4"

class BedrockMock < EndpointMock
end

# Nova is implemented entirely inside the Bedrock endpoint; its specs are split out here
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
  fab!(:user)
  fab!(:nova_model)

  subject(:endpoint) { described_class.new(nova_model) }

  let(:bedrock_mock) { BedrockMock.new(endpoint) }

  let(:stream_url) do
    "https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.nova-pro-v1:0/invoke-with-response-stream"
  end

  def encode_message(message)
    wrapped = { bytes: Base64.encode64(message.to_json) }.to_json
    io = StringIO.new(wrapped)
    aws_message = Aws::EventStream::Message.new(payload: io)
    Aws::EventStream::Encoder.new.encode(aws_message)
  end

  it "should be able to make a simple request" do
    proxy = DiscourseAi::Completions::Llm.proxy("custom:#{nova_model.id}")

    content = {
      "output" => {
        "message" => {
          "content" => [{ "text" => "it is 2." }],
          "role" => "assistant",
        },
      },
      "stopReason" => "end_turn",
      "usage" => {
        "inputTokens" => 14,
        "outputTokens" => 119,
        "totalTokens" => 133,
        "cacheReadInputTokenCount" => nil,
        "cacheWriteInputTokenCount" => nil,
      },
    }.to_json

    stub_request(
      :post,
      "https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.nova-pro-v1:0/invoke",
    ).to_return(status: 200, body: content)

    response = proxy.generate("hello world", user: user)
    expect(response).to eq("it is 2.")

    log = AiApiAuditLog.order(:id).last
    expect(log.request_tokens).to eq(14)
    expect(log.response_tokens).to eq(119)
  end

  it "should be able to make a streaming request" do
    messages =
      [
        { messageStart: { role: "assistant" } },
        { contentBlockDelta: { delta: { text: "Hello" }, contentBlockIndex: 0 } },
        { contentBlockStop: { contentBlockIndex: 0 } },
        { contentBlockDelta: { delta: { text: "!" }, contentBlockIndex: 1 } },
        { contentBlockStop: { contentBlockIndex: 1 } },
        {
          metadata: {
            usage: {
              inputTokens: 14,
              outputTokens: 18,
            },
            metrics: {
            },
            trace: {
            },
          },
          "amazon-bedrock-invocationMetrics": {
            inputTokenCount: 14,
            outputTokenCount: 18,
            invocationLatency: 402,
            firstByteLatency: 72,
          },
        },
      ].map { |message| encode_message(message) }

    stub_request(:post, stream_url).to_return(status: 200, body: messages.join)

    proxy = DiscourseAi::Completions::Llm.proxy("custom:#{nova_model.id}")
    responses = []
    proxy.generate("Hello!", user: user) { |partial| responses << partial }

    expect(responses).to eq(%w[Hello !])
    log = AiApiAuditLog.order(:id).last
    expect(log.request_tokens).to eq(14)
    expect(log.response_tokens).to eq(18)
  end

  it "should support native streaming tool calls" do
    proxy = DiscourseAi::Completions::Llm.proxy("custom:#{nova_model.id}")
    prompt =
      DiscourseAi::Completions::Prompt.new(
        "You are a helpful assistant.",
        messages: [{ type: :user, content: "what is the time in EST" }],
      )

    tool = {
      name: "time",
      description: "Will look up the current time",
      parameters: [
        { name: "timezone", description: "The timezone", type: "string", required: true },
      ],
    }

    prompt.tools = [tool]

    messages =
      [
        { messageStart: { role: "assistant" } },
        {
          contentBlockStart: {
            start: {
              toolUse: {
                toolUseId: "e1bd7033-7244-4408-b088-1d33cbcf0b67",
                name: "time",
              },
            },
            contentBlockIndex: 0,
          },
        },
        {
          contentBlockDelta: {
            delta: {
              toolUse: {
                input: "{\"timezone\":\"EST\"}",
              },
            },
            contentBlockIndex: 0,
          },
        },
        { contentBlockStop: { contentBlockIndex: 0 } },
        { messageStop: { stopReason: "end_turn" } },
        {
          metadata: {
            usage: {
              inputTokens: 481,
              outputTokens: 28,
            },
            metrics: {
            },
            trace: {
            },
          },
          "amazon-bedrock-invocationMetrics": {
            inputTokenCount: 481,
            outputTokenCount: 28,
            invocationLatency: 383,
            firstByteLatency: 57,
          },
        },
      ].map { |message| encode_message(message) }

    request = nil
    stub_request(:post, stream_url)
      .with do |inner_request|
        request = inner_request
        true
      end
      .to_return(status: 200, body: messages)

    response = []
    bedrock_mock.with_chunk_array_support do
      proxy.generate(prompt, user: user, max_tokens: 200) { |partial| response << partial }
    end

    parsed_request = JSON.parse(request.body)
    expected = {
      "system" => [{ "text" => "You are a helpful assistant." }],
      "messages" => [{ "role" => "user", "content" => [{ "text" => "what is the time in EST" }] }],
      "inferenceConfig" => {
        "max_new_tokens" => 200,
      },
      "toolConfig" => {
        "tools" => [
          {
            "toolSpec" => {
              "name" => "time",
              "description" => "Will look up the current time",
              "inputSchema" => {
                "json" => {
                  "type" => "object",
                  "properties" => {
                    "timezone" => {
                      "type" => "string",
                      "required" => true,
                    },
                  },
                },
              },
            },
          },
        ],
      },
    }

    expect(parsed_request).to eq(expected)
    expect(response).to eq(
      [
        DiscourseAi::Completions::ToolCall.new(
          name: "time",
          id: "e1bd7033-7244-4408-b088-1d33cbcf0b67",
          parameters: {
            "timezone" => "EST",
          },
        ),
      ],
    )

    # let's continue and ensure all messages are mapped correctly
    prompt.push(type: :tool_call, name: "time", content: { timezone: "EST" }.to_json, id: "111")
    prompt.push(type: :tool, name: "time", content: "1pm".to_json, id: "111")

    # let's just return the tool call again; this is about ensuring we encode the prompt right
    stub_request(:post, stream_url)
      .with do |inner_request|
        request = inner_request
        true
      end
      .to_return(status: 200, body: messages)

    response = []
    bedrock_mock.with_chunk_array_support do
      proxy.generate(prompt, user: user, max_tokens: 200) { |partial| response << partial }
    end

    expected = {
      system: [{ text: "You are a helpful assistant." }],
      messages: [
        { role: "user", content: [{ text: "what is the time in EST" }] },
        {
          role: "assistant",
          content: [{ toolUse: { toolUseId: "111", name: "time", input: nil } }],
        },
        {
          role: "user",
          content: [{ toolResult: { toolUseId: "111", content: [{ json: "1pm" }] } }],
        },
      ],
      inferenceConfig: {
        max_new_tokens: 200,
      },
      toolConfig: {
        tools: [
          {
            toolSpec: {
              name: "time",
              description: "Will look up the current time",
              inputSchema: {
                json: {
                  type: "object",
                  properties: {
                    timezone: {
                      type: "string",
                      required: true,
                    },
                  },
                },
              },
            },
          },
        ],
      },
    }

    expect(JSON.parse(request.body, symbolize_names: true)).to eq(expected)
  end
end