FIX: never provide tools with invalid UTF-8 strings (#692)

Previous to this change, on truncation we could return invalid
UTF-8 strings to caller

This also allows tools to read up to 30 megs vs the old 4 megs.
This commit is contained in:
Sam 2024-06-27 14:06:52 +10:00 committed by GitHub
parent e26c5986f2
commit af4f871096
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 62 additions and 2 deletions

View File

@ -4,6 +4,12 @@ module DiscourseAi
module AiBot module AiBot
module Tools module Tools
class Tool class Tool
# Why 30 mega bytes?
# This general limit is mainly a security feature to avoid tools
# forcing infinite downloads or causing memory exhaustion.
# The limit is somewhat arbitrary and can be increased in future if needed.
MAX_RESPONSE_BODY_LENGTH = 30.megabyte
class << self class << self
def signature def signature
raise NotImplemented raise NotImplemented
@ -158,14 +164,24 @@ module DiscourseAi
end end
end end
def read_response_body(response, max_length: 4.megabyte) def self.read_response_body(response, max_length: nil)
max_length ||= MAX_RESPONSE_BODY_LENGTH
body = +"" body = +""
response.read_body do |chunk| response.read_body do |chunk|
body << chunk body << chunk
break if body.bytesize > max_length break if body.bytesize > max_length
end end
body[0..max_length] if body.bytesize > max_length
body[0...max_length].scrub
else
body.scrub
end
end
def read_response_body(response, max_length: nil)
self.class.read_response_body(response, max_length: max_length)
end end
def truncate(text, llm:, percent_length: nil, max_length: nil) def truncate(text, llm:, percent_length: nil, max_length: nil)

View File

@ -0,0 +1,44 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::Tools::Tool do
let :tool_class do
described_class
end
let :corrupt_string do
"\xC3\x28\xA0\xA1\xE2\x28\xA1\xE2\x82\x28\xF0\x28\x8C\xBC"
end
describe "#read_response_body" do
class FakeResponse
def initialize(chunk)
@chunk = chunk
end
def read_body
yield @chunk while true
end
end
it "never returns a corrupt string" do
response = FakeResponse.new(corrupt_string)
result = tool_class.read_response_body(response, max_length: 100.bytes)
expect(result.encoding).to eq(Encoding::UTF_8)
expect(result.valid_encoding?).to eq(true)
# scrubbing removes 7 chars
expect(result.length).to eq(93)
end
it "returns correctly truncated strings" do
response = FakeResponse.new("abc")
result = tool_class.read_response_body(response, max_length: 10.bytes)
expect(result.encoding).to eq(Encoding::UTF_8)
expect(result.valid_encoding?).to eq(true)
expect(result).to eq("abcabcabca")
end
end
end