FEATURE: Allow Anthropic inference via AWS Bedrock (#235)

If a module's LLM model is set to claude-2 and all of the ai_bedrock settings are present, we will use AWS Bedrock instead of Anthropic's own APIs.

This is quite hacky, but will allow us to test the waters with AWS Bedrock early access with every module.

This situation of "same module, completely different API" is quite far from what we had with the OpenAI/Azure separation, so it's more food for thought for when we start working on the LLM abstraction layer later this year.
This commit is contained in:
Rafael dos Santos Silva 2023-10-02 12:58:36 -03:00 committed by GitHub
parent ed7d1f06d1
commit 102f47c1c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 204 additions and 2 deletions

View File

@ -120,12 +120,19 @@ discourse_ai:
default: 4096
ai_hugging_face_model_display_name:
default: ""
ai_google_custom_search_api_key:
default: ""
secret: true
ai_google_custom_search_cx:
default: ""
ai_bedrock_access_key_id:
default: ""
secret: true
ai_bedrock_secret_access_key:
default: ""
secret: true
ai_bedrock_region:
default: "us-east-1"
composer_ai_helper_enabled:
default: false

View File

@ -0,0 +1,174 @@
# frozen_string_literal: true
require "base64"
require "json"
require "aws-eventstream"
require "aws-sigv4"
module ::DiscourseAi
  module Inference
    # Sends completion requests for a Bedrock-hosted model (Anthropic Claude by
    # default) to AWS Bedrock, signing every request with AWS Signature V4 and
    # recording an AiApiAuditLog entry for each successful call.
    class AmazonBedrockInference
      CompletionFailed = Class.new(StandardError)
      TIMEOUT = 60

      # Performs a completion request.
      #
      # Without a block, the whole response is read and the parsed JSON body
      # (symbolized keys) is returned. With a block, the streaming endpoint is
      # used: every decoded partial that carries a :completion is yielded as
      # `partial, cancel`, where calling `cancel` stops the stream; the
      # concatenated completion text is returned at the end (nil when the
      # stream was cancelled and another chunk arrived afterwards).
      #
      # @param prompt [String] prompt sent as the `prompt` payload field
      # @param model [String] Bedrock model id used in the request path
      # @param temperature [Numeric, nil] sent only when present
      # @param top_p [Numeric, nil] sent only when present
      # @param top_k [Numeric, nil] sent only when present
      # @param max_tokens [Integer, nil] sent as max_tokens_to_sample (2000 when nil)
      # @param user_id [Integer, nil] stored on the audit log entry
      # @param stop_sequences [Array<String>, nil] sent only when present
      # @param tokenizer [#size] used to count request/response tokens for the audit log
      # @raise [CompletionFailed] when the model or any ai_bedrock_* setting is
      #   blank, or when the API answers with a non-200 status
      def self.perform!(
        prompt,
        model = "anthropic.claude-v2",
        temperature: nil,
        top_p: nil,
        top_k: nil,
        max_tokens: 20_000,
        user_id: nil,
        stop_sequences: nil,
        tokenizer: Tokenizer::AnthropicTokenizer
      )
        raise CompletionFailed, "model is missing" if model.blank?
        if SiteSetting.ai_bedrock_access_key_id.blank?
          raise CompletionFailed, "ai_bedrock_access_key_id is not set"
        end
        if SiteSetting.ai_bedrock_secret_access_key.blank?
          raise CompletionFailed, "ai_bedrock_secret_access_key is not set"
        end
        raise CompletionFailed, "ai_bedrock_region is not set" if SiteSetting.ai_bedrock_region.blank?

        region = SiteSetting.ai_bedrock_region

        # The endpoint must live in the same region the signer signs for,
        # otherwise the signature is rejected.
        # NOTE(review): AWS documentation names the runtime host
        # `bedrock-runtime.<region>.amazonaws.com` - confirm which host applies.
        action = block_given? ? "invoke-with-response-stream" : "invoke"
        url = URI("https://bedrock.#{region}.amazonaws.com/model/#{model}/#{action}")

        headers = { "content-type" => "application/json", "Accept" => "*/*" }

        signer =
          Aws::Sigv4::Signer.new(
            access_key_id: SiteSetting.ai_bedrock_access_key_id,
            region: region,
            secret_access_key: SiteSetting.ai_bedrock_secret_access_key,
            service: "bedrock",
          )

        payload = { prompt: prompt }
        payload[:top_p] = top_p if top_p
        payload[:top_k] = top_k if top_k
        payload[:max_tokens_to_sample] = max_tokens || 2000
        payload[:temperature] = temperature if temperature
        payload[:stop_sequences] = stop_sequences if stop_sequences

        log = nil
        response_data = +""

        Net::HTTP.start(
          url.host,
          url.port,
          use_ssl: true,
          read_timeout: TIMEOUT,
          open_timeout: TIMEOUT,
          write_timeout: TIMEOUT,
        ) do |http|
          request = Net::HTTP::Post.new(url)
          request_body = payload.to_json
          request.body = request_body

          signed_request =
            signer.sign_request(
              req: request,
              http_method: request.method,
              url: url,
              body: request.body,
            )

          request.initialize_http_header(headers.merge!(signed_request.headers))

          http.request(request) do |response|
            if response.code.to_i != 200
              Rails.logger.error(
                "BedRockInference: status: #{response.code.to_i} - body: #{response.body}",
              )
              raise CompletionFailed
            end

            log =
              AiApiAuditLog.create!(
                provider_id: AiApiAuditLog::Provider::Anthropic,
                raw_request_payload: request_body,
                user_id: user_id,
              )

            if !block_given?
              response_body = response.read_body
              parsed_response = JSON.parse(response_body, symbolize_names: true)

              log.update!(
                raw_response_payload: response_body,
                request_tokens: tokenizer.size(prompt),
                response_tokens: tokenizer.size(parsed_response[:completion]),
              )
              return parsed_response
            end

            cancelled = false
            cancel = lambda { cancelled = true }
            decoder = Aws::EventStream::Decoder.new

            begin
              response.read_body do |chunk|
                if cancelled
                  http.finish
                  return
                end

                begin
                  # A network chunk may hold an incomplete frame (message is
                  # nil) or several frames; drain the decoder buffer until it
                  # has no complete message left.
                  message, _eof = decoder.decode_chunk(chunk)

                  while message
                    partial = decode_partial(message)

                    if partial
                      response_data << partial[:completion].to_s
                      yield partial, cancel if partial[:completion]
                    end

                    break if cancelled
                    message, _eof = decoder.decode_chunk
                  end
                rescue JSON::ParserError,
                       Aws::EventStream::Errors::MessageChecksumError,
                       Aws::EventStream::Errors::PreludeChecksumError => e
                  Rails.logger.error("BedrockInference: #{e}")
                end
              rescue IOError
                # An IOError after the consumer cancelled is expected.
                raise if !cancelled
              end
            end

            return response_data
          end
        ensure
          # `log` is still nil when the request failed before the audit entry
          # was created; skip the update so the original error propagates.
          if log && block_given?
            log.update!(
              raw_response_payload: response_data,
              request_tokens: tokenizer.size(prompt),
              response_tokens: tokenizer.size(response_data),
            )
          end
          if Rails.env.development? && log
            puts "BedrockInference: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
          end
        end
      end

      # Extracts the model JSON from one event-stream message: the message
      # payload is JSON whose "bytes" field holds the Base64-encoded JSON
      # partial. Returns nil when the message carries no "bytes" field.
      def self.decode_partial(message)
        encoded = JSON.parse(message.payload.string)["bytes"]
        return if encoded.blank?

        JSON.parse(Base64.decode64(encoded), symbolize_names: true)
      end
      private_class_method :decode_partial
    end
  end
end

View File

@ -13,8 +13,27 @@ module ::DiscourseAi
top_p: nil,
max_tokens: nil,
user_id: nil,
stop_sequences: nil
stop_sequences: nil,
&blk
)
# HACK to get around the fact that they have different APIs
# we will introduce a proper LLM abstraction layer to handle these shenanigans later this year
if model == "claude-2" && SiteSetting.ai_bedrock_access_key_id.present? &&
SiteSetting.ai_bedrock_secret_access_key.present? &&
SiteSetting.ai_bedrock_region.present?
return(
AmazonBedrockInference.perform!(
prompt,
temperature: temperature,
top_p: top_p,
max_tokens: max_tokens,
user_id: user_id,
stop_sequences: stop_sequences,
&blk
)
)
end
log = nil
response_data = +""
response_raw = +""

View File

@ -9,6 +9,7 @@
gem "tokenizers", "0.3.3"
gem "tiktoken_ruby", "0.0.5"
gem "aws-eventstream", "1.2.0"
enabled_site_setting :discourse_ai_enabled
@ -33,6 +34,7 @@ after_initialize do
require_relative "lib/shared/inference/anthropic_completions"
require_relative "lib/shared/inference/stability_generator"
require_relative "lib/shared/inference/hugging_face_text_generation"
require_relative "lib/shared/inference/amazon_bedrock_inference"
require_relative "lib/shared/inference/function"
require_relative "lib/shared/inference/function_list"