Refactor dialect selection and add Nova API support

- Change dialect selection to use llm_model object instead of just provider name
- Add support for Amazon Bedrock's Nova API with native tools
- Implement Nova-specific message processing and formatting
- Update specs for Nova and AWS Bedrock endpoints
- Enhance AWS Bedrock support to handle Nova models
- Fix Gemini beta API detection logic

# frozen_string_literal: true

require_relative "endpoint_compliance"
require "aws-eventstream"
require "aws-sigv4"

class BedrockMock < EndpointMock
end

# Nova support is implemented in the Bedrock endpoint; its specs are split out here.
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
  fab!(:user)
  fab!(:nova_model)

  subject(:endpoint) { described_class.new(nova_model) }

  let(:bedrock_mock) { BedrockMock.new(endpoint) }

  let(:stream_url) do
    "https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.nova-pro-v1:0/invoke-with-response-stream"
  end
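
  # Bedrock streams use the AWS event stream wire format: each event payload is a
  # JSON envelope whose "bytes" field holds the Base64-encoded JSON event, framed
  # by Aws::EventStream::Encoder. This helper builds one such frame for the stubs below.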
  def encode_message(message)
    wrapped = { bytes: Base64.encode64(message.to_json) }.to_json
    io = StringIO.new(wrapped)
    aws_message = Aws::EventStream::Message.new(payload: io)
    Aws::EventStream::Encoder.new.encode(aws_message)
  end
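
  # Non-streaming requests hit the plain /invoke endpoint; Nova returns the
  # completion under output.message.content and token counts under usage, which
  # the endpoint records in the audit log.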
  it "should be able to make a simple request" do
    proxy = DiscourseAi::Completions::Llm.proxy("custom:#{nova_model.id}")

    content = {
      "output" => {
        "message" => {
          "content" => [{ "text" => "it is 2." }],
          "role" => "assistant",
        },
      },
      "stopReason" => "end_turn",
      "usage" => {
        "inputTokens" => 14,
        "outputTokens" => 119,
        "totalTokens" => 133,
        "cacheReadInputTokenCount" => nil,
        "cacheWriteInputTokenCount" => nil,
      },
    }.to_json

    stub_request(
      :post,
      "https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.nova-pro-v1:0/invoke",
    ).to_return(status: 200, body: content)

    response = proxy.generate("hello world", user: user)
    expect(response).to eq("it is 2.")

    log = AiApiAuditLog.order(:id).last
    expect(log.request_tokens).to eq(14)
    expect(log.response_tokens).to eq(119)
  end
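
  # Streaming requests hit /invoke-with-response-stream; the body is a sequence of
  # encoded events: messageStart, then contentBlockDelta/contentBlockStop pairs for
  # each text chunk, then a metadata event carrying token usage.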
  it "should be able to make a streaming request" do
    messages =
      [
        { messageStart: { role: "assistant" } },
        { contentBlockDelta: { delta: { text: "Hello" }, contentBlockIndex: 0 } },
        { contentBlockStop: { contentBlockIndex: 0 } },
        { contentBlockDelta: { delta: { text: "!" }, contentBlockIndex: 1 } },
        { contentBlockStop: { contentBlockIndex: 1 } },
        {
          metadata: {
            usage: {
              inputTokens: 14,
              outputTokens: 18,
            },
            metrics: {},
            trace: {},
          },
          "amazon-bedrock-invocationMetrics": {
            inputTokenCount: 14,
            outputTokenCount: 18,
            invocationLatency: 402,
            firstByteLatency: 72,
          },
        },
      ].map { |message| encode_message(message) }

    stub_request(:post, stream_url).to_return(status: 200, body: messages.join)

    proxy = DiscourseAi::Completions::Llm.proxy("custom:#{nova_model.id}")
    responses = []
    proxy.generate("Hello!", user: user) { |partial| responses << partial }

    expect(responses).to eq(%w[Hello !])
    log = AiApiAuditLog.order(:id).last
    expect(log.request_tokens).to eq(14)
    expect(log.response_tokens).to eq(18)
  end
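
  # With native tools, Nova streams a tool invocation as a toolUse content block:
  # contentBlockStart carries the tool name and id, contentBlockDelta streams the
  # JSON arguments, and the endpoint surfaces the result as a ToolCall.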
  it "should support native streaming tool calls" do
    proxy = DiscourseAi::Completions::Llm.proxy("custom:#{nova_model.id}")
    prompt =
      DiscourseAi::Completions::Prompt.new(
        "You are a helpful assistant.",
        messages: [{ type: :user, content: "what is the time in EST" }],
      )

    tool = {
      name: "time",
      description: "Will look up the current time",
      parameters: [
        { name: "timezone", description: "The timezone", type: "string", required: true },
      ],
    }

    prompt.tools = [tool]
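
    # Simulated stream: open a toolUse block for "time", deliver the arguments as
    # a JSON delta, close the block, then finish with usage metadata.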
    messages =
      [
        { messageStart: { role: "assistant" } },
        {
          contentBlockStart: {
            start: {
              toolUse: {
                toolUseId: "e1bd7033-7244-4408-b088-1d33cbcf0b67",
                name: "time",
              },
            },
            contentBlockIndex: 0,
          },
        },
        {
          contentBlockDelta: {
            delta: {
              toolUse: {
                input: "{\"timezone\":\"EST\"}",
              },
            },
            contentBlockIndex: 0,
          },
        },
        { contentBlockStop: { contentBlockIndex: 0 } },
        { messageStop: { stopReason: "end_turn" } },
        {
          metadata: {
            usage: {
              inputTokens: 481,
              outputTokens: 28,
            },
            metrics: {},
            trace: {},
          },
          "amazon-bedrock-invocationMetrics": {
            inputTokenCount: 481,
            outputTokenCount: 28,
            invocationLatency: 383,
            firstByteLatency: 57,
          },
        },
      ].map { |message| encode_message(message) }

    request = nil
    stub_request(:post, stream_url)
      .with do |inner_request|
        request = inner_request
        true
      end
      .to_return(status: 200, body: messages)
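
    # The stubbed body is an array of encoded frames; with_chunk_array_support has
    # the mock deliver each element as a separate HTTP chunk so the streaming
    # decoder is exercised frame by frame.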
    response = []
    bedrock_mock.with_chunk_array_support do
      proxy.generate(prompt, user: user, max_tokens: 200) { |partial| response << partial }
    end

    parsed_request = JSON.parse(request.body)
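
    # Nova's native tool schema: each tool is a toolSpec whose inputSchema.json is
    # a JSON-schema-style object built from the prompt's tool parameters.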
    expected = {
      "system" => [{ "text" => "You are a helpful assistant." }],
      "messages" => [{ "role" => "user", "content" => [{ "text" => "what is the time in EST" }] }],
      "inferenceConfig" => {
        "max_new_tokens" => 200,
      },
      "toolConfig" => {
        "tools" => [
          {
            "toolSpec" => {
              "name" => "time",
              "description" => "Will look up the current time",
              "inputSchema" => {
                "json" => {
                  "type" => "object",
                  "properties" => {
                    "timezone" => {
                      "type" => "string",
                      "required" => true,
                    },
                  },
                },
              },
            },
          },
        ],
      },
    }

    expect(parsed_request).to eq(expected)
    expect(response).to eq(
      [
        DiscourseAi::Completions::ToolCall.new(
          name: "time",
          id: "e1bd7033-7244-4408-b088-1d33cbcf0b67",
          parameters: {
            "timezone" => "EST",
          },
        ),
      ],
    )

    # let's continue and ensure all messages are mapped correctly
    prompt.push(type: :tool_call, name: "time", content: { timezone: "EST" }.to_json, id: "111")
    prompt.push(type: :tool, name: "time", content: "1pm".to_json, id: "111")

    # return the same tool call again; this request is only about ensuring the
    # prompt is encoded correctly
    stub_request(:post, stream_url)
      .with do |inner_request|
        request = inner_request
        true
      end
      .to_return(status: 200, body: messages)

    response = []
    bedrock_mock.with_chunk_array_support do
      proxy.generate(prompt, user: user, max_tokens: 200) { |partial| response << partial }
    end
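
    # In the re-encoded prompt the earlier tool call becomes an assistant toolUse
    # block and the tool result a user-role toolResult block with a json entry.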
    expected = {
      system: [{ text: "You are a helpful assistant." }],
      messages: [
        { role: "user", content: [{ text: "what is the time in EST" }] },
        {
          role: "assistant",
          content: [{ toolUse: { toolUseId: "111", name: "time", input: nil } }],
        },
        {
          role: "user",
          content: [{ toolResult: { toolUseId: "111", content: [{ json: "1pm" }] } }],
        },
      ],
      inferenceConfig: {
        max_new_tokens: 200,
      },
      toolConfig: {
        tools: [
          {
            toolSpec: {
              name: "time",
              description: "Will look up the current time",
              inputSchema: {
                json: {
                  type: "object",
                  properties: {
                    timezone: {
                      type: "string",
                      required: true,
                    },
                  },
                },
              },
            },
          },
        ],
      },
    }

    expect(JSON.parse(request.body, symbolize_names: true)).to eq(expected)
  end
end