# frozen_string_literal: true

require_relative "endpoint_compliance"
require "aws-eventstream"
require "aws-sigv4"

class BedrockMock < EndpointMock
end

# Nova is implemented entirely in the Bedrock endpoint; its specs are split out here.
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
  fab!(:user)
  fab!(:nova_model)

  subject(:endpoint) { described_class.new(nova_model) }

  let(:bedrock_mock) { BedrockMock.new(endpoint) }

  let(:stream_url) do
    "https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.nova-pro-v1:0/invoke-with-response-stream"
  end
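
  # Bedrock frames each streamed event as an AWS EventStream message whose payload
  # is a JSON envelope of the form {"bytes": "<base64-encoded event JSON>"}.
  # This helper reproduces that double wrapping for the stubbed responses below.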
  def encode_message(message)
    wrapped = { bytes: Base64.encode64(message.to_json) }.to_json
    io = StringIO.new(wrapped)
    aws_message = Aws::EventStream::Message.new(payload: io)
    Aws::EventStream::Encoder.new.encode(aws_message)
  end
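
  # Non-streaming completions go to the plain /invoke endpoint and return a single JSON body.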
it "should be able to make a simple request" do
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{nova_model.id}")
content = {
"output" => {
"message" => {
"content" => [{ "text" => "it is 2." }],
"role" => "assistant",
},
},
"stopReason" => "end_turn",
"usage" => {
"inputTokens" => 14,
"outputTokens" => 119,
"totalTokens" => 133,
"cacheReadInputTokenCount" => nil,
"cacheWriteInputTokenCount" => nil,
},
}.to_json
stub_request(
:post,
"https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.nova-pro-v1:0/invoke",
).to_return(status: 200, body: content)
response = proxy.generate("hello world", user: user)
expect(response).to eq("it is 2.")
log = AiApiAuditLog.order(:id).last
expect(log.request_tokens).to eq(14)
expect(log.response_tokens).to eq(119)
end
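
  # Streaming responses arrive as a sequence of EventStream frames; each hash below
  # is one event, encoded with encode_message and concatenated into the stubbed body.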
it "should be able to make a streaming request" do
messages =
[
{ messageStart: { role: "assistant" } },
{ contentBlockDelta: { delta: { text: "Hello" }, contentBlockIndex: 0 } },
{ contentBlockStop: { contentBlockIndex: 0 } },
{ contentBlockDelta: { delta: { text: "!" }, contentBlockIndex: 1 } },
{ contentBlockStop: { contentBlockIndex: 1 } },
{
metadata: {
usage: {
inputTokens: 14,
outputTokens: 18,
},
metrics: {
},
trace: {
},
},
"amazon-bedrock-invocationMetrics": {
inputTokenCount: 14,
outputTokenCount: 18,
invocationLatency: 402,
firstByteLatency: 72,
},
},
].map { |message| encode_message(message) }
stub_request(:post, stream_url).to_return(status: 200, body: messages.join)
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{nova_model.id}")
responses = []
proxy.generate("Hello!", user: user) { |partial| responses << partial }
expect(responses).to eq(%w[Hello !])
log = AiApiAuditLog.order(:id).last
expect(log.request_tokens).to eq(14)
expect(log.response_tokens).to eq(18)
end
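
  # Tool calls stream as a contentBlockStart (carrying the toolUse id and name),
  # followed by contentBlockDelta frames with the JSON arguments, then a contentBlockStop.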
it "should support native streaming tool calls" do
#model.provider_params["disable_native_tools"] = true
#model.save!
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{nova_model.id}")
prompt =
DiscourseAi::Completions::Prompt.new(
"You are a helpful assistant.",
messages: [{ type: :user, content: "what is the time in EST" }],
)
tool = {
name: "time",
description: "Will look up the current time",
parameters: [
{ name: "timezone", description: "The timezone", type: "string", required: true },
],
}
prompt.tools = [tool]
messages =
[
{ messageStart: { role: "assistant" } },
{
contentBlockStart: {
start: {
toolUse: {
toolUseId: "e1bd7033-7244-4408-b088-1d33cbcf0b67",
name: "time",
},
},
contentBlockIndex: 0,
},
},
{
contentBlockDelta: {
delta: {
toolUse: {
input: "{\"timezone\":\"EST\"}",
},
},
contentBlockIndex: 0,
},
},
{ contentBlockStop: { contentBlockIndex: 0 } },
{ messageStop: { stopReason: "end_turn" } },
{
metadata: {
usage: {
inputTokens: 481,
outputTokens: 28,
},
metrics: {
},
trace: {
},
},
"amazon-bedrock-invocationMetrics": {
inputTokenCount: 481,
outputTokenCount: 28,
invocationLatency: 383,
firstByteLatency: 57,
},
},
].map { |message| encode_message(message) }
request = nil
stub_request(:post, stream_url)
.with do |inner_request|
request = inner_request
true
end
.to_return(status: 200, body: messages)
response = []
bedrock_mock.with_chunk_array_support do
proxy.generate(prompt, user: user, max_tokens: 200) { |partial| response << partial }
end
parsed_request = JSON.parse(request.body)
expected = {
"system" => [{ "text" => "You are a helpful assistant." }],
"messages" => [{ "role" => "user", "content" => [{ "text" => "what is the time in EST" }] }],
"inferenceConfig" => {
"max_new_tokens" => 200,
},
"toolConfig" => {
"tools" => [
{
"toolSpec" => {
"name" => "time",
"description" => "Will look up the current time",
"inputSchema" => {
"json" => {
"type" => "object",
"properties" => {
"timezone" => {
"type" => "string",
"required" => true,
},
},
},
},
},
},
],
},
}
expect(parsed_request).to eq(expected)
expect(response).to eq(
[
DiscourseAi::Completions::ToolCall.new(
name: "time",
id: "e1bd7033-7244-4408-b088-1d33cbcf0b67",
parameters: {
"timezone" => "EST",
},
),
],
)

    # Let's continue the conversation and ensure all messages are mapped correctly.
    prompt.push(type: :tool_call, name: "time", content: { timezone: "EST" }.to_json, id: "111")
    prompt.push(type: :tool, name: "time", content: "1pm".to_json, id: "111")

    # Return the same tool call again; this pass is about ensuring the prompt,
    # including the tool_call and tool messages, is encoded correctly.
    stub_request(:post, stream_url)
      .with do |inner_request|
        request = inner_request
        true
      end
      .to_return(status: 200, body: messages)

    response = []
    bedrock_mock.with_chunk_array_support do
      proxy.generate(prompt, user: user, max_tokens: 200) { |partial| response << partial }
    end

    expected = {
      system: [{ text: "You are a helpful assistant." }],
      messages: [
        { role: "user", content: [{ text: "what is the time in EST" }] },
        {
          role: "assistant",
          content: [{ toolUse: { toolUseId: "111", name: "time", input: nil } }],
        },
        {
          role: "user",
          content: [{ toolResult: { toolUseId: "111", content: [{ json: "1pm" }] } }],
        },
      ],
      inferenceConfig: {
        max_new_tokens: 200,
      },
      toolConfig: {
        tools: [
          {
            toolSpec: {
              name: "time",
              description: "Will look up the current time",
              inputSchema: {
                json: {
                  type: "object",
                  properties: {
                    timezone: {
                      type: "string",
                      required: true,
                    },
                  },
                },
              },
            },
          },
        ],
      },
    }

    expect(JSON.parse(request.body, symbolize_names: true)).to eq(expected)
  end
end