FEATURE: Add support for StableBeluga and Upstage Llama2 instruct (#126)

* FEATURE: Add support for StableBeluga and Upstage Llama2 instruct

This means we now support all of the models in the top 3 of the Open LLM Leaderboard.

Since some of those models use RoPE scaling (and thus support different context lengths),
we now have a setting so you can customize the token limit depending on which model you use.
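
For example, an admin running a RoPE-scaled checkpoint with a longer context could raise the limit from the Rails console; a minimal sketch (the value 8192 is illustrative, not a recommendation):

# Illustrative only: raise the limit to match a RoPE-scaled model's context window.
SiteSetting.ai_hugging_face_token_limit = 8192
SiteSetting.ai_hugging_face_token_limit # => 8192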
Rafael dos Santos Silva authored on 2023-08-03 15:29:30 -03:00, committed by GitHub
commit eb7fff3a55 (parent 8b157feea5)
5 changed files with 85 additions and 13 deletions


@@ -114,7 +114,10 @@ plugins:
   ai_hugging_face_api_key:
     default: ""
     secret: true
+  ai_hugging_face_token_limit:
+    default: 4096
+  ai_hugging_face_model_display_name:
+    default: ""
   ai_google_custom_search_api_key:
     default: ""
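
Both new settings are read at runtime by the summarization code later in this diff; with the defaults above, the blank display name reads as unset via ActiveSupport's presence. A quick illustration (assuming stock defaults):

SiteSetting.ai_hugging_face_token_limit # => 4096
SiteSetting.ai_hugging_face_model_display_name # => ""
SiteSetting.ai_hugging_face_model_display_name.presence # => nil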


@@ -9,6 +9,7 @@ module DiscourseAi
         require_relative "models/discourse"
         require_relative "models/open_ai"
         require_relative "models/llama2"
+        require_relative "models/llama2_fine_tuned_orca_style"
         require_relative "strategies/fold_content"
         require_relative "strategies/truncate_content"

@@ -22,7 +23,11 @@ module DiscourseAi
           Models::OpenAi.new("gpt-3.5-turbo-16k", max_tokens: 16_384),
           Models::Discourse.new("long-t5-tglobal-base-16384-book-summary", max_tokens: 16_384),
           Models::Anthropic.new("claude-2", max_tokens: 100_000),
-          Models::Llama2.new("Llama-2-7b-chat-hf", max_tokens: 4096),
+          Models::Llama2.new("Llama2-chat-hf", max_tokens: SiteSetting.ai_hugging_face_token_limit),
+          Models::Llama2FineTunedOrcaStyle.new(
+            "StableBeluga2",
+            max_tokens: SiteSetting.ai_hugging_face_token_limit,
+          ),
         ]

         foldable_models.each do |model|
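
Any other Hugging Face-hosted Llama2-family model could now be registered the same way, sized by the site setting instead of a hard-coded limit; a hypothetical entry (the checkpoint name is a placeholder, not a real deployment):

# Hypothetical registration following the pattern above.
Models::Llama2.new("some-llama2-checkpoint", max_tokens: SiteSetting.ai_hugging_face_token_limit)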


@@ -5,7 +5,7 @@ module DiscourseAi
     module Models
       class Llama2 < Base
         def display_name
-          "Llama2's #{model}"
+          "Llama2's #{SiteSetting.ai_hugging_face_model_display_name.presence || model}"
         end

         def correctly_configured?
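
Because presence returns nil for a blank string, the display-name setting only overrides the model name when an admin has actually filled it in; for instance (values illustrative):

"".presence || "Llama2-chat-hf" # => "Llama2-chat-hf" (setting unset, falls back)
"StableBeluga2".presence || "Llama2-chat-hf" # => "StableBeluga2" (setting wins)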
@@ -42,6 +42,7 @@ module DiscourseAi
             Summarize the following in up to 400 words:
             #{truncated_content} [/INST]
+            Here is a summary of the above topic:
           TEXT
         end

@@ -66,6 +67,7 @@ module DiscourseAi
             #{summary_instruction}
             #{chunk_text} [/INST]
+            Here is a summary of the above topic:
           TEXT
         end

@@ -90,20 +92,14 @@ module DiscourseAi
         end

         def completion(prompt)
-          ::DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(
-            prompt,
-            model,
-            token_limit: token_limit,
-          ).dig(:generated_text)
+          ::DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(prompt, model).dig(
+            :generated_text,
+          )
         end

         def tokenizer
           DiscourseAi::Tokenizer::Llama2Tokenizer
         end
-
-        def token_limit
-          4096
-        end
       end
     end
   end


@@ -0,0 +1,66 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Summarization
+    module Models
+      class Llama2FineTunedOrcaStyle < Llama2
+        def display_name
+          "Llama2FineTunedOrcaStyle's #{SiteSetting.ai_hugging_face_model_display_name.presence || model}"
+        end
+
+        def concatenate_summaries(summaries)
+          completion(<<~TEXT)
+            ### System:
+            You are a helpful bot
+
+            ### User:
+            Concatenate these disjoint summaries, creating a cohesive narrative:
+            #{summaries.join("\n")}
+
+            ### Assistant:
+          TEXT
+        end
+
+        def summarize_with_truncation(contents, opts)
+          text_to_summarize = contents.map { |c| format_content_item(c) }.join
+          truncated_content = tokenizer.truncate(text_to_summarize, available_tokens)
+
+          completion(<<~TEXT)
+            ### System:
+            #{build_base_prompt(opts)}
+
+            ### User:
+            Summarize the following in up to 400 words:
+            #{truncated_content}
+
+            ### Assistant:
+            Here is a summary of the above topic:
+          TEXT
+        end
+
+        private
+
+        def summarize_chunk(chunk_text, opts)
+          summary_instruction =
+            if opts[:single_chunk]
+              "Summarize the following forum discussion, creating a cohesive narrative:"
+            else
+              "Summarize the following in up to 400 words:"
+            end
+
+          completion(<<~TEXT)
+            ### System:
+            #{build_base_prompt(opts)}
+
+            ### User:
+            #{summary_instruction}
+            #{chunk_text}
+
+            ### Assistant:
+            Here is a summary of the above topic:
+          TEXT
+        end
+      end
+    end
+  end
+end
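
For reference, a prompt assembled by summarize_chunk above comes out roughly like this (the system line and chunk text are placeholders for whatever build_base_prompt and the topic supply):

prompt = <<~TEXT
  ### System:
  <output of build_base_prompt(opts)>

  ### User:
  Summarize the following in up to 400 words:
  <chunk text>

  ### Assistant:
  Here is a summary of the above topic:
TEXT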


@@ -17,7 +17,7 @@ module ::DiscourseAi
         repetition_penalty: 1.1,
         user_id: nil,
         tokenizer: DiscourseAi::Tokenizer::Llama2Tokenizer,
-        token_limit: 4096
+        token_limit: nil
       )

         raise CompletionFailed if model.blank?

@@ -33,6 +33,8 @@ module ::DiscourseAi
           headers["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_api_key}"
         end

+        token_limit = token_limit || SiteSetting.ai_hugging_face_token_limit
+
         parameters = {}
         payload = { inputs: prompt, parameters: parameters }
         prompt_size = tokenizer.size(prompt)
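
With the hard-coded default gone, the limit now resolves in one place: an explicit token_limit argument wins, and otherwise the site setting applies. A minimal usage sketch (prompt text hypothetical), matching the call shape used by the summarization models above:

# Falls back to SiteSetting.ai_hugging_face_token_limit since no token_limit is passed.
DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(
  "<prompt text>",
  "StableBeluga2",
).dig(:generated_text)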