discourse/lib/presence_channel.rb

622 lines
20 KiB
Ruby

# frozen_string_literal: true
# The server-side implementation of PresenceChannels. See also {PresenceController}
# and +app/assets/javascripts/discourse/app/services/presence.js+
class PresenceChannel
class NotFound < StandardError
end
class InvalidAccess < StandardError
end
class ConfigNotLoaded < StandardError
end
class InvalidConfig < StandardError
end
class State
include ActiveModel::Serialization
attr_reader :message_bus_last_id
attr_reader :user_ids
attr_reader :count
def initialize(message_bus_last_id:, user_ids: nil, count: nil)
raise "user_ids or count required" if user_ids.nil? && count.nil?
@message_bus_last_id = message_bus_last_id
@user_ids = user_ids
@count = count || user_ids.count
end
def users
return nil if user_ids.nil?
User.where(id: user_ids)
end
end
# Class for managing config of PresenceChannel
# Three parameters can be provided on initialization:
# public: boolean value. If true, channel information is visible to all users (default false)
# allowed_user_ids: array of user_ids that can view, and become present in, the channel (default [])
# allowed_group_ids: array of group_ids that can view, and become present in, the channel (default [])
# count_only: boolean. If true, user identities are never revealed to clients. (default [])
class Config
NOT_FOUND ||= "notfound"
attr_accessor :public, :allowed_user_ids, :allowed_group_ids, :count_only, :timeout
def initialize(
public: false,
allowed_user_ids: nil,
allowed_group_ids: nil,
count_only: false,
timeout: nil
)
@public = public
@allowed_user_ids = allowed_user_ids
@allowed_group_ids = allowed_group_ids
@count_only = count_only
@timeout = timeout
end
def self.from_json(json)
data = JSON.parse(json, symbolize_names: true)
data = {} if !data.is_a? Hash
new(**data.slice(:public, :allowed_user_ids, :allowed_group_ids, :count_only, :timeout))
end
def to_json
data = { public: public }
data[:allowed_user_ids] = allowed_user_ids if allowed_user_ids
data[:allowed_group_ids] = allowed_group_ids if allowed_group_ids
data[:count_only] = count_only if count_only
data[:timeout] = timeout if timeout
data.to_json
end
end
DEFAULT_TIMEOUT ||= 60
CONFIG_CACHE_SECONDS ||= 10
GC_SECONDS ||= 24.hours.to_i
MUTEX_TIMEOUT_SECONDS ||= 10
MUTEX_LOCKED_ERROR ||= "PresenceChannel mutex is locked"
@@configuration_blocks ||= {}
attr_reader :name, :timeout, :message_bus_channel_name, :config
def initialize(name, raise_not_found: true, use_cache: true)
@name = name
@message_bus_channel_name = "/presence#{name}"
begin
@config = fetch_config(use_cache: use_cache)
rescue PresenceChannel::NotFound
raise if raise_not_found
@config = Config.new
end
@timeout = config.timeout || DEFAULT_TIMEOUT
end
# Is this user allowed to view this channel?
# Pass `nil` for anonymous viewers
def can_view?(user_id: nil, group_ids: nil)
return true if config.public
return true if user_id && config.allowed_user_ids&.include?(user_id)
if user_id && config.allowed_group_ids.present?
return true if config.allowed_group_ids.include?(Group::AUTO_GROUPS[:everyone])
group_ids ||= GroupUser.where(user_id: user_id).pluck("group_id")
return true if (group_ids & config.allowed_group_ids).present?
end
false
end
# Is a user allowed to enter this channel?
# Currently equal to the the can_view? permission
def can_enter?(user_id: nil, group_ids: nil)
return false if user_id.nil?
can_view?(user_id: user_id, group_ids: group_ids)
end
# Mark a user's client as present in this channel. The client_id should be unique per
# browser tab. This method should be called repeatedly (at least once every DEFAULT_TIMEOUT)
# while the user is present in the channel.
def present(user_id:, client_id:)
raise PresenceChannel::InvalidAccess if !can_enter?(user_id: user_id)
mutex_value = SecureRandom.hex
result =
retry_on_mutex_error do
PresenceChannel.redis_eval(
:present,
redis_keys,
[name, user_id, client_id, (Time.zone.now + timeout).to_i, mutex_value],
)
end
if result == 1
begin
publish_message(entering_user_ids: [user_id])
ensure
release_mutex(mutex_value)
end
end
end
# Immediately mark a user's client as leaving the channel
def leave(user_id:, client_id:)
mutex_value = SecureRandom.hex
result =
retry_on_mutex_error do
PresenceChannel.redis_eval(:leave, redis_keys, [name, user_id, client_id, nil, mutex_value])
end
if result == 1
begin
publish_message(leaving_user_ids: [user_id])
ensure
release_mutex(mutex_value)
end
end
end
# Fetch a {PresenceChannel::State} instance representing the current state of this
#
# @param [Boolean] count_only set true to skip fetching the list of user ids from redis
def state(count_only: config.count_only)
if count_only
last_id, count = retry_on_mutex_error { PresenceChannel.redis_eval(:count, redis_keys) }
else
last_id, ids = retry_on_mutex_error { PresenceChannel.redis_eval(:user_ids, redis_keys) }
end
count ||= ids&.count
last_id = nil if last_id == -1
if Rails.env.test? && MessageBus.backend == :memory
# Doing it this way is not atomic, but we have no other option when
# messagebus is not using the redis backend
last_id = MessageBus.last_id(message_bus_channel_name)
end
State.new(message_bus_last_id: last_id, user_ids: ids, count: count)
end
def user_ids
state.user_ids
end
def count
state(count_only: true).count
end
# Automatically expire all users which have not been 'present' for more than +DEFAULT_TIMEOUT+
def auto_leave
mutex_value = SecureRandom.hex
left_user_ids =
retry_on_mutex_error do
PresenceChannel.redis_eval(:auto_leave, redis_keys, [name, Time.zone.now.to_i, mutex_value])
end
if !left_user_ids.empty?
begin
publish_message(leaving_user_ids: left_user_ids)
ensure
release_mutex(mutex_value)
end
end
end
# Clear all members of the channel. This is intended for debugging/development only
def clear
PresenceChannel.redis.del(redis_key_zlist)
PresenceChannel.redis.del(redis_key_hash)
PresenceChannel.redis.del(redis_key_config)
PresenceChannel.redis.del(redis_key_mutex)
PresenceChannel.redis.zrem(self.class.redis_key_channel_list, name)
end
# Designed to be run periodically. Checks the channel list for channels with expired members,
# and runs auto_leave for each eligible channel
def self.auto_leave_all
channels_with_expiring_members =
PresenceChannel.redis.zrangebyscore(redis_key_channel_list, "-inf", Time.zone.now.to_i)
channels_with_expiring_members.each { |name| new(name, raise_not_found: false).auto_leave }
end
# Clear all known channels. This is intended for debugging/development only
def self.clear_all!
channels = PresenceChannel.redis.zrangebyscore(redis_key_channel_list, "-inf", "+inf")
channels.each { |name| new(name, raise_not_found: false).clear }
config_cache_keys =
PresenceChannel
.redis
.scan_each(match: Discourse.redis.namespace_key("_presence_*_config"))
.to_a
PresenceChannel.redis.del(*config_cache_keys) if config_cache_keys.present?
end
# Shortcut to access a redis client for all PresenceChannel activities.
# PresenceChannel must use the same Redis server as MessageBus, so that
# actions can be applied atomically. For the vast majority of Discourse
# installations, this is the same Redis server as `Discourse.redis`.
def self.redis
if MessageBus.backend == :redis
MessageBus.backend_instance.send(:pub_redis) # TODO: avoid a private API?
elsif Rails.env.test?
Discourse.redis.without_namespace
else
raise "PresenceChannel is unable to access MessageBus's Redis instance"
end
end
def self.redis_eval(key, *args)
LUA_SCRIPTS[key].eval(redis, *args)
end
# Register a callback to configure channels with a given prefix
# Prefix must match [a-zA-Z0-9_-]+
#
# For example, this registration will be used for
# all channels starting /topic-reply/...:
#
# register_prefix("topic-reply") do |channel_name|
# PresenceChannel::Config.new(public: true)
# end
#
# At runtime, the block will be passed a full channel name. If the channel
# should not exist, the block should return `nil`. If the channel should exist,
# the block should return a PresenceChannel::Config object.
#
# Return values may be cached for up to 10 seconds.
#
# Plugins should use the {Plugin::Instance.register_presence_channel_prefix} API instead
def self.register_prefix(prefix, &block)
unless prefix.match? /[a-zA-Z0-9_-]+/
raise "PresenceChannel prefix #{prefix} must match [a-zA-Z0-9_-]+"
end
if @@configuration_blocks&.[](prefix)
raise "PresenceChannel prefix #{prefix} already registered"
end
@@configuration_blocks[prefix] = block
end
# For use in a test environment only
def self.unregister_prefix(prefix)
raise "Only allowed in test environment" if !Rails.env.test?
@@configuration_blocks&.delete(prefix)
end
private
def fetch_config(use_cache: true)
cached_config = (PresenceChannel.redis.get(redis_key_config) if use_cache)
if cached_config == Config::NOT_FOUND
raise PresenceChannel::NotFound
elsif cached_config
Config.from_json(cached_config)
else
prefix = name[%r{/([a-zA-Z0-9_-]+)/.*}, 1]
raise PresenceChannel::NotFound if prefix.nil?
config_block = @@configuration_blocks[prefix]
config_block ||=
DiscoursePluginRegistry.presence_channel_prefixes.find { |t| t[0] == prefix }&.[](1)
raise PresenceChannel::NotFound if config_block.nil?
result = config_block.call(name)
to_cache =
if result.is_a? Config
result.to_json
elsif result.nil?
Config::NOT_FOUND
else
raise InvalidConfig.new "Expected PresenceChannel::Config or nil. Got a #{result.class.name}"
end
PresenceChannel.redis.set(redis_key_config, to_cache, ex: CONFIG_CACHE_SECONDS)
raise PresenceChannel::NotFound if result.nil?
result
end
end
def publish_message(entering_user_ids: nil, leaving_user_ids: nil)
message = {}
if config.count_only
message["count_delta"] = entering_user_ids&.count || 0
message["count_delta"] -= leaving_user_ids&.count || 0
return if message["count_delta"] == 0
else
message["leaving_user_ids"] = leaving_user_ids if leaving_user_ids.present?
if entering_user_ids.present?
users = User.where(id: entering_user_ids)
message["entering_users"] = ActiveModel::ArraySerializer.new(
users,
each_serializer: BasicUserSerializer,
)
end
end
params = {}
if config.public
# no params required
elsif config.allowed_user_ids || config.allowed_group_ids
params[:user_ids] = config.allowed_user_ids
params[:group_ids] = config.allowed_group_ids
else
# nobody is allowed... don't publish anything
return
end
MessageBus.publish(message_bus_channel_name, message.as_json, **params)
end
# Most atomic actions are achieved via lua scripts. However, when a lua action
# will result in publishing a messagebus message, the atomicity is broken.
#
# For example, if one process is handling a 'user enter' event, and another is
# handling a 'user leave' event, we need to make sure the messagebus messages
# are published in the same sequence that the PresenceChannel lua script are run.
#
# The present/leave/auto_leave lua scripts will automatically acquire this mutex
# if needed. If their return value indicates a change has occurred, the mutex
# should be released via #release_mutex after the messagebus message has been sent
#
# If they need a change, and the mutex is not available, they will raise an error
# and should be retried periodically
def redis_key_mutex
Discourse.redis.namespace_key("_presence_#{name}_mutex")
end
def release_mutex(mutex_value)
PresenceChannel.redis_eval(:release_mutex, [redis_key_mutex], [mutex_value])
end
def retry_on_mutex_error
attempts ||= 0
yield
rescue ::Redis::CommandError => e
if e.to_s =~ /#{MUTEX_LOCKED_ERROR}/ && attempts < 1000
attempts += 1
sleep 0.001
retry
else
raise
end
end
# The redis key which MessageBus uses to store the 'last_id' for the channel
# associated with this PresenceChannel.
def message_bus_last_id_key
return "" if Rails.env.test? && MessageBus.backend == :memory
# TODO: Avoid using private MessageBus methods here
encoded_channel_name = MessageBus.send(:encode_channel_name, message_bus_channel_name)
MessageBus.backend_instance.send(:backlog_id_key, encoded_channel_name)
end
def redis_keys
[
redis_key_zlist,
redis_key_hash,
self.class.redis_key_channel_list,
message_bus_last_id_key,
redis_key_mutex,
]
end
# The zlist is a list of client_ids, ranked by their expiration timestamp
# we periodically delete the 'lowest ranked' items in this list based on the `timeout` of the channel
def redis_key_zlist
Discourse.redis.namespace_key("_presence_#{name}_zlist")
end
# The hash contains a map of user_id => session_count
# when the count for a user reaches 0, the key is deleted
# We use this hash to efficiently count the number of present users
def redis_key_hash
Discourse.redis.namespace_key("_presence_#{name}_hash")
end
# The hash contains a map of user_id => session_count
# when the count for a user reaches 0, the key is deleted
# We use this hash to efficiently count the number of present users
def redis_key_config
Discourse.redis.namespace_key("_presence_#{name}_config")
end
# This list contains all active presence channels, ranked with the expiration timestamp of their least-recently-seen client_id
# We periodically check the 'lowest ranked' items in this list based on the `timeout` of the channel
def self.redis_key_channel_list
Discourse.redis.namespace_key("_presence_channels")
end
COMMON_PRESENT_LEAVE_LUA = <<~LUA
local channel = ARGV[1]
local user_id = ARGV[2]
local client_id = ARGV[3]
local expires = ARGV[4]
local mutex_value = ARGV[5]
local zlist_key = KEYS[1]
local hash_key = KEYS[2]
local channels_key = KEYS[3]
local message_bus_id_key = KEYS[4]
local mutex_key = KEYS[5]
local mutex_locked = redis.call('EXISTS', mutex_key) == 1
local zlist_elem = tostring(user_id) .. " " .. tostring(client_id)
LUA
UPDATE_GLOBAL_CHANNELS_LUA = <<~LUA
-- Update the global channels list with the timestamp of the oldest client
local oldest_client = redis.call('ZRANGE', zlist_key, 0, 0, 'WITHSCORES')
if table.getn(oldest_client) > 0 then
local oldest_client_expire_timestamp = oldest_client[2]
redis.call('ZADD', channels_key, tonumber(oldest_client_expire_timestamp), tostring(channel))
else
-- The channel is now empty, delete from global list
redis.call('ZREM', channels_key, tostring(channel))
end
LUA
LUA_SCRIPTS ||= {}
LUA_SCRIPTS[:present] = DiscourseRedis::EvalHelper.new <<~LUA
#{COMMON_PRESENT_LEAVE_LUA}
if mutex_locked then
local mutex_required = redis.call('HGET', hash_key, tostring(user_id)) == false
if mutex_required then
error("#{MUTEX_LOCKED_ERROR}")
end
end
local added_clients = redis.call('ZADD', zlist_key, expires, zlist_elem)
local added_users = 0
if tonumber(added_clients) > 0 then
local new_count = redis.call('HINCRBY', hash_key, tostring(user_id), 1)
if new_count == 1 then
added_users = 1
redis.call('SET', mutex_key, mutex_value, 'EX', #{MUTEX_TIMEOUT_SECONDS})
end
-- Add the channel to the global channel list. 'NX' means the value will
-- only be set if doesn't already exist
redis.call('ZADD', channels_key, "NX", expires, tostring(channel))
end
redis.call('EXPIREAT', hash_key, expires + #{GC_SECONDS})
redis.call('EXPIREAT', zlist_key, expires + #{GC_SECONDS})
return added_users
LUA
LUA_SCRIPTS[:leave] = DiscourseRedis::EvalHelper.new <<~LUA
#{COMMON_PRESENT_LEAVE_LUA}
if mutex_locked then
local user_session_count = redis.call('HGET', hash_key, tostring(user_id))
local mutex_required = user_session_count == 1 and redis.call('ZRANK', zlist_key, zlist_elem) ~= false
if mutex_required then
error("#{MUTEX_LOCKED_ERROR}")
end
end
-- Remove the user from the channel zlist
local removed_clients = redis.call('ZREM', zlist_key, zlist_elem)
local removed_users = 0
if tonumber(removed_clients) > 0 then
#{UPDATE_GLOBAL_CHANNELS_LUA}
-- Update the user session count in the channel hash
local val = redis.call('HINCRBY', hash_key, user_id, -1)
if val <= 0 then
redis.call('HDEL', hash_key, user_id)
removed_users = 1
redis.call('SET', mutex_key, mutex_value, 'EX', #{MUTEX_TIMEOUT_SECONDS})
end
end
return removed_users
LUA
LUA_SCRIPTS[:release_mutex] = DiscourseRedis::EvalHelper.new <<~LUA
local mutex_key = KEYS[1]
local expected_value = ARGV[1]
if redis.call("GET", mutex_key) == expected_value then
redis.call("DEL", mutex_key)
end
LUA
LUA_SCRIPTS[:user_ids] = DiscourseRedis::EvalHelper.new <<~LUA
local zlist_key = KEYS[1]
local hash_key = KEYS[2]
local message_bus_id_key = KEYS[4]
local mutex_key = KEYS[5]
if redis.call('EXISTS', mutex_key) > 0 then
error('#{MUTEX_LOCKED_ERROR}')
end
local user_ids = redis.call('HKEYS', hash_key)
table.foreach(user_ids, function(k,v) user_ids[k] = tonumber(v) end)
local message_bus_id = tonumber(redis.call('GET', message_bus_id_key))
if message_bus_id == nil then
message_bus_id = -1
end
return { message_bus_id, user_ids }
LUA
LUA_SCRIPTS[:count] = DiscourseRedis::EvalHelper.new <<~LUA
local zlist_key = KEYS[1]
local hash_key = KEYS[2]
local message_bus_id_key = KEYS[4]
local mutex_key = KEYS[5]
if redis.call('EXISTS', mutex_key) > 0 then
error('#{MUTEX_LOCKED_ERROR}')
end
local message_bus_id = tonumber(redis.call('GET', message_bus_id_key))
if message_bus_id == nil then
message_bus_id = -1
end
local count = redis.call('HLEN', hash_key)
return { message_bus_id, count }
LUA
LUA_SCRIPTS[:auto_leave] = DiscourseRedis::EvalHelper.new <<~LUA
local zlist_key = KEYS[1]
local hash_key = KEYS[2]
local channels_key = KEYS[3]
local mutex_key = KEYS[5]
local channel = ARGV[1]
local time = ARGV[2]
local mutex_value = ARGV[3]
local expire = redis.call('ZRANGEBYSCORE', zlist_key, '-inf', time)
local has_mutex = false
local get_mutex = function()
if redis.call('SETNX', mutex_key, mutex_value) == 0 then
error("#{MUTEX_LOCKED_ERROR}")
end
redis.call('EXPIRE', mutex_key, #{MUTEX_TIMEOUT_SECONDS})
has_mutex = true
end
local expired_user_ids = {}
local expireOld = function(k, v)
local user_id = v:match("[^ ]+")
if (not has_mutex) and (tonumber(redis.call('HGET', hash_key, user_id)) == 1) then
get_mutex()
end
local val = redis.call('HINCRBY', hash_key, user_id, -1)
if val <= 0 then
table.insert(expired_user_ids, tonumber(user_id))
redis.call('HDEL', hash_key, user_id)
end
redis.call('ZREM', zlist_key, v)
end
table.foreach(expire, expireOld)
#{UPDATE_GLOBAL_CHANNELS_LUA}
return expired_user_ids
LUA
end