FEATURE: Compatibility with protected Hugging Face Endpoints (#123)

* FEATURE: Compatibility with protected Hugging Face Endpoints
This commit is contained in:
Rafael dos Santos Silva 2023-08-02 17:00:00 -03:00 committed by GitHub
parent 58b96eda6c
commit 8b157feea5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 10 deletions

View File

@@ -111,6 +111,9 @@ plugins:
       - "stable-diffusion-v1-5"
   ai_hugging_face_api_url:
     default: ""
+  ai_hugging_face_api_key:
+    default: ""
+    secret: true
   ai_google_custom_search_api_key:

View File

@@ -90,14 +90,20 @@ module DiscourseAi
       end

       def completion(prompt)
-        ::DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(prompt, model).dig(
-          :generated_text,
-        )
+        ::DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(
+          prompt,
+          model,
+          token_limit: token_limit,
+        ).dig(:generated_text)
       end

       def tokenizer
         DiscourseAi::Tokenizer::Llama2Tokenizer
       end
+
+      def token_limit
+        4096
+      end
     end
   end
 end

View File

@@ -15,7 +15,9 @@ module ::DiscourseAi
         typical_p: nil,
         max_tokens: 2000,
         repetition_penalty: 1.1,
-        user_id: nil
+        user_id: nil,
+        tokenizer: DiscourseAi::Tokenizer::Llama2Tokenizer,
+        token_limit: 4096
       )
         raise CompletionFailed if model.blank?
@@ -27,13 +29,18 @@ module ::DiscourseAi
         end

         headers = { "Content-Type" => "application/json" }

+        if SiteSetting.ai_hugging_face_api_key.present?
+          headers["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_api_key}"
+        end
+
         parameters = {}
         payload = { inputs: prompt, parameters: parameters }
+        prompt_size = tokenizer.size(prompt)

         parameters[:top_p] = top_p if top_p
         parameters[:top_k] = top_k if top_k
         parameters[:typical_p] = typical_p if typical_p
-        parameters[:max_new_tokens] = max_tokens if max_tokens
+        parameters[:max_new_tokens] = token_limit - prompt_size
         parameters[:temperature] = temperature if temperature
         parameters[:repetition_penalty] = repetition_penalty if repetition_penalty
@@ -70,9 +77,8 @@ module ::DiscourseAi
           log.update!(
             raw_response_payload: response_body,
-            request_tokens: DiscourseAi::Tokenizer::Llama2Tokenizer.size(prompt),
-            response_tokens:
-              DiscourseAi::Tokenizer::Llama2Tokenizer.size(parsed_response[:generated_text]),
+            request_tokens: tokenizer.size(prompt),
+            response_tokens: tokenizer.size(parsed_response[:generated_text]),
           )
           return parsed_response
         end
@@ -118,8 +124,8 @@ module ::DiscourseAi
       ensure
         log.update!(
           raw_response_payload: response_raw,
-          request_tokens: DiscourseAi::Tokenizer::Llama2Tokenizer.size(prompt),
-          response_tokens: DiscourseAi::Tokenizer::Llama2Tokenizer.size(response_data),
+          request_tokens: tokenizer.size(prompt),
+          response_tokens: tokenizer.size(response_data),
         )
       end
     end