FIX: Handle unicode on tokenizer (#515)

Our fast-track code broke when strings contained characters that are longer in tokens than in UTF-8 characters, i.e. characters that can encode to more than one token each. Admins can set `DISCOURSE_AI_STRICT_TOKEN_COUNTING: true` in app.yml to ensure token counting is strict, even if slower.

Co-authored-by: wozulong <sidle.pax_0e@icloud.com>

parent b327313115
commit 3b8f900486
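For context, a minimal sketch of the failure mode, assuming the tiktoken_ruby gem that backs the plugin's OpenAI tokenizer (token counts vary by encoding, so the numbers are illustrative):

    require "tiktoken_ruby"

    # cl100k_base is the encoding the plugin uses for OpenAI models.
    enc = Tiktoken.get_encoding("cl100k_base")

    text = "我喜欢吃比萨" # 6 characters
    tokens = enc.encode(text) # CJK text often encodes to more tokens than characters

    # The old fast track treated character count as an upper bound on token
    # count, returning early whenever the string looked short enough:
    max_length = 8
    text.size < max_length     # => true, so the old code skipped tokenizing...
    tokens.length < max_length # ...even though this may well be false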
@@ -177,6 +177,9 @@ discourse_ai:
     default: ""
     hidden: true
   ai_llava_api_key: ""
+  ai_strict_token_counting:
+    default: false
+    hidden: true
 
   composer_ai_helper_enabled:
     default: false
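The new setting is hidden, so it never appears in the admin UI; the intended switch is the `DISCOURSE_AI_STRICT_TOKEN_COUNTING: true` entry in app.yml mentioned above. Like any hidden Discourse site setting, it can also be flipped from the rails console:

    # In `rails c` on the running instance; standard site-setting
    # access, nothing plugin-specific:
    SiteSetting.ai_strict_token_counting = true
    SiteSetting.ai_strict_token_counting # => true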
@@ -17,14 +17,19 @@ module DiscourseAi
       end
 
       def truncate(text, max_length)
-        # Fast track the common case where the text is already short enough.
-        return text if text.size < max_length
+        # fast track common case, /2 to handle unicode chars
+        # that can take more than 1 token per char
+        return text if !SiteSetting.ai_strict_token_counting && text.size < max_length / 2
 
         tokenizer.decode(tokenizer.encode(text).ids.take(max_length))
       end
 
       def can_expand_tokens?(text, addition, max_length)
-        return true if text.size + addition.size < max_length
+        # fast track common case, /2 to handle unicode chars
+        # that can take more than 1 token per char
+        if !SiteSetting.ai_strict_token_counting && text.size + addition.size < max_length / 2
+          return true
+        end
 
         tokenizer.encode(text).ids.length + tokenizer.encode(addition).ids.length < max_length
       end
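Note that the /2 divisor is a heuristic rather than a guarantee: it only protects strings averaging under two tokens per character, which is presumably why the strict setting exists for anyone willing to pay the tokenization cost on every call. A sketch of the guard's two paths (hypothetical helper, not plugin code):

    # Hypothetical stand-in for the guard added above; `strict` mirrors
    # SiteSetting.ai_strict_token_counting.
    def fast_track?(text, max_length, strict:)
      return false if strict # strict mode always counts real tokens
      # trust character count only when even 2 tokens/char would still fit
      text.size < max_length / 2
    end

    fast_track?("plain ascii", 100, strict: false) # => true, tokenizer skipped
    fast_track?("plain ascii", 100, strict: true)  # => false, tokenizer consulted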
@@ -13,8 +13,9 @@ module DiscourseAi
       end
 
       def truncate(text, max_length)
-        # Fast track the common case where the text is already short enough.
-        return text if text.size < max_length
+        # fast track common case, /2 to handle unicode chars
+        # that can take more than 1 token per char
+        return text if !SiteSetting.ai_strict_token_counting && text.size < max_length / 2
 
         tokenizer.decode(tokenize(text).take(max_length))
       rescue Tiktoken::UnicodeError

@@ -23,7 +24,11 @@ module DiscourseAi
       end
 
       def can_expand_tokens?(text, addition, max_length)
-        return true if text.size + addition.size < max_length
+        # fast track common case, /2 to handle unicode chars
+        # that can take more than 1 token per char
+        if !SiteSetting.ai_strict_token_counting && text.size + addition.size < max_length / 2
+          return true
+        end
 
         tokenizer.encode(text).length + tokenizer.encode(addition).length < max_length
       end
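This is the OpenAI (Tiktoken) variant of the same guard; the only differences from the hunk above are that Tiktoken's encode returns a plain array (no `.ids`) and that decoding a mid-grapheme cut can raise `Tiktoken::UnicodeError`, which the existing rescue handles. A usage sketch mirroring the specs below:

    tokenizer = DiscourseAi::Tokenizer::OpenAiTokenizer

    # Plainly short ASCII still takes the fast track (7 < 100 / 2).
    tokenizer.truncate("foo bar", 100) # => "foo bar"

    # CJK near the limit now gets tokenized before the decision:
    # 6 chars is not less than 5 (= 10 / 2), so no false fast track.
    tokenizer.truncate("我喜欢吃比萨", 10)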
@@ -81,6 +81,31 @@ describe DiscourseAi::Tokenizer::OpenAiTokenizer do
       sentence = "foo bar 👨🏿👩🏿👧🏿👧🏿 baz qux quux corge grault garply waldo fred plugh xyzzy thud"
       expect(described_class.truncate(sentence, 7)).to eq("foo bar 👨🏿")
     end
+
+    it "truncates unicode characters properly when they use more than one token per char" do
+      sentence = "我喜欢吃比萨"
+      original_size = described_class.size(sentence)
+      expect(described_class.size(described_class.truncate(sentence, original_size - 1))).to be <
+        original_size
+    end
+  end
+
+  describe "#can_expand_tokens?" do
+    it "returns true when the tokens can be expanded" do
+      expect(described_class.can_expand_tokens?("foo bar", "baz qux", 6)).to eq(true)
+    end
+
+    it "returns false when the tokens cannot be expanded" do
+      expect(described_class.can_expand_tokens?("foo bar", "baz qux", 3)).to eq(false)
+    end
+
+    it "returns false when the tokens cannot be expanded due to multibyte unicode characters" do
+      expect(described_class.can_expand_tokens?("foo bar 👨🏿", "baz qux", 6)).to eq(false)
+    end
+
+    it "handles unicode characters properly when they use more than one token per char" do
+      expect(described_class.can_expand_tokens?("我喜欢吃比萨", "萨", 10)).to eq(false)
+    end
+  end
   end
 end
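The last spec pins down the regression: by character count the expansion looks safe, so the old guard answered true without ever counting tokens, while the new guard's arithmetic refuses the shortcut and defers to the tokenizer:

    text, addition, max_length = "我喜欢吃比萨", "萨", 10

    # Old fast track: 6 + 1 = 7 < 10 held, so it returned true
    # without tokenizing (the bug).
    text.size + addition.size < max_length # => true

    # New fast track: 7 < 10 / 2 = 5 is false, so the code falls through
    # to real token counts, which the spec expects to exceed the budget.
    text.size + addition.size < max_length / 2 # => false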