discourse-ai/lib/shared/inference/hugging_face_text_generation.rb
Rafael dos Santos Silva eb7fff3a55
FEATURE: Add support for StableBeluga and Upstage Llama2 instruct (#126)
* FEATURE: Add support for StableBeluga and Upstage Llama2 instruct

This means we support all models in the top3 of the Open LLM Leaderboard

Since some of those models have RoPE, we now have a setting so you can
customize the token limit depending on which model you use.
2023-08-03 15:29:30 -03:00

146 lines
4.5 KiB
Ruby

# frozen_string_literal: true
module ::DiscourseAi
  module Inference
    # Client for a Hugging Face text-generation-inference (TGI) server.
    # Supports one-shot completion (POST /generate) and streaming
    # (POST /generate_stream, SSE "data:" lines); the mode is selected by
    # whether a block is passed to .perform!.
    class HuggingFaceTextGeneration
      CompletionFailed = Class.new(StandardError)
      TIMEOUT = 60

      # Runs a completion against the endpoint configured in
      # SiteSetting.ai_hugging_face_api_url.
      #
      # @param prompt [String] fully formatted prompt text
      # @param model [String] model name; only checked for presence here
      # @param tokenizer [Class] tokenizer used to size prompt/response for audit logging
      # @param token_limit [Integer, nil] context window size; defaults to the
      #   ai_hugging_face_token_limit site setting when nil
      # @yield [partial, cancel] each parsed SSE payload plus a lambda that
      #   cancels the stream when called
      # @return [Hash] parsed response body (symbolized keys) when no block is given
      # @raise [CompletionFailed] when model is blank or the HTTP status is not 200
      #
      # NOTE: max_tokens is accepted for interface compatibility but is unused;
      # the effective cap sent to the server is token_limit - prompt size.
      def self.perform!(
        prompt,
        model,
        temperature: 0.7,
        top_p: nil,
        top_k: nil,
        typical_p: nil,
        max_tokens: 2000,
        repetition_penalty: 1.1,
        user_id: nil,
        tokenizer: DiscourseAi::Tokenizer::Llama2Tokenizer,
        token_limit: nil
      )
        raise CompletionFailed if model.blank?

        url = URI(SiteSetting.ai_hugging_face_api_url)
        # A block means the caller wants streamed partial results.
        url.path = block_given? ? "/generate_stream" : "/generate"

        headers = { "Content-Type" => "application/json" }
        if SiteSetting.ai_hugging_face_api_key.present?
          headers["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_api_key}"
        end

        token_limit ||= SiteSetting.ai_hugging_face_token_limit

        parameters = {}
        payload = { inputs: prompt, parameters: parameters }
        prompt_size = tokenizer.size(prompt)

        parameters[:top_p] = top_p if top_p
        parameters[:top_k] = top_k if top_k
        parameters[:typical_p] = typical_p if typical_p
        # Give the model exactly the room left in its context window.
        parameters[:max_new_tokens] = token_limit - prompt_size
        parameters[:temperature] = temperature if temperature
        parameters[:repetition_penalty] = repetition_penalty if repetition_penalty

        Net::HTTP.start(
          url.host,
          url.port,
          use_ssl: url.scheme == "https",
          read_timeout: TIMEOUT,
          open_timeout: TIMEOUT,
          write_timeout: TIMEOUT,
        ) do |http|
          request = Net::HTTP::Post.new(url, headers)
          request_body = payload.to_json
          request.body = request_body

          http.request(request) do |response|
            if response.code.to_i != 200
              Rails.logger.error(
                "HuggingFaceTextGeneration: status: #{response.code.to_i} - body: #{response.body}",
              )
              raise CompletionFailed
            end

            log =
              AiApiAuditLog.create!(
                provider_id: AiApiAuditLog::Provider::HuggingFaceTextGeneration,
                raw_request_payload: request_body,
                user_id: user_id,
              )

            if !block_given?
              # Non-streaming: read the whole body, audit it, and return.
              response_body = response.read_body
              parsed_response = JSON.parse(response_body, symbolize_names: true)

              log.update!(
                raw_response_payload: response_body,
                request_tokens: prompt_size,
                response_tokens: tokenizer.size(parsed_response[:generated_text]),
              )
              return parsed_response
            end

            # Streaming: parse SSE "data: ..." lines chunk by chunk.
            begin
              cancelled = false
              cancel = lambda { cancelled = true }

              response_data = +""
              response_raw = +""

              response.read_body do |chunk|
                if cancelled
                  http.finish
                  return
                end

                response_raw << chunk

                chunk
                  .split("\n")
                  .each do |line|
                    data = line.split("data: ", 2)[1]
                    next if !data || data.squish == "[DONE]"

                    if !cancelled
                      begin
                        # partial contains the entire payload till now
                        partial = JSON.parse(data, symbolize_names: true)

                        # this is the last chunk and contains the full response
                        next if partial[:token][:special] == true

                        # NOTE(review): this overwrites rather than appends, so
                        # response_tokens below only measures the final token
                        # text — confirm whether accumulation was intended.
                        response_data = partial[:token][:text].to_s

                        yield partial, cancel
                      rescue JSON::ParserError
                        nil # tolerate partial/truncated SSE lines
                      end
                    end
                  end
              end
            rescue IOError
              # Finishing the connection after cancellation raises IOError;
              # that is expected — anything else propagates.
              raise if !cancelled
            ensure
              # Audit whatever was streamed, even on cancellation or error.
              log.update!(
                raw_response_payload: response_raw,
                request_tokens: prompt_size,
                response_tokens: tokenizer.size(response_data),
              )
            end
          end
        end
      end

      # Parses JSON with symbolized keys, returning nil instead of raising
      # on invalid input.
      def self.try_parse(data)
        JSON.parse(data, symbolize_names: true)
      rescue JSON::ParserError
        nil
      end
    end
  end
end