FEATURE: Amazon Nova support via bedrock (#997)

Refactor dialect selection and add Nova API support

- Change dialect selection to use llm_model object instead of just provider name
- Add support for Amazon Bedrock's Nova API with native tools
- Implement Nova-specific message processing and formatting
- Update specs for Nova and AWS Bedrock endpoints
- Enhance AWS Bedrock support to handle Nova models
- Fix Gemini beta API detection logic
Sam 2024-12-06 07:45:58 +11:00, committed by GitHub
parent 5e87a50202 · commit a55216773a
18 changed files with 907 additions and 37 deletions
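
The pivotal refactor: dialect lookup now takes the whole llm_model record instead of a provider string, so a dialect can discriminate on provider and model name together (aws_bedrock hosts both Claude and Nova models). A minimal sketch of the resulting dispatch, not part of the diff — LlmModelStub is a hypothetical stand-in for the real LlmModel ActiveRecord model:

# Hypothetical stand-in; only #provider and #name are needed by can_translate?.
LlmModelStub = Struct.new(:provider, :name)

nova = LlmModelStub.new("aws_bedrock", "amazon.nova-pro-v1:0")
claude = LlmModelStub.new("aws_bedrock", "anthropic.claude-3-sonnet") # illustrative name

# Same provider, different dialects, selected by model name:
DiscourseAi::Completions::Dialects::Dialect.dialect_for(nova)
#=> DiscourseAi::Completions::Dialects::Nova
DiscourseAi::Completions::Dialects::Dialect.dialect_for(claude)
#=> DiscourseAi::Completions::Dialects::Claude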


@@ -5,8 +5,8 @@ module DiscourseAi
   module Dialects
     class ChatGpt < Dialect
       class << self
-        def can_translate?(model_provider)
-          model_provider == "open_ai" || model_provider == "azure"
+        def can_translate?(llm_model)
+          llm_model.provider == "open_ai" || llm_model.provider == "azure"
         end
       end


@@ -5,8 +5,10 @@ module DiscourseAi
   module Dialects
     class Claude < Dialect
       class << self
-        def can_translate?(provider_name)
-          provider_name == "anthropic" || provider_name == "aws_bedrock"
+        def can_translate?(llm_model)
+          llm_model.provider == "anthropic" ||
+            (llm_model.provider == "aws_bedrock") &&
+              (llm_model.name.include?("anthropic") || llm_model.name.include?("claude"))
         end
       end


@@ -6,8 +6,8 @@ module DiscourseAi
   module Completions
     module Dialects
       class Command < Dialect
-        def self.can_translate?(model_provider)
-          model_provider == "cohere"
+        def self.can_translate?(llm_model)
+          llm_model.provider == "cohere"
         end

         VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/


@@ -5,7 +5,7 @@ module DiscourseAi
   module Dialects
     class Dialect
      class << self
-        def can_translate?(model_provider)
+        def can_translate?(llm_model)
           raise NotImplemented
         end
@@ -17,11 +17,12 @@ module DiscourseAi
            DiscourseAi::Completions::Dialects::Command,
            DiscourseAi::Completions::Dialects::Ollama,
            DiscourseAi::Completions::Dialects::Mistral,
+           DiscourseAi::Completions::Dialects::Nova,
            DiscourseAi::Completions::Dialects::OpenAiCompatible,
          ]
        end

-       def dialect_for(model_provider)
+       def dialect_for(llm_model)
         dialects = []

         if Rails.env.test? || Rails.env.development?
@@ -30,7 +31,7 @@ module DiscourseAi
         dialects = dialects.concat(all_dialects)

-        dialect = dialects.find { |d| d.can_translate?(model_provider) }
+        dialect = dialects.find { |d| d.can_translate?(llm_model) }
         raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL if !dialect

         dialect


@@ -5,8 +5,8 @@ module DiscourseAi
   module Dialects
     class Fake < Dialect
       class << self
-        def can_translate?(model_name)
-          model_name == "fake"
+        def can_translate?(llm_model)
+          llm_model.provider == "fake"
         end
       end


@@ -5,8 +5,8 @@ module DiscourseAi
   module Dialects
     class Gemini < Dialect
       class << self
-        def can_translate?(model_provider)
-          model_provider == "google"
+        def can_translate?(llm_model)
+          llm_model.provider == "google"
        end
      end
@@ -80,7 +80,7 @@ module DiscourseAi
       end

       def beta_api?
-        @beta_api ||= llm_model.name.start_with?("gemini-1.5")
+        @beta_api ||= !llm_model.name.start_with?("gemini-1.0")
       end

       def system_msg(msg)


@@ -7,8 +7,8 @@ module DiscourseAi
   module Dialects
     class Mistral < ChatGpt
       class << self
-        def can_translate?(model_provider)
-          model_provider == "mistral"
+        def can_translate?(llm_model)
+          llm_model.provider == "mistral"
         end
       end


@@ -0,0 +1,177 @@
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Dialects
      class Nova < Dialect
        class << self
          def can_translate?(llm_model)
            llm_model.provider == "aws_bedrock" && llm_model.name.include?("amazon.nova")
          end
        end

        class NovaPrompt
          attr_reader :system, :messages, :inference_config, :tool_config

          def initialize(system, messages, inference_config = nil, tool_config = nil)
            @system = system
            @messages = messages
            @inference_config = inference_config
            @tool_config = tool_config
          end

          def system_prompt
            # small hack for size estimation
            system.to_s
          end

          def has_tools?
            tool_config.present?
          end

          def to_payload(options = nil)
            stop_sequences = options[:stop_sequences]
            max_tokens = options[:max_tokens]

            inference_config = options&.slice(:temperature, :top_p, :top_k)
            inference_config[:stopSequences] = stop_sequences if stop_sequences.present?
            inference_config[:max_new_tokens] = max_tokens if max_tokens.present?

            result = { system: system, messages: messages }
            result[:inferenceConfig] = inference_config if inference_config.present?
            result[:toolConfig] = tool_config if tool_config.present?

            result
          end
        end

        def translate
          messages = super

          system = messages.shift[:content] if messages.first&.dig(:role) == "system"

          nova_messages = messages.map { |msg| { role: msg[:role], content: build_content(msg) } }

          inference_config = build_inference_config
          tool_config = tools_dialect.translated_tools if native_tool_support?

          NovaPrompt.new(
            system.presence && [{ text: system }],
            nova_messages,
            inference_config,
            tool_config,
          )
        end

        def max_prompt_tokens
          llm_model.max_prompt_tokens
        end

        def native_tool_support?
          !llm_model.lookup_custom_param("disable_native_tools")
        end

        def tools_dialect
          if native_tool_support?
            @tools_dialect ||= DiscourseAi::Completions::Dialects::NovaTools.new(prompt.tools)
          else
            super
          end
        end

        private

        def build_content(msg)
          content = []

          existing_content = msg[:content]
          if existing_content.is_a?(Hash)
            content << existing_content
          elsif existing_content.is_a?(String)
            content << { text: existing_content }
          end

          msg[:images]&.each { |image| content << image }

          content
        end

        def build_inference_config
          return unless opts[:inference_config]

          config = {}
          ic = opts[:inference_config]

          config[:max_new_tokens] = ic[:max_new_tokens] if ic[:max_new_tokens]
          config[:temperature] = ic[:temperature] if ic[:temperature]
          config[:top_p] = ic[:top_p] if ic[:top_p]
          config[:top_k] = ic[:top_k] if ic[:top_k]
          config[:stopSequences] = ic[:stop_sequences] if ic[:stop_sequences]

          config.present? ? config : nil
        end

        def detect_format(mime_type)
          case mime_type
          when "image/jpeg"
            "jpeg"
          when "image/png"
            "png"
          when "image/gif"
            "gif"
          when "image/webp"
            "webp"
          else
            "jpeg" # default
          end
        end

        def system_msg(msg)
          msg = { role: "system", content: msg[:content] }

          if tools_dialect.instructions.present?
            msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
          end

          msg
        end

        def user_msg(msg)
          images = nil

          if vision_support?
            encoded_uploads = prompt.encoded_uploads(msg)
            encoded_uploads&.each do |upload|
              images ||= []
              images << {
                image: {
                  format: upload[:format] || detect_format(upload[:mime_type]),
                  source: {
                    bytes: upload[:base64],
                  },
                },
              }
            end
          end

          { role: "user", content: msg[:content], images: images }
        end

        def model_msg(msg)
          { role: "assistant", content: msg[:content] }
        end

        def tool_msg(msg)
          translated = tools_dialect.from_raw_tool(msg)
          { role: "user", content: translated }
        end

        def tool_call_msg(msg)
          translated = tools_dialect.from_raw_tool_call(msg)
          { role: "assistant", content: translated }
        end
      end
    end
  end
end
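
For orientation, not part of the diff: exercising NovaPrompt#to_payload directly shows the Bedrock-style request body the endpoint will send. A minimal sketch with made-up values, mirroring the inference-config spec further down:

# Illustrative sketch; NovaPrompt is defined in the new file above.
prompt =
  DiscourseAi::Completions::Dialects::Nova::NovaPrompt.new(
    [{ text: "You are a helpful bot" }],
    [{ role: "user", content: [{ text: "Hello" }] }],
  )

prompt.to_payload(temperature: 0.7, max_tokens: 100, stop_sequences: ["STOP"])
# => {
#      system: [{ text: "You are a helpful bot" }],
#      messages: [{ role: "user", content: [{ text: "Hello" }] }],
#      inferenceConfig: { temperature: 0.7, stopSequences: ["STOP"], max_new_tokens: 100 },
#    }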


@@ -0,0 +1,83 @@
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Dialects
      class NovaTools
        def initialize(tools)
          @raw_tools = tools
        end

        def translated_tools
          return if !@raw_tools.present?

          # note: forced tools are not supported yet, toolChoice is always auto
          {
            tools:
              @raw_tools.map do |tool|
                {
                  toolSpec: {
                    name: tool[:name],
                    description: tool[:description],
                    inputSchema: {
                      json: convert_tool_to_input_schema(tool),
                    },
                  },
                }
              end,
          }
        end

        # native tools require no system instructions
        def instructions
          ""
        end

        def from_raw_tool_call(raw_message)
          {
            toolUse: {
              toolUseId: raw_message[:id],
              name: raw_message[:name],
              input: JSON.parse(raw_message[:content])["arguments"],
            },
          }
        end

        def from_raw_tool(raw_message)
          {
            toolResult: {
              toolUseId: raw_message[:id],
              content: [{ json: JSON.parse(raw_message[:content]) }],
            },
          }
        end

        private

        def convert_tool_to_input_schema(tool)
          tool = tool.transform_keys(&:to_sym)
          properties = {}
          tool[:parameters].each do |param|
            schema = {}

            type = param[:type]
            type = "string" if !%w[string number boolean integer array].include?(type)
            schema[:type] = type

            if enum = param[:enum]
              schema[:enum] = enum
            end

            schema[:items] = { type: param[:item_type] } if type == "array"

            schema[:required] = true if param[:required]
            properties[param[:name]] = schema
          end

          { type: "object", properties: properties }
        end
      end
    end
  end
end
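
Again for orientation, not part of the diff: NovaTools#translated_tools maps the plugin's tool definitions onto Bedrock's toolSpec shape, as the tool-configuration spec below also asserts. A small sketch with an invented tool:

tools = [
  {
    name: "get_weather",
    description: "Get the weather in a city",
    parameters: [{ name: "location", type: "string", required: true }],
  },
]

DiscourseAi::Completions::Dialects::NovaTools.new(tools).translated_tools
# => {
#      tools: [
#        {
#          toolSpec: {
#            name: "get_weather",
#            description: "Get the weather in a city",
#            inputSchema: {
#              json: {
#                type: "object",
#                properties: { "location" => { type: "string", required: true } },
#              },
#            },
#          },
#        },
#      ],
#    }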


@@ -5,8 +5,8 @@ module DiscourseAi
   module Dialects
     class Ollama < Dialect
       class << self
-        def can_translate?(model_provider)
-          model_provider == "ollama"
+        def can_translate?(llm_model)
+          llm_model.provider == "ollama"
         end
       end


@@ -5,7 +5,8 @@ module DiscourseAi
   module Dialects
     class OpenAiCompatible < Dialect
       class << self
-        def can_translate?(_model_name)
+        def can_translate?(_llm_model)
+          # fallback dialect
           true
         end
       end


@@ -6,6 +6,8 @@ module DiscourseAi
   module Completions
     module Endpoints
       class AwsBedrock < Base
+        attr_reader :dialect
+
         def self.can_contact?(model_provider)
           model_provider == "aws_bedrock"
         end
@@ -19,10 +21,15 @@ module DiscourseAi
         end

         def default_options(dialect)
-          max_tokens = 4096
-          max_tokens = 8192 if bedrock_model_id.match?(/3.5/)
-
-          options = { max_tokens: max_tokens, anthropic_version: "bedrock-2023-05-31" }
+          options =
+            if dialect.is_a?(DiscourseAi::Completions::Dialects::Claude)
+              max_tokens = 4096
+              max_tokens = 8192 if bedrock_model_id.match?(/3.5/)
+
+              { max_tokens: max_tokens, anthropic_version: "bedrock-2023-05-31" }
+            else
+              {}
+            end

           options[:stop_sequences] = ["</function_calls>"] if !dialect.native_tool_support? &&
             dialect.prompt.has_tools?
@@ -86,17 +93,25 @@ module DiscourseAi
         def prepare_payload(prompt, model_params, dialect)
           @native_tool_support = dialect.native_tool_support?
+          @dialect = dialect

-          payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
-
-          payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
-          if prompt.has_tools?
-            payload[:tools] = prompt.tools
-            if dialect.tool_choice.present?
-              payload[:tool_choice] = { type: "tool", name: dialect.tool_choice }
+          payload = nil
+
+          if dialect.is_a?(DiscourseAi::Completions::Dialects::Claude)
+            payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
+            payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
+
+            if prompt.has_tools?
+              payload[:tools] = prompt.tools
+              if dialect.tool_choice.present?
+                payload[:tool_choice] = { type: "tool", name: dialect.tool_choice }
+              end
             end
+          elsif dialect.is_a?(DiscourseAi::Completions::Dialects::Nova)
+            payload = prompt.to_payload(default_options(dialect).merge(model_params))
+          else
+            raise "Unsupported dialect"
           end

           payload
         end
@@ -151,6 +166,11 @@ module DiscourseAi
             i = 0
             while decoded
               parsed = JSON.parse(decoded.payload.string)
+              if exception = decoded.headers[":exception-type"]
+                Rails.logger.error("#{self.class.name}: #{exception}: #{parsed}")
+                # TODO based on how often this happens, we may want to raise so we
+                # can retry, this may catch rate limits for example
+              end
               # perhaps some control message we can just ignore
               messages << Base64.decode64(parsed["bytes"]) if parsed && parsed["bytes"]
@@ -180,8 +200,15 @@ module DiscourseAi
         end

         def processor
-          @processor ||=
-            DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
+          if dialect.is_a?(DiscourseAi::Completions::Dialects::Claude)
+            @processor ||=
+              DiscourseAi::Completions::AnthropicMessageProcessor.new(
+                streaming_mode: @streaming_mode,
+              )
+          else
+            @processor ||=
+              DiscourseAi::Completions::NovaMessageProcessor.new(streaming_mode: @streaming_mode)
+          end
         end

         def xml_tools_enabled?


@@ -176,8 +176,7 @@ module DiscourseAi
           raise UNKNOWN_MODEL if llm_model.nil?

-          model_provider = llm_model.provider
-          dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_provider)
+          dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(llm_model)

           if @canned_response
             if @canned_llm && @canned_llm != model
@@ -187,6 +186,7 @@ module DiscourseAi
             return new(dialect_klass, nil, llm_model, gateway: @canned_response)
           end

+          model_provider = llm_model.provider
           gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_provider)

           new(dialect_klass, gateway_klass, llm_model)


@@ -0,0 +1,94 @@
# frozen_string_literal: true

class DiscourseAi::Completions::NovaMessageProcessor
  class NovaToolCall
    attr_reader :name, :raw_json, :id

    def initialize(name, id, partial_tool_calls: false)
      @name = name
      @id = id
      @raw_json = +""
      @tool_call = DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: {})
      @streaming_parser =
        DiscourseAi::Completions::ToolCallProgressTracker.new(self) if partial_tool_calls
    end

    def append(json)
      @raw_json << json
      @streaming_parser << json if @streaming_parser
    end

    def notify_progress(key, value)
      @tool_call.partial = true
      @tool_call.parameters[key.to_sym] = value
      @has_new_data = true
    end

    def has_partial?
      @has_new_data
    end

    def partial_tool_call
      @has_new_data = false
      @tool_call
    end

    def to_tool_call
      parameters = JSON.parse(raw_json, symbolize_names: true)
      # we dupe to avoid poisoning the original tool call
      @tool_call = @tool_call.dup
      @tool_call.partial = false
      @tool_call.parameters = parameters
      @tool_call
    end
  end

  attr_reader :tool_calls, :input_tokens, :output_tokens

  def initialize(streaming_mode:, partial_tool_calls: false)
    @streaming_mode = streaming_mode
    @tool_calls = []
    @current_tool_call = nil
    @partial_tool_calls = partial_tool_calls
  end

  def to_tool_calls
    @tool_calls.map { |tool_call| tool_call.to_tool_call }
  end

  def process_streamed_message(parsed)
    return if !parsed
    result = nil

    if tool_start = parsed.dig(:contentBlockStart, :start, :toolUse)
      @current_tool_call = NovaToolCall.new(tool_start[:name], tool_start[:toolUseId])
    end

    if tool_progress = parsed.dig(:contentBlockDelta, :delta, :toolUse, :input)
      @current_tool_call.append(tool_progress)
    end

    result = @current_tool_call.to_tool_call if parsed[:contentBlockStop] && @current_tool_call

    if metadata = parsed[:metadata]
      @input_tokens = metadata.dig(:usage, :inputTokens)
      @output_tokens = metadata.dig(:usage, :outputTokens)
    end

    result || parsed.dig(:contentBlockDelta, :delta, :text)
  end

  def process_message(payload)
    result = []
    parsed = payload
    parsed = JSON.parse(payload, symbolize_names: true) if payload.is_a?(String)

    result << parsed.dig(:output, :message, :content, 0, :text)

    @input_tokens = parsed.dig(:usage, :inputTokens)
    @output_tokens = parsed.dig(:usage, :outputTokens)

    result
  end
end
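
A quick sketch of the streaming contract, not part of the diff: the Bedrock endpoint decodes each event-stream frame, JSON-parses it with symbolized keys (the processor digs symbol keys throughout), and hands it to process_streamed_message, which returns either a text partial or a ToolCall. Hand-written events stand in for a real stream here:

# Illustrative only; events are hand-written instead of decoded from a stream.
processor = DiscourseAi::Completions::NovaMessageProcessor.new(streaming_mode: true)

processor.process_streamed_message({ contentBlockDelta: { delta: { text: "Hello" } } })
# => "Hello"

processor.process_streamed_message({ metadata: { usage: { inputTokens: 14, outputTokens: 18 } } })
processor.input_tokens # => 14
processor.output_tokens # => 18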


@@ -64,6 +64,17 @@ Fabricator(:bedrock_model, from: :anthropic_model) do
   provider_params { { region: "us-east-1", access_key_id: "123456" } }
 end

+Fabricator(:nova_model, from: :llm_model) do
+  display_name "Amazon Nova pro"
+  name "amazon.nova-pro-v1:0"
+  provider "aws_bedrock"
+  tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
+  max_prompt_tokens 300_000
+  api_key "fake"
+  url ""
+  provider_params { { region: "us-east-1", access_key_id: "123456" } }
+end
+
 Fabricator(:cohere_model, from: :llm_model) do
   display_name "Cohere Command R+"
   name "command-r-plus"


@@ -1,12 +1,12 @@
 # frozen_string_literal: true

 RSpec.describe DiscourseAi::Completions::Dialects::Claude do
-  let :opus_dialect_klass do
-    DiscourseAi::Completions::Dialects::Dialect.dialect_for("anthropic")
-  end
-
   fab!(:llm_model) { Fabricate(:anthropic_model, name: "claude-3-opus") }

+  let :opus_dialect_klass do
+    DiscourseAi::Completions::Dialects::Dialect.dialect_for(llm_model)
+  end
+
   describe "#translate" do
     it "can insert OKs to make stuff interleve properly" do
       messages = [


@@ -0,0 +1,190 @@
# frozen_string_literal: true

RSpec.describe DiscourseAi::Completions::Dialects::Nova do
  fab!(:llm_model) { Fabricate(:nova_model, vision_enabled: true) }

  let(:nova_dialect_klass) { DiscourseAi::Completions::Dialects::Dialect.dialect_for(llm_model) }

  it "finds the right dialect" do
    expect(nova_dialect_klass).to eq(DiscourseAi::Completions::Dialects::Nova)
  end

  describe "#translate" do
    it "properly formats a basic conversation" do
      messages = [
        { type: :user, id: "user1", content: "Hello" },
        { type: :model, content: "Hi there!" },
      ]

      prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
      dialect = nova_dialect_klass.new(prompt, llm_model)
      translated = dialect.translate

      expect(translated.system).to eq([{ text: "You are a helpful bot" }])
      expect(translated.messages).to eq(
        [
          { role: "user", content: [{ text: "Hello" }] },
          { role: "assistant", content: [{ text: "Hi there!" }] },
        ],
      )
    end

    context "with image content" do
      let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
      let(:upload) do
        UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
      end

      it "properly formats messages with images" do
        messages = [
          { type: :user, id: "user1", content: "What's in this image?", upload_ids: [upload.id] },
        ]

        prompt = DiscourseAi::Completions::Prompt.new(messages: messages)
        dialect = nova_dialect_klass.new(prompt, llm_model)
        translated = dialect.translate

        encoded = prompt.encoded_uploads(messages.first).first[:base64]

        expect(translated.messages.first[:content]).to eq(
          [
            { text: "What's in this image?" },
            { image: { format: "jpeg", source: { bytes: encoded } } },
          ],
        )
      end
    end

    context "with tools" do
      it "properly formats tool configuration" do
        tools = [
          {
            name: "get_weather",
            description: "Get the weather in a city",
            parameters: [
              { name: "location", type: "string", description: "the city name", required: true },
            ],
          },
        ]

        messages = [{ type: :user, content: "What's the weather?" }]

        prompt =
          DiscourseAi::Completions::Prompt.new(
            "You are a helpful bot",
            messages: messages,
            tools: tools,
          )

        dialect = nova_dialect_klass.new(prompt, llm_model)
        translated = dialect.translate

        expect(translated.tool_config).to eq(
          {
            tools: [
              {
                toolSpec: {
                  name: "get_weather",
                  description: "Get the weather in a city",
                  inputSchema: {
                    json: {
                      type: "object",
                      properties: {
                        "location" => {
                          type: "string",
                          required: true,
                        },
                      },
                    },
                  },
                },
              },
            ],
          },
        )
      end
    end

    context "with inference configuration" do
      it "includes inference configuration when provided" do
        messages = [{ type: :user, content: "Hello" }]
        prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
        dialect = nova_dialect_klass.new(prompt, llm_model)

        options = { temperature: 0.7, top_p: 0.9, max_tokens: 100, stop_sequences: ["STOP"] }

        translated = dialect.translate

        expected = {
          system: [{ text: "You are a helpful bot" }],
          messages: [{ role: "user", content: [{ text: "Hello" }] }],
          inferenceConfig: {
            temperature: 0.7,
            top_p: 0.9,
            stopSequences: ["STOP"],
            max_new_tokens: 100,
          },
        }

        expect(translated.to_payload(options)).to eq(expected)
      end

      it "omits inference configuration when not provided" do
        messages = [{ type: :user, content: "Hello" }]
        prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
        dialect = nova_dialect_klass.new(prompt, llm_model)

        translated = dialect.translate
        expect(translated.inference_config).to be_nil
      end
    end

    it "handles tool calls and responses" do
      tool_call_prompt = { name: "get_weather", arguments: { location: "London" } }

      messages = [
        { type: :user, id: "user1", content: "What's the weather in London?" },
        { type: :tool_call, name: "get_weather", id: "tool_id", content: tool_call_prompt.to_json },
        { type: :tool, id: "tool_id", content: "Sunny, 22°C".to_json },
        { type: :model, content: "The weather in London is sunny with 22°C" },
      ]

      prompt =
        DiscourseAi::Completions::Prompt.new(
          "You are a helpful bot",
          messages: messages,
          tools: [
            {
              name: "get_weather",
              description: "Get the weather in a city",
              parameters: [
                { name: "location", type: "string", description: "the city name", required: true },
              ],
            },
          ],
        )

      dialect = nova_dialect_klass.new(prompt, llm_model)
      translated = dialect.translate

      expect(translated.messages.map { |m| m[:role] }).to eq(%w[user assistant user assistant])
      expect(translated.messages.last[:content]).to eq(
        [{ text: "The weather in London is sunny with 22°C" }],
      )
    end
  end

  describe "#max_prompt_tokens" do
    it "returns the model's max prompt tokens" do
      prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot")
      dialect = nova_dialect_klass.new(prompt, llm_model)
      expect(dialect.max_prompt_tokens).to eq(llm_model.max_prompt_tokens)
    end
  end
end


@@ -0,0 +1,284 @@
# frozen_string_literal: true

require_relative "endpoint_compliance"
require "aws-eventstream"
require "aws-sigv4"

class BedrockMock < EndpointMock
end

# nova is all implemented in bedrock endpoint, split out here
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
  fab!(:user)
  fab!(:nova_model)

  subject(:endpoint) { described_class.new(nova_model) }

  let(:bedrock_mock) { BedrockMock.new(endpoint) }

  let(:stream_url) do
    "https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.nova-pro-v1:0/invoke-with-response-stream"
  end

  def encode_message(message)
    wrapped = { bytes: Base64.encode64(message.to_json) }.to_json
    io = StringIO.new(wrapped)
    aws_message = Aws::EventStream::Message.new(payload: io)
    Aws::EventStream::Encoder.new.encode(aws_message)
  end

  it "should be able to make a simple request" do
    proxy = DiscourseAi::Completions::Llm.proxy("custom:#{nova_model.id}")

    content = {
      "output" => {
        "message" => {
          "content" => [{ "text" => "it is 2." }],
          "role" => "assistant",
        },
      },
      "stopReason" => "end_turn",
      "usage" => {
        "inputTokens" => 14,
        "outputTokens" => 119,
        "totalTokens" => 133,
        "cacheReadInputTokenCount" => nil,
        "cacheWriteInputTokenCount" => nil,
      },
    }.to_json

    stub_request(
      :post,
      "https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.nova-pro-v1:0/invoke",
    ).to_return(status: 200, body: content)

    response = proxy.generate("hello world", user: user)
    expect(response).to eq("it is 2.")

    log = AiApiAuditLog.order(:id).last
    expect(log.request_tokens).to eq(14)
    expect(log.response_tokens).to eq(119)
  end

  it "should be able to make a streaming request" do
    messages =
      [
        { messageStart: { role: "assistant" } },
        { contentBlockDelta: { delta: { text: "Hello" }, contentBlockIndex: 0 } },
        { contentBlockStop: { contentBlockIndex: 0 } },
        { contentBlockDelta: { delta: { text: "!" }, contentBlockIndex: 1 } },
        { contentBlockStop: { contentBlockIndex: 1 } },
        {
          metadata: {
            usage: {
              inputTokens: 14,
              outputTokens: 18,
            },
            metrics: {
            },
            trace: {
            },
          },
          "amazon-bedrock-invocationMetrics": {
            inputTokenCount: 14,
            outputTokenCount: 18,
            invocationLatency: 402,
            firstByteLatency: 72,
          },
        },
      ].map { |message| encode_message(message) }

    stub_request(:post, stream_url).to_return(status: 200, body: messages.join)

    proxy = DiscourseAi::Completions::Llm.proxy("custom:#{nova_model.id}")
    responses = []
    proxy.generate("Hello!", user: user) { |partial| responses << partial }

    expect(responses).to eq(%w[Hello !])

    log = AiApiAuditLog.order(:id).last
    expect(log.request_tokens).to eq(14)
    expect(log.response_tokens).to eq(18)
  end

  it "should support native streaming tool calls" do
    #model.provider_params["disable_native_tools"] = true
    #model.save!
    proxy = DiscourseAi::Completions::Llm.proxy("custom:#{nova_model.id}")

    prompt =
      DiscourseAi::Completions::Prompt.new(
        "You are a helpful assistant.",
        messages: [{ type: :user, content: "what is the time in EST" }],
      )

    tool = {
      name: "time",
      description: "Will look up the current time",
      parameters: [
        { name: "timezone", description: "The timezone", type: "string", required: true },
      ],
    }

    prompt.tools = [tool]

    messages =
      [
        { messageStart: { role: "assistant" } },
        {
          contentBlockStart: {
            start: {
              toolUse: {
                toolUseId: "e1bd7033-7244-4408-b088-1d33cbcf0b67",
                name: "time",
              },
            },
            contentBlockIndex: 0,
          },
        },
        {
          contentBlockDelta: {
            delta: {
              toolUse: {
                input: "{\"timezone\":\"EST\"}",
              },
            },
            contentBlockIndex: 0,
          },
        },
        { contentBlockStop: { contentBlockIndex: 0 } },
        { messageStop: { stopReason: "end_turn" } },
        {
          metadata: {
            usage: {
              inputTokens: 481,
              outputTokens: 28,
            },
            metrics: {
            },
            trace: {
            },
          },
          "amazon-bedrock-invocationMetrics": {
            inputTokenCount: 481,
            outputTokenCount: 28,
            invocationLatency: 383,
            firstByteLatency: 57,
          },
        },
      ].map { |message| encode_message(message) }

    request = nil
    stub_request(:post, stream_url)
      .with do |inner_request|
        request = inner_request
        true
      end
      .to_return(status: 200, body: messages)

    response = []
    bedrock_mock.with_chunk_array_support do
      proxy.generate(prompt, user: user, max_tokens: 200) { |partial| response << partial }
    end

    parsed_request = JSON.parse(request.body)
    expected = {
      "system" => [{ "text" => "You are a helpful assistant." }],
      "messages" => [{ "role" => "user", "content" => [{ "text" => "what is the time in EST" }] }],
      "inferenceConfig" => {
        "max_new_tokens" => 200,
      },
      "toolConfig" => {
        "tools" => [
          {
            "toolSpec" => {
              "name" => "time",
              "description" => "Will look up the current time",
              "inputSchema" => {
                "json" => {
                  "type" => "object",
                  "properties" => {
                    "timezone" => {
                      "type" => "string",
                      "required" => true,
                    },
                  },
                },
              },
            },
          },
        ],
      },
    }

    expect(parsed_request).to eq(expected)
    expect(response).to eq(
      [
        DiscourseAi::Completions::ToolCall.new(
          name: "time",
          id: "e1bd7033-7244-4408-b088-1d33cbcf0b67",
          parameters: {
            "timezone" => "EST",
          },
        ),
      ],
    )

    # lets continue and ensure all messages are mapped correctly
    prompt.push(type: :tool_call, name: "time", content: { timezone: "EST" }.to_json, id: "111")
    prompt.push(type: :tool, name: "time", content: "1pm".to_json, id: "111")

    # lets just return the tool call again, this is about ensuring we encode the prompt right
    stub_request(:post, stream_url)
      .with do |inner_request|
        request = inner_request
        true
      end
      .to_return(status: 200, body: messages)

    response = []
    bedrock_mock.with_chunk_array_support do
      proxy.generate(prompt, user: user, max_tokens: 200) { |partial| response << partial }
    end

    expected = {
      system: [{ text: "You are a helpful assistant." }],
      messages: [
        { role: "user", content: [{ text: "what is the time in EST" }] },
        {
          role: "assistant",
          content: [{ toolUse: { toolUseId: "111", name: "time", input: nil } }],
        },
        {
          role: "user",
          content: [{ toolResult: { toolUseId: "111", content: [{ json: "1pm" }] } }],
        },
      ],
      inferenceConfig: {
        max_new_tokens: 200,
      },
      toolConfig: {
        tools: [
          {
            toolSpec: {
              name: "time",
              description: "Will look up the current time",
              inputSchema: {
                json: {
                  type: "object",
                  properties: {
                    timezone: {
                      type: "string",
                      required: true,
                    },
                  },
                },
              },
            },
          },
        ],
      },
    }

    expect(JSON.parse(request.body, symbolize_names: true)).to eq(expected)
  end
end