mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-08-05 20:53:28 +00:00
* FIX: Llm selector / forced tools / search tool. This fixes a few issues: 1. When search was not finding any semantic results we would break the tool. 2. Gemini / Anthropic models did not implement forced tools previously, despite it being an API option. 3. Mechanics around displaying the LLM selector were not right: if you disabled the LLM selector server side, persona PMs did not work correctly. 4. Disabling native tools for Anthropic models was moved out of a site setting. This deliberately does not migrate because this feature is really rare to need now; people who had it set probably did not need it. 5. Updates Anthropic model names to the latest release. * linting * fix a couple of tests I missed * clean up conditional
376 lines
11 KiB
Ruby
376 lines
11 KiB
Ruby
# frozen_string_literal: true
|
|
|
|
require_relative "endpoint_compliance"
|
|
require "aws-eventstream"
|
|
require "aws-sigv4"
|
|
|
|
# Spec-local mock type for the Bedrock examples. It adds nothing of its own:
# all behaviour comes from EndpointMock (presumably provided by
# endpoint_compliance — confirm).
class BedrockMock < EndpointMock
end
|
|
|
|
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|
# Endpoint under test, built from the fabricated Bedrock model below.
subject(:endpoint) { described_class.new(model) }

fab!(:user)
fab!(:model) { Fabricate(:bedrock_model) }

# Mock wrapper around the endpoint; the streaming examples call
# `with_chunk_array_support` on it.
let(:bedrock_mock) { BedrockMock.new(endpoint) }

# Compliance helper bound to the Claude dialect.
# NOTE(review): none of the examples visible here reference `compliance` —
# confirm it is still needed.
let(:compliance) do
  claude_dialect = DiscourseAi::Completions::Dialects::Claude
  EndpointsCompliance.new(self, endpoint, claude_dialect, user)
end
|
|
|
|
# Builds one encoded event-stream message for the stubbed streaming responses.
#
# The message is serialized to JSON and Base64-encoded (Base64.encode64, so the
# encoded text keeps its line breaks), stored under the `bytes` key of a JSON
# envelope, and that envelope is used as the payload of an
# Aws::EventStream::Message.
#
# @param message [#to_json] the event to wrap (the examples pass Hashes)
# @return the value returned by Aws::EventStream::Encoder#encode for that message
def encode_message(message)
  encoded_body = Base64.encode64(message.to_json)
  payload = StringIO.new({ bytes: encoded_body }.to_json)

  Aws::EventStream::Encoder.new.encode(Aws::EventStream::Message.new(payload: payload))
end
|
|
|
|
describe "function calling" do
|
|
it "supports old school xml function calls" do
  # With native tools disabled on the model, the request is expected to carry
  # the tool instructions in `system` and no `tools` key (asserted below).
  model.provider_params["disable_native_tools"] = true
  model.save!

  llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
  stream_url =
    "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream"

  xml_tool_call = <<~XML.strip
    <thinking>I should be ignored</thinking>
    <search_quality_reflection>also ignored</search_quality_reflection>
    <search_quality_score>0</search_quality_score>
    <function_calls>
    <invoke>
    <tool_name>google</tool_name>
    <parameters><query>sydney weather today</query></parameters>
    </invoke>
    </function_calls>
  XML

  # One encoded event-stream message per event; passed to the stub as an array body.
  stream_chunks =
    [
      { type: "message_start", message: { usage: { input_tokens: 9 } } },
      { type: "content_block_delta", delta: { text: "hello\n" } },
      { type: "content_block_delta", delta: { text: xml_tool_call } },
      { type: "message_delta", delta: { usage: { output_tokens: 25 } } },
    ].map { |event| encode_message(event) }

  google_tool = {
    name: "google",
    description: "Will search using Google",
    parameters: [
      { name: "query", description: "The search query", type: "string", required: true },
    ],
  }

  captured_request = nil
  bedrock_mock.with_chunk_array_support do
    # The `with` block only records the outgoing request; returning true makes
    # the stub match every POST to the streaming URL.
    stub_request(:post, stream_url)
      .with do |signature|
        captured_request = signature
        true
      end
      .to_return(status: 200, body: stream_chunks)

    prompt =
      DiscourseAi::Completions::Prompt.new(
        messages: [{ type: :user, content: "what is the weather in sydney" }],
      )
    prompt.tools = [google_tool]

    streamed = +""
    llm.generate(prompt, user: user) { |partial| streamed << partial }

    # The request must carry the signing headers.
    expect(captured_request.headers["Authorization"]).to be_present
    expect(captured_request.headers["X-Amz-Content-Sha256"]).to be_present

    payload = JSON.parse(captured_request.body)
    expect(payload["system"]).to include("<function_calls>")
    expect(payload["tools"]).to be_nil
    expect(payload["stop_sequences"]).to eq(["</function_calls>"])

    # The streamed output is normalized: the thinking/reflection tags are gone
    # and the invocation now carries a generated tool_id.
    expected_output = <<~XML.strip
      hello


      <function_calls>
      <invoke>
      <tool_name>google</tool_name>
      <parameters><query>sydney weather today</query></parameters>
      <tool_id>tool_0</tool_id>
      </invoke>
      </function_calls>
    XML

    expect(streamed.strip).to eq(expected_output)
  end
