* FEATURE: Add support for StableBeluga and Upstage Llama2 instruct. This means we support all models in the top 3 of the Open LLM Leaderboard. Since some of those models use RoPE, we now have a setting so you can customize the token limit depending on which model you use.
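
The budget handed to the server follows from that setting: the prompt is tokenized, and whatever remains under the configured window becomes max_new_tokens. A minimal sketch of the arithmetic (the setting value and prompt size below are assumed for illustration):

  token_limit = 4096                         # e.g. SiteSetting.ai_hugging_face_token_limit
  prompt_size = 3000                         # tokenizer.size(prompt)
  max_new_tokens = token_limit - prompt_size # => 1096 tokens left for generation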
# frozen_string_literal: true

module ::DiscourseAi
  module Inference
    class HuggingFaceTextGeneration
      CompletionFailed = Class.new(StandardError)
      TIMEOUT = 60
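
      # Performs a completion against the Hugging Face endpoint configured in
      # SiteSetting.ai_hugging_face_api_url (POST /generate, or /generate_stream
      # when a block is given for streaming).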
      def self.perform!(
        prompt,
        model,
        temperature: 0.7,
        top_p: nil,
        top_k: nil,
        typical_p: nil,
        max_tokens: 2000, # NOTE: currently unused; max_new_tokens is derived from token_limit below
        repetition_penalty: 1.1,
        user_id: nil,
        tokenizer: DiscourseAi::Tokenizer::Llama2Tokenizer,
        token_limit: nil
      )
        raise CompletionFailed if model.blank?

        url = URI(SiteSetting.ai_hugging_face_api_url)
        if block_given?
          url.path = "/generate_stream"
        else
          url.path = "/generate"
        end
        headers = { "Content-Type" => "application/json" }

        if SiteSetting.ai_hugging_face_api_key.present?
          headers["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_api_key}"
        end

        token_limit = token_limit || SiteSetting.ai_hugging_face_token_limit

        parameters = {}
        payload = { inputs: prompt, parameters: parameters }
        prompt_size = tokenizer.size(prompt)

        parameters[:top_p] = top_p if top_p
        parameters[:top_k] = top_k if top_k
        parameters[:typical_p] = typical_p if typical_p
        # whatever fits below the configured context window is left for generation
        parameters[:max_new_tokens] = token_limit - prompt_size
        parameters[:temperature] = temperature if temperature
        parameters[:repetition_penalty] = repetition_penalty if repetition_penalty

        Net::HTTP.start(
          url.host,
          url.port,
          use_ssl: url.scheme == "https",
          read_timeout: TIMEOUT,
          open_timeout: TIMEOUT,
          write_timeout: TIMEOUT,
        ) do |http|
          request = Net::HTTP::Post.new(url, headers)
          request_body = payload.to_json
          request.body = request_body

          http.request(request) do |response|
            if response.code.to_i != 200
              Rails.logger.error(
                "HuggingFaceTextGeneration: status: #{response.code.to_i} - body: #{response.body}",
              )
              raise CompletionFailed
            end

            log =
              AiApiAuditLog.create!(
                provider_id: AiApiAuditLog::Provider::HuggingFaceTextGeneration,
                raw_request_payload: request_body,
                user_id: user_id,
              )

            if !block_given?
              response_body = response.read_body
              parsed_response = JSON.parse(response_body, symbolize_names: true)

              log.update!(
                raw_response_payload: response_body,
                request_tokens: tokenizer.size(prompt),
                response_tokens: tokenizer.size(parsed_response[:generated_text]),
              )
              return parsed_response
            end
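
            # Streaming path: the server sends "data: {json}" lines; each parsed
            # event is yielded together with a cancel lambda the caller can
            # invoke to stop reading mid-stream.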
            begin
              cancelled = false
              cancel = lambda { cancelled = true }

              response_data = +""
              response_raw = +""

              response.read_body do |chunk|
                if cancelled
                  http.finish
                  return
                end

                response_raw << chunk

                chunk
                  .split("\n")
                  .each do |line|
                    data = line.split("data: ", 2)[1]
                    next if !data || data.squish == "[DONE]"

                    if !cancelled
                      begin
                        # each event carries a single generated token
                        partial = JSON.parse(data, symbolize_names: true)

                        # the final event holds the special end-of-stream token; skip it
                        next if partial[:token][:special] == true

                        response_data << partial[:token][:text].to_s

                        yield partial, cancel
                      rescue JSON::ParserError
                        nil
                      end
                    end
                  end
              end
            rescue IOError
              raise if !cancelled
            ensure
              log.update!(
                raw_response_payload: response_raw,
                request_tokens: tokenizer.size(prompt),
                response_tokens: tokenizer.size(response_data),
              )
            end
          end
        end
      end
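
      # Parses data as JSON, returning nil instead of raising on malformed input.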
      def self.try_parse(data)
        JSON.parse(data, symbolize_names: true)
      rescue JSON::ParserError
        nil
      end
    end
  end
end
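
A minimal usage sketch, assuming the site settings (ai_hugging_face_api_url, ai_hugging_face_token_limit) point at a running inference server; the model id and prompt format here are illustrative:

  # blocking call: returns the parsed JSON, e.g. { generated_text: "..." }
  result =
    DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(
      "### User: Hello\n### Assistant:",
      "stabilityai/StableBeluga2",
    )

  # streaming call: the block receives each parsed token event plus a cancel
  # lambda; invoking cancel stops the read loop
  DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(
    "### User: Hello\n### Assistant:",
    "stabilityai/StableBeluga2",
  ) do |partial, cancel|
    print partial[:token][:text]
  end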