FEATURE: Allow Anthropic inference via AWS Bedrock (#235)
If a module LLM model is set to claude-2 and the ai_bedrock variables are all present we will use AWS Bedrock instead of Antrhopic own APIs. This is quite hacky, but will allow us to test the waters with AWS Bedrock early access with every module. This situation of "same module, completely different API" is quite a bit far from what we had in the OpenAI/Azure separation, so it's more food for thought for when we start working on the LLM abstraction layer soon this year.
This commit is contained in:
parent
ed7d1f06d1
commit
102f47c1c4
|
@ -120,12 +120,19 @@ discourse_ai:
|
|||
default: 4096
|
||||
ai_hugging_face_model_display_name:
|
||||
default: ""
|
||||
|
||||
ai_google_custom_search_api_key:
|
||||
default: ""
|
||||
secret: true
|
||||
ai_google_custom_search_cx:
|
||||
default: ""
|
||||
ai_bedrock_access_key_id:
|
||||
default: ""
|
||||
secret: true
|
||||
ai_bedrock_secret_access_key:
|
||||
default: ""
|
||||
secret: true
|
||||
ai_bedrock_region:
|
||||
default: "us-east-1"
|
||||
|
||||
composer_ai_helper_enabled:
|
||||
default: false
|
||||
|
|
|
@ -0,0 +1,174 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require "base64"
|
||||
require "json"
|
||||
require "aws-eventstream"
|
||||
require "aws-sigv4"
|
||||
|
||||
module ::DiscourseAi
|
||||
module Inference
|
||||
class AmazonBedrockInference
|
||||
CompletionFailed = Class.new(StandardError)
|
||||
TIMEOUT = 60
|
||||
|
||||
def self.perform!(
|
||||
prompt,
|
||||
model = "anthropic.claude-v2",
|
||||
temperature: nil,
|
||||
top_p: nil,
|
||||
top_k: nil,
|
||||
max_tokens: 20_000,
|
||||
user_id: nil,
|
||||
stop_sequences: nil,
|
||||
tokenizer: Tokenizer::AnthropicTokenizer
|
||||
)
|
||||
raise CompletionFailed if model.blank?
|
||||
raise CompletionFailed if !SiteSetting.ai_bedrock_access_key_id.present?
|
||||
raise CompletionFailed if !SiteSetting.ai_bedrock_secret_access_key.present?
|
||||
raise CompletionFailed if !SiteSetting.ai_bedrock_region.present?
|
||||
|
||||
url =
|
||||
URI(
|
||||
"https://bedrock.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{model}/invoke",
|
||||
)
|
||||
url.path = url.path + "-with-response-stream" if block_given?
|
||||
headers = { "Content-Type" => "application/json" }
|
||||
|
||||
signer =
|
||||
Aws::Sigv4::Signer.new(
|
||||
access_key_id: SiteSetting.ai_bedrock_access_key_id,
|
||||
region: SiteSetting.ai_bedrock_region,
|
||||
secret_access_key: SiteSetting.ai_bedrock_secret_access_key,
|
||||
service: "bedrock",
|
||||
)
|
||||
|
||||
log = nil
|
||||
response_data = +""
|
||||
response_raw = +""
|
||||
|
||||
url_api = "https://bedrock.us-east-1.amazonaws.com/model/#{model}/"
|
||||
if block_given?
|
||||
url_api = url_api + "invoke-with-response-stream"
|
||||
else
|
||||
url_api = url_api + "invoke"
|
||||
end
|
||||
|
||||
url = URI(url_api)
|
||||
headers = { "content-type" => "application/json", "Accept" => "*/*" }
|
||||
|
||||
payload = { prompt: prompt }
|
||||
|
||||
payload[:top_p] = top_p if top_p
|
||||
payload[:top_k] = top_k if top_k
|
||||
payload[:max_tokens_to_sample] = max_tokens || 2000
|
||||
payload[:temperature] = temperature if temperature
|
||||
payload[:stop_sequences] = stop_sequences if stop_sequences
|
||||
|
||||
Net::HTTP.start(
|
||||
url.host,
|
||||
url.port,
|
||||
use_ssl: true,
|
||||
read_timeout: TIMEOUT,
|
||||
open_timeout: TIMEOUT,
|
||||
write_timeout: TIMEOUT,
|
||||
) do |http|
|
||||
request = Net::HTTP::Post.new(url)
|
||||
request_body = payload.to_json
|
||||
request.body = request_body
|
||||
|
||||
signed_request =
|
||||
signer.sign_request(
|
||||
req: request,
|
||||
http_method: request.method,
|
||||
url: url,
|
||||
body: request.body,
|
||||
)
|
||||
|
||||
request.initialize_http_header(headers.merge!(signed_request.headers))
|
||||
|
||||
http.request(request) do |response|
|
||||
if response.code.to_i != 200
|
||||
Rails.logger.error(
|
||||
"BedRockInference: status: #{response.code.to_i} - body: #{response.body}",
|
||||
)
|
||||
raise CompletionFailed
|
||||
end
|
||||
|
||||
log =
|
||||
AiApiAuditLog.create!(
|
||||
provider_id: AiApiAuditLog::Provider::Anthropic,
|
||||
raw_request_payload: request_body,
|
||||
user_id: user_id,
|
||||
)
|
||||
|
||||
if !block_given?
|
||||
response_body = response.read_body
|
||||
parsed_response = JSON.parse(response_body, symbolize_names: true)
|
||||
|
||||
log.update!(
|
||||
raw_response_payload: response_body,
|
||||
request_tokens: tokenizer.size(prompt),
|
||||
response_tokens: tokenizer.size(parsed_response[:completion]),
|
||||
)
|
||||
return parsed_response
|
||||
end
|
||||
|
||||
begin
|
||||
cancelled = false
|
||||
cancel = lambda { cancelled = true }
|
||||
decoder = Aws::EventStream::Decoder.new
|
||||
|
||||
response.read_body do |chunk|
|
||||
if cancelled
|
||||
http.finish
|
||||
return
|
||||
end
|
||||
|
||||
response_raw << chunk
|
||||
|
||||
begin
|
||||
message = decoder.decode_chunk(chunk)
|
||||
|
||||
partial =
|
||||
message
|
||||
.first
|
||||
.payload
|
||||
.string
|
||||
.then { JSON.parse(_1) }
|
||||
.dig("bytes")
|
||||
.then { Base64.decode64(_1) }
|
||||
.then { JSON.parse(_1, symbolize_names: true) }
|
||||
|
||||
next if !partial
|
||||
|
||||
response_data << partial[:completion].to_s
|
||||
|
||||
yield partial, cancel if partial[:completion]
|
||||
rescue JSON::ParserError,
|
||||
Aws::EventStream::Errors::MessageChecksumError,
|
||||
Aws::EventStream::Errors::PreludeChecksumError => e
|
||||
Rails.logger.error("BedrockInference: #{e}")
|
||||
end
|
||||
rescue IOError
|
||||
raise if !cancelled
|
||||
end
|
||||
end
|
||||
|
||||
return response_data
|
||||
end
|
||||
ensure
|
||||
if block_given?
|
||||
log.update!(
|
||||
raw_response_payload: response_data,
|
||||
request_tokens: tokenizer.size(prompt),
|
||||
response_tokens: tokenizer.size(response_data),
|
||||
)
|
||||
end
|
||||
if Rails.env.development? && log
|
||||
puts "BedrockInference: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -13,8 +13,27 @@ module ::DiscourseAi
|
|||
top_p: nil,
|
||||
max_tokens: nil,
|
||||
user_id: nil,
|
||||
stop_sequences: nil
|
||||
stop_sequences: nil,
|
||||
&blk
|
||||
)
|
||||
# HACK to get around the fact that they have different APIs
|
||||
# we will introduce a proper LLM abstraction layer to handle this shenanigas later this year
|
||||
if model == "claude-2" && SiteSetting.ai_bedrock_access_key_id.present? &&
|
||||
SiteSetting.ai_bedrock_secret_access_key.present? &&
|
||||
SiteSetting.ai_bedrock_region.present?
|
||||
return(
|
||||
AmazonBedrockInference.perform!(
|
||||
prompt,
|
||||
temperature: temperature,
|
||||
top_p: top_p,
|
||||
max_tokens: max_tokens,
|
||||
user_id: user_id,
|
||||
stop_sequences: stop_sequences,
|
||||
&blk
|
||||
)
|
||||
)
|
||||
end
|
||||
|
||||
log = nil
|
||||
response_data = +""
|
||||
response_raw = +""
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
|
||||
gem "tokenizers", "0.3.3"
|
||||
gem "tiktoken_ruby", "0.0.5"
|
||||
gem "aws-eventstream", "1.2.0"
|
||||
|
||||
enabled_site_setting :discourse_ai_enabled
|
||||
|
||||
|
@ -33,6 +34,7 @@ after_initialize do
|
|||
require_relative "lib/shared/inference/anthropic_completions"
|
||||
require_relative "lib/shared/inference/stability_generator"
|
||||
require_relative "lib/shared/inference/hugging_face_text_generation"
|
||||
require_relative "lib/shared/inference/amazon_bedrock_inference"
|
||||
require_relative "lib/shared/inference/function"
|
||||
require_relative "lib/shared/inference/function_list"
|
||||
|
||||
|
|
Loading…
Reference in New Issue