# frozen_string_literal: true

module DiscourseAi
  module Tokenizer
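    # Abstract base class: subclasses only implement .tokenizer and inherit the
    # shared tokenize/size/truncate/can_expand_tokens? helpers below.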
    class BasicTokenizer
      class << self
        def tokenizer
          raise NotImplementedError
        end

        def tokenize(text)
          tokenizer.encode(text).tokens
        end

        def size(text)
          tokenize(text).size
        end

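        # Truncates text so it encodes to at most max_length tokens.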
        def truncate(text, max_length)
          # Fast track the common case where the text is already short enough.
          return text if text.size < max_length

          tokenizer.decode(tokenizer.encode(text).ids.take(max_length))
        end

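        # Whether addition can be appended to text while keeping the combined
        # token count under max_length.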
        def can_expand_tokens?(text, addition, max_length)
          return true if text.size + addition.size < max_length

          tokenizer.encode(text).ids.length + tokenizer.encode(addition).ids.length < max_length
        end
      end
    end

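    # Each tokenizer below is loaded from a definition file bundled with the
    # plugin and memoized in a class variable on first use.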
    class BertTokenizer < BasicTokenizer
      def self.tokenizer
        @@tokenizer ||=
          Tokenizers.from_file("./plugins/discourse-ai/tokenizers/bert-base-uncased.json")
      end
    end

    class AnthropicTokenizer < BasicTokenizer
      def self.tokenizer
        @@tokenizer ||=
          Tokenizers.from_file("./plugins/discourse-ai/tokenizers/claude-v1-tokenization.json")
      end
    end

    class AllMpnetBaseV2Tokenizer < BasicTokenizer
      def self.tokenizer
        @@tokenizer ||=
          Tokenizers.from_file("./plugins/discourse-ai/tokenizers/all-mpnet-base-v2.json")
      end
    end

    class Llama2Tokenizer < BasicTokenizer
      def self.tokenizer
        @@tokenizer ||=
          Tokenizers.from_file("./plugins/discourse-ai/tokenizers/llama-2-70b-chat-hf.json")
      end
    end

    class MultilingualE5LargeTokenizer < BasicTokenizer
      def self.tokenizer
        @@tokenizer ||=
          Tokenizers.from_file("./plugins/discourse-ai/tokenizers/multilingual-e5-large.json")
      end
    end

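    # OpenAI models are tokenized with tiktoken, whose encode returns plain
    # token ids, so the helpers that deal with ids are overridden here.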
    class OpenAiTokenizer < BasicTokenizer
      class << self
        def tokenizer
          @@tokenizer ||= Tiktoken.get_encoding("cl100k_base")
        end

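        # tiktoken's encode already returns ids, so tokenize is the id array itself.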
        def tokenize(text)
          tokenizer.encode(text)
        end

        def truncate(text, max_length)
          # Fast track the common case where the text is already short enough.
          return text if text.size < max_length

          tokenizer.decode(tokenize(text).take(max_length))
        rescue Tiktoken::UnicodeError
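          # Truncation can split a multi-byte character, making the decoded text
          # invalid; drop one more token and retry.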
          max_length = max_length - 1
          retry
        end

        def can_expand_tokens?(text, addition, max_length)
          return true if text.size + addition.size < max_length

          tokenizer.encode(text).length + tokenizer.encode(addition).length < max_length
        end
      end
    end
  end
end