From 102f47c1c4e74cdf341e5ad5bc49488e49b7f7d4 Mon Sep 17 00:00:00 2001 From: Rafael dos Santos Silva Date: Mon, 2 Oct 2023 12:58:36 -0300 Subject: [PATCH] FEATURE: Allow Anthropic inference via AWS Bedrock (#235) If a module LLM model is set to claude-2 and the ai_bedrock variables are all present we will use AWS Bedrock instead of Antrhopic own APIs. This is quite hacky, but will allow us to test the waters with AWS Bedrock early access with every module. This situation of "same module, completely different API" is quite a bit far from what we had in the OpenAI/Azure separation, so it's more food for thought for when we start working on the LLM abstraction layer soon this year. --- config/settings.yml | 9 +- .../inference/amazon_bedrock_inference.rb | 174 ++++++++++++++++++ lib/shared/inference/anthropic_completions.rb | 21 ++- plugin.rb | 2 + 4 files changed, 204 insertions(+), 2 deletions(-) create mode 100644 lib/shared/inference/amazon_bedrock_inference.rb diff --git a/config/settings.yml b/config/settings.yml index 626efab4..b6aa05fa 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -120,12 +120,19 @@ discourse_ai: default: 4096 ai_hugging_face_model_display_name: default: "" - ai_google_custom_search_api_key: default: "" secret: true ai_google_custom_search_cx: default: "" + ai_bedrock_access_key_id: + default: "" + secret: true + ai_bedrock_secret_access_key: + default: "" + secret: true + ai_bedrock_region: + default: "us-east-1" composer_ai_helper_enabled: default: false diff --git a/lib/shared/inference/amazon_bedrock_inference.rb b/lib/shared/inference/amazon_bedrock_inference.rb new file mode 100644 index 00000000..54727eed --- /dev/null +++ b/lib/shared/inference/amazon_bedrock_inference.rb @@ -0,0 +1,174 @@ +# frozen_string_literal: true + +require "base64" +require "json" +require "aws-eventstream" +require "aws-sigv4" + +module ::DiscourseAi + module Inference + class AmazonBedrockInference + CompletionFailed = Class.new(StandardError) + TIMEOUT = 60 + + def self.perform!( + prompt, + model = "anthropic.claude-v2", + temperature: nil, + top_p: nil, + top_k: nil, + max_tokens: 20_000, + user_id: nil, + stop_sequences: nil, + tokenizer: Tokenizer::AnthropicTokenizer + ) + raise CompletionFailed if model.blank? + raise CompletionFailed if !SiteSetting.ai_bedrock_access_key_id.present? + raise CompletionFailed if !SiteSetting.ai_bedrock_secret_access_key.present? + raise CompletionFailed if !SiteSetting.ai_bedrock_region.present? + + url = + URI( + "https://bedrock.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{model}/invoke", + ) + url.path = url.path + "-with-response-stream" if block_given? + headers = { "Content-Type" => "application/json" } + + signer = + Aws::Sigv4::Signer.new( + access_key_id: SiteSetting.ai_bedrock_access_key_id, + region: SiteSetting.ai_bedrock_region, + secret_access_key: SiteSetting.ai_bedrock_secret_access_key, + service: "bedrock", + ) + + log = nil + response_data = +"" + response_raw = +"" + + url_api = "https://bedrock.us-east-1.amazonaws.com/model/#{model}/" + if block_given? + url_api = url_api + "invoke-with-response-stream" + else + url_api = url_api + "invoke" + end + + url = URI(url_api) + headers = { "content-type" => "application/json", "Accept" => "*/*" } + + payload = { prompt: prompt } + + payload[:top_p] = top_p if top_p + payload[:top_k] = top_k if top_k + payload[:max_tokens_to_sample] = max_tokens || 2000 + payload[:temperature] = temperature if temperature + payload[:stop_sequences] = stop_sequences if stop_sequences + + Net::HTTP.start( + url.host, + url.port, + use_ssl: true, + read_timeout: TIMEOUT, + open_timeout: TIMEOUT, + write_timeout: TIMEOUT, + ) do |http| + request = Net::HTTP::Post.new(url) + request_body = payload.to_json + request.body = request_body + + signed_request = + signer.sign_request( + req: request, + http_method: request.method, + url: url, + body: request.body, + ) + + request.initialize_http_header(headers.merge!(signed_request.headers)) + + http.request(request) do |response| + if response.code.to_i != 200 + Rails.logger.error( + "BedRockInference: status: #{response.code.to_i} - body: #{response.body}", + ) + raise CompletionFailed + end + + log = + AiApiAuditLog.create!( + provider_id: AiApiAuditLog::Provider::Anthropic, + raw_request_payload: request_body, + user_id: user_id, + ) + + if !block_given? + response_body = response.read_body + parsed_response = JSON.parse(response_body, symbolize_names: true) + + log.update!( + raw_response_payload: response_body, + request_tokens: tokenizer.size(prompt), + response_tokens: tokenizer.size(parsed_response[:completion]), + ) + return parsed_response + end + + begin + cancelled = false + cancel = lambda { cancelled = true } + decoder = Aws::EventStream::Decoder.new + + response.read_body do |chunk| + if cancelled + http.finish + return + end + + response_raw << chunk + + begin + message = decoder.decode_chunk(chunk) + + partial = + message + .first + .payload + .string + .then { JSON.parse(_1) } + .dig("bytes") + .then { Base64.decode64(_1) } + .then { JSON.parse(_1, symbolize_names: true) } + + next if !partial + + response_data << partial[:completion].to_s + + yield partial, cancel if partial[:completion] + rescue JSON::ParserError, + Aws::EventStream::Errors::MessageChecksumError, + Aws::EventStream::Errors::PreludeChecksumError => e + Rails.logger.error("BedrockInference: #{e}") + end + rescue IOError + raise if !cancelled + end + end + + return response_data + end + ensure + if block_given? + log.update!( + raw_response_payload: response_data, + request_tokens: tokenizer.size(prompt), + response_tokens: tokenizer.size(response_data), + ) + end + if Rails.env.development? && log + puts "BedrockInference: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}" + end + end + end + end + end +end diff --git a/lib/shared/inference/anthropic_completions.rb b/lib/shared/inference/anthropic_completions.rb index 43bfa2d1..329c69aa 100644 --- a/lib/shared/inference/anthropic_completions.rb +++ b/lib/shared/inference/anthropic_completions.rb @@ -13,8 +13,27 @@ module ::DiscourseAi top_p: nil, max_tokens: nil, user_id: nil, - stop_sequences: nil + stop_sequences: nil, + &blk ) + # HACK to get around the fact that they have different APIs + # we will introduce a proper LLM abstraction layer to handle this shenanigas later this year + if model == "claude-2" && SiteSetting.ai_bedrock_access_key_id.present? && + SiteSetting.ai_bedrock_secret_access_key.present? && + SiteSetting.ai_bedrock_region.present? + return( + AmazonBedrockInference.perform!( + prompt, + temperature: temperature, + top_p: top_p, + max_tokens: max_tokens, + user_id: user_id, + stop_sequences: stop_sequences, + &blk + ) + ) + end + log = nil response_data = +"" response_raw = +"" diff --git a/plugin.rb b/plugin.rb index 0ad43c07..2bf3b73b 100644 --- a/plugin.rb +++ b/plugin.rb @@ -9,6 +9,7 @@ gem "tokenizers", "0.3.3" gem "tiktoken_ruby", "0.0.5" +gem "aws-eventstream", "1.2.0" enabled_site_setting :discourse_ai_enabled @@ -33,6 +34,7 @@ after_initialize do require_relative "lib/shared/inference/anthropic_completions" require_relative "lib/shared/inference/stability_generator" require_relative "lib/shared/inference/hugging_face_text_generation" + require_relative "lib/shared/inference/amazon_bedrock_inference" require_relative "lib/shared/inference/function" require_relative "lib/shared/inference/function_list"