DEV: Consolidate Redis evalsha logic into DiscourseRedis::EvalHelper (#15957)

This commit is contained in:
David Taylor 2022-02-15 16:06:12 +00:00 committed by GitHub
parent dd5373cc4c
commit 11c93342dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 79 additions and 49 deletions

View File

@ -161,4 +161,20 @@ class DiscourseRedis
key[(namespace.length + 1)..-1] key[(namespace.length + 1)..-1]
end end
class EvalHelper
def initialize(script)
@script = script
@sha1 = Digest::SHA1.hexdigest(script)
end
def eval(redis, *args, **kwargs)
redis.evalsha @sha1, *args, **kwargs
rescue ::Redis::CommandError => e
if e.to_s =~ /^NOSCRIPT/
redis.eval @script, *args, **kwargs
else
raise
end
end
end
end end

View File

@ -49,6 +49,12 @@ module Middleware
ACCEPT_ENCODING = "HTTP_ACCEPT_ENCODING" ACCEPT_ENCODING = "HTTP_ACCEPT_ENCODING"
DISCOURSE_RENDER = "HTTP_DISCOURSE_RENDER" DISCOURSE_RENDER = "HTTP_DISCOURSE_RENDER"
REDIS_STORE_SCRIPT = DiscourseRedis::EvalHelper.new <<~LUA
local current = redis.call("incr", KEYS[1])
redis.call("expire",KEYS[1],ARGV[1])
return current
LUA
def initialize(env, request = nil) def initialize(env, request = nil)
@env = env @env = env
@request = request || Rack::Request.new(@env) @request = request || Rack::Request.new(@env)
@ -259,11 +265,7 @@ module Middleware
if status == 200 && cache_duration if status == 200 && cache_duration
if GlobalSetting.anon_cache_store_threshold > 1 if GlobalSetting.anon_cache_store_threshold > 1
count = Discourse.redis.eval(<<~REDIS, [cache_key_count], [cache_duration]) count = REDIS_STORE_SCRIPT.eval(Discourse.redis, [cache_key_count], [cache_duration])
local current = redis.call("incr", KEYS[1])
redis.call("expire",KEYS[1],ARGV[1])
return current
REDIS
# technically lua will cast for us, but might as well be # technically lua will cast for us, but might as well be
# prudent here, hence the to_i # prudent here, hence the to_i

View File

@ -260,15 +260,7 @@ class PresenceChannel
end end
def self.redis_eval(key, *args) def self.redis_eval(key, *args)
script_sha1 = LUA_SCRIPTS_SHA1[key] LUA_SCRIPTS[key].eval(redis, *args)
raise ArgumentError.new("No script for #{key}") if script_sha1.nil?
redis.evalsha script_sha1, *args
rescue ::Redis::CommandError => e
if e.to_s =~ /^NOSCRIPT/
redis.eval LUA_SCRIPTS[key], *args
else
raise
end
end end
# Register a callback to configure channels with a given prefix # Register a callback to configure channels with a given prefix
@ -473,7 +465,7 @@ class PresenceChannel
LUA_SCRIPTS ||= {} LUA_SCRIPTS ||= {}
LUA_SCRIPTS[:present] = <<~LUA LUA_SCRIPTS[:present] = DiscourseRedis::EvalHelper.new <<~LUA
#{COMMON_PRESENT_LEAVE_LUA} #{COMMON_PRESENT_LEAVE_LUA}
if mutex_locked then if mutex_locked then
@ -502,7 +494,7 @@ class PresenceChannel
return added_users return added_users
LUA LUA
LUA_SCRIPTS[:leave] = <<~LUA LUA_SCRIPTS[:leave] = DiscourseRedis::EvalHelper.new <<~LUA
#{COMMON_PRESENT_LEAVE_LUA} #{COMMON_PRESENT_LEAVE_LUA}
if mutex_locked then if mutex_locked then
@ -532,7 +524,7 @@ class PresenceChannel
return removed_users return removed_users
LUA LUA
LUA_SCRIPTS[:release_mutex] = <<~LUA LUA_SCRIPTS[:release_mutex] = DiscourseRedis::EvalHelper.new <<~LUA
local mutex_key = KEYS[1] local mutex_key = KEYS[1]
local expected_value = ARGV[1] local expected_value = ARGV[1]
@ -541,7 +533,7 @@ class PresenceChannel
end end
LUA LUA
LUA_SCRIPTS[:user_ids] = <<~LUA LUA_SCRIPTS[:user_ids] = DiscourseRedis::EvalHelper.new <<~LUA
local zlist_key = KEYS[1] local zlist_key = KEYS[1]
local hash_key = KEYS[2] local hash_key = KEYS[2]
local message_bus_id_key = KEYS[4] local message_bus_id_key = KEYS[4]
@ -562,7 +554,7 @@ class PresenceChannel
return { message_bus_id, user_ids } return { message_bus_id, user_ids }
LUA LUA
LUA_SCRIPTS[:count] = <<~LUA LUA_SCRIPTS[:count] = DiscourseRedis::EvalHelper.new <<~LUA
local zlist_key = KEYS[1] local zlist_key = KEYS[1]
local hash_key = KEYS[2] local hash_key = KEYS[2]
local message_bus_id_key = KEYS[4] local message_bus_id_key = KEYS[4]
@ -582,7 +574,7 @@ class PresenceChannel
return { message_bus_id, count } return { message_bus_id, count }
LUA LUA
LUA_SCRIPTS[:auto_leave] = <<~LUA LUA_SCRIPTS[:auto_leave] = DiscourseRedis::EvalHelper.new <<~LUA
local zlist_key = KEYS[1] local zlist_key = KEYS[1]
local hash_key = KEYS[2] local hash_key = KEYS[2]
local channels_key = KEYS[3] local channels_key = KEYS[3]
@ -626,9 +618,4 @@ class PresenceChannel
return expired_user_ids return expired_user_ids
LUA LUA
LUA_SCRIPTS.freeze
LUA_SCRIPTS_SHA1 = LUA_SCRIPTS.transform_values do |script|
Digest::SHA1.hexdigest(script)
end.freeze
end end

View File

@ -70,7 +70,7 @@ class RateLimiter
# reloader friendly # reloader friendly
unless defined? PERFORM_LUA unless defined? PERFORM_LUA
PERFORM_LUA = <<~LUA PERFORM_LUA = DiscourseRedis::EvalHelper.new <<~LUA
local now = tonumber(ARGV[1]) local now = tonumber(ARGV[1])
local secs = tonumber(ARGV[2]) local secs = tonumber(ARGV[2])
local max = tonumber(ARGV[3]) local max = tonumber(ARGV[3])
@ -89,12 +89,10 @@ class RateLimiter
return 0 return 0
end end
LUA LUA
PERFORM_LUA_SHA = Digest::SHA1.hexdigest(PERFORM_LUA)
end end
unless defined? PERFORM_LUA_AGGRESSIVE unless defined? PERFORM_LUA_AGGRESSIVE
PERFORM_LUA_AGGRESSIVE = <<~LUA PERFORM_LUA_AGGRESSIVE = DiscourseRedis::EvalHelper.new <<~LUA
local now = tonumber(ARGV[1]) local now = tonumber(ARGV[1])
local secs = tonumber(ARGV[2]) local secs = tonumber(ARGV[2])
local max = tonumber(ARGV[3]) local max = tonumber(ARGV[3])
@ -116,8 +114,6 @@ class RateLimiter
return return_val return return_val
LUA LUA
PERFORM_LUA_AGGRESSIVE_SHA = Digest::SHA1.hexdigest(PERFORM_LUA_AGGRESSIVE)
end end
def performed!(raise_error: true) def performed!(raise_error: true)
@ -161,15 +157,9 @@ class RateLimiter
def rate_limiter_allowed?(now) def rate_limiter_allowed?(now)
lua, lua_sha = nil lua, lua_sha = nil
if @aggressive eval_helper = @aggressive ? PERFORM_LUA_AGGRESSIVE : PERFORM_LUA
lua = PERFORM_LUA_AGGRESSIVE
lua_sha = PERFORM_LUA_AGGRESSIVE_SHA
else
lua = PERFORM_LUA
lua_sha = PERFORM_LUA_SHA
end
eval_lua(lua, lua_sha, [prefixed_key], [now, @secs, @max]) == 0 eval_helper.eval(redis, [prefixed_key], [now, @secs, @max]) == 0
end end
def prefixed_key def prefixed_key
@ -201,14 +191,4 @@ class RateLimiter
def rate_unlimited? def rate_unlimited?
!!(RateLimiter.disabled? || (@user&.staff? && !@apply_limit_to_staff && @staff_limit[:max].nil?)) !!(RateLimiter.disabled? || (@user&.staff? && !@apply_limit_to_staff && @staff_limit[:max].nil?))
end end
def eval_lua(lua, sha, keys, args)
redis.evalsha(sha, keys, args)
rescue Redis::CommandError => e
if e.to_s =~ /^NOSCRIPT/
redis.eval(lua, keys, args)
else
raise
end
end
end end

View File

@ -132,4 +132,49 @@ describe DiscourseRedis do
end end
end end
end end
describe DiscourseRedis::EvalHelper do
it "works" do
helper = DiscourseRedis::EvalHelper.new <<~LUA
return 'hello world'
LUA
expect(helper.eval(Discourse.redis)).to eq('hello world')
end
it "works with arguments" do
helper = DiscourseRedis::EvalHelper.new <<~LUA
return ARGV[1]..ARGV[2]..KEYS[1]..KEYS[2]
LUA
expect(helper.eval(Discourse.redis, ['key1', 'key2'], ['arg1', 'arg2'])).to eq("arg1arg2key1key2")
end
it "works with arguments" do
helper = DiscourseRedis::EvalHelper.new <<~LUA
return ARGV[1]..ARGV[2]..KEYS[1]..KEYS[2]
LUA
expect(helper.eval(Discourse.redis, ['key1', 'key2'], ['arg1', 'arg2'])).to eq("arg1arg2key1key2")
end
it "uses evalsha correctly" do
redis_proxy = Class.new do
attr_reader :calls
def method_missing(meth, *args, **kwargs, &block)
@calls ||= []
@calls.push(meth)
Discourse.redis.public_send(meth, *args, **kwargs, &block)
end
end.new
Discourse.redis.call("SCRIPT", "FLUSH", "SYNC")
helper = DiscourseRedis::EvalHelper.new <<~LUA
return 'hello world'
LUA
expect(helper.eval(redis_proxy)).to eq("hello world")
expect(helper.eval(redis_proxy)).to eq("hello world")
expect(helper.eval(redis_proxy)).to eq("hello world")
expect(redis_proxy.calls).to eq([:evalsha, :eval, :evalsha, :evalsha])
end
end
end end