discourse/script/cache_critical_dns

390 lines
9.5 KiB
Ruby
Executable File

#!/usr/bin/env ruby
# frozen_string_literal: true
# Specifying this env var ensures ruby can load gems installed via the Discourse
# project Gemfile (e.g. pg, redis).
ENV['BUNDLE_GEMFILE'] ||= '/var/www/discourse/Gemfile'
require 'bundler/setup'
require 'ipaddr'
require 'pg'
require 'redis'
require 'resolv'
require 'socket'
require 'time'
# Environment variables whose values name the hosts this script keeps resolved.
CRITICAL_HOST_ENV_VARS = %w[
  DISCOURSE_DB_HOST
  DISCOURSE_DB_REPLICA_HOST
  DISCOURSE_REDIS_HOST
  DISCOURSE_REDIS_SLAVE_HOST
  DISCOURSE_REDIS_REPLICA_HOST
]

# Fallback used for both the database user and the database name.
DEFAULT_DB_NAME = "discourse"

# Per environment variable caches, filled lazily by run.
HOST_RESOLVER_CACHE = {}
HOST_HEALTHY_CACHE = {}

# Hosts file to rewrite; overridable through the environment.
HOSTS_PATH = ENV.fetch('DISCOURSE_DNS_CACHE_HOSTS_FILE', "/etc/hosts")
# An inclusive window of accepted SRV RR target priorities.
# min and max must be integers: the lowest and highest priority accepted.
PrioFilter = Struct.new(:min, :max) do
  # Returns true when the priority p is greater than or equal to min and
  # less than or equal to max; false otherwise.
  def within_threshold?(p)
    return false if p < min

    p <= max
  end
end
# Widest priority window a filter may be configured with.
SRV_PRIORITY_THRESHOLD_MIN = 0
SRV_PRIORITY_THRESHOLD_MAX = 65_535

# Filters keyed by SRV environment variable name. Unknown keys fall back to a
# single shared filter spanning the whole permitted range, i.e. no filtering.
SRV_PRIORITY_FILTERS = Hash.new(PrioFilter.new(SRV_PRIORITY_THRESHOLD_MIN, SRV_PRIORITY_THRESHOLD_MAX))

# Pause, in seconds, between two refresh cycles of the main loop.
REFRESH_SECONDS = 30
# Mixin giving resolvers a short-lived DNS client with a 2 second timeout.
module DNSClient
  # Yields a Resolv::DNS client configured with the timeout, closes it once
  # the block has finished (even on error) and returns the block's value.
  def dns_client_with_timeout
    resolver = Resolv::DNS.new
    resolver.timeouts = 2
    yield resolver
  ensure
    resolver&.close
  end
end
# A plain hostname resolved through its A and AAAA records.
class Name
  include DNSClient

  def initialize(hostname)
    @name = hostname
  end

  # Returns the addresses of the hostname as strings: A records first,
  # AAAA records after them.
  def resolve
    record_types = [Resolv::DNS::Resource::IN::A, Resolv::DNS::Resource::IN::AAAA]
    dns_client_with_timeout do |dns|
      record_types.flat_map do |record_type|
        dns.getresources(@name, record_type).map { |record| record.address.to_s }
      end
    end
  end
end
# An SRV name whose targets are filtered by priority and then resolved.
class SRVName
  include DNSClient

  # srv_hostname is the SRV name to query; prio_filter must respond to
  # within_threshold?(priority).
  def initialize(srv_hostname, prio_filter)
    @name = srv_hostname
    @prio_filter = prio_filter
  end

  # Returns the addresses (as strings) of every SRV target whose priority is
  # accepted by the filter, in the order the records were returned.
  def resolve
    dns_client_with_timeout do |dns|
      records = dns.getresources(@name, Resolv::DNS::Resource::IN::SRV)
      accepted = records.select { |record| @prio_filter.within_threshold?(record.priority) }
      accepted.flat_map { |record| Name.new(record.target.to_s).resolve }
    end
  end
end
# Timestamps recorded for every cached address.
CacheMeta = Struct.new(:first_seen, :last_seen)

