mirror of https://github.com/discourse/discourse-ai.git
If a module's LLM model is set to claude-2 and the ai_bedrock site settings are all present, we will use AWS Bedrock instead of Anthropic's own API. This is quite hacky, but it lets us test the waters of AWS Bedrock early access with every module. This "same module, completely different API" situation is quite a bit further from what we had with the OpenAI/Azure separation, so it's more food for thought for when we start working on the LLM abstraction layer later this year.
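A minimal sketch of the resulting call-site behavior (hypothetical usage, not part of this commit; the prompt string follows Anthropic's legacy Human/Assistant format): when the three ai_bedrock settings are present, the exact same call is routed through AmazonBedrockInference; otherwise it goes straight to Anthropic.

    # Non-streaming call; returns the parsed API response as a symbolized hash.
    result =
      DiscourseAi::Inference::AnthropicCompletions.perform!(
        "\n\nHuman: Write a haiku about forums.\n\nAssistant:",
        "claude-2",
        max_tokens: 200,
      )
    result[:completion] # => the generated text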
156 lines
4.8 KiB
Ruby
# frozen_string_literal: true

module ::DiscourseAi
  module Inference
    class AnthropicCompletions
      CompletionFailed = Class.new(StandardError)
      TIMEOUT = 60
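
      # Calls Anthropic's /v1/complete endpoint. When given a block, streams
      # partial completions to the block as they arrive; otherwise returns the
      # parsed response hash.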
      def self.perform!(
        prompt,
        model = "claude-2",
        temperature: nil,
        top_p: nil,
        max_tokens: nil,
        user_id: nil,
        stop_sequences: nil,
        &blk
      )
        # HACK to get around the fact that the two providers have different APIs;
        # we will introduce a proper LLM abstraction layer to handle these
        # shenanigans later this year
        if model == "claude-2" && SiteSetting.ai_bedrock_access_key_id.present? &&
             SiteSetting.ai_bedrock_secret_access_key.present? &&
             SiteSetting.ai_bedrock_region.present?
          return(
            AmazonBedrockInference.perform!(
              prompt,
              temperature: temperature,
              top_p: top_p,
              max_tokens: max_tokens,
              user_id: user_id,
              stop_sequences: stop_sequences,
              &blk
            )
          )
        end

        log = nil
        response_data = +""
        response_raw = +""

        url = URI("https://api.anthropic.com/v1/complete")
        headers = {
          "anthropic-version" => "2023-06-01",
          "x-api-key" => SiteSetting.ai_anthropic_api_key,
          "content-type" => "application/json",
        }
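
        # Anthropic's legacy completions payload; max_tokens_to_sample is that
        # API's name for the completion budget.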
        payload = { model: model, prompt: prompt }

        payload[:top_p] = top_p if top_p
        payload[:max_tokens_to_sample] = max_tokens || 2000
        payload[:temperature] = temperature if temperature
        payload[:stream] = true if block_given?
        payload[:stop_sequences] = stop_sequences if stop_sequences
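
        # Use a raw Net::HTTP connection so the response body can be consumed
        # incrementally when streaming.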
        Net::HTTP.start(
          url.host,
          url.port,
          use_ssl: true,
          read_timeout: TIMEOUT,
          open_timeout: TIMEOUT,
          write_timeout: TIMEOUT,
        ) do |http|
          request = Net::HTTP::Post.new(url, headers)
          request_body = payload.to_json
          request.body = request_body
http.request(request) do |response|
|
|
if response.code.to_i != 200
|
|
Rails.logger.error(
|
|
"AnthropicCompletions: status: #{response.code.to_i} - body: #{response.body}",
|
|
)
|
|
raise CompletionFailed
|
|
end
|
|
|
|
log =
|
|
AiApiAuditLog.create!(
|
|
provider_id: AiApiAuditLog::Provider::Anthropic,
|
|
raw_request_payload: request_body,
|
|
user_id: user_id,
|
|
)
|
|
|
|
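
            # Non-streaming path: read the whole body, log token counts, and
            # return the parsed response.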
            if !block_given?
              response_body = response.read_body
              parsed_response = JSON.parse(response_body, symbolize_names: true)

              log.update!(
                raw_response_payload: response_body,
                request_tokens: DiscourseAi::Tokenizer::AnthropicTokenizer.size(prompt),
                response_tokens:
                  DiscourseAi::Tokenizer::AnthropicTokenizer.size(parsed_response[:completion]),
              )
              return parsed_response
            end
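
            # Streaming path: consume the SSE stream line by line, yielding each
            # parsed event; the caller can stop early via the cancel lambda.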
            begin
              cancelled = false
              cancel = lambda { cancelled = true }

              response.read_body do |chunk|
                if cancelled
                  http.finish
                  return
                end

                response_raw << chunk

                chunk
                  .split("\n")
                  .each do |line|
                    data = line.split("data: ", 2)[1]
                    next if !data

                    if !cancelled
                      begin
                        partial = JSON.parse(data, symbolize_names: true)
                        response_data << partial[:completion].to_s

                        # ping has no completion data... do not yield it
                        yield partial, cancel if partial[:completion]
                      rescue JSON::ParserError
                        nil
                        # TODO: carry any leftover partial chunk over to the next read
                      end
                    end
                  end
              end
            rescue IOError
              raise if !cancelled
            end
          end

          return response_data
        end
      ensure
        # log may be nil if the request failed before the audit record was created
        if block_given? && log
          log.update!(
            raw_response_payload: response_raw,
            request_tokens: DiscourseAi::Tokenizer::AnthropicTokenizer.size(prompt),
            response_tokens: DiscourseAi::Tokenizer::AnthropicTokenizer.size(response_data),
          )
        end
        if Rails.env.development? && log
          puts "AnthropicCompletions: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
        end
      end
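
      # Best-effort JSON parse: returns nil instead of raising on malformed input.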
      def self.try_parse(data)
        JSON.parse(data, symbolize_names: true)
      rescue JSON::ParserError
        nil
      end
    end
  end
end