end
|
|
|
|
it "supports streaming function calls" do
  llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
  stream_url =
    "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream"

  # The tool input arrives as a sequence of partial JSON fragments.
  json_delta_events =
    ["", "{\"query\": \"s", "ydney weat", "her today\"}"].map do |fragment|
      {
        type: "content_block_delta",
        index: 0,
        delta: {
          type: "input_json_delta",
          partial_json: fragment,
        },
      }
    end

  events = [
    {
      type: "message_start",
      message: {
        id: "msg_bdrk_01WYxeNMk6EKn9s98r6XXrAB",
        type: "message",
        role: "assistant",
        model: "claude-3-sonnet-20240307",
        stop_sequence: nil,
        usage: { input_tokens: 840, output_tokens: 1 },
        content: [],
        stop_reason: nil,
      },
    },
    { type: "content_block_start", index: 0, delta: { text: "<thinking>I should be ignored</thinking>" } },
    {
      type: "content_block_start",
      index: 0,
      content_block: {
        type: "tool_use",
        id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
        name: "google",
        input: {},
      },
    },
    *json_delta_events,
    { type: "content_block_stop", index: 0 },
    {
      type: "message_delta",
      delta: { stop_reason: "tool_use", stop_sequence: nil },
      usage: { output_tokens: 53 },
    },
    {
      type: "message_stop",
      "amazon-bedrock-invocationMetrics": {
        inputTokenCount: 846,
        outputTokenCount: 39,
        invocationLatency: 880,
        firstByteLatency: 402,
      },
    },
  ]

  # Re-chunk the stream: join every encoded message, then split the result with
  # String#split's default (whitespace) separator.
  # NOTE(review): presumably meant to exercise chunk boundaries that differ from
  # message boundaries — confirm.
  encoded_stream = events.map { |event| encode_message(event) }.join("")
  stream_chunks = encoded_stream.split

  google_tool = {
    name: "google",
    description: "Will search using Google",
    parameters: [
      { name: "query", description: "The search query", type: "string", required: true },
    ],
  }

  captured_request = nil
  bedrock_mock.with_chunk_array_support do
    # Record the outgoing request and match unconditionally.
    stub_request(:post, stream_url)
      .with do |signature|
        captured_request = signature
        true
      end
      .to_return(status: 200, body: stream_chunks)

    prompt =
      DiscourseAi::Completions::Prompt.new(
        messages: [{ type: :user, content: "what is the weather in sydney" }],
      )
    prompt.tools = [google_tool]

    streamed = +""
    llm.generate(prompt, user: user) { |partial| streamed << partial }

    expect(captured_request.headers["Authorization"]).to be_present
    expect(captured_request.headers["X-Amz-Content-Sha256"]).to be_present

    # The native tool_use block is surfaced as an XML invocation that keeps
    # the tool id from the stream.
    expected_output = (<<~RESPONSE).strip
      <function_calls>
      <invoke>
      <tool_name>google</tool_name>
      <parameters><query>sydney weather today</query></parameters>
      <tool_id>toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7</tool_id>
      </invoke>
      </function_calls>
    RESPONSE

    expect(streamed.strip).to eq(expected_output)

    # The request body must describe the tool natively under `tools`.
    query_schema = { "type" => "string", "description" => "The search query" }
    expected_body = {
      "max_tokens" => 3000,
      "anthropic_version" => "bedrock-2023-05-31",
      "messages" => [{ "role" => "user", "content" => "what is the weather in sydney" }],
      "tools" => [
        {
          "name" => "google",
          "description" => "Will search using Google",
          "input_schema" => {
            "type" => "object",
            "properties" => {
              "query" => query_schema,
            },
            "required" => ["query"],
          },
        },
      ],
    }
    expect(JSON.parse(captured_request.body)).to eq(expected_body)

    # The logged token counts are expected to equal the
    # amazon-bedrock-invocationMetrics values (846 / 39), not the `usage`
    # figures sent in message_start / message_delta.
    audit_log = AiApiAuditLog.order(:id).last
    expect(audit_log.request_tokens).to eq(846)
    expect(audit_log.response_tokens).to eq(39)
  end
