discourse-ai/spec/lib/completions/endpoints/aws_bedrock_spec.rb
Roman Rizzi e0bf6adb5b
DEV: Tool support for the LLM service. (#366)
This PR adds tool support to available LLMs. We'll buffer tool invocations and return them instead of making users of this service parse the response.

It also adds support for conversation context in the generic prompt. The context includes bot messages, user messages, and tool invocations, which we trim so the prompt doesn't exceed its length limit, then translate to the correct dialect.

Finally, it adds some buffering when reading chunks to handle cases where streaming is extremely slow.
2023-12-18 18:06:01 -03:00
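
A minimal sketch of the buffering idea described above, assuming streamed deltas and a closing </function_calls> tag; `buffer_tool_call` is an illustrative name, not the service's actual API:

    # Illustrative only: accumulate streamed deltas until the closing
    # </function_calls> tag has arrived, then return the whole block so
    # callers receive one complete tool invocation instead of XML fragments.
    def buffer_tool_call(deltas)
      buffer = +""
      deltas.each { |delta| buffer << delta }
      buffer.include?("</function_calls>") ? buffer : nil
    end

    buffer_tool_call(["<function", "_calls>...</function_calls>"])
    # => "<function_calls>...</function_calls>"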


# frozen_string_literal: true

require_relative "endpoint_examples"
require "aws-eventstream"
require "aws-sigv4"

RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
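
  # "claude-2" is the generic model name the dialect uses; Bedrock addresses
  # the same model as "anthropic.claude-v2" in its invoke URLs (see the stubs
  # below).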
  subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) }

  let(:model_name) { "claude-2" }
  let(:bedrock_name) { "claude-v2" }
  let(:generic_prompt) { { insts: "write 3 words" } }
  let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
  let(:prompt) { dialect.translate }

  let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
  let(:stream_request_body) { request_body }
  before do
    SiteSetting.ai_bedrock_access_key_id = "123456"
    SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
    SiteSetting.ai_bedrock_region = "us-east-1"
  end
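
  # Canned payload in the shape Bedrock returns for a non-streaming Anthropic
  # completion.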
  def response(content)
    {
      completion: content,
      stop: "\n\nHuman:",
      stop_reason: "stop_sequence",
      truncated: false,
      log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
      model: model_name,
      exception: nil,
    }
  end
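
  # Stubs Bedrock's synchronous `invoke` endpoint so the shared examples can
  # assert against a complete, non-streamed completion.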
  def stub_response(prompt, response_text, tool_call: false)
    WebMock
      .stub_request(
        :post,
        "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.#{bedrock_name}/invoke",
      )
      .with(body: request_body)
      .to_return(status: 200, body: JSON.dump(response(response_text)))
  end
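
  # Encodes one streaming chunk the way Bedrock frames it: an AWS event-stream
  # message whose JSON payload carries the Base64-encoded completion delta.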
  def stream_line(delta, finish_reason: nil)
    encoder = Aws::EventStream::Encoder.new

    message =
      Aws::EventStream::Message.new(
        payload:
          StringIO.new(
            {
              bytes:
                Base64.encode64(
                  {
                    completion: delta,
                    stop: finish_reason ? "\n\nHuman:" : nil,
                    stop_reason: finish_reason,
                    truncated: false,
                    log_id: "12b029451c6d18094d868bc04ce83f63",
                    model: "claude-2",
                    exception: nil,
                  }.to_json,
                ),
            }.to_json,
          ),
      )

    encoder.encode(message)
  end
  def stub_streamed_response(prompt, deltas, tool_call: false)
    chunks =
      deltas.each_with_index.map do |_, index|
        if index == (deltas.length - 1)
          stream_line(deltas[index], finish_reason: "stop_sequence")
        else
          stream_line(deltas[index])
        end
      end

    WebMock
      .stub_request(
        :post,
        "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.#{bedrock_name}/invoke-with-response-stream",
      )
      .with(body: stream_request_body)
      .to_return(status: 200, body: chunks)
  end
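
  # The first delta deliberately cuts off mid-tag ("<function") so the shared
  # examples exercise the buffering of partial tool-call XML that the commit
  # message describes.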
  let(:tool_deltas) { ["<function", <<~REPLY] }
    _calls>
    <invoke>
    <tool_name>get_weather</tool_name>
    <parameters>
    <location>Sydney</location>
    <unit>c</unit>
    </parameters>
    </invoke>
    </function_calls>
  REPLY

  let(:tool_call) { invocation }
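
  # `invocation` is presumably provided by the shared examples in
  # endpoint_examples.rb (required above); this file only supplies the
  # endpoint-specific stubs and deltas they drive.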
  it_behaves_like "an endpoint that can communicate with a completion service"
end