FEATURE: add Claude 3 sonnet/haiku support for Amazon Bedrock (#534)
This PR implements the new Anthropic Messages interface for the Bedrock Claude endpoints and adds support for the new Claude 3 models (haiku, opus, sonnet). Key changes: - Merged the `AnthropicMessages` and `Anthropic` endpoint classes into a single `Anthropic` class (likewise, the `ClaudeMessages` dialect was merged into `Claude`) - Updated the `AwsBedrock` endpoint to use the new `/messages` API format for all Claude models - Added `claude-3-haiku`, `claude-3-opus` and `claude-3-sonnet` model support to the Anthropic endpoint, and `claude-3-haiku` and `claude-3-sonnet` support to the AWS Bedrock endpoint (opus is not yet available on Bedrock) - Updated specs for the new consolidated endpoints and Claude 3 model support. This refactor removes support for the old non-Messages API (`/v1/complete`), which has been deprecated by Anthropic.
This commit is contained in:
parent
d7ed8180af
commit
f62703760f
|
@ -178,9 +178,16 @@ module DiscourseAi
|
|||
when DiscourseAi::AiBot::EntryPoint::FAKE_ID
|
||||
"fake:fake"
|
||||
when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID
|
||||
# no bedrock support yet 18-03
|
||||
"anthropic:claude-3-opus"
|
||||
when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_SONNET_ID
|
||||
"anthropic:claude-3-sonnet"
|
||||
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(
|
||||
"claude-3-sonnet",
|
||||
)
|
||||
"aws_bedrock:claude-3-sonnet"
|
||||
else
|
||||
"anthropic:claude-3-sonnet"
|
||||
end
|
||||
else
|
||||
nil
|
||||
end
|
||||
|
|
|
@ -216,13 +216,16 @@ Follow the provided writing composition instructions carefully and precisely ste
|
|||
def translate_model(model)
|
||||
return "google:gemini-pro" if model == "gemini-pro"
|
||||
return "open_ai:#{model}" if model.start_with? "gpt"
|
||||
return "anthropic:#{model}" if model.start_with? "claude-3"
|
||||
|
||||
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
|
||||
"aws_bedrock:#{model}"
|
||||
else
|
||||
"anthropic:#{model}"
|
||||
if model.start_with? "claude"
|
||||
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(model)
|
||||
return "aws_bedrock:#{model}"
|
||||
else
|
||||
return "anthropic:#{model}"
|
||||
end
|
||||
end
|
||||
|
||||
raise "Unknown model #{model}"
|
||||
end
|
||||
|
||||
private
|
||||
|
|
|
@ -6,7 +6,9 @@ module DiscourseAi
|
|||
class Claude < Dialect
|
||||
class << self
|
||||
def can_translate?(model_name)
|
||||
%w[claude-instant-1 claude-2].include?(model_name)
|
||||
%w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus].include?(
|
||||
model_name,
|
||||
)
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
|
@ -14,53 +16,69 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
class ClaudePrompt
|
||||
attr_reader :system_prompt
|
||||
attr_reader :messages
|
||||
|
||||
def initialize(system_prompt, messages)
|
||||
@system_prompt = system_prompt
|
||||
@messages = messages
|
||||
end
|
||||
end
|
||||
|
||||
def translate
|
||||
messages = prompt.messages
|
||||
system_prompt = +""
|
||||
|
||||
trimmed_messages = trim_messages(messages)
|
||||
messages =
|
||||
trim_messages(messages)
|
||||
.map do |msg|
|
||||
case msg[:type]
|
||||
when :system
|
||||
system_prompt << msg[:content]
|
||||
nil
|
||||
when :tool_call
|
||||
{ role: "assistant", content: tool_call_to_xml(msg) }
|
||||
when :tool
|
||||
{ role: "user", content: tool_result_to_xml(msg) }
|
||||
when :model
|
||||
{ role: "assistant", content: msg[:content] }
|
||||
when :user
|
||||
content = +""
|
||||
content << "#{msg[:id]}: " if msg[:id]
|
||||
content << msg[:content]
|
||||
|
||||
# Need to include this differently
|
||||
last_message = trimmed_messages.last[:type] == :assistant ? trimmed_messages.pop : nil
|
||||
|
||||
claude_prompt =
|
||||
trimmed_messages.reduce(+"") do |memo, msg|
|
||||
if msg[:type] == :tool_call
|
||||
memo << "\n\nAssistant: #{tool_call_to_xml(msg)}"
|
||||
elsif msg[:type] == :system
|
||||
memo << "Human: " unless uses_system_message?
|
||||
memo << msg[:content]
|
||||
if prompt.tools.present?
|
||||
memo << "\n"
|
||||
memo << build_tools_prompt
|
||||
{ role: "user", content: content }
|
||||
end
|
||||
elsif msg[:type] == :model
|
||||
memo << "\n\nAssistant: #{msg[:content]}"
|
||||
elsif msg[:type] == :tool
|
||||
memo << "\n\nHuman:\n"
|
||||
memo << tool_result_to_xml(msg)
|
||||
else
|
||||
memo << "\n\nHuman: "
|
||||
memo << "#{msg[:id]}: " if msg[:id]
|
||||
memo << msg[:content]
|
||||
end
|
||||
.compact
|
||||
|
||||
memo
|
||||
if prompt.tools.present?
|
||||
system_prompt << "\n\n"
|
||||
system_prompt << build_tools_prompt
|
||||
end
|
||||
|
||||
interleving_messages = []
|
||||
|
||||
previous_message = nil
|
||||
messages.each do |message|
|
||||
if previous_message
|
||||
if previous_message[:role] == "user" && message[:role] == "user"
|
||||
interleving_messages << { role: "assistant", content: "OK" }
|
||||
elsif previous_message[:role] == "assistant" && message[:role] == "assistant"
|
||||
interleving_messages << { role: "user", content: "OK" }
|
||||
end
|
||||
end
|
||||
interleving_messages << message
|
||||
previous_message = message
|
||||
end
|
||||
|
||||
claude_prompt << "\n\nAssistant:"
|
||||
claude_prompt << " #{last_message[:content]}:" if last_message
|
||||
|
||||
claude_prompt
|
||||
ClaudePrompt.new(system_prompt.presence, interleving_messages)
|
||||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
100_000 # Claude-2.1 has a 200k context window.
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def uses_system_message?
|
||||
model_name == "claude-2"
|
||||
# Longer term it will have over 1 million
|
||||
200_000 # Claude-3 has a 200k context window for now
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,85 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Completions
|
||||
module Dialects
|
||||
class ClaudeMessages < Dialect
|
||||
class << self
|
||||
def can_translate?(model_name)
|
||||
# TODO: add haiku not released yet as of 2024-03-05
|
||||
%w[claude-3-sonnet claude-3-opus].include?(model_name)
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::AnthropicTokenizer
|
||||
end
|
||||
end
|
||||
|
||||
class ClaudePrompt
|
||||
attr_reader :system_prompt
|
||||
attr_reader :messages
|
||||
|
||||
def initialize(system_prompt, messages)
|
||||
@system_prompt = system_prompt
|
||||
@messages = messages
|
||||
end
|
||||
end
|
||||
|
||||
def translate
|
||||
messages = prompt.messages
|
||||
system_prompt = +""
|
||||
|
||||
messages =
|
||||
trim_messages(messages)
|
||||
.map do |msg|
|
||||
case msg[:type]
|
||||
when :system
|
||||
system_prompt << msg[:content]
|
||||
nil
|
||||
when :tool_call
|
||||
{ role: "assistant", content: tool_call_to_xml(msg) }
|
||||
when :tool
|
||||
{ role: "user", content: tool_result_to_xml(msg) }
|
||||
when :model
|
||||
{ role: "assistant", content: msg[:content] }
|
||||
when :user
|
||||
content = +""
|
||||
content << "#{msg[:id]}: " if msg[:id]
|
||||
content << msg[:content]
|
||||
|
||||
{ role: "user", content: content }
|
||||
end
|
||||
end
|
||||
.compact
|
||||
|
||||
if prompt.tools.present?
|
||||
system_prompt << "\n\n"
|
||||
system_prompt << build_tools_prompt
|
||||
end
|
||||
|
||||
interleving_messages = []
|
||||
|
||||
previous_message = nil
|
||||
messages.each do |message|
|
||||
if previous_message
|
||||
if previous_message[:role] == "user" && message[:role] == "user"
|
||||
interleving_messages << { role: "assistant", content: "OK" }
|
||||
elsif previous_message[:role] == "assistant" && message[:role] == "assistant"
|
||||
interleving_messages << { role: "user", content: "OK" }
|
||||
end
|
||||
end
|
||||
interleving_messages << message
|
||||
previous_message = message
|
||||
end
|
||||
|
||||
ClaudePrompt.new(system_prompt.presence, interleving_messages)
|
||||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
# Longer term it will have over 1 million
|
||||
200_000 # Claude-3 has a 200k context window for now
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -11,13 +11,12 @@ module DiscourseAi
|
|||
|
||||
def dialect_for(model_name)
|
||||
dialects = [
|
||||
DiscourseAi::Completions::Dialects::Claude,
|
||||
DiscourseAi::Completions::Dialects::Llama2Classic,
|
||||
DiscourseAi::Completions::Dialects::ChatGpt,
|
||||
DiscourseAi::Completions::Dialects::OrcaStyle,
|
||||
DiscourseAi::Completions::Dialects::Gemini,
|
||||
DiscourseAi::Completions::Dialects::Mixtral,
|
||||
DiscourseAi::Completions::Dialects::ClaudeMessages,
|
||||
DiscourseAi::Completions::Dialects::Claude,
|
||||
]
|
||||
|
||||
if Rails.env.test? || Rails.env.development?
|
||||
|
|
|
@ -6,7 +6,10 @@ module DiscourseAi
|
|||
class Anthropic < Base
|
||||
class << self
|
||||
def can_contact?(endpoint_name, model_name)
|
||||
endpoint_name == "anthropic" && %w[claude-instant-1 claude-2].include?(model_name)
|
||||
endpoint_name == "anthropic" &&
|
||||
%w[claude-instant-1 claude-2 claude-3-haiku claude-3-opus claude-3-sonnet].include?(
|
||||
model_name,
|
||||
)
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
|
@ -23,23 +26,32 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def normalize_model_params(model_params)
|
||||
model_params = model_params.dup
|
||||
|
||||
# temperature, stop_sequences are already supported
|
||||
#
|
||||
if model_params[:max_tokens]
|
||||
model_params[:max_tokens_to_sample] = model_params.delete(:max_tokens)
|
||||
end
|
||||
|
||||
# max_tokens, temperature, stop_sequences are already supported
|
||||
model_params
|
||||
end
|
||||
|
||||
def default_options
|
||||
{
|
||||
model: model == "claude-2" ? "claude-2.1" : model,
|
||||
max_tokens_to_sample: 3_000,
|
||||
stop_sequences: ["\n\nHuman:", "</function_calls>"],
|
||||
}
|
||||
def default_options(dialect)
|
||||
# skipping 2.0 support for now, since other models are better
|
||||
mapped_model =
|
||||
case model
|
||||
when "claude-2"
|
||||
"claude-2.1"
|
||||
when "claude-instant-1"
|
||||
"claude-instant-1.2"
|
||||
when "claude-3-haiku"
|
||||
"claude-3-haiku-20240307"
|
||||
when "claude-3-sonnet"
|
||||
"claude-3-sonnet-20240229"
|
||||
when "claude-3-opus"
|
||||
"claude-3-opus-20240229"
|
||||
else
|
||||
raise "Unsupported model: #{model}"
|
||||
end
|
||||
|
||||
options = { model: mapped_model, max_tokens: 3_000 }
|
||||
|
||||
options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
|
||||
options
|
||||
end
|
||||
|
||||
def provider_id
|
||||
|
@ -48,15 +60,22 @@ module DiscourseAi
|
|||
|
||||
private
|
||||
|
||||
def model_uri
|
||||
@uri ||= URI("https://api.anthropic.com/v1/complete")
|
||||
# this is an approximation, we will update it later if request goes through
|
||||
def prompt_size(prompt)
|
||||
super(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
|
||||
end
|
||||
|
||||
def prepare_payload(prompt, model_params, _dialect)
|
||||
default_options
|
||||
.merge(model_params)
|
||||
.merge(prompt: prompt)
|
||||
.tap { |payload| payload[:stream] = true if @streaming_mode }
|
||||
def model_uri
|
||||
@uri ||= URI("https://api.anthropic.com/v1/messages")
|
||||
end
|
||||
|
||||
def prepare_payload(prompt, model_params, dialect)
|
||||
payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
|
||||
|
||||
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
|
||||
payload[:stream] = true if @streaming_mode
|
||||
|
||||
payload
|
||||
end
|
||||
|
||||
def prepare_request(payload)
|
||||
|
@ -69,8 +88,30 @@ module DiscourseAi
|
|||
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
|
||||
end
|
||||
|
||||
def final_log_update(log)
|
||||
log.request_tokens = @input_tokens if @input_tokens
|
||||
log.response_tokens = @output_tokens if @output_tokens
|
||||
end
|
||||
|
||||
def extract_completion_from(response_raw)
|
||||
JSON.parse(response_raw, symbolize_names: true)[:completion].to_s
|
||||
result = ""
|
||||
parsed = JSON.parse(response_raw, symbolize_names: true)
|
||||
|
||||
if @streaming_mode
|
||||
if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
|
||||
result = parsed.dig(:delta, :text).to_s
|
||||
elsif parsed[:type] == "message_start"
|
||||
@input_tokens = parsed.dig(:message, :usage, :input_tokens)
|
||||
elsif parsed[:type] == "message_delta"
|
||||
@output_tokens = parsed.dig(:delta, :usage, :output_tokens)
|
||||
end
|
||||
else
|
||||
result = parsed.dig(:content, 0, :text).to_s
|
||||
@input_tokens = parsed.dig(:usage, :input_tokens)
|
||||
@output_tokens = parsed.dig(:usage, :output_tokens)
|
||||
end
|
||||
|
||||
result
|
||||
end
|
||||
|
||||
def partials_from(decoded_chunk)
|
||||
|
|
|
@ -1,103 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Completions
|
||||
module Endpoints
|
||||
class AnthropicMessages < Base
|
||||
class << self
|
||||
def can_contact?(endpoint_name, model_name)
|
||||
endpoint_name == "anthropic" && %w[claude-3-opus claude-3-sonnet].include?(model_name)
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[ai_anthropic_api_key]
|
||||
end
|
||||
|
||||
def correctly_configured?(_model_name)
|
||||
SiteSetting.ai_anthropic_api_key.present?
|
||||
end
|
||||
|
||||
def endpoint_name(model_name)
|
||||
"Anthropic - #{model_name}"
|
||||
end
|
||||
end
|
||||
|
||||
def normalize_model_params(model_params)
|
||||
# max_tokens, temperature, stop_sequences are already supported
|
||||
model_params
|
||||
end
|
||||
|
||||
def default_options(dialect)
|
||||
options = { model: model + "-20240229", max_tokens: 3_000 }
|
||||
|
||||
options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
|
||||
options
|
||||
end
|
||||
|
||||
def provider_id
|
||||
AiApiAuditLog::Provider::Anthropic
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
# this is an approximation, we will update it later if request goes through
|
||||
def prompt_size(prompt)
|
||||
super(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
|
||||
end
|
||||
|
||||
def model_uri
|
||||
@uri ||= URI("https://api.anthropic.com/v1/messages")
|
||||
end
|
||||
|
||||
def prepare_payload(prompt, model_params, dialect)
|
||||
payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
|
||||
|
||||
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
|
||||
payload[:stream] = true if @streaming_mode
|
||||
|
||||
payload
|
||||
end
|
||||
|
||||
def prepare_request(payload)
|
||||
headers = {
|
||||
"anthropic-version" => "2023-06-01",
|
||||
"x-api-key" => SiteSetting.ai_anthropic_api_key,
|
||||
"content-type" => "application/json",
|
||||
}
|
||||
|
||||
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
|
||||
end
|
||||
|
||||
def final_log_update(log)
|
||||
log.request_tokens = @input_tokens if @input_tokens
|
||||
log.response_tokens = @output_tokens if @output_tokens
|
||||
end
|
||||
|
||||
def extract_completion_from(response_raw)
|
||||
result = ""
|
||||
parsed = JSON.parse(response_raw, symbolize_names: true)
|
||||
|
||||
if @streaming_mode
|
||||
if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
|
||||
result = parsed.dig(:delta, :text).to_s
|
||||
elsif parsed[:type] == "message_start"
|
||||
@input_tokens = parsed.dig(:message, :usage, :input_tokens)
|
||||
elsif parsed[:type] == "message_delta"
|
||||
@output_tokens = parsed.dig(:delta, :usage, :output_tokens)
|
||||
end
|
||||
else
|
||||
result = parsed.dig(:content, 0, :text).to_s
|
||||
@input_tokens = parsed.dig(:usage, :input_tokens)
|
||||
@output_tokens = parsed.dig(:usage, :output_tokens)
|
||||
end
|
||||
|
||||
result
|
||||
end
|
||||
|
||||
def partials_from(decoded_chunk)
|
||||
decoded_chunk.split("\n").map { |line| line.split("data: ", 2)[1] }.compact
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -8,17 +8,18 @@ module DiscourseAi
|
|||
class AwsBedrock < Base
|
||||
class << self
|
||||
def can_contact?(endpoint_name, model_name)
|
||||
endpoint_name == "aws_bedrock" && %w[claude-instant-1 claude-2].include?(model_name)
|
||||
endpoint_name == "aws_bedrock" &&
|
||||
%w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet].include?(model_name)
|
||||
end
|
||||
|
||||
def dependant_setting_names
|
||||
%w[ai_bedrock_access_key_id ai_bedrock_secret_access_key ai_bedrock_region]
|
||||
end
|
||||
|
||||
def correctly_configured?(_model_name)
|
||||
def correctly_configured?(model)
|
||||
SiteSetting.ai_bedrock_access_key_id.present? &&
|
||||
SiteSetting.ai_bedrock_secret_access_key.present? &&
|
||||
SiteSetting.ai_bedrock_region.present?
|
||||
SiteSetting.ai_bedrock_region.present? && can_contact?("aws_bedrock", model)
|
||||
end
|
||||
|
||||
def endpoint_name(model_name)
|
||||
|
@ -29,17 +30,15 @@ module DiscourseAi
|
|||
def normalize_model_params(model_params)
|
||||
model_params = model_params.dup
|
||||
|
||||
# temperature, stop_sequences are already supported
|
||||
#
|
||||
if model_params[:max_tokens]
|
||||
model_params[:max_tokens_to_sample] = model_params.delete(:max_tokens)
|
||||
end
|
||||
# max_tokens, temperature, stop_sequences, top_p are already supported
|
||||
|
||||
model_params
|
||||
end
|
||||
|
||||
def default_options
|
||||
{ max_tokens_to_sample: 3_000, stop_sequences: ["\n\nHuman:", "</function_calls>"] }
|
||||
def default_options(dialect)
|
||||
options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" }
|
||||
options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
|
||||
options
|
||||
end
|
||||
|
||||
def provider_id
|
||||
|
@ -48,25 +47,40 @@ module DiscourseAi
|
|||
|
||||
private
|
||||
|
||||
def model_uri
|
||||
# Bedrock uses slightly different names
|
||||
# See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
|
||||
bedrock_model_id = model.split("-")
|
||||
bedrock_model_id[-1] = "v#{bedrock_model_id.last}"
|
||||
bedrock_model_id = +(bedrock_model_id.join("-"))
|
||||
def prompt_size(prompt)
|
||||
# approximation
|
||||
super(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
|
||||
end
|
||||
|
||||
bedrock_model_id << ":1" if model == "claude-2" # For claude-2.1
|
||||
def model_uri
|
||||
# See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
|
||||
#
|
||||
# FYI there is a 2.0 version of Claude, very little need to support it given
|
||||
# haiku/sonnet are better fits anyway, we map to claude-2.1
|
||||
bedrock_model_id =
|
||||
case model
|
||||
when "claude-2"
|
||||
"anthropic.claude-v2:1"
|
||||
when "claude-3-haiku"
|
||||
"anthropic.claude-3-haiku-20240307-v1:0"
|
||||
when "claude-3-sonnet"
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
when "claude-instant-1"
|
||||
"anthropic.claude-instant-v1"
|
||||
end
|
||||
|
||||
api_url =
|
||||
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.#{bedrock_model_id}/invoke"
|
||||
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{bedrock_model_id}/invoke"
|
||||
|
||||
api_url = @streaming_mode ? (api_url + "-with-response-stream") : api_url
|
||||
|
||||
URI(api_url)
|
||||
end
|
||||
|
||||
def prepare_payload(prompt, model_params, _dialect)
|
||||
default_options.merge(prompt: prompt).merge(model_params)
|
||||
def prepare_payload(prompt, model_params, dialect)
|
||||
payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
|
||||
payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
|
||||
payload
|
||||
end
|
||||
|
||||
def prepare_request(payload)
|
||||
|
@ -117,8 +131,30 @@ module DiscourseAi
|
|||
nil
|
||||
end
|
||||
|
||||
def final_log_update(log)
|
||||
log.request_tokens = @input_tokens if @input_tokens
|
||||
log.response_tokens = @output_tokens if @output_tokens
|
||||
end
|
||||
|
||||
def extract_completion_from(response_raw)
|
||||
JSON.parse(response_raw, symbolize_names: true)[:completion].to_s
|
||||
result = ""
|
||||
parsed = JSON.parse(response_raw, symbolize_names: true)
|
||||
|
||||
if @streaming_mode
|
||||
if parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
|
||||
result = parsed.dig(:delta, :text).to_s
|
||||
elsif parsed[:type] == "message_start"
|
||||
@input_tokens = parsed.dig(:message, :usage, :input_tokens)
|
||||
elsif parsed[:type] == "message_delta"
|
||||
@output_tokens = parsed.dig(:delta, :usage, :output_tokens)
|
||||
end
|
||||
else
|
||||
result = parsed.dig(:content, 0, :text).to_s
|
||||
@input_tokens = parsed.dig(:usage, :input_tokens)
|
||||
@output_tokens = parsed.dig(:usage, :output_tokens)
|
||||
end
|
||||
|
||||
result
|
||||
end
|
||||
|
||||
def partials_from(decoded_chunk)
|
||||
|
|
|
@ -11,12 +11,11 @@ module DiscourseAi
|
|||
def endpoint_for(provider_name, model_name)
|
||||
endpoints = [
|
||||
DiscourseAi::Completions::Endpoints::AwsBedrock,
|
||||
DiscourseAi::Completions::Endpoints::Anthropic,
|
||||
DiscourseAi::Completions::Endpoints::OpenAi,
|
||||
DiscourseAi::Completions::Endpoints::HuggingFace,
|
||||
DiscourseAi::Completions::Endpoints::Gemini,
|
||||
DiscourseAi::Completions::Endpoints::Vllm,
|
||||
DiscourseAi::Completions::Endpoints::AnthropicMessages,
|
||||
DiscourseAi::Completions::Endpoints::Anthropic,
|
||||
]
|
||||
|
||||
if Rails.env.test? || Rails.env.development?
|
||||
|
|
|
@ -23,8 +23,8 @@ module DiscourseAi
|
|||
# However, since they use the same URL/key settings, there's no reason to duplicate them.
|
||||
@models_by_provider ||=
|
||||
{
|
||||
aws_bedrock: %w[claude-instant-1 claude-2],
|
||||
anthropic: %w[claude-instant-1 claude-2 claude-3-sonnet claude-3-opus],
|
||||
aws_bedrock: %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet],
|
||||
anthropic: %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus],
|
||||
vllm: %w[
|
||||
mistralai/Mixtral-8x7B-Instruct-v0.1
|
||||
mistralai/Mistral-7B-Instruct-v0.2
|
||||
|
|
|
@ -1,87 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Dialects::ClaudeMessages do
|
||||
describe "#translate" do
|
||||
it "can insert OKs to make stuff interleve properly" do
|
||||
messages = [
|
||||
{ type: :user, id: "user1", content: "1" },
|
||||
{ type: :model, content: "2" },
|
||||
{ type: :user, id: "user1", content: "4" },
|
||||
{ type: :user, id: "user1", content: "5" },
|
||||
{ type: :model, content: "6" },
|
||||
]
|
||||
|
||||
prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
|
||||
|
||||
dialectKlass = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
|
||||
dialect = dialectKlass.new(prompt, "claude-3-opus")
|
||||
translated = dialect.translate
|
||||
|
||||
expected_messages = [
|
||||
{ role: "user", content: "user1: 1" },
|
||||
{ role: "assistant", content: "2" },
|
||||
{ role: "user", content: "user1: 4" },
|
||||
{ role: "assistant", content: "OK" },
|
||||
{ role: "user", content: "user1: 5" },
|
||||
{ role: "assistant", content: "6" },
|
||||
]
|
||||
|
||||
expect(translated.messages).to eq(expected_messages)
|
||||
end
|
||||
|
||||
it "can properly translate a prompt" do
|
||||
dialect = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
|
||||
|
||||
tools = [
|
||||
{
|
||||
name: "echo",
|
||||
description: "echo a string",
|
||||
parameters: [
|
||||
{ name: "text", type: "string", description: "string to echo", required: true },
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
tool_call_prompt = { name: "echo", arguments: { text: "something" } }
|
||||
|
||||
messages = [
|
||||
{ type: :user, id: "user1", content: "echo something" },
|
||||
{ type: :tool_call, content: tool_call_prompt.to_json },
|
||||
{ type: :tool, id: "tool_id", content: "something".to_json },
|
||||
{ type: :model, content: "I did it" },
|
||||
{ type: :user, id: "user1", content: "echo something else" },
|
||||
]
|
||||
|
||||
prompt =
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
"You are a helpful bot",
|
||||
messages: messages,
|
||||
tools: tools,
|
||||
)
|
||||
|
||||
dialect = dialect.new(prompt, "claude-3-opus")
|
||||
translated = dialect.translate
|
||||
|
||||
expect(translated.system_prompt).to start_with("You are a helpful bot")
|
||||
expect(translated.system_prompt).to include("echo a string")
|
||||
|
||||
expected = [
|
||||
{ role: "user", content: "user1: echo something" },
|
||||
{
|
||||
role: "assistant",
|
||||
content:
|
||||
"<function_calls>\n<invoke>\n<tool_name>echo</tool_name>\n<parameters>\n<text>something</text>\n</parameters>\n</invoke>\n</function_calls>",
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content:
|
||||
"<function_results>\n<result>\n<tool_name>tool_id</tool_name>\n<json>\n\"something\"\n</json>\n</result>\n</function_results>",
|
||||
},
|
||||
{ role: "assistant", content: "I did it" },
|
||||
{ role: "user", content: "user1: echo something else" },
|
||||
]
|
||||
|
||||
expect(translated.messages).to eq(expected)
|
||||
end
|
||||
end
|
||||
end
|
|
@ -1,100 +1,87 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require_relative "dialect_context"
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
||||
let(:model_name) { "claude-2" }
|
||||
let(:context) { DialectContext.new(described_class, model_name) }
|
||||
|
||||
describe "#translate" do
|
||||
it "translates a prompt written in our generic format to Claude's format" do
|
||||
anthropic_version = (<<~TEXT).strip
|
||||
#{context.system_insts}
|
||||
#{described_class.tool_preamble}
|
||||
<tools>
|
||||
#{context.dialect_tools}</tools>
|
||||
it "can insert OKs to make stuff interleve properly" do
|
||||
messages = [
|
||||
{ type: :user, id: "user1", content: "1" },
|
||||
{ type: :model, content: "2" },
|
||||
{ type: :user, id: "user1", content: "4" },
|
||||
{ type: :user, id: "user1", content: "5" },
|
||||
{ type: :model, content: "6" },
|
||||
]
|
||||
|
||||
Human: #{context.simple_user_input}
|
||||
prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages)
|
||||
|
||||
Assistant:
|
||||
TEXT
|
||||
dialectKlass = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
|
||||
dialect = dialectKlass.new(prompt, "claude-3-opus")
|
||||
translated = dialect.translate
|
||||
|
||||
translated = context.system_user_scenario
|
||||
expected_messages = [
|
||||
{ role: "user", content: "user1: 1" },
|
||||
{ role: "assistant", content: "2" },
|
||||
{ role: "user", content: "user1: 4" },
|
||||
{ role: "assistant", content: "OK" },
|
||||
{ role: "user", content: "user1: 5" },
|
||||
{ role: "assistant", content: "6" },
|
||||
]
|
||||
|
||||
expect(translated).to eq(anthropic_version)
|
||||
expect(translated.messages).to eq(expected_messages)
|
||||
end
|
||||
|
||||
it "translates tool messages" do
|
||||
expected = +(<<~TEXT).strip
|
||||
#{context.system_insts}
|
||||
#{described_class.tool_preamble}
|
||||
<tools>
|
||||
#{context.dialect_tools}</tools>
|
||||
it "can properly translate a prompt" do
|
||||
dialect = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus")
|
||||
|
||||
Human: user1: This is a message by a user
|
||||
tools = [
|
||||
{
|
||||
name: "echo",
|
||||
description: "echo a string",
|
||||
parameters: [
|
||||
{ name: "text", type: "string", description: "string to echo", required: true },
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
Assistant: I'm a previous bot reply, that's why there's no user
|
||||
tool_call_prompt = { name: "echo", arguments: { text: "something" } }
|
||||
|
||||
Human: user1: This is a new message by a user
|
||||
messages = [
|
||||
{ type: :user, id: "user1", content: "echo something" },
|
||||
{ type: :tool_call, content: tool_call_prompt.to_json },
|
||||
{ type: :tool, id: "tool_id", content: "something".to_json },
|
||||
{ type: :model, content: "I did it" },
|
||||
{ type: :user, id: "user1", content: "echo something else" },
|
||||
]
|
||||
|
||||
Assistant: <function_calls>
|
||||
<invoke>
|
||||
<tool_name>get_weather</tool_name>
|
||||
<parameters>
|
||||
<location>Sydney</location>
|
||||
<unit>c</unit>
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
|
||||
Human:
|
||||
<function_results>
|
||||
<result>
|
||||
<tool_name>get_weather</tool_name>
|
||||
<json>
|
||||
"I'm a tool result"
|
||||
</json>
|
||||
</result>
|
||||
</function_results>
|
||||
|
||||
Assistant:
|
||||
TEXT
|
||||
|
||||
expect(context.multi_turn_scenario).to eq(expected)
|
||||
end
|
||||
|
||||
it "trims content if it's getting too long" do
|
||||
length = 19_000
|
||||
|
||||
translated = context.long_user_input_scenario(length: length)
|
||||
|
||||
expect(translated.length).to be < context.long_message_text(length: length).length
|
||||
end
|
||||
|
||||
it "retains usernames in generated prompt" do
|
||||
prompt =
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
"You are a bot",
|
||||
messages: [
|
||||
{ id: "👻", type: :user, content: "Message1" },
|
||||
{ type: :model, content: "Ok" },
|
||||
{ id: "joe", type: :user, content: "Message2" },
|
||||
],
|
||||
"You are a helpful bot",
|
||||
messages: messages,
|
||||
tools: tools,
|
||||
)
|
||||
|
||||
translated = context.dialect(prompt).translate
|
||||
dialect = dialect.new(prompt, "claude-3-opus")
|
||||
translated = dialect.translate
|
||||
|
||||
expect(translated).to eq(<<~TEXT.strip)
|
||||
You are a bot
|
||||
expect(translated.system_prompt).to start_with("You are a helpful bot")
|
||||
expect(translated.system_prompt).to include("echo a string")
|
||||
|
||||
Human: 👻: Message1
|
||||
expected = [
|
||||
{ role: "user", content: "user1: echo something" },
|
||||
{
|
||||
role: "assistant",
|
||||
content:
|
||||
"<function_calls>\n<invoke>\n<tool_name>echo</tool_name>\n<parameters>\n<text>something</text>\n</parameters>\n</invoke>\n</function_calls>",
|
||||
},
|
||||
{
|
||||
role: "user",
|
||||
content:
|
||||
"<function_results>\n<result>\n<tool_name>tool_id</tool_name>\n<json>\n\"something\"\n</json>\n</result>\n</function_results>",
|
||||
},
|
||||
{ role: "assistant", content: "I did it" },
|
||||
{ role: "user", content: "user1: echo something else" },
|
||||
]
|
||||
|
||||
Assistant: Ok
|
||||
|
||||
Human: joe: Message2
|
||||
|
||||
Assistant:
|
||||
TEXT
|
||||
expect(translated.messages).to eq(expected)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,395 +0,0 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::AnthropicMessages do
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-opus") }
|
||||
|
||||
# Minimal conversation: a system prompt plus a single message from "user1".
let(:prompt) do
  user_message = { type: :user, id: "user1", content: "hello" }

  DiscourseAi::Completions::Prompt.new("You are hello bot", messages: [user_message])
end
|
||||
|
||||
# Tool definition with one required string parameter, "text".
let(:echo_tool) do
  text_parameter = { name: "text", type: "string", description: "text to echo", required: true }

  { name: "echo", description: "echo something", parameters: [text_parameter] }
end
|
||||
|
||||
# Tool definition with one required string parameter, "query".
let(:google_tool) do
  query_parameter = {
    name: "query",
    type: "string",
    description: "text to google",
    required: true,
  }

  { name: "google", description: "google something", parameters: [query_parameter] }
end
|
||||
|
||||
# Attaches the echo tool to the memoized `prompt` and returns that same
# object; the prompt is mutated in place, not copied.
let(:prompt_with_echo_tool) { prompt.tap { |shared_prompt| shared_prompt.tools = [echo_tool] } }
|
||||
|
||||
# Attaches the google tool to the memoized `prompt` and returns that same
# object; the prompt is mutated in place, not copied.
#
# This previously attached `echo_tool`, so the prompt never declared the
# `google` tool that the streamed tool-call fixture invokes.
let(:prompt_with_google_tool) do
  prompt.tools = [google_tool]
  prompt
end
|
||||
|
||||
before { SiteSetting.ai_anthropic_api_key = "123" }
|
||||
|
||||
# Streams a tool invocation one small delta at a time (including deltas that
# are a lone " ") and checks that the reassembled <query> keeps its spaces.
it "does not eat spaces with tool calls" do
  sse_stream = <<~STRING
    event: message_start
    data: {"type":"message_start","message":{"id":"msg_019kmW9Q3GqfWmuFJbePJTBR","type":"message","role":"assistant","content":[],"model":"claude-3-opus-20240229","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":347,"output_tokens":1}}}

    event: content_block_start
    data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}

    event: ping
    data: {"type": "ping"}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<function"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"calls"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<invoke"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<tool"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"name"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"google"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</tool"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"name"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<parameters"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<query"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"top"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"10"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"things"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" to"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" do"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" in"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" japan"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" for"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" tourists"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</query"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</parameters"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</invoke"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}

    event: content_block_delta
    data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}

    event: content_block_stop
    data: {"type":"content_block_stop","index":0}

    event: message_delta
    data: {"type":"message_delta","delta":{"stop_reason":"stop_sequence","stop_sequence":"</function_calls>"},"usage":{"output_tokens":57}}

    event: message_stop
    data: {"type":"message_stop"}
  STRING

  # No request matcher here: any POST to the messages URL gets the stream.
  stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(
    status: 200,
    body: sse_stream,
  )

  streamed = +""
  llm.generate(prompt_with_google_tool, user: Discourse.system_user) { |partial| streamed << partial }

  expected_invocation = (<<~TEXT).strip
    <function_calls>
    <invoke>
    <tool_name>google</tool_name>
    <parameters>
    <query>top 10 things to do in japan for tourists</query>
    </parameters>
    <tool_id>tool_0</tool_id>
    </invoke>
    </function_calls>
  TEXT

  expect(streamed.strip).to eq(expected_invocation)
end
|
||||
|
||||
# Streams a two-delta "Hello!" reply, then checks the streamed text, the JSON
# request body that was posted, and the token counts in the latest audit log.
it "can stream a response" do
  sse_stream = (<<~STRING).strip
    event: message_start
    data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}}

    event: content_block_start
    data: {"type": "content_block_start", "index":0, "content_block": {"type": "text", "text": ""}}

    event: ping
    data: {"type": "ping"}

    event: content_block_delta
    data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}

    event: content_block_delta
    data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "!"}}

    event: content_block_stop
    data: {"type": "content_block_stop", "index": 0}

    event: message_delta
    data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null, "usage":{"output_tokens": 15}}}

    event: message_stop
    data: {"type": "message_stop"}
  STRING

  # The body matcher always matches; it exists to capture the posted JSON.
  request_payload = nil
  capture_payload =
    proc do |raw_body|
      request_payload = JSON.parse(raw_body, symbolize_names: true)
      true
    end
  required_headers = {
    "Content-Type" => "application/json",
    "X-Api-Key" => "123",
    "Anthropic-Version" => "2023-06-01",
  }

  stub_request(:post, "https://api.anthropic.com/v1/messages").with(
    body: capture_payload,
    headers: required_headers,
  ).to_return(status: 200, body: sse_stream)

  streamed = +""
  llm.generate(prompt, user: Discourse.system_user) { |partial, _cancel| streamed << partial }

  expect(streamed).to eq("Hello!")

  expected_payload = {
    model: "claude-3-opus-20240229",
    max_tokens: 3000,
    messages: [{ role: "user", content: "user1: hello" }],
    system: "You are hello bot",
    stream: true,
  }
  expect(request_payload).to eq(expected_payload)

  audit_entry = AiApiAuditLog.order(:id).last
  expect(audit_entry.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
  expect(audit_entry.request_tokens).to eq(25)
  expect(audit_entry.response_tokens).to eq(15)
end
|
||||
|
||||
# A non-streamed reply whose text carries two <invoke> entries, with "Hello!"
# before them and "junk" after. The expected result is just the
# <function_calls> block, closed and with a tool_id on each invoke.
it "can return multiple function calls" do
  invocations_xml = <<~FUNCTIONS
    <function_calls>
    <invoke>
    <tool_name>echo</tool_name>
    <parameters>
    <text>something</text>
    </parameters>
    </invoke>
    <invoke>
    <tool_name>echo</tool_name>
    <parameters>
    <text>something else</text>
    </parameters>
    </invoke>
  FUNCTIONS

  response_json = <<~STRING
    {
    "content": [
    {
    "text": "Hello!\n\n#{invocations_xml}\njunk",
    "type": "text"
    }
    ],
    "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
    "model": "claude-3-opus-20240229",
    "role": "assistant",
    "stop_reason": "end_turn",
    "stop_sequence": null,
    "type": "message",
    "usage": {
    "input_tokens": 10,
    "output_tokens": 25
    }
    }
  STRING

  stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(
    status: 200,
    body: response_json,
  )

  completion = llm.generate(prompt_with_echo_tool, user: Discourse.system_user)

  normalized_calls = (<<~EXPECTED).strip
    <function_calls>
    <invoke>
    <tool_name>echo</tool_name>
    <parameters>
    <text>something</text>
    </parameters>
    <tool_id>tool_0</tool_id>
    </invoke>
    <invoke>
    <tool_name>echo</tool_name>
    <parameters>
    <text>something else</text>
    </parameters>
    <tool_id>tool_1</tool_id>
    </invoke>
    </function_calls>
  EXPECTED

  expect(completion.strip).to eq(normalized_calls)
end
|
||||
|
||||
# Non-streaming round trip: checks the returned text, the JSON request body
# that was posted (no `stream` key), and the token counts in the audit log.
it "can operate in regular mode" do
  response_json = <<~STRING
    {
    "content": [
    {
    "text": "Hello!",
    "type": "text"
    }
    ],
    "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
    "model": "claude-3-opus-20240229",
    "role": "assistant",
    "stop_reason": "end_turn",
    "stop_sequence": null,
    "type": "message",
    "usage": {
    "input_tokens": 10,
    "output_tokens": 25
    }
    }
  STRING

  # The body matcher always matches; it exists to capture the posted JSON.
  request_payload = nil
  capture_payload =
    proc do |raw_body|
      request_payload = JSON.parse(raw_body, symbolize_names: true)
      true
    end
  required_headers = {
    "Content-Type" => "application/json",
    "X-Api-Key" => "123",
    "Anthropic-Version" => "2023-06-01",
  }

  stub_request(:post, "https://api.anthropic.com/v1/messages").with(
    body: capture_payload,
    headers: required_headers,
  ).to_return(status: 200, body: response_json)

  completion = llm.generate(prompt, user: Discourse.system_user)
  expect(completion).to eq("Hello!")

  expected_payload = {
    model: "claude-3-opus-20240229",
    max_tokens: 3000,
    messages: [{ role: "user", content: "user1: hello" }],
    system: "You are hello bot",
  }
  expect(request_payload).to eq(expected_payload)

  audit_entry = AiApiAuditLog.order(:id).last
  expect(audit_entry.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
  expect(audit_entry.request_tokens).to eq(10)
  expect(audit_entry.response_tokens).to eq(25)
end
|
||||
end
|
|
@ -1,96 +1,395 @@
|
|||
# frozen_String_literal: true
|
||||
# frozen_string_literal: true
|
||||
|
||||
require_relative "endpoint_compliance"
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||
let(:llm) { DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-opus") }
|
||||
|
||||
class AnthropicMock < EndpointMock
|
||||
def response(content)
|
||||
# Minimal conversation: a system prompt plus a single message from "user1".
let(:prompt) do
  user_message = { type: :user, id: "user1", content: "hello" }

  DiscourseAi::Completions::Prompt.new("You are hello bot", messages: [user_message])
end
|
||||
|
||||
let(:echo_tool) do
|
||||
{
|
||||
completion: content,
|
||||
stop: "\n\nHuman:",
|
||||
stop_reason: "stop_sequence",
|
||||
truncated: false,
|
||||
log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
|
||||
model: "claude-2",
|
||||
exception: nil,
|
||||
name: "echo",
|
||||
description: "echo something",
|
||||
parameters: [{ name: "text", type: "string", description: "text to echo", required: true }],
|
||||
}
|
||||
end
|
||||
|
||||
def stub_response(prompt, response_text, tool_call: false)
|
||||
WebMock
|
||||
.stub_request(:post, "https://api.anthropic.com/v1/complete")
|
||||
.with(body: model.default_options.merge(prompt: prompt).to_json)
|
||||
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
||||
# Tool definition with one required string parameter, "query".
let(:google_tool) do
  query_parameter = {
    name: "query",
    type: "string",
    description: "text to google",
    required: true,
  }

  { name: "google", description: "google something", parameters: [query_parameter] }
end
|
||||
|
||||
def stream_line(delta, finish_reason: nil)
|
||||
+"data: " << {
|
||||
completion: delta,
|
||||
stop: finish_reason ? "\n\nHuman:" : nil,
|
||||
stop_reason: finish_reason,
|
||||
truncated: false,
|
||||
log_id: "12b029451c6d18094d868bc04ce83f63",
|
||||
model: "claude-2",
|
||||
exception: nil,
|
||||
}.to_json
|
||||
# Attaches the echo tool to the memoized `prompt` and returns that same
# object; the prompt is mutated in place, not copied.
let(:prompt_with_echo_tool) { prompt.tap { |shared_prompt| shared_prompt.tools = [echo_tool] } }
|
||||
|
||||
def stub_streamed_response(prompt, deltas, tool_call: false)
|
||||
chunks =
|
||||
deltas.each_with_index.map do |_, index|
|
||||
if index == (deltas.length - 1)
|
||||
stream_line(deltas[index], finish_reason: "stop_sequence")
|
||||
else
|
||||
stream_line(deltas[index])
|
||||
end
|
||||
end
|
||||
|
||||
chunks = chunks.join("\n\n").split("")
|
||||
|
||||
WebMock
|
||||
.stub_request(:post, "https://api.anthropic.com/v1/complete")
|
||||
.with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
|
||||
.to_return(status: 200, body: chunks)
|
||||
end
|
||||
end
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||
subject(:endpoint) { described_class.new("claude-2", DiscourseAi::Tokenizer::AnthropicTokenizer) }
|
||||
|
||||
fab!(:user)
|
||||
|
||||
let(:anthropic_mock) { AnthropicMock.new(endpoint) }
|
||||
|
||||
let(:compliance) do
|
||||
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Claude, user)
|
||||
# Attaches the google tool to the memoized `prompt` and returns that same
# object; the prompt is mutated in place, not copied.
#
# This previously attached `echo_tool`, so the prompt never declared the
# `google` tool that the streamed tool-call fixture invokes.
let(:prompt_with_google_tool) do
  prompt.tools = [google_tool]
  prompt
end
|
||||
|
||||
describe "#perform_completion!" do
|
||||
context "when using regular mode" do
|
||||
context "with simple prompts" do
|
||||
it "completes a trivial prompt and logs the response" do
|
||||
compliance.regular_mode_simple_prompt(anthropic_mock)
|
||||
end
|
||||
end
|
||||
before { SiteSetting.ai_anthropic_api_key = "123" }
|
||||
|
||||
context "with tools" do
|
||||
it "returns a function invocation" do
|
||||
compliance.regular_mode_tools(anthropic_mock)
|
||||
end
|
||||
end
|
||||
it "does not eat spaces with tool calls" do
|
||||
body = <<~STRING
|
||||
event: message_start
|
||||
data: {"type":"message_start","message":{"id":"msg_019kmW9Q3GqfWmuFJbePJTBR","type":"message","role":"assistant","content":[],"model":"claude-3-opus-20240229","stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":347,"output_tokens":1}}}
|
||||
|
||||
event: content_block_start
|
||||
data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
|
||||
|
||||
event: ping
|
||||
data: {"type": "ping"}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<function"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"calls"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<invoke"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<tool"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"name"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"google"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</tool"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"_"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"name"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<parameters"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"<query"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"top"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"10"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" "}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"things"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" to"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" do"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" in"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" japan"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" for"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" tourists"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</query"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</parameters"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"</invoke"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":">"}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"\\n"}}
|
||||
|
||||
event: content_block_stop
|
||||
data: {"type":"content_block_stop","index":0}
|
||||
|
||||
event: message_delta
|
||||
data: {"type":"message_delta","delta":{"stop_reason":"stop_sequence","stop_sequence":"</function_calls>"},"usage":{"output_tokens":57}}
|
||||
|
||||
event: message_stop
|
||||
data: {"type":"message_stop"}
|
||||
STRING
|
||||
|
||||
stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(status: 200, body: body)
|
||||
|
||||
result = +""
|
||||
llm.generate(prompt_with_google_tool, user: Discourse.system_user) do |partial|
|
||||
result << partial
|
||||
end
|
||||
|
||||
describe "when using streaming mode" do
|
||||
context "with simple prompts" do
|
||||
it "completes a trivial prompt and logs the response" do
|
||||
compliance.streaming_mode_simple_prompt(anthropic_mock)
|
||||
end
|
||||
end
|
||||
expected = (<<~TEXT).strip
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>google</tool_name>
|
||||
<parameters>
|
||||
<query>top 10 things to do in japan for tourists</query>
|
||||
</parameters>
|
||||
<tool_id>tool_0</tool_id>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
TEXT
|
||||
|
||||
context "with tools" do
|
||||
it "returns a function invocation" do
|
||||
compliance.streaming_mode_tools(anthropic_mock)
|
||||
end
|
||||
end
|
||||
end
|
||||
expect(result.strip).to eq(expected)
|
||||
end
|
||||
|
||||
# Streams a two-delta "Hello!" reply, then checks the streamed text, the JSON
# request body that was posted, and the token counts in the latest audit log.
it "can stream a response" do
  sse_stream = (<<~STRING).strip
    event: message_start
    data: {"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-opus-20240229", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 25, "output_tokens": 1}}}

    event: content_block_start
    data: {"type": "content_block_start", "index":0, "content_block": {"type": "text", "text": ""}}

    event: ping
    data: {"type": "ping"}

    event: content_block_delta
    data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}

    event: content_block_delta
    data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "!"}}

    event: content_block_stop
    data: {"type": "content_block_stop", "index": 0}

    event: message_delta
    data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null, "usage":{"output_tokens": 15}}}

    event: message_stop
    data: {"type": "message_stop"}
  STRING

  # The body matcher always matches; it exists to capture the posted JSON.
  request_payload = nil
  capture_payload =
    proc do |raw_body|
      request_payload = JSON.parse(raw_body, symbolize_names: true)
      true
    end
  required_headers = {
    "Content-Type" => "application/json",
    "X-Api-Key" => "123",
    "Anthropic-Version" => "2023-06-01",
  }

  stub_request(:post, "https://api.anthropic.com/v1/messages").with(
    body: capture_payload,
    headers: required_headers,
  ).to_return(status: 200, body: sse_stream)

  streamed = +""
  llm.generate(prompt, user: Discourse.system_user) { |partial, _cancel| streamed << partial }

  expect(streamed).to eq("Hello!")

  expected_payload = {
    model: "claude-3-opus-20240229",
    max_tokens: 3000,
    messages: [{ role: "user", content: "user1: hello" }],
    system: "You are hello bot",
    stream: true,
  }
  expect(request_payload).to eq(expected_payload)

  audit_entry = AiApiAuditLog.order(:id).last
  expect(audit_entry.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
  expect(audit_entry.request_tokens).to eq(25)
  expect(audit_entry.response_tokens).to eq(15)
end
|
||||
|
||||
# A non-streamed reply whose text carries two <invoke> entries, with "Hello!"
# before them and "junk" after. The expected result is just the
# <function_calls> block, closed and with a tool_id on each invoke.
it "can return multiple function calls" do
  invocations_xml = <<~FUNCTIONS
    <function_calls>
    <invoke>
    <tool_name>echo</tool_name>
    <parameters>
    <text>something</text>
    </parameters>
    </invoke>
    <invoke>
    <tool_name>echo</tool_name>
    <parameters>
    <text>something else</text>
    </parameters>
    </invoke>
  FUNCTIONS

  response_json = <<~STRING
    {
    "content": [
    {
    "text": "Hello!\n\n#{invocations_xml}\njunk",
    "type": "text"
    }
    ],
    "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
    "model": "claude-3-opus-20240229",
    "role": "assistant",
    "stop_reason": "end_turn",
    "stop_sequence": null,
    "type": "message",
    "usage": {
    "input_tokens": 10,
    "output_tokens": 25
    }
    }
  STRING

  stub_request(:post, "https://api.anthropic.com/v1/messages").to_return(
    status: 200,
    body: response_json,
  )

  completion = llm.generate(prompt_with_echo_tool, user: Discourse.system_user)

  normalized_calls = (<<~EXPECTED).strip
    <function_calls>
    <invoke>
    <tool_name>echo</tool_name>
    <parameters>
    <text>something</text>
    </parameters>
    <tool_id>tool_0</tool_id>
    </invoke>
    <invoke>
    <tool_name>echo</tool_name>
    <parameters>
    <text>something else</text>
    </parameters>
    <tool_id>tool_1</tool_id>
    </invoke>
    </function_calls>
  EXPECTED

  expect(completion.strip).to eq(normalized_calls)
end
|
||||
|
||||
# Non-streaming round trip: checks the returned text, the JSON request body
# that was posted (no `stream` key), and the token counts in the audit log.
it "can operate in regular mode" do
  response_json = <<~STRING
    {
    "content": [
    {
    "text": "Hello!",
    "type": "text"
    }
    ],
    "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
    "model": "claude-3-opus-20240229",
    "role": "assistant",
    "stop_reason": "end_turn",
    "stop_sequence": null,
    "type": "message",
    "usage": {
    "input_tokens": 10,
    "output_tokens": 25
    }
    }
  STRING

  # The body matcher always matches; it exists to capture the posted JSON.
  request_payload = nil
  capture_payload =
    proc do |raw_body|
      request_payload = JSON.parse(raw_body, symbolize_names: true)
      true
    end
  required_headers = {
    "Content-Type" => "application/json",
    "X-Api-Key" => "123",
    "Anthropic-Version" => "2023-06-01",
  }

  stub_request(:post, "https://api.anthropic.com/v1/messages").with(
    body: capture_payload,
    headers: required_headers,
  ).to_return(status: 200, body: response_json)

  completion = llm.generate(prompt, user: Discourse.system_user)
  expect(completion).to eq("Hello!")

  expected_payload = {
    model: "claude-3-opus-20240229",
    max_tokens: 3000,
    messages: [{ role: "user", content: "user1: hello" }],
    system: "You are hello bot",
  }
  expect(request_payload).to eq(expected_payload)

  audit_entry = AiApiAuditLog.order(:id).last
  expect(audit_entry.provider_id).to eq(AiApiAuditLog::Provider::Anthropic)
  expect(audit_entry.request_tokens).to eq(10)
  expect(audit_entry.response_tokens).to eq(25)
end
|
||||
end
|
||||
|
|
|
@ -5,71 +5,6 @@ require "aws-eventstream"
|
|||
require "aws-sigv4"
|
||||
|
||||
# WebMock-based test double for the Bedrock `/invoke` and
# `/invoke-with-response-stream` endpoints, answering with completion-style
# payloads (`completion`, `stop`, `stop_reason`, ...).
class BedrockMock < EndpointMock
  # Builds the non-streamed response payload for the given completion text.
  def response(content)
    {
      completion: content,
      stop: "\n\nHuman:",
      stop_reason: "stop_sequence",
      truncated: false,
      log_id: "12dcc7feafbee4a394e0de9dffde3ac5",
      model: "claude",
      exception: nil,
    }
  end

  # Stubs the `/invoke` endpoint for a request whose JSON body equals the
  # model's default options merged with `prompt`. `tool_call` is accepted but
  # not used in this method.
  def stub_response(prompt, response_content, tool_call: false)
    expected_request = model.default_options.merge(prompt: prompt).to_json
    canned_reply = JSON.dump(response(response_content))

    WebMock
      .stub_request(:post, "#{base_url}/invoke")
      .with(body: expected_request)
      .to_return(status: 200, body: canned_reply)
  end

  # Encodes one streamed chunk: the completion JSON is Base64-encoded, placed
  # under a `bytes` key, and wrapped in an AWS event-stream message.
  # `stop` is only populated when a `finish_reason` is given.
  def stream_line(delta, finish_reason: nil)
    chunk = {
      completion: delta,
      stop: finish_reason ? "\n\nHuman:" : nil,
      stop_reason: finish_reason,
      truncated: false,
      log_id: "12b029451c6d18094d868bc04ce83f63",
      model: "claude-2.1",
      exception: nil,
    }
    envelope = { bytes: Base64.encode64(chunk.to_json) }.to_json
    message = Aws::EventStream::Message.new(payload: StringIO.new(envelope))

    Aws::EventStream::Encoder.new.encode(message)
  end

  # Stubs the streaming endpoint with one encoded chunk per delta; only the
  # final chunk carries the "stop_sequence" finish reason. `tool_call` is
  # accepted but not used in this method.
  def stub_streamed_response(prompt, deltas, tool_call: false)
    final_index = deltas.length - 1
    chunks =
      deltas.each_with_index.map do |delta, index|
        index == final_index ? stream_line(delta, finish_reason: "stop_sequence") : stream_line(delta)
      end

    WebMock
      .stub_request(:post, "#{base_url}/invoke-with-response-stream")
      .with(body: model.default_options.merge(prompt: prompt).to_json)
      .to_return(status: 200, body: chunks)
  end

  # Base URL of the mocked model, built from the configured Bedrock region.
  def base_url
    "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/anthropic.claude-v2:1"
  end
end
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||
|
@ -89,32 +24,98 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
SiteSetting.ai_bedrock_region = "us-east-1"
|
||||
end
|
||||
|
||||
describe "#perform_completion!" do
|
||||
context "when using regular mode" do
|
||||
context "with simple prompts" do
|
||||
it "completes a trivial prompt and logs the response" do
|
||||
compliance.regular_mode_simple_prompt(bedrock_mock)
|
||||
end
|
||||
end
|
||||
describe "Claude 3 Sonnet support" do
|
||||
it "supports the sonnet model" do
|
||||
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
|
||||
|
||||
context "with tools" do
|
||||
it "returns a function invocation" do
|
||||
compliance.regular_mode_tools(bedrock_mock)
|
||||
request = nil
|
||||
|
||||
content = {
|
||||
content: [text: "hello sam"],
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 20,
|
||||
},
|
||||
}.to_json
|
||||
|
||||
stub_request(
|
||||
:post,
|
||||
"https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke",
|
||||
)
|
||||
.with do |inner_request|
|
||||
request = inner_request
|
||||
true
|
||||
end
|
||||
end
|
||||
.to_return(status: 200, body: content)
|
||||
|
||||
response = proxy.generate("hello world", user: user)
|
||||
|
||||
expect(request.headers["Authorization"]).to be_present
|
||||
expect(request.headers["X-Amz-Content-Sha256"]).to be_present
|
||||
|
||||
expected = {
|
||||
"max_tokens" => 3000,
|
||||
"anthropic_version" => "bedrock-2023-05-31",
|
||||
"messages" => [{ "role" => "user", "content" => "hello world" }],
|
||||
"system" => "You are a helpful bot",
|
||||
}
|
||||
expect(JSON.parse(request.body)).to eq(expected)
|
||||
|
||||
expect(response).to eq("hello sam")
|
||||
|
||||
log = AiApiAuditLog.order(:id).last
|
||||
expect(log.request_tokens).to eq(10)
|
||||
expect(log.response_tokens).to eq(20)
|
||||
end
|
||||
|
||||
describe "when using streaming mode" do
|
||||
context "with simple prompts" do
|
||||
it "completes a trivial prompt and logs the response" do
|
||||
compliance.streaming_mode_simple_prompt(bedrock_mock)
|
||||
end
|
||||
end
|
||||
it "supports claude 3 sonnet streaming" do
|
||||
proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet")
|
||||
|
||||
context "with tools" do
|
||||
it "returns a function invocation" do
|
||||
compliance.streaming_mode_tools(bedrock_mock)
|
||||
request = nil
|
||||
|
||||
messages =
|
||||
[
|
||||
{ type: "message_start", message: { usage: { input_tokens: 9 } } },
|
||||
{ type: "content_block_delta", delta: { text: "hello " } },
|
||||
{ type: "content_block_delta", delta: { text: "sam" } },
|
||||
{ type: "message_delta", delta: { usage: { output_tokens: 25 } } },
|
||||
].map do |message|
|
||||
wrapped = { bytes: Base64.encode64(message.to_json) }.to_json
|
||||
io = StringIO.new(wrapped)
|
||||
aws_message = Aws::EventStream::Message.new(payload: io)
|
||||
Aws::EventStream::Encoder.new.encode(aws_message)
|
||||
end
|
||||
|
||||
bedrock_mock.with_chunk_array_support do
|
||||
stub_request(
|
||||
:post,
|
||||
"https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream",
|
||||
)
|
||||
.with do |inner_request|
|
||||
request = inner_request
|
||||
true
|
||||
end
|
||||
.to_return(status: 200, body: messages)
|
||||
|
||||
response = +""
|
||||
proxy.generate("hello world", user: user) { |partial| response << partial }
|
||||
|
||||
expect(request.headers["Authorization"]).to be_present
|
||||
expect(request.headers["X-Amz-Content-Sha256"]).to be_present
|
||||
|
||||
expected = {
|
||||
"max_tokens" => 3000,
|
||||
"anthropic_version" => "bedrock-2023-05-31",
|
||||
"messages" => [{ "role" => "user", "content" => "hello world" }],
|
||||
"system" => "You are a helpful bot",
|
||||
}
|
||||
expect(JSON.parse(request.body)).to eq(expected)
|
||||
|
||||
expect(response).to eq("hello sam")
|
||||
|
||||
log = AiApiAuditLog.order(:id).last
|
||||
expect(log.request_tokens).to eq(9)
|
||||
expect(log.response_tokens).to eq(25)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue