diff --git a/lib/ai_bot/tools/tool.rb b/lib/ai_bot/tools/tool.rb index f5112c3e..a695036d 100644 --- a/lib/ai_bot/tools/tool.rb +++ b/lib/ai_bot/tools/tool.rb @@ -4,6 +4,12 @@ module DiscourseAi module AiBot module Tools class Tool + # Why 30 mega bytes? + # This general limit is mainly a security feature to avoid tools + # forcing infinite downloads or causing memory exhaustion. + # The limit is somewhat arbitrary and can be increased in future if needed. + MAX_RESPONSE_BODY_LENGTH = 30.megabyte + class << self def signature raise NotImplemented @@ -158,14 +164,24 @@ module DiscourseAi end end - def read_response_body(response, max_length: 4.megabyte) + def self.read_response_body(response, max_length: nil) + max_length ||= MAX_RESPONSE_BODY_LENGTH + body = +"" response.read_body do |chunk| body << chunk break if body.bytesize > max_length end - body[0..max_length] + if body.bytesize > max_length + body[0...max_length].scrub + else + body.scrub + end + end + + def read_response_body(response, max_length: nil) + self.class.read_response_body(response, max_length: max_length) end def truncate(text, llm:, percent_length: nil, max_length: nil) diff --git a/spec/lib/modules/ai_bot/tools/tool_spec.rb b/spec/lib/modules/ai_bot/tools/tool_spec.rb new file mode 100644 index 00000000..e7eca272 --- /dev/null +++ b/spec/lib/modules/ai_bot/tools/tool_spec.rb @@ -0,0 +1,44 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::AiBot::Tools::Tool do + let :tool_class do + described_class + end + + let :corrupt_string do + "\xC3\x28\xA0\xA1\xE2\x28\xA1\xE2\x82\x28\xF0\x28\x8C\xBC" + end + + describe "#read_response_body" do + class FakeResponse + def initialize(chunk) + @chunk = chunk + end + + def read_body + yield @chunk while true + end + end + + it "never returns a corrupt string" do + response = FakeResponse.new(corrupt_string) + result = tool_class.read_response_body(response, max_length: 100.bytes) + + expect(result.encoding).to eq(Encoding::UTF_8) + expect(result.valid_encoding?).to eq(true) + + # scrubbing removes 7 chars + expect(result.length).to eq(93) + end + + it "returns correctly truncated strings" do + response = FakeResponse.new("abc") + result = tool_class.read_response_body(response, max_length: 10.bytes) + + expect(result.encoding).to eq(Encoding::UTF_8) + expect(result.valid_encoding?).to eq(true) + + expect(result).to eq("abcabcabca") + end + end +end