# Remembers every address a name has resolved to recently, so that a DNS
# outage does not immediately leave the caller without candidates.
class ResolverCache
  # Addresses not seen for this many seconds are dropped from the cache.
  MAX_AGE_SECONDS = 30 * 60

  def initialize(name)
    # instance of Name|SRVName
    @name = name
    # {IPv4/IPv6 address: CacheMeta}
    @cached = {}
  end

  # resolve returns the list of cached addresses ordered by the time first
  # seen, most recently first seen at the head of the list.
  # Addresses last seen more than 30 minutes ago are evicted from the cache on
  # a call to resolve().
  # If an exception occurs during DNS resolution it is logged and whatever
  # addresses are still cached are returned.
  def resolve
    begin
      @name.resolve.each do |address|
        if @cached[address]
          @cached[address].last_seen = Time.now.utc
        else
          @cached[address] = CacheMeta.new(Time.now.utc, Time.now.utc)
        end
      end
    rescue StandardError => e
      error("DNS resolution failed, falling back to cached addresses: #{e}")
    end

    # Eviction and ordering live in the method body proper: the value of an
    # `ensure` clause is discarded, so computing the list there never reached
    # the caller.
    cutoff = Time.now.utc - MAX_AGE_SECONDS
    @cached.delete_if { |_, meta| meta.last_seen < cutoff }
    @cached.sort_by { |_, meta| meta.first_seen }.reverse.map(&:first)
  end
end
# Remembers the last address of a name that passed its health check.
class HealthyCache
  def initialize(resolver_cache, check)
    # instance of ResolverCache
    @resolver_cache = resolver_cache
    # callable taking an address and returning whether it is healthy
    @check = check
    # the single address most recently found to be healthy
    @cached = nil
  end

  # Returns the first resolved address that passes the health check. When no
  # address passes, the previously remembered healthy address (possibly nil)
  # is returned instead.
  def first_healthy
    candidate = @resolver_cache.resolve.find { |addr| @check.call(addr) }
    @cached = candidate unless nilempty(candidate).nil?
    @cached
  end
end
# Pings the Redis server at host and returns true only when it answers
# "PONG". Any error while connecting or pinging yields false. The client is
# closed before returning.
# When DISCOURSE_REDIS_USE_SSL is set to a non-empty value the connection uses
# SSL without certificate verification.
def redis_healthcheck(host:, password:)
  options = { host: host, password: password, timeout: 1 }
  unless nilempty(ENV['DISCOURSE_REDIS_USE_SSL']).nil?
    options[:ssl] = true
    options[:ssl_params] = { verify_mode: OpenSSL::SSL::VERIFY_NONE }
  end

  client = Redis.new(**options)
  client.ping == "PONG"
rescue
  false
ensure
  client&.close
end
# Connects to PostgreSQL at host, runs an empty statement and returns true
# when the result contains no rows. Any error while connecting or querying
# yields false. The connection is closed before returning.
def postgres_healthcheck(host:, user:, password:, dbname:)
  client = PG::Connection.new(
    host: host,
    user: user,
    password: password,
    dbname: dbname,
    connect_timeout: 2, # minimum
  )
  result = client.exec(';')
  result.none?
rescue
  false
ensure
  client&.close
end
# Health check shared by every PostgreSQL host variable. Credentials are read
# from the environment at call time; user and database name default to
# DEFAULT_DB_NAME.
postgres_check = lambda { |addr|
  postgres_healthcheck(
    host: addr,
    user: ENV["DISCOURSE_DB_USERNAME"] || DEFAULT_DB_NAME,
    dbname: ENV["DISCOURSE_DB_NAME"] || DEFAULT_DB_NAME,
    password: ENV["DISCOURSE_DB_PASSWORD"])
}

# Health check shared by every Redis host variable.
redis_check = lambda { |addr|
  redis_healthcheck(
    host: addr,
    password: ENV["DISCOURSE_REDIS_PASSWORD"])
}

# Health check to run against a candidate address, keyed by the symbolized
# name of the host environment variable. Every entry of
# CRITICAL_HOST_ENV_VARS needs a check here: a missing key hands nil to
# HealthyCache, which then fails on every refresh for that variable.
HEALTH_CHECKS = {
  DISCOURSE_DB_HOST: postgres_check,
  DISCOURSE_DB_REPLICA_HOST: postgres_check,
  DISCOURSE_REDIS_HOST: redis_check,
  DISCOURSE_REDIS_SLAVE_HOST: redis_check,
  DISCOURSE_REDIS_REPLICA_HOST: redis_check,
}.freeze
# Writes msg to standard error, prefixed with the current UTC time in
# ISO 8601 format.
def log(msg)
  timestamp = Time.now.utc.iso8601
  STDERR.puts("#{timestamp}: #{msg}")
end
# Reports an error message. Errors are written through the same channel as
# ordinary log messages.
def error(msg)
  log(msg)
end
# Returns new hosts file content in which the entries for name are replaced.
#
# hosts is the current file content, name the hostname to replace and ips the
# addresses it should now map to. Every line is stripped of surrounding
# whitespace. Non-comment lines whose second field equals name are dropped;
# all other lines are kept in order. One timestamped entry per address is
# appended at the end.
def swap_address(hosts, name, ips)
  retained = hosts.split("\n").map(&:strip).reject do |entry|
    entry[0] != '#' && entry.split(/\s+/)[1] == name
  end
  generated = ips.map do |ip|
    "#{ip} #{name} # AUTO GENERATED: #{Time.now.utc.iso8601}\n"
  end
  (retained.map { |entry| "#{entry}\n" } + generated).join
end
# Reports a counter metric to the Prometheus collector listening on localhost.
#
# name and description identify the counter, labels is a Hash of label names
# to values (or nil for none) and value is the amount to report. The port is
# read from DISCOURSE_PROMETHEUS_COLLECTOR_PORT and defaults to 9405.
#
# Reporting is best effort: any failure is logged through error and never
# raised to the caller. The socket is always closed, including when the
# collector answers with an empty or unexpected response.
# NOTE(review): names, descriptions and labels are interpolated into the JSON
# body without escaping; values containing quotes would produce invalid JSON.
def send_counter(name, description, labels, value)
  host = "localhost"
  port = ENV.fetch("DISCOURSE_PROMETHEUS_COLLECTOR_PORT", 9405).to_i

  labels = (labels || {}).map { |k, v| "\"#{k}\": \"#{v}\"" }.join(",")

  json = <<~JSON
    {
      "_type": "Custom",
      "type": "Counter",
      "name": "#{name}",
      "description": "#{description}",
      "labels": { #{labels} },
      "value": #{value}
    }
  JSON

  payload = +"POST /send-metrics HTTP/1.1\r\n"
  payload << "Host: #{host}\r\n"
  payload << "Connection: Close\r\n"
  payload << "Content-Type: application/json\r\n"
  payload << "Content-Length: #{json.bytesize}\r\n"
  payload << "\r\n"
  payload << json

  socket = TCPSocket.new(host, port)
  begin
    socket.write(payload)
    socket.flush
    result = socket.read
    # An empty response has no first line; treat it as a failed report
    # instead of failing on nil.
    status_line = result.to_s.split("\n").first.to_s.strip
    error("Failed to report metric #{result}") if status_line != "HTTP/1.1 200 OK"
  ensure
    # Close on every path so a failed exchange does not leak the socket.
    socket.close
  end
rescue => e
  error("Failed to send metric to Prometheus #{e}")
end
# Increments the success counter by one, without labels.
def report_success
  metric = 'critical_dns_successes_total'
  description = 'critical DNS resolution success'
  send_counter(metric, description, nil, 1)
end
# Reports one failure counter per entry of errors, a Hash mapping a host to
# its number of failures. Entries whose host is nil are reported without
# labels.
def report_failure(errors)
  errors.each do |host, count|
    labels = host ? { host: host } : nil
    send_counter('critical_dns_failures_total', 'critical DNS resolution failures', labels, count)
  end
end
# Normalizes "nothing" to nil: returns nil when v is nil or responds to
# empty? and is empty, otherwise returns v itself.
def nilempty(v)
  return nil if v.nil?
  return nil if v.respond_to?(:empty?) && v.empty?

  v
end
# Returns the name of the environment variable that holds the SRV name for
# env_name: env_name with "_SRV" appended.
def env_srv_var(env_name)
  format('%s_SRV', env_name)
end
# Returns the SRV name configured in the environment for env_name, or nil
# when the corresponding variable is unset or empty.
def env_srv_name(env_name)
  configured = ENV[env_srv_var(env_name)]
  nilempty(configured)
end
# Performs one refresh cycle for the given host environment variables.
#
# For every variable the first healthy address of its host is looked up
# through the per-variable caches. The hosts file is rewritten when the
# addresses it holds for a host differ from the healthy one. Failures are
# counted per host and reported as metrics at the end of the cycle; a cycle
# without any failure reports a success instead.
def run(hostname_vars)
  # HOSTNAME: [IP_ADDRESS, ...]
  # this will usually be a single address
  resolved = {}
  errors = Hash.new(0)

  hostname_vars.each do |var|
    name = ENV[var]

    # Both caches are created once per variable and reused on later cycles.
    HOST_RESOLVER_CACHE[var] ||= begin
      srv_name = env_srv_name(var)
      resolver =
        if srv_name
          SRVName.new(srv_name, SRV_PRIORITY_FILTERS[env_srv_var(var)])
        else
          Name.new(name)
        end
      ResolverCache.new(resolver)
    end
    HOST_HEALTHY_CACHE[var] ||= HealthyCache.new(HOST_RESOLVER_CACHE[var], HEALTH_CHECKS[var.to_sym])

    begin
      address = HOST_HEALTHY_CACHE[var].first_healthy
      if address
        resolved[name] = [address]
      else
        error("#{var}: #{name}: no address")
        errors[name] += 1
      end
    rescue => e
      error("#{var}: #{name}: #{e}")
      errors[name] += 1
    end
  end

  hosts_content = File.read(HOSTS_PATH)
  hosts = Resolv::Hosts.new(HOSTS_PATH)
  changed = false

  resolved.each do |hostname, ips|
    current = hosts.getaddresses(hostname).map(&:to_s).sort
    next if current == ips.sort

    log("IP addresses for #{hostname} changed to #{ips}")
    hosts_content = swap_address(hosts_content, hostname, ips)
    changed = true
  end

  File.write(HOSTS_PATH, hosts_content) if changed
rescue => e
  error("DNS lookup failed: #{e}")
  errors[nil] = 1
ensure
  if errors.empty?
    report_success
  else
    report_failure(errors)
  end
end
# Host variables that are unset, empty or already hold an explicit IP address
# are left out: only real hostnames are cached.
all_hostname_vars = CRITICAL_HOST_ENV_VARS.select do |name|
  host = ENV[name]
  if nilempty(host).nil?
    false
  else
    begin
      # Parsing succeeds only for literal IP addresses.
      IPAddr.new(host)
      false
    rescue IPAddr::InvalidAddressError, IPAddr::AddressFamilyError
      true
    end
  end
end
# Register a priority filter for every host variable that has an SRV name
# configured. The bounds come from the <VAR>_SRV_PRIORITY_LE (upper) and
# <VAR>_SRV_PRIORITY_GE (lower) environment variables and default to the
# widest permitted range, which amounts to no filtering. Bounds outside the
# permitted range, or a lower bound above the upper bound, abort the script.
CRITICAL_HOST_ENV_VARS.each do |v|
  next unless env_srv_name(v)

  srv_var = env_srv_var(v)
  upper = ENV.fetch("#{srv_var}_PRIORITY_LE", SRV_PRIORITY_THRESHOLD_MAX).to_i
  lower = ENV.fetch("#{srv_var}_PRIORITY_GE", SRV_PRIORITY_THRESHOLD_MIN).to_i

  acceptable = upper <= SRV_PRIORITY_THRESHOLD_MAX &&
    lower >= SRV_PRIORITY_THRESHOLD_MIN &&
    lower <= upper
  raise "invalid priority threshold set for #{v}" unless acceptable

  SRV_PRIORITY_FILTERS[srv_var] = PrioFilter.new(lower, upper)
end
# Refresh forever, pausing REFRESH_SECONDS between two cycles.
loop do
  run(all_hostname_vars)
  sleep REFRESH_SECONDS
end