end
|
|
end
|
|
|
|
describe "Claude 3 support" do
|
|
it "supports regular completions" do
  llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
  invoke_url =
    "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke"

  # Canned non-streaming response: one text block plus token usage.
  response_body = {
    content: [{ text: "hello sam" }],
    usage: { input_tokens: 10, output_tokens: 20 },
  }.to_json

  # Record the outgoing request and match unconditionally.
  captured_request = nil
  stub_request(:post, invoke_url)
    .with do |signature|
      captured_request = signature
      true
    end
    .to_return(status: 200, body: response_body)

  completion = llm.generate("hello world", user: user)

  expect(captured_request.headers["Authorization"]).to be_present
  expect(captured_request.headers["X-Amz-Content-Sha256"]).to be_present

  expected_body = {
    "max_tokens" => 3000,
    "anthropic_version" => "bedrock-2023-05-31",
    "messages" => [{ "role" => "user", "content" => "hello world" }],
    "system" => "You are a helpful bot",
  }
  expect(JSON.parse(captured_request.body)).to eq(expected_body)

  expect(completion).to eq("hello sam")

  # Usage from the response body must be recorded in the audit log.
  audit_log = AiApiAuditLog.order(:id).last
  expect(audit_log.request_tokens).to eq(10)
  expect(audit_log.response_tokens).to eq(20)
end
|
|
|
|
it "supports claude 3 streaming" do
  llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
  stream_url =
    "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream"

  events = [
    { type: "message_start", message: { usage: { input_tokens: 9 } } },
    { type: "content_block_delta", delta: { text: "hello " } },
    { type: "content_block_delta", delta: { text: "sam" } },
    { type: "message_delta", delta: { usage: { output_tokens: 25 } } },
  ]

  # Re-chunk the stream so the endpoint has to cope with fragments instead of
  # one chunk per message.
  # NOTE(review): the previous comment here said "1 letter at a time", but
  # String#split with no argument splits on whitespace rather than per
  # character — confirm which chunking is intended.
  encoded_stream = events.map { |event| encode_message(event) }.join("")
  stream_chunks = encoded_stream.split

  captured_request = nil
  bedrock_mock.with_chunk_array_support do
    # Record the outgoing request and match unconditionally.
    stub_request(:post, stream_url)
      .with do |signature|
        captured_request = signature
        true
      end
      .to_return(status: 200, body: stream_chunks)

    streamed = +""
    llm.generate("hello world", user: user) { |partial| streamed << partial }

    expect(captured_request.headers["Authorization"]).to be_present
    expect(captured_request.headers["X-Amz-Content-Sha256"]).to be_present

    expected_body = {
      "max_tokens" => 3000,
      "anthropic_version" => "bedrock-2023-05-31",
      "messages" => [{ "role" => "user", "content" => "hello world" }],
      "system" => "You are a helpful bot",
    }
    expect(JSON.parse(captured_request.body)).to eq(expected_body)

    expect(streamed).to eq("hello sam")

    # Token usage comes from the message_start / message_delta events.
    audit_log = AiApiAuditLog.order(:id).last
    expect(audit_log.request_tokens).to eq(9)
    expect(audit_log.response_tokens).to eq(25)
  end
end
|
|
end
|
|
end
|