# Patch summary (commentary only; everything from the first "diff --git" header onward is unchanged).
#
# Purpose: let the Hugging Face text-generation client send an API key and size its
# generation budget from a token limit instead of a fixed max_tokens value.
#
# config/settings.yml
#   * Adds the `ai_hugging_face_api_key` setting (default "", secret: true).
#
# lib/modules/summarization/models/llama2.rb
#   * `completion` now passes `token_limit: token_limit` to
#     ::DiscourseAi::Inference::HuggingFaceTextGeneration.perform! and still returns
#     the `:generated_text` entry of the result via `dig`.
#   * Adds a `token_limit` method that returns 4096.
#
# lib/shared/inference/hugging_face_text_generation.rb
#   * Adds two keyword parameters: `tokenizer:` (default
#     DiscourseAi::Tokenizer::Llama2Tokenizer) and `token_limit:` (default 4096).
#   * Adds an "Authorization: Bearer <key>" request header, only when
#     SiteSetting.ai_hugging_face_api_key is present.
#   * Computes `prompt_size = tokenizer.size(prompt)` and sets
#     `parameters[:max_new_tokens]` to `token_limit - prompt_size`.
#   * Both `log.update!` calls now count tokens with the injected `tokenizer`
#     instead of the hard-coded Llama2Tokenizer.
#
# NOTE(review): in the hunks shown, `token_limit - prompt_size` has no lower bound, so
#   a prompt whose token count reaches or exceeds `token_limit` would send a zero or
#   negative `max_new_tokens` -- confirm whether callers truncate the prompt first.
# NOTE(review): the `max_tokens:` keyword is still in the signature, but the only
#   visible line that read it is removed here -- confirm it is not silently ignored
#   for existing callers.
# NOTE(review): `tokenizer.size(prompt)` is recomputed in the `log.update!` calls even
#   though `prompt_size` already holds that value -- presumably it could be reused.
# NOTE(review): line breaks in this patch text appear to have been collapsed, so the
#   hunk line counts cannot be checked against the content as shown.
diff --git a/config/settings.yml b/config/settings.yml index 00315a47..0ad3e68f 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -111,6 +111,9 @@ plugins: - "stable-diffusion-v1-5" ai_hugging_face_api_url: default: "" + ai_hugging_face_api_key: + default: "" + secret: true ai_google_custom_search_api_key: diff --git a/lib/modules/summarization/models/llama2.rb b/lib/modules/summarization/models/llama2.rb index bc5c04f2..a6e5ef37 100644 --- a/lib/modules/summarization/models/llama2.rb +++ b/lib/modules/summarization/models/llama2.rb @@ -90,14 +90,20 @@ module DiscourseAi end def completion(prompt) - ::DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(prompt, model).dig( - :generated_text, - ) + ::DiscourseAi::Inference::HuggingFaceTextGeneration.perform!( + prompt, + model, + token_limit: token_limit, + ).dig(:generated_text) end def tokenizer DiscourseAi::Tokenizer::Llama2Tokenizer end + + def token_limit + 4096 + end end end end diff --git a/lib/shared/inference/hugging_face_text_generation.rb b/lib/shared/inference/hugging_face_text_generation.rb index c3d69e12..52b8d174 100644 --- a/lib/shared/inference/hugging_face_text_generation.rb +++ b/lib/shared/inference/hugging_face_text_generation.rb @@ -15,7 +15,9 @@ module ::DiscourseAi typical_p: nil, max_tokens: 2000, repetition_penalty: 1.1, - user_id: nil + user_id: nil, + tokenizer: DiscourseAi::Tokenizer::Llama2Tokenizer, + token_limit: 4096 ) raise CompletionFailed if model.blank? @@ -27,13 +29,18 @@ module ::DiscourseAi end headers = { "Content-Type" => "application/json" } + if SiteSetting.ai_hugging_face_api_key.present? 
+ headers["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_api_key}" + end + + parameters = {} payload = { inputs: prompt, parameters: parameters } + prompt_size = tokenizer.size(prompt) parameters[:top_p] = top_p if top_p parameters[:top_k] = top_k if top_k parameters[:typical_p] = typical_p if typical_p - parameters[:max_new_tokens] = max_tokens if max_tokens + parameters[:max_new_tokens] = token_limit - prompt_size parameters[:temperature] = temperature if temperature parameters[:repetition_penalty] = repetition_penalty if repetition_penalty @@ -70,9 +77,8 @@ module ::DiscourseAi log.update!( raw_response_payload: response_body, - request_tokens: DiscourseAi::Tokenizer::Llama2Tokenizer.size(prompt), - response_tokens: - DiscourseAi::Tokenizer::Llama2Tokenizer.size(parsed_response[:generated_text]), + request_tokens: tokenizer.size(prompt), + response_tokens: tokenizer.size(parsed_response[:generated_text]), ) return parsed_response end @@ -118,8 +124,8 @@ module ::DiscourseAi ensure log.update!( raw_response_payload: response_raw, - request_tokens: DiscourseAi::Tokenizer::Llama2Tokenizer.size(prompt), - response_tokens: DiscourseAi::Tokenizer::Llama2Tokenizer.size(response_data), + request_tokens: tokenizer.size(prompt), + response_tokens: tokenizer.size(response_data), ) end end