FEATURE: Compatibility with protected Hugging Face Endpoints (#123)
parent 58b96eda6c
commit 8b157feea5
@@ -111,6 +111,9 @@ plugins:
       - "stable-diffusion-v1-5"
   ai_hugging_face_api_url:
     default: ""
+  ai_hugging_face_api_key:
+    default: ""
+    secret: true
 
 
   ai_google_custom_search_api_key:
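The new `ai_hugging_face_api_key` setting is declared with `secret: true`, so its value is not exposed in plain text. A minimal sketch of how the setting is meant to be consumed (the inference change further down applies exactly this pattern), assuming the usual `SiteSetting` accessor:

    # Sketch: attach the key as a bearer token only when one is configured,
    # so unauthenticated self-hosted endpoints keep working unchanged.
    headers = { "Content-Type" => "application/json" }

    if SiteSetting.ai_hugging_face_api_key.present?
      headers["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_api_key}"
    end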
@@ -90,14 +90,20 @@ module DiscourseAi
       end
 
       def completion(prompt)
-        ::DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(prompt, model).dig(
-          :generated_text,
-        )
+        ::DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(
+          prompt,
+          model,
+          token_limit: token_limit,
+        ).dig(:generated_text)
       end
 
       def tokenizer
         DiscourseAi::Tokenizer::Llama2Tokenizer
       end
+
+      def token_limit
+        4096
+      end
     end
   end
 end
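For callers of the completion provider, the change amounts to passing the model's token budget through to the inference layer. A hedged usage sketch; the prompt text and model name below are placeholders, not values from this commit:

    # Illustrative call shape only; prompt and model name are placeholders.
    generated =
      ::DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(
        "Summarise the discussion above in three sentences.",
        "Llama2-chat-hf",
        token_limit: 4096,
      ).dig(:generated_text)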
@@ -15,7 +15,9 @@ module ::DiscourseAi
         typical_p: nil,
         max_tokens: 2000,
         repetition_penalty: 1.1,
-        user_id: nil
+        user_id: nil,
+        tokenizer: DiscourseAi::Tokenizer::Llama2Tokenizer,
+        token_limit: 4096
       )
         raise CompletionFailed if model.blank?
 
@@ -27,13 +29,18 @@ module ::DiscourseAi
         end
         headers = { "Content-Type" => "application/json" }
 
+        if SiteSetting.ai_hugging_face_api_key.present?
+          headers["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_api_key}"
+        end
+
         parameters = {}
         payload = { inputs: prompt, parameters: parameters }
+        prompt_size = tokenizer.size(prompt)
 
         parameters[:top_p] = top_p if top_p
         parameters[:top_k] = top_k if top_k
         parameters[:typical_p] = typical_p if typical_p
-        parameters[:max_new_tokens] = max_tokens if max_tokens
+        parameters[:max_new_tokens] = token_limit - prompt_size
         parameters[:temperature] = temperature if temperature
         parameters[:repetition_penalty] = repetition_penalty if repetition_penalty
 
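With `prompt_size` measured by the (now injectable) tokenizer, `max_new_tokens` is derived from the remaining token budget rather than the fixed `max_tokens` argument. A small worked sketch with illustrative numbers:

    # Worked example with made-up numbers, assuming the default 4096 token limit.
    token_limit = 4096                          # default keyword argument
    prompt_size = 1250                          # e.g. tokenizer.size(prompt) for a long prompt

    max_new_tokens = token_limit - prompt_size  # => 2846 tokens left for the completion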
@@ -70,9 +77,8 @@ module ::DiscourseAi
 
           log.update!(
             raw_response_payload: response_body,
-            request_tokens: DiscourseAi::Tokenizer::Llama2Tokenizer.size(prompt),
-            response_tokens:
-              DiscourseAi::Tokenizer::Llama2Tokenizer.size(parsed_response[:generated_text]),
+            request_tokens: tokenizer.size(prompt),
+            response_tokens: tokenizer.size(parsed_response[:generated_text]),
           )
           return parsed_response
         end
@@ -118,8 +124,8 @@ module ::DiscourseAi
         ensure
           log.update!(
             raw_response_payload: response_raw,
-            request_tokens: DiscourseAi::Tokenizer::Llama2Tokenizer.size(prompt),
-            response_tokens: DiscourseAi::Tokenizer::Llama2Tokenizer.size(response_data),
+            request_tokens: tokenizer.size(prompt),
+            response_tokens: tokenizer.size(response_data),
           )
         end
       end
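End to end, a call against a protected Inference Endpoint now carries the bearer token. A standalone sketch of the request shape using Net::HTTP; the URL, key, and parameters are placeholders, and this is not the plugin's actual HTTP client:

    require "json"
    require "net/http"
    require "uri"

    # Placeholder endpoint and key; real values come from the new site settings.
    uri = URI("https://example.endpoints.huggingface.cloud")
    api_key = "hf_xxxxxxxxxxxx"

    headers = { "Content-Type" => "application/json" }
    headers["Authorization"] = "Bearer #{api_key}" unless api_key.empty?

    payload = { inputs: "Hello", parameters: { max_new_tokens: 2846 } }

    response = Net::HTTP.post(uri, payload.to_json, headers)
    parsed = JSON.parse(response.body, symbolize_names: true)
    # The plugin digs :generated_text out of the parsed response.
    puts parsed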