diff --git a/lib/middleware/request_tracker.rb b/lib/middleware/request_tracker.rb index a346b44112d..e27622547b1 100644 --- a/lib/middleware/request_tracker.rb +++ b/lib/middleware/request_tracker.rb @@ -241,7 +241,8 @@ class Middleware::RequestTracker "global_ip_limit_10_#{ip}", GlobalSetting.max_reqs_per_ip_per_10_seconds, 10, - global: true + global: true, + aggressive: true ) limiter60 = RateLimiter.new( @@ -249,7 +250,8 @@ class Middleware::RequestTracker "global_ip_limit_60_#{ip}", GlobalSetting.max_reqs_per_ip_per_minute, 60, - global: true + global: true, + aggressive: true ) limiter_assets10 = RateLimiter.new( diff --git a/lib/rate_limiter.rb b/lib/rate_limiter.rb index 7f3094e11a6..26225939866 100644 --- a/lib/rate_limiter.rb +++ b/lib/rate_limiter.rb @@ -37,13 +37,14 @@ class RateLimiter "#{RateLimiter.key_prefix}:#{@user && @user.id}:#{type}" end - def initialize(user, type, max, secs, global: false) + def initialize(user, type, max, secs, global: false, aggressive: false) @user = user @type = type @key = build_key(type) @max = max @secs = secs @global = global + @aggressive = aggressive end def clear! @@ -79,13 +80,38 @@ class RateLimiter PERFORM_LUA_SHA = Digest::SHA1.hexdigest(PERFORM_LUA) end + unless defined? PERFORM_LUA_AGGRESSIVE + PERFORM_LUA_AGGRESSIVE = <<~LUA + local now = tonumber(ARGV[1]) + local secs = tonumber(ARGV[2]) + local max = tonumber(ARGV[3]) + + local key = KEYS[1] + + local return_val = 0 + + if ((tonumber(redis.call("LLEN", key)) < max) or + (now - tonumber(redis.call("LRANGE", key, -1, -1)[1])) > secs) then + return_val = 1 + else + return_val = 0 + end + + redis.call("LPUSH", key, now) + redis.call("LTRIM", key, 0, max - 1) + redis.call("EXPIRE", key, secs * 2) + + return return_val + LUA + + PERFORM_LUA_AGGRESSIVE_SHA = Digest::SHA1.hexdigest(PERFORM_LUA_AGGRESSIVE) + end + def performed!(raise_error: true) return true if rate_unlimited? now = Time.now.to_i - if ((max || 0) <= 0) || - (eval_lua(PERFORM_LUA, PERFORM_LUA_SHA, [prefixed_key], [now, @secs, @max]) == 0) - + if ((max || 0) <= 0) || rate_limiter_allowed?(now) raise RateLimiter::LimitExceeded.new(seconds_to_wait, @type) if raise_error false else @@ -121,6 +147,20 @@ class RateLimiter private + def rate_limiter_allowed?(now) + + lua, lua_sha = nil + if @aggressive + lua = PERFORM_LUA_AGGRESSIVE + lua_sha = PERFORM_LUA_AGGRESSIVE_SHA + else + lua = PERFORM_LUA + lua_sha = PERFORM_LUA_SHA + end + + eval_lua(lua, lua_sha, [prefixed_key], [now, @secs, @max]) == 0 + end + def prefixed_key if @global "GLOBAL::#{key}" diff --git a/spec/components/rate_limiter_spec.rb b/spec/components/rate_limiter_spec.rb index 7e18fee7788..506fa13c531 100644 --- a/spec/components/rate_limiter_spec.rb +++ b/spec/components/rate_limiter_spec.rb @@ -38,6 +38,48 @@ describe RateLimiter do RateLimiter.disable end + context 'aggressive rate limiter' do + + it 'can operate correctly and totally stop limiting' do + + freeze_time + + # 2 requests every 30 seconds + limiter = RateLimiter.new(nil, "test", 2, 30, global: true, aggressive: true) + limiter.clear! + + limiter.performed! + limiter.performed! + + freeze_time 29.seconds.from_now + + expect do + limiter.performed! + end.to raise_error(RateLimiter::LimitExceeded) + + expect do + limiter.performed! + end.to raise_error(RateLimiter::LimitExceeded) + + # in aggressive mode both these ^^^ count as an attempt + freeze_time 29.seconds.from_now + + expect do + limiter.performed! + end.to raise_error(RateLimiter::LimitExceeded) + + expect do + limiter.performed! + end.to raise_error(RateLimiter::LimitExceeded) + + freeze_time 31.seconds.from_now + + limiter.performed! + limiter.performed! + + end + end + context 'global rate limiter' do it 'can operate in global mode' do