# frozen_string_literal: true

require "base64"
require "json"
require "aws-eventstream"
require "aws-sigv4"

module ::DiscourseAi
  module Inference
    class AmazonBedrockInference
      CompletionFailed = Class.new(StandardError)
      TIMEOUT = 60

      # Performs a completion request against an Anthropic model hosted on
      # AWS Bedrock.
      #
      # prompt          - String prompt to send to the model.
      # model           - Bedrock model identifier (default "anthropic.claude-v2").
      # temperature:    - optional sampling temperature.
      # top_p:, top_k:  - optional nucleus / top-k sampling parameters.
      # max_tokens:     - mapped to Bedrock's `max_tokens_to_sample`.
      # user_id:        - optional Discourse user id recorded in the audit log.
      # stop_sequences: - optional array of stop sequences.
      # tokenizer:      - tokenizer used only for audit-log token counting.
      #
      # When a block is given the streaming `invoke-with-response-stream`
      # endpoint is used: each decoded partial response hash is yielded
      # together with a cancel lambda, and the accumulated completion text is
      # returned. Without a block the parsed JSON response hash is returned.
      #
      # Raises CompletionFailed when the Bedrock settings are missing or the
      # HTTP call does not return 200.
      def self.perform!(
        prompt,
        model = "anthropic.claude-v2",
        temperature: nil,
        top_p: nil,
        top_k: nil,
        max_tokens: 20_000,
        user_id: nil,
        stop_sequences: nil,
        tokenizer: Tokenizer::AnthropicTokenizer
      )
        raise CompletionFailed if model.blank?
        raise CompletionFailed if SiteSetting.ai_bedrock_access_key_id.blank?
        raise CompletionFailed if SiteSetting.ai_bedrock_secret_access_key.blank?
        raise CompletionFailed if SiteSetting.ai_bedrock_region.blank?

        signer =
          Aws::Sigv4::Signer.new(
            access_key_id: SiteSetting.ai_bedrock_access_key_id,
            region: SiteSetting.ai_bedrock_region,
            secret_access_key: SiteSetting.ai_bedrock_secret_access_key,
            service: "bedrock",
          )

        log = nil
        response_data = +""
        response_raw = +""

        # BUGFIX: the original built the URL twice — once honoring
        # SiteSetting.ai_bedrock_region and then again with a hard-coded
        # us-east-1 endpoint that silently overrode the region setting (and
        # re-assigned `headers` as well). Build it exactly once from the
        # configured region.
        endpoint = block_given? ? "invoke-with-response-stream" : "invoke"
        url =
          URI(
            "https://bedrock.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{model}/#{endpoint}",
          )
        headers = { "content-type" => "application/json", "Accept" => "*/*" }

        payload = { prompt: prompt }

        payload[:top_p] = top_p if top_p
        payload[:top_k] = top_k if top_k
        payload[:max_tokens_to_sample] = max_tokens || 2000
        payload[:temperature] = temperature if temperature
        payload[:stop_sequences] = stop_sequences if stop_sequences

        Net::HTTP.start(
          url.host,
          url.port,
          use_ssl: true,
          read_timeout: TIMEOUT,
          open_timeout: TIMEOUT,
          write_timeout: TIMEOUT,
        ) do |http|
          request = Net::HTTP::Post.new(url)
          request_body = payload.to_json
          request.body = request_body

          # Sign with SigV4; the signature headers must be merged back into
          # the outgoing request for Bedrock to authenticate it.
          signed_request =
            signer.sign_request(
              req: request,
              http_method: request.method,
              url: url,
              body: request.body,
            )

          request.initialize_http_header(headers.merge!(signed_request.headers))

          http.request(request) do |response|
            if response.code.to_i != 200
              Rails.logger.error(
                "BedRockInference: status: #{response.code.to_i} - body: #{response.body}",
              )
              raise CompletionFailed
            end

            log =
              AiApiAuditLog.create!(
                provider_id: AiApiAuditLog::Provider::Anthropic,
                raw_request_payload: request_body,
                user_id: user_id,
              )

            # Non-streaming path: read the whole body, log it, and return the
            # parsed response.
            if !block_given?
              response_body = response.read_body
              parsed_response = JSON.parse(response_body, symbolize_names: true)

              log.update!(
                raw_response_payload: response_body,
                request_tokens: tokenizer.size(prompt),
                response_tokens: tokenizer.size(parsed_response[:completion]),
              )
              return parsed_response
            end

            # Streaming path: decode AWS event-stream chunks, each carrying a
            # base64 "bytes" payload that is itself a JSON partial completion.
            begin
              cancelled = false
              cancel = lambda { cancelled = true }
              decoder = Aws::EventStream::Decoder.new

              response.read_body do |chunk|
                if cancelled
                  # Closing the connection aborts the stream; the resulting
                  # IOError is swallowed below when cancellation was requested.
                  http.finish
                  return
                end

                response_raw << chunk

                begin
                  message = decoder.decode_chunk(chunk)

                  partial =
                    message
                      .first
                      .payload
                      .string
                      .then { JSON.parse(_1) }
                      .dig("bytes")
                      .then { Base64.decode64(_1) }
                      .then { JSON.parse(_1, symbolize_names: true) }

                  next if !partial

                  response_data << partial[:completion].to_s

                  yield partial, cancel if partial[:completion]
                rescue JSON::ParserError,
                       Aws::EventStream::Errors::MessageChecksumError,
                       Aws::EventStream::Errors::PreludeChecksumError => e
                  # A malformed chunk should not abort the whole stream.
                  Rails.logger.error("BedrockInference: #{e}")
                end
              end
            rescue IOError
              raise if !cancelled
            end
          end

          return response_data
        end
      ensure
        # BUGFIX: guard on `log` — an early failure (missing settings,
        # non-200 response) leaves `log` nil, and the original raised
        # NoMethodError from inside `ensure`, masking the real error.
        if block_given? && log
          log.update!(
            raw_response_payload: response_data,
            request_tokens: tokenizer.size(prompt),
            response_tokens: tokenizer.size(response_data),
          )
        end
        if Rails.env.development? && log
          puts "BedrockInference: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
        end
      end
    end
  end
end