FIX: Handle unicode on tokenizer (#515)

* FIX: Handle unicode on tokenizer

Our fast-track code broke when strings contained characters that take more than one token per character (emoji and CJK text, for example).
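
For example, "我喜欢吃比萨" is only 6 characters, yet the new specs below assert it does not fit within a 10-token budget, so the old `text.size < max_length` check waved it through untruncated. A minimal sketch of the failure mode (Ruby, using the plugin's tokenizer; exact token counts depend on the encoding):

    text = "我喜欢吃比萨"
    text.size # => 6

    # Each of these characters can encode to several BPE tokens, so the
    # true token count is well above the character count:
    DiscourseAi::Tokenizer::OpenAiTokenizer.size(text) # => 10 or more

    # The old fast track `return text if text.size < max_length` therefore
    # returned the string untouched for max_length = 10, blowing the budget.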

Admins can set `DISCOURSE_AI_STRICT_TOKEN_COUNTING: true` in app.yml to ensure token counting is strict, even if slower.
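
In a standard container install this is an `env` entry in app.yml, applied with a rebuild (a sketch; the surrounding keys are illustrative):

    env:
      ## other DISCOURSE_* settings...
      DISCOURSE_AI_STRICT_TOKEN_COUNTING: true

Then run `./launcher rebuild app` for the container to pick it up.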


Co-authored-by: wozulong <sidle.pax_0e@icloud.com>
Rafael dos Santos Silva, 2024-03-14 17:33:30 -03:00 (committed by GitHub)
parent b327313115
commit 3b8f900486
4 changed files with 44 additions and 6 deletions


@@ -177,6 +177,9 @@ discourse_ai:
     default: ""
     hidden: true
   ai_llava_api_key: ""
+  ai_strict_token_counting:
+    default: false
+    hidden: true
   composer_ai_helper_enabled:
     default: false


@@ -17,14 +17,19 @@ module DiscourseAi
         end
 
         def truncate(text, max_length)
-          # Fast track the common case where the text is already short enough.
-          return text if text.size < max_length
+          # fast track common case, /2 to handle unicode chars
+          # that can take more than 1 token per char
+          return text if !SiteSetting.ai_strict_token_counting && text.size < max_length / 2
 
           tokenizer.decode(tokenizer.encode(text).ids.take(max_length))
         end
 
         def can_expand_tokens?(text, addition, max_length)
-          return true if text.size + addition.size < max_length
+          # fast track common case, /2 to handle unicode chars
+          # that can take more than 1 token per char
+          if !SiteSetting.ai_strict_token_counting && text.size + addition.size < max_length / 2
+            return true
+          end
 
           tokenizer.encode(text).ids.length + tokenizer.encode(addition).ids.length < max_length
         end

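The `/2` is the safety margin that keeps the shortcut honest for multi-token characters: a string may only skip real tokenization when, even at roughly two tokens per character, it could not exceed the budget, and `ai_strict_token_counting` disables the shortcut entirely. A condensed, standalone sketch of the guard (method name hypothetical, site setting passed as a plain argument):

    # Hypothetical helper mirroring the fast-track guard above.
    def fast_track?(text, max_length, strict: false)
      # Skip real tokenization only when the halved budget still covers
      # the character count; strict mode always tokenizes.
      !strict && text.size < max_length / 2
    end

    fast_track?("foo bar", 30)                # => true  (7 < 15, shortcut safe)
    fast_track?("我喜欢吃比萨", 10)            # => false (6 >= 5, must tokenize)
    fast_track?("foo bar", 30, strict: true)  # => false (always count tokens)

Note the halving is still a heuristic: a rare character costing more than two tokens can slip past it, which is exactly what the strict setting guards against.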

@@ -13,8 +13,9 @@ module DiscourseAi
         end
 
         def truncate(text, max_length)
-          # Fast track the common case where the text is already short enough.
-          return text if text.size < max_length
+          # fast track common case, /2 to handle unicode chars
+          # that can take more than 1 token per char
+          return text if !SiteSetting.ai_strict_token_counting && text.size < max_length / 2
 
           tokenizer.decode(tokenize(text).take(max_length))
         rescue Tiktoken::UnicodeError
@@ -23,7 +24,11 @@ module DiscourseAi
         end
 
         def can_expand_tokens?(text, addition, max_length)
-          return true if text.size + addition.size < max_length
+          # fast track common case, /2 to handle unicode chars
+          # that can take more than 1 token per char
+          if !SiteSetting.ai_strict_token_counting && text.size + addition.size < max_length / 2
+            return true
+          end
 
           tokenizer.encode(text).length + tokenizer.encode(addition).length < max_length
         end

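With the fix in place, truncation honors the token budget for multi-token text instead of returning it unchanged; a usage sketch mirroring the new specs (exact counts depend on the tiktoken encoding):

    tok = DiscourseAi::Tokenizer::OpenAiTokenizer

    sentence = "我喜欢吃比萨"
    budget = tok.size(sentence) - 1

    # The truncated string now fits the smaller budget rather than being
    # waved through by the old character-count check.
    tok.size(tok.truncate(sentence, budget)) < tok.size(sentence) # => true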

@@ -81,6 +81,31 @@ describe DiscourseAi::Tokenizer::OpenAiTokenizer do
       sentence = "foo bar 👨🏿‍👩🏿‍👧🏿‍👧🏿 baz qux quux corge grault garply waldo fred plugh xyzzy thud"
       expect(described_class.truncate(sentence, 7)).to eq("foo bar 👨🏿")
     end
+
+    it "truncates unicode characters properly when they use more than one token per char" do
+      sentence = "我喜欢吃比萨"
+      original_size = described_class.size(sentence)
+      expect(described_class.size(described_class.truncate(sentence, original_size - 1))).to be <
+        original_size
+    end
   end
+
+  describe "#can_expand_tokens?" do
+    it "returns true when the tokens can be expanded" do
+      expect(described_class.can_expand_tokens?("foo bar", "baz qux", 6)).to eq(true)
+    end
+
+    it "returns false when the tokens cannot be expanded" do
+      expect(described_class.can_expand_tokens?("foo bar", "baz qux", 3)).to eq(false)
+    end
+
+    it "returns false when the tokens cannot be expanded due to multibyte unicode characters" do
+      expect(described_class.can_expand_tokens?("foo bar 👨🏿", "baz qux", 6)).to eq(false)
+    end
+
+    it "handles unicode characters properly when they use more than one token per char" do
+      expect(described_class.can_expand_tokens?("我喜欢吃比萨", "", 10)).to eq(false)
+    end
+  end
 end