FEATURE: Llama2 for summarization (#116)
This commit is contained in:
parent
4b0c077ce5
commit
b25daed60b
|
@ -4,6 +4,7 @@ class AiApiAuditLog < ActiveRecord::Base
|
||||||
module Provider
|
module Provider
|
||||||
OpenAI = 1
|
OpenAI = 1
|
||||||
Anthropic = 2
|
Anthropic = 2
|
||||||
|
HuggingFaceTextGeneration = 3
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -36,6 +36,7 @@ en:
|
||||||
ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)"
|
ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)"
|
||||||
ai_openai_api_key: "API key for OpenAI API"
|
ai_openai_api_key: "API key for OpenAI API"
|
||||||
ai_anthropic_api_key: "API key for Anthropic API"
|
ai_anthropic_api_key: "API key for Anthropic API"
|
||||||
|
ai_hugging_face_api_url: "Custom URL used for open-source LLM inference. Compatible with https://github.com/huggingface/text-generation-inference"
|
||||||
|
|
||||||
composer_ai_helper_enabled: "Enable the Composer's AI helper."
|
composer_ai_helper_enabled: "Enable the Composer's AI helper."
|
||||||
ai_helper_allowed_groups: "Users on these groups will see the AI helper button in the composer."
|
ai_helper_allowed_groups: "Users on these groups will see the AI helper button in the composer."
|
||||||
|
|
|
@ -108,6 +108,9 @@ plugins:
|
||||||
choices:
|
choices:
|
||||||
- "stable-diffusion-xl-beta-v2-2-2"
|
- "stable-diffusion-xl-beta-v2-2-2"
|
||||||
- "stable-diffusion-v1-5"
|
- "stable-diffusion-v1-5"
|
||||||
|
ai_hugging_face_api_url:
|
||||||
|
default: ""
|
||||||
|
|
||||||
|
|
||||||
ai_google_custom_search_api_key:
|
ai_google_custom_search_api_key:
|
||||||
default: ""
|
default: ""
|
||||||
|
|
|
@ -8,6 +8,7 @@ module DiscourseAi
|
||||||
require_relative "models/anthropic"
|
require_relative "models/anthropic"
|
||||||
require_relative "models/discourse"
|
require_relative "models/discourse"
|
||||||
require_relative "models/open_ai"
|
require_relative "models/open_ai"
|
||||||
|
require_relative "models/llama2"
|
||||||
|
|
||||||
require_relative "strategies/fold_content"
|
require_relative "strategies/fold_content"
|
||||||
require_relative "strategies/truncate_content"
|
require_relative "strategies/truncate_content"
|
||||||
|
@ -21,6 +22,7 @@ module DiscourseAi
|
||||||
Models::OpenAi.new("gpt-3.5-turbo-16k", max_tokens: 16_384),
|
Models::OpenAi.new("gpt-3.5-turbo-16k", max_tokens: 16_384),
|
||||||
Models::Discourse.new("long-t5-tglobal-base-16384-book-summary", max_tokens: 16_384),
|
Models::Discourse.new("long-t5-tglobal-base-16384-book-summary", max_tokens: 16_384),
|
||||||
Models::Anthropic.new("claude-2", max_tokens: 100_000),
|
Models::Anthropic.new("claude-2", max_tokens: 100_000),
|
||||||
|
Models::Llama2.new("Llama-2-7b-chat-hf", max_tokens: 4096),
|
||||||
]
|
]
|
||||||
|
|
||||||
foldable_models.each do |model|
|
foldable_models.each do |model|
|
||||||
|
|
|
@ -0,0 +1,104 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Summarization
|
||||||
|
module Models
|
||||||
|
class Llama2 < Base
|
||||||
|
# Human-readable label for this summarization model.
#
# NOTE(review): `model` is not defined in this class; presumably it is an
# accessor provided by the superclass — confirm.
#
# @return [String] the model name prefixed with "Llama2's "
def display_name
  format("Llama2's %s", model)
end
|
||||||
|
|
||||||
|
# Tells whether the site setting this model depends on has a value.
#
# @return [Boolean] true when `ai_hugging_face_api_url` is not blank
def correctly_configured?
  api_url = SiteSetting.ai_hugging_face_api_url
  !api_url.blank?
end
|
||||||
|
|
||||||
|
# Translated hint naming the single site setting that must be filled in
# before this model can be used.
#
# @return [String] whatever I18n returns for the configuration hint key
def configuration_hint
  hint_options = { count: 1, setting: "ai_hugging_face_api_url" }

  I18n.t("discourse_ai.summarization.configuration_hint", **hint_options)
