FEATURE: Compatibility with protected Hugging Face Endpoints (#123)

* FEATURE: Compatibility with protected Hugging Face Endpoints
This commit is contained in:
Rafael dos Santos Silva 2023-08-02 17:00:00 -03:00 committed by GitHub
parent 58b96eda6c
commit 8b157feea5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 10 deletions

View File

@@ -111,6 +111,9 @@ plugins:
- "stable-diffusion-v1-5"
ai_hugging_face_api_url:
default: ""
ai_hugging_face_api_key:
default: ""
secret: true
ai_google_custom_search_api_key:

View File

@@ -90,14 +90,20 @@ module DiscourseAi
end
# Run text generation for `prompt` against the configured Hugging Face
# endpoint and return only the generated text.
#
# Passes `token_limit` through so the inference layer can size
# `max_new_tokens` to whatever room the prompt leaves in the model's
# context window.
#
# NOTE(review): the stale pre-change call to `perform!(prompt, model)` that
# preceded this one was diff residue (its result was discarded but it would
# still have issued a second remote request); it has been removed.
def completion(prompt)
  ::DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(
    prompt,
    model,
    token_limit: token_limit,
  ).dig(:generated_text)
end
# Tokenizer used for counting prompt/response tokens for this model family.
def tokenizer
  DiscourseAi::Tokenizer::Llama2Tokenizer
end
# Total context-window size (prompt + completion tokens) assumed for the
# model served by the endpoint.
def token_limit
  4096
end
end
end
end

View File

@@ -15,7 +15,9 @@ module ::DiscourseAi
typical_p: nil,
max_tokens: 2000,
repetition_penalty: 1.1,
user_id: nil
user_id: nil,
tokenizer: DiscourseAi::Tokenizer::Llama2Tokenizer,
token_limit: 4096
)
raise CompletionFailed if model.blank?
@@ -27,13 +29,18 @@ module ::DiscourseAi
end
headers = { "Content-Type" => "application/json" }
if SiteSetting.ai_hugging_face_api_key.present?
headers["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_api_key}"
end
parameters = {}
payload = { inputs: prompt, parameters: parameters }
prompt_size = tokenizer.size(prompt)
parameters[:top_p] = top_p if top_p
parameters[:top_k] = top_k if top_k
parameters[:typical_p] = typical_p if typical_p
parameters[:max_new_tokens] = max_tokens if max_tokens
parameters[:max_new_tokens] = token_limit - prompt_size
parameters[:temperature] = temperature if temperature
parameters[:repetition_penalty] = repetition_penalty if repetition_penalty
@@ -70,9 +77,8 @@ module ::DiscourseAi
log.update!(
raw_response_payload: response_body,
request_tokens: DiscourseAi::Tokenizer::Llama2Tokenizer.size(prompt),
response_tokens:
DiscourseAi::Tokenizer::Llama2Tokenizer.size(parsed_response[:generated_text]),
request_tokens: tokenizer.size(prompt),
response_tokens: tokenizer.size(parsed_response[:generated_text]),
)
return parsed_response
end
@@ -118,8 +124,8 @@ module ::DiscourseAi
ensure
log.update!(
raw_response_payload: response_raw,
request_tokens: DiscourseAi::Tokenizer::Llama2Tokenizer.size(prompt),
response_tokens: DiscourseAi::Tokenizer::Llama2Tokenizer.size(response_data),
request_tokens: tokenizer.size(prompt),
response_tokens: tokenizer.size(response_data),
)
end
end