discourse-ai/lib/shared/tokenizer/tokenizer.rb

98 lines
2.6 KiB
Ruby

# frozen_string_literal: true
module DiscourseAi
module Tokenizer
class BasicTokenizer
  class << self
    # The underlying tokenizer object. Subclasses override this to return a
    # memoized instance whose #encode result responds to #ids and #tokens and
    # which provides #decode(ids).
    #
    # @raise [NotImplementedError] on BasicTokenizer itself
    def tokenizer
      raise NotImplementedError
    end

    # @param text [String]
    # @return [Array<String>] the token strings produced for +text+
    def tokenize(text)
      tokenizer.encode(text).tokens
    end

    # @param text [String]
    # @return [Integer] how many tokens +text+ encodes to
    def size(text)
      tokenize(text).size
    end

    # Truncates +text+ so that it encodes to at most +max_length+ tokens.
    #
    # Text that already fits is returned unchanged (the same object). Only
    # text that actually exceeds the budget is round-tripped through
    # encode/decode, and that round trip may not reproduce the input exactly.
    #
    # @param text [String]
    # @param max_length [Integer] token budget; values <= 0 yield ""
    # @return [String]
    def truncate(text, max_length)
      return "" if max_length <= 0
      return text if fits_without_encoding?(text.bytesize, max_length)

      ids = tokenizer.encode(text).ids
      # Don't round-trip text that is already within budget: decoding is not
      # guaranteed to give back the original string.
      return text if ids.length <= max_length

      tokenizer.decode(ids.take(max_length))
    end

    # Whether appending +addition+ to +text+ stays strictly below
    # +max_length+ tokens. The two strings are encoded separately and their
    # token counts are summed.
    #
    # @param text [String]
    # @param addition [String]
    # @param max_length [Integer]
    # @return [Boolean]
    def can_expand_tokens?(text, addition, max_length)
      return true if fits_without_encoding?(text.bytesize + addition.bytesize, max_length)

      tokenizer.encode(text).ids.length + tokenizer.encode(addition).ids.length < max_length
    end

    private

    # Conservative shortcut that skips encoding for obviously short input.
    # Character count is not a safe bound: one character can become several
    # tokens (multibyte characters in byte-level vocabularies), and encoders
    # may add special tokens. Requiring the UTF-8 byte count to stay under
    # half the budget leaves headroom for both. Anything that fails this
    # check is measured exactly by the caller.
    def fits_without_encoding?(bytesize, max_length)
      bytesize * 2 < max_length
    end
  end
end
class BertTokenizer < BasicTokenizer
  # Memoized bert-base-uncased tokenizer, loaded from its JSON definition on
  # first use. It is stored in a class-level instance variable rather than a
  # @@class variable, so it cannot be shared with, or overtaken by, other
  # classes in the BasicTokenizer hierarchy.
  # NOTE(review): the path is relative to the process working directory
  # (presumably the Discourse root) — confirm.
  def self.tokenizer
    @tokenizer ||=
      Tokenizers.from_file("./plugins/discourse-ai/tokenizers/bert-base-uncased.json")
  end
end
class AnthropicTokenizer < BasicTokenizer
  # Memoized Claude v1 tokenizer, loaded from its JSON definition on first
  # use. It is stored in a class-level instance variable rather than a
  # @@class variable, so it cannot be shared with, or overtaken by, other
  # classes in the BasicTokenizer hierarchy.
  # NOTE(review): the path is relative to the process working directory
  # (presumably the Discourse root) — confirm.
  def self.tokenizer
    @tokenizer ||=
      Tokenizers.from_file("./plugins/discourse-ai/tokenizers/claude-v1-tokenization.json")
  end
end
class AllMpnetBaseV2Tokenizer < BasicTokenizer
  # Memoized all-mpnet-base-v2 tokenizer, loaded from its JSON definition on
  # first use. It is stored in a class-level instance variable rather than a
  # @@class variable, so it cannot be shared with, or overtaken by, other
  # classes in the BasicTokenizer hierarchy.
  # NOTE(review): the path is relative to the process working directory
  # (presumably the Discourse root) — confirm.
  def self.tokenizer
    @tokenizer ||=
      Tokenizers.from_file("./plugins/discourse-ai/tokenizers/all-mpnet-base-v2.json")
  end
end
class Llama2Tokenizer < BasicTokenizer
  # Memoized Llama 2 (70b chat) tokenizer, loaded from its JSON definition on
  # first use. It is stored in a class-level instance variable rather than a
  # @@class variable, so it cannot be shared with, or overtaken by, other
  # classes in the BasicTokenizer hierarchy.
  # NOTE(review): the path is relative to the process working directory
  # (presumably the Discourse root) — confirm.
  def self.tokenizer
    @tokenizer ||=
      Tokenizers.from_file("./plugins/discourse-ai/tokenizers/llama-2-70b-chat-hf.json")
  end
end
class MultilingualE5LargeTokenizer < BasicTokenizer
  # Memoized multilingual-e5-large tokenizer, loaded from its JSON definition
  # on first use. It is stored in a class-level instance variable rather than
  # a @@class variable, so it cannot be shared with, or overtaken by, other
  # classes in the BasicTokenizer hierarchy.
  # NOTE(review): the path is relative to the process working directory
  # (presumably the Discourse root) — confirm.
  def self.tokenizer
    @tokenizer ||=
      Tokenizers.from_file("./plugins/discourse-ai/tokenizers/multilingual-e5-large.json")
  end
end
class OpenAiTokenizer < BasicTokenizer
  class << self
    # Memoized tiktoken "cl100k_base" encoding. It is held in a class-level
    # instance variable (not a @@class variable), so it is never shared with,
    # or overtaken by, other classes in the BasicTokenizer hierarchy.
    def tokenizer
      @tokenizer ||= Tiktoken.get_encoding("cl100k_base")
    end

    # Unlike BasicTokenizer.tokenize, this returns whatever tiktoken's
    # #encode returns (token ids), not token strings.
    #
    # @param text [String]
    # @return [Array<Integer>] token ids
    def tokenize(text)
      tokenizer.encode(text)
    end

    # Truncates +text+ to at most +max_length+ tokens.
    #
    # @param text [String]
    # @param max_length [Integer] token budget; values <= 0 yield ""
    # @return [String] +text+ itself when it already fits
    def truncate(text, max_length)
      return "" if max_length <= 0
      # Every cl100k_base token covers at least one byte, so a byte count
      # below the budget guarantees the token count is too. A character
      # count does not: a single CJK character or emoji can be several tokens.
      return text if text.bytesize < max_length

      ids = tokenize(text)
      return text if ids.length <= max_length

      kept = ids.take(max_length)
      begin
        tokenizer.decode(kept)
      rescue Tiktoken::UnicodeError
        # Presumably the cut ended partway through a multibyte character.
        # Drop trailing tokens until the prefix decodes. This is bounded
        # because each retry shortens +kept+, and we re-raise if nothing is left.
        raise if kept.empty?

        kept.pop
        retry
      end
    end

    # Whether +text+ plus +addition+ (encoded separately) stays strictly
    # below +max_length+ tokens.
    #
    # @param text [String]
    # @param addition [String]
    # @param max_length [Integer]
    # @return [Boolean]
    def can_expand_tokens?(text, addition, max_length)
      # Same exact byte bound as in #truncate.
      return true if text.bytesize + addition.bytesize < max_length

      tokenize(text).length + tokenize(addition).length < max_length
    end
  end
end
end
end