end
|
||||||
|
|
||||||
|
# Asks the model to merge several partial summaries into one narrative.
#
# The prompt is wrapped in the `[INST] <<SYS>> ... <</SYS>> ... [/INST]`
# markers and sent through `completion`.
#
# @param summaries [Array<String>] partial summaries, joined with newlines
# @return whatever `completion` returns for the built prompt
def concatenate_summaries(summaries)
  joined_summaries = summaries.join("\n")

  prompt = <<~TEXT
    [INST] <<SYS>>
    You are a helpful bot
    <</SYS>>

    Concatenate these disjoint summaries, creating a cohesive narrative:
    #{joined_summaries} [/INST]
  TEXT

  completion(prompt)
end
|
||||||
|
|
||||||
|
# Summarizes content in a single request, cutting the input down to
# `available_tokens` tokens first.
#
# NOTE(review): `format_content_item` and `available_tokens` are not defined
# in this class; presumably they come from the superclass — confirm.
#
# @param contents [Enumerable] items passed one by one to `format_content_item`
# @param opts [Hash] options forwarded to `build_base_prompt`
# @return whatever `completion` returns for the built prompt
def summarize_with_truncation(contents, opts)
  formatted_items = contents.map { |item| format_content_item(item) }
  truncated_content = tokenizer.truncate(formatted_items.join, available_tokens)
  system_prompt = build_base_prompt(opts)

  completion(<<~TEXT)
    [INST] <<SYS>>
    #{system_prompt}
    <</SYS>>

    Summarize the following in up to 400 words:
    #{truncated_content} [/INST]
  TEXT
end
|
||||||
|
|
||||||
|
# Summarizes content that fits in one chunk by delegating to
# `summarize_chunk` with the `single_chunk` flag switched on.
#
# @param chunk_text [String] the text to summarize
# @param opts [Hash] prompt options; left unmodified (a merged copy is used)
# @return whatever `summarize_chunk` returns
def summarize_single(chunk_text, opts)
  single_chunk_opts = opts.merge(single_chunk: true)

  summarize_chunk(chunk_text, single_chunk_opts)
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
# Builds the prompt for one chunk of text and sends it to the model.
#
# @param chunk_text [String] the text to summarize
# @param opts [Hash] prompt options; `:single_chunk` selects the
#   "forum discussion" instruction instead of the 400-word one, and the whole
#   hash is forwarded to `build_base_prompt`
# @return whatever `completion` returns for the built prompt
def summarize_chunk(chunk_text, opts)
  summary_instruction = "Summarize the following in up to 400 words:"
  if opts[:single_chunk]
    summary_instruction =
      "Summarize the following forum discussion, creating a cohesive narrative:"
  end

  system_prompt = build_base_prompt(opts)

  prompt = <<~TEXT
    [INST] <<SYS>>
    #{system_prompt}
    <</SYS>>

    #{summary_instruction}
    #{chunk_text} [/INST]
  TEXT

  completion(prompt)
end
|
||||||
|
|
||||||
|
# Builds the system prompt shared by every summarization request.
#
# The fixed instructions come first; optional sentences are appended when the
# corresponding option is present.
#
# @param opts [Hash] prompt options
# @option opts [String] :resource_path base path used in the example link
# @option opts [String] :content_title title of the discussion being summarized
# @return [String] the prompt text, each sentence terminated by a newline
def build_base_prompt(opts)
  # Parts are collected and joined once instead of re-allocating the string
  # with repeated `+=`.
  prompt_parts = [<<~TEXT]
    You are a summarization bot.
    You effectively summarise any text and reply ONLY with the summarized text.
    You condense it into a shorter version.
    You understand and generate Discourse forum Markdown.
  TEXT

  if opts[:resource_path]
    prompt_parts <<
      "Try generating links as well the format is #{opts[:resource_path]}. eg: [ref](#{opts[:resource_path]}/77)\n"
  end

  prompt_parts << "The discussion title is: #{opts[:content_title]}.\n" if opts[:content_title]

  prompt_parts.join
end
|
||||||
|
|
||||||
|
# Sends a prompt to the Hugging Face text-generation endpoint and extracts
# the generated text from the response.
#
# @param prompt [String] the full prompt to send
# @return the value stored under `:generated_text` in the response
def completion(prompt)
  response = ::DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(prompt, model)

  response.dig(:generated_text)
end
|
||||||
|
|
||||||
|
# Tokenizer used to measure and truncate text for this model.
#
# @return [Class] the Llama 2 tokenizer class
def tokenizer
  ::DiscourseAi::Tokenizer::Llama2Tokenizer
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,137 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module ::DiscourseAi
|
||||||
|
module Inference
|
||||||
|
# HTTP client for a server exposing `/generate` and `/generate_stream`
# endpoints at the URL configured in `SiteSetting.ai_hugging_face_api_url`.
# Every successful request is recorded in an AiApiAuditLog row.
class HuggingFaceTextGeneration
  CompletionFailed = Class.new(StandardError)
  TIMEOUT = 60

  # Requests a completion for `prompt`.
  #
  # Without a block the whole response is read and returned. With a block the
  # streaming endpoint is used and each non-special token event is yielded as
  # `(partial, cancel)`; calling `cancel.call` stops the stream before the
  # next chunk is processed.
  #
  # @param prompt [String] text sent as the `inputs` field
  # @param model [String] must not be blank; it is only validated here and is
  #   not included in the request payload
  # @param temperature [Float, nil] sampling parameter, omitted when nil
  # @param top_p [Float, nil] sampling parameter, omitted when nil
  # @param top_k [Integer, nil] sampling parameter, omitted when nil
  # @param typical_p [Float, nil] sampling parameter, omitted when nil
  # @param max_tokens [Integer, nil] sent as `max_new_tokens`, omitted when nil
  # @param repetition_penalty [Float, nil] omitted when nil
  # @param user_id [Integer, nil] stored on the audit log row
  # @yieldparam partial [Hash] parsed event with symbolized keys
  # @yieldparam cancel [Proc] call to stop streaming
  # @return [Hash, nil] the parsed response when no block is given; nil when
  #   streaming
  # @raise [CompletionFailed] when `model` is blank or the server does not
  #   answer with HTTP 200
  def self.perform!(
    prompt,
    model,
    temperature: 0.7,
    top_p: nil,
    top_k: nil,
    typical_p: nil,
    max_tokens: 2000,
    repetition_penalty: 1.1,
    user_id: nil
  )
    raise CompletionFailed if model.blank?

    url = URI(SiteSetting.ai_hugging_face_api_url)
    # Any path present in the configured URL is replaced by the endpoint path.
    url.path = block_given? ? "/generate_stream" : "/generate"
    headers = { "Content-Type" => "application/json" }

    parameters = {}
    payload = { inputs: prompt, parameters: parameters }

    parameters[:top_p] = top_p if top_p
    parameters[:top_k] = top_k if top_k
    parameters[:typical_p] = typical_p if typical_p
    parameters[:max_new_tokens] = max_tokens if max_tokens
    parameters[:temperature] = temperature if temperature
    parameters[:repetition_penalty] = repetition_penalty if repetition_penalty

    Net::HTTP.start(
      url.host,
      url.port,
      use_ssl: url.scheme == "https",
      read_timeout: TIMEOUT,
      open_timeout: TIMEOUT,
      write_timeout: TIMEOUT,
    ) do |http|
      request = Net::HTTP::Post.new(url, headers)
      request_body = payload.to_json
      request.body = request_body

      http.request(request) do |response|
        if response.code.to_i != 200
          Rails.logger.error(
            "HuggingFaceTextGeneration: status: #{response.code.to_i} - body: #{response.body}",
          )
          raise CompletionFailed
        end

        log =
          AiApiAuditLog.create!(
            provider_id: AiApiAuditLog::Provider::HuggingFaceTextGeneration,
            raw_request_payload: request_body,
            user_id: user_id,
          )

        unless block_given?
          response_body = response.read_body
          parsed_response = JSON.parse(response_body, symbolize_names: true)

          log.update!(
            raw_response_payload: response_body,
            request_tokens: DiscourseAi::Tokenizer::Llama2Tokenizer.size(prompt),
            response_tokens:
              DiscourseAi::Tokenizer::Llama2Tokenizer.size(parsed_response[:generated_text]),
          )
          return parsed_response
        end

        begin
          cancelled = false
          cancel = lambda { cancelled = true }
          response_data = +""
          response_raw = +""

          response.read_body do |chunk|
            if cancelled
              http.finish
              return
            end

            response_raw << chunk

            chunk
              .split("\n")
              .each do |line|
                # The space after "data:" is optional in an event stream, so
                # split on the bare field name; JSON.parse ignores the
                # leading whitespace that may remain.
                data = line.split("data:", 2)[1]
                next if !data || data.squish == "[DONE]"
                next if cancelled

                begin
                  # Each event carries a single generated token.
                  partial = JSON.parse(data, symbolize_names: true)

                  # Special tokens are not part of the generated text.
                  next if partial[:token][:special] == true

                  # Accumulate so the audit log counts every streamed token,
                  # not only the last one.
                  response_data << partial[:token][:text].to_s

                  yield partial, cancel
                rescue JSON::ParserError
                  # An event split across two chunks does not parse on its
                  # own and is skipped.
                  nil
                end
              end
          end
        rescue IOError
          raise if !cancelled
        ensure
          # Runs once per request, including on cancellation and on errors.
          log.update!(
            raw_response_payload: response_raw,
            request_tokens: DiscourseAi::Tokenizer::Llama2Tokenizer.size(prompt),
            response_tokens: DiscourseAi::Tokenizer::Llama2Tokenizer.size(response_data),
          )
        end
      end
    end

    nil
  end

  # Parses a JSON document without raising on malformed input.
  #
  # @param data [String] JSON text
  # @return [Object, nil] the parsed value with symbolized keys, or nil when
  #   `data` is not valid JSON
  def self.try_parse(data)
    JSON.parse(data, symbolize_names: true)
  rescue JSON::ParserError
    nil
  end
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -52,6 +52,13 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Tokenizer backed by the bundled Llama 2 tokenizer definition file.
class Llama2Tokenizer < BasicTokenizer
  # Loads the tokenizer definition on first use and reuses it afterwards.
  #
  # Memoized in a class-level instance variable rather than a class variable,
  # so the cached object can never be shared with the superclass.
  #
  # NOTE(review): the path is relative to the process working directory —
  # confirm it is always the application root.
  #
  # @return the object returned by `Tokenizers.from_file`
  def self.tokenizer
    @tokenizer ||=
      Tokenizers.from_file("./plugins/discourse-ai/tokenizers/llama-2-70b-chat-hf.json")
  end
end
|
||||||
|
|
||||||
class OpenAiTokenizer < BasicTokenizer
|
class OpenAiTokenizer < BasicTokenizer
|
||||||
class << self
|
class << self
|
||||||
def tokenizer
|
def tokenizer
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
# url: https://meta.discourse.org/t/discourse-ai/259214
|
# url: https://meta.discourse.org/t/discourse-ai/259214
|
||||||
# required_version: 2.7.0
|
# required_version: 2.7.0
|
||||||
|
|
||||||
gem "tokenizers", "0.3.2"
|
gem "tokenizers", "0.3.3"
|
||||||
gem "tiktoken_ruby", "0.0.5"
|
gem "tiktoken_ruby", "0.0.5"
|
||||||
|
|
||||||
enabled_site_setting :discourse_ai_enabled
|
enabled_site_setting :discourse_ai_enabled
|
||||||
|
@ -31,6 +31,7 @@ after_initialize do
|
||||||
require_relative "lib/shared/inference/openai_embeddings"
|
require_relative "lib/shared/inference/openai_embeddings"
|
||||||
require_relative "lib/shared/inference/anthropic_completions"
|
require_relative "lib/shared/inference/anthropic_completions"
|
||||||
require_relative "lib/shared/inference/stability_generator"
|
require_relative "lib/shared/inference/stability_generator"
|
||||||
|
require_relative "lib/shared/inference/hugging_face_text_generation"
|
||||||
|
|
||||||
require_relative "lib/shared/classificator"
|
require_relative "lib/shared/classificator"
|
||||||
require_relative "lib/shared/post_classificator"
|
require_relative "lib/shared/post_classificator"
|
||||||
|
|
|
@ -100,3 +100,20 @@ describe DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer do
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Covers token counting and truncation for the Llama 2 tokenizer.
describe DiscourseAi::Tokenizer::Llama2Tokenizer do
  describe "#size" do
    describe "returns a token count" do
      it "for a sentence with punctuation and capitalization and numbers" do
        token_count = described_class.size("Hello, World! 123")

        expect(token_count).to eq(9)
      end
    end
  end

  describe "#truncate" do
    it "truncates a sentence" do
      sentence = "foo bar baz qux quux corge grault garply waldo fred plugh xyzzy thud"

      truncated = described_class.truncate(sentence, 3)

      expect(truncated).to eq("foo bar")
    end
  end
end
|
||||||
|
|
|
@ -9,3 +9,7 @@ Licensed under MIT License
|
||||||
## all-mpnet-base-v2.json
|
## all-mpnet-base-v2.json
|
||||||
|
|
||||||
Licensed under Apache License
|
Licensed under Apache License
|
||||||
|
|
||||||
|
## llama-2-70b-chat-hf
|
||||||
|
|
||||||
|
Licensed under LLAMA 2 COMMUNITY LICENSE AGREEMENT
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue