#!/usr/bin/env ruby
# frozen_string_literal: true

require 'optparse'

# cache_critical_dns is intended to be used for performing DNS lookups against
# the services critical for Discourse to run - PostgreSQL and Redis. The
# cache mechanism is storing the resolved host addresses in /etc/hosts. This can
# protect against DNS lookup failures _after_ the resolved addresses have been
# written to /etc/hosts at least once. Example lookup failures may be NXDOMAIN
# or SERVFAIL responses from DNS requests.
#
# Before a resolved address is cached, a protocol-aware healthcheck is
# performed against the host with the authentication details found for that
# service in the process environment. Any hosts that fail the healthcheck will
# never be cached.
#
# The list of environment variables that cache_critical_dns will read for
# critical service hostnames can be extended at process execution time by
# specifying environment variable names within the
# DISCOURSE_DNS_CACHE_ADDITIONAL_SERVICE_NAMES environment variable. This is a
# comma-delimited string of extra environment variables to be added to the list
# defined in the static CRITICAL_HOST_ENV_VARS hash.
#
# DISCOURSE_DNS_CACHE_ADDITIONAL_SERVICE_NAMES serves as a kind of lookup table
# for extra services for caching. Any environment variable names within this
# list are treated with the same rules as the DISCOURSE_DB_HOST (and co.)
# variables, as described below.
#
# This is as far as you need to read if you are using CNAME or A records for
# your services.
#
# The extended behaviour of cache_critical_dns is to add SRV RR lookup support
# for DNS Service Discovery (see http://www.dns-sd.org/). For any of the critical
# service environment variables (see CRITICAL_HOST_ENV_VARS), if a corresponding
# SRV environment variable is found (suffixed with _SRV), cache_critical_dns
# will assume that SRV RRs should exist and will begin to look up SRV targets
# for resolving the host addresses for caching, and ignore the original service
# name variable. Healthy target addresses are cached against the original service
# environment variable, as the Discourse application expects. For example a
# healthy target found from the SRV lookup for DISCOURSE_DB_HOST_SRV will be
# cached against the name specified by the DISCOURSE_DB_HOST.
#
# Example environment variables for SRV lookups are:
#   DISCOURSE_DB_HOST_SRV
#   DISCOURSE_DB_REPLICA_HOST_SRV
#   DISCOURSE_REDIS_HOST_SRV
#   DISCOURSE_REDIS_REPLICA_HOST_SRV
#
# cache_critical_dns will keep an internal record of all names resolved within
# the last 30 minutes. This internal cache is to give a priority order to new
# SRV targets that have appeared during the program runtime (SRV records
# contain zero or more targets, which may appear or disappear at any time).
# If a target has not been seen for more than 30 minutes it will be evicted from
# the internal cache. The internal cache of healthy targets is a fallback for
# when errors occur during DNS lookups.
#
# Targets that are resolved and found healthy usually find themselves in the host
# cache, depending on if they are the newest or not. Targets that are resolved
# but never found healthy will never be cached or even stored in the internal
# cache. Targets that _begin_ healthy and are cached, and _become_ unhealthy
# will only be removed from the host cache if another newer target is resolved
# and found to be healthy. This is because we never write a resolved target to
# the hosts cache unless it is both the newest and healthy. We assume that
# cached hosts are healthy until they are superseded by a newer healthy target.
#
# An SRV RR specifies a priority value for each of the SRV targets that
# are present, ranging from 0 - 65535. When caching SRV records we may want to
# filter out any targets above or below a particular threshold. The LE (less
# than or equal to) and GE (greater than or equal to) environment variables
# (suffixed with _PRIORITY_LE or _PRIORITY_GE) for a corresponding SRV variable
# will ignore targets above or below the threshold, respectively.
#
# This mechanism may be used for SRV RRs that share a single name and utilise
# the priority value for signalling to cache_critical_dns which targets are
# relevant to a given name. Any target found outside of the threshold is ignored.
# The host and internal caching behavior are otherwise the same.
#
# Example environment variables for SRV priority thresholds are:
#   DISCOURSE_DB_HOST_SRV_PRIORITY_LE
#   DISCOURSE_DB_REPLICA_HOST_SRV_PRIORITY_GE

# Specifying this env var ensures ruby can load gems installed via the Discourse
# project Gemfile (e.g. pg, redis).
ENV['BUNDLE_GEMFILE'] ||= '/var/www/discourse/Gemfile'
require 'bundler/setup'

require 'ipaddr'
require 'pg'
require 'redis'
require 'resolv'
require 'socket'
require 'time'

CRITICAL_HOST_ENV_VARS = %w{
  DISCOURSE_DB_HOST
  DISCOURSE_DB_REPLICA_HOST
  DISCOURSE_REDIS_HOST
  DISCOURSE_REDIS_SLAVE_HOST
  DISCOURSE_REDIS_REPLICA_HOST
  DISCOURSE_MESSAGE_BUS_REDIS_HOST
  DISCOURSE_MESSAGE_BUS_REDIS_REPLICA_HOST
}.union(
  ENV.fetch('DISCOURSE_DNS_CACHE_ADDITIONAL_SERVICE_NAMES', '')
    .split(',')
    .map(&:strip)
)

DEFAULT_DB_NAME = "discourse"
DEFAULT_DB_PORT = 5432
DEFAULT_REDIS_PORT = 6379

HOST_RESOLVER_CACHE = {}
HOST_HEALTHY_CACHE = {}
HOSTS_PATH = ENV['DISCOURSE_DNS_CACHE_HOSTS_FILE'] || "/etc/hosts"

PrioFilter = Struct.new(:min, :max) do
  # min and max must be integers and relate to the minimum or maximum accepted
  # priority of an SRV RR target.
  # The return value from within_threshold? indicates if the priority is less
  # than or equal to the upper threshold, or greater than or equal to the
  # lower threshold.
  def within_threshold?(p)
    p >= min && p <= max
  end
end
# Bounds of the SRV priority range. Thresholds read from the environment are
# validated against these before a filter is stored.
SRV_PRIORITY_THRESHOLD_MIN = 0
SRV_PRIORITY_THRESHOLD_MAX = 65535
# Priority filters keyed by SRV environment variable name. Keys that were never
# assigned fall back to one shared PrioFilter spanning the whole range, i.e. no
# filtering by priority.
SRV_PRIORITY_FILTERS = Hash.new(
  PrioFilter.new(SRV_PRIORITY_THRESHOLD_MIN, SRV_PRIORITY_THRESHOLD_MAX))

# Seconds to sleep between runs when the script loops indefinitely.
REFRESH_SECONDS = 30

module DNSClient
  def dns_client_with_timeout
    Resolv::DNS.open do |dns_client|
      dns_client.timeouts = 2
      yield dns_client
    end
  end
end

class Name
  include DNSClient

  def initialize(hostname)
    @name = hostname
  end

  def resolve
    dns_client_with_timeout do |dns_client|
      [].tap do |addresses|
        addresses.concat(dns_client.getresources(@name, Resolv::DNS::Resource::IN::A).map(&:address))
        addresses.concat(dns_client.getresources(@name, Resolv::DNS::Resource::IN::AAAA).map(&:address))
      end.map(&:to_s)
    end
  end
end

class SRVName
  include DNSClient

  def initialize(srv_hostname, prio_filter)
    @name = srv_hostname
    @prio_filter = prio_filter
  end

  def resolve
    dns_client_with_timeout do |dns_client|
      [].tap do |addresses|
        targets = dns_client.getresources(@name, Resolv::DNS::Resource::IN::SRV)
        targets.delete_if { |t| !@prio_filter.within_threshold?(t.priority) }
        addresses.concat(targets.map { |t| Name.new(t.target.to_s).resolve }.flatten)
      end
    end
  end
end

# Timestamps kept for each resolved address: when it was first returned by a
# lookup and when it was most recently returned.
CacheMeta = Struct.new(:first_seen, :last_seen)

# Raised when no usable (resolved or healthy) address is available.
class EmptyCacheError < StandardError; end

class ResolverCache
  def initialize(name)
    # instance of Name|SRVName
    @name = name

    # {IPv4/IPv6 address: CacheMeta}
    @cached = {}
  end

  # Returns a list of resolved addresses ordered by first seen time.  Most
  # recently seen address is first.
  # If an exception occurs during DNS resolution we return whatever addresses
  # are cached.
  # Addresses last seen more than 30 minutes ago are evicted from the cache.
  # Raises EmptyCacheError if the cache is empty after DNS resolution and cache
  # eviction is performed.
  def resolve
    begin
      @name.resolve.each do |address|
        if @cached[address]
          @cached[address].last_seen = Time.now.utc
        else
          @cached[address] = CacheMeta.new(Time.now.utc, Time.now.utc)
        end
      end
    rescue Resolv::ResolvError, Resolv::ResolvTimeout
    end

    @cached = @cached.delete_if { |_, meta| Time.now.utc - 30 * 60 > meta.last_seen }
    if @cached.empty?
      raise EmptyCacheError, "DNS resolver found no usable addresses"
    end
    @cached.sort_by { |_, meta| meta.first_seen }.reverse.map(&:first)
  end
end

class HealthyCache
  def initialize(resolver_cache, check)
    @resolver_cache = resolver_cache  # instance of ResolverCache
    @check = check  # lambda function to perform for health checks
    @cached = nil  # a single IP address that was most recently found to be healthy
  end

  # Returns the first healthy server found in the list of resolved addresses.
  # Returns the last known healthy server if all servers disappear from the
  # DNS.
  # Raises EmptyCacheError if no healthy servers have been cached.
  def first_healthy
    begin
      addresses = @resolver_cache.resolve
    rescue EmptyCacheError
      return @cached if @cached
      raise
    end

    if (address = addresses.lazy.select { |addr| @check.call(addr) }.first)
      @cached = address
    end
    if @cached.nil?
      raise EmptyCacheError, "no healthy servers found amongst #{addresses}"
    end

    @cached
  end
end

def redis_healthcheck(host:, password:, port: DEFAULT_REDIS_PORT)
  client_opts = {
    host: host,
    password: password,
    port: port,
    timeout: 1,
  }
  if !nilempty(ENV['DISCOURSE_REDIS_USE_SSL']).nil? then
    client_opts[:ssl] = true
    client_opts[:ssl_params] = {
      verify_mode: OpenSSL::SSL::VERIFY_NONE,
    }
  end
  client = Redis.new(**client_opts)
  response = client.ping
  response == "PONG"
rescue
  false
ensure
  client.close if client
end

def postgres_healthcheck(host:, user:, password:, dbname:, port: DEFAULT_DB_PORT)
  client = PG::Connection.new(
    host: host,
    user: user,
    password: password,
    dbname: dbname,
    port: port,
    connect_timeout: 2,  # minimum
  )
  client.exec(';').none?
rescue
  false
ensure
  client.close if client
end

# Protocol-aware health checks, keyed by the symbolised name of the environment
# variable that holds the service hostname. Each value is a lambda that takes a
# resolved address and returns whether the service at that address responded.
#
# Every check reads its port from the matching *_PORT environment variable and
# falls back to the protocol default, so services listening on a non-default
# port are probed on the port they actually use.
HEALTH_CHECKS = Hash.new(
  # unknown keys (like services defined at runtime) are assumed to be healthy
  lambda { |addr| true }
).merge!({
  "DISCOURSE_DB_HOST": lambda { |addr|
    postgres_healthcheck(
      host: addr,
      user: ENV["DISCOURSE_DB_USERNAME"] || DEFAULT_DB_NAME,
      dbname: ENV["DISCOURSE_DB_NAME"] || DEFAULT_DB_NAME,
      port: env_as_int("DISCOURSE_DB_PORT", DEFAULT_DB_PORT),
      password: ENV["DISCOURSE_DB_PASSWORD"])},
  "DISCOURSE_DB_REPLICA_HOST": lambda { |addr|
    postgres_healthcheck(
      host: addr,
      user: ENV["DISCOURSE_DB_USERNAME"] || DEFAULT_DB_NAME,
      dbname: ENV["DISCOURSE_DB_NAME"] || DEFAULT_DB_NAME,
      port: env_as_int("DISCOURSE_DB_REPLICA_PORT", DEFAULT_DB_PORT),
      password: ENV["DISCOURSE_DB_PASSWORD"])},
  "DISCOURSE_REDIS_HOST": lambda { |addr|
    redis_healthcheck(
      host: addr,
      port: env_as_int("DISCOURSE_REDIS_PORT", DEFAULT_REDIS_PORT),
      password: ENV["DISCOURSE_REDIS_PASSWORD"])},
  "DISCOURSE_REDIS_REPLICA_HOST": lambda { |addr|
    redis_healthcheck(
      host: addr,
      port: env_as_int("DISCOURSE_REDIS_REPLICA_PORT", DEFAULT_REDIS_PORT),
      password: ENV["DISCOURSE_REDIS_PASSWORD"])},
  "DISCOURSE_MESSAGE_BUS_REDIS_HOST": lambda { |addr|
    redis_healthcheck(
      host: addr,
      port: env_as_int("DISCOURSE_MESSAGE_BUS_REDIS_PORT", DEFAULT_REDIS_PORT),
      password: ENV["DISCOURSE_MESSAGE_BUS_REDIS_PASSWORD"])},
  "DISCOURSE_MESSAGE_BUS_REDIS_REPLICA_HOST": lambda { |addr|
    redis_healthcheck(
      host: addr,
      port: env_as_int("DISCOURSE_MESSAGE_BUS_REDIS_REPLICA_PORT", DEFAULT_REDIS_PORT),
      password: ENV["DISCOURSE_MESSAGE_BUS_REDIS_PASSWORD"])},
})

def log(msg)
  STDERR.puts "#{Time.now.utc.iso8601}: #{msg}"
end

# Reports an error message. Currently identical to `log`: the message is
# written through the same timestamped logger.
def error(msg)
  log(msg)
end

def swap_address(hosts, name, ips)
  new_file = []

  hosts.split("\n").each do |line|
    line.strip!
    if line[0] != '#'
      _, hostname = line.split(/\s+/)
      next if hostname == name
    end
    new_file << line
    new_file << "\n"
  end

  ips.each do |ip|
    new_file << "#{ip} #{name} # AUTO GENERATED: #{Time.now.utc.iso8601}\n"
  end

  new_file.join
end

# Sends a single counter metric to the Prometheus collector on localhost using
# a hand-written HTTP POST to /send-metrics.
#
# name:        metric name.
# description: metric description.
# labels:      Hash of label name => value, or nil for no labels.
# value:       value to report; interpolated into the JSON as-is.
#
# The port is read from DISCOURSE_PROMETHEUS_COLLECTOR_PORT (default 9405).
# Failures are logged through `error` and never raised to the caller. The
# socket is always closed, including when the collector returns an empty or
# unexpected response.
#
# NOTE(review): name, description and labels are interpolated into the JSON
# body without escaping; callers in this file pass fixed strings and hostnames.
def send_counter(name, description, labels, value)
  host = "localhost"
  port = env_as_int("DISCOURSE_PROMETHEUS_COLLECTOR_PORT", 9405)

  labels = (labels || {}).map { |k, v| "\"#{k}\": \"#{v}\"" }.join(",")

  json = <<~JSON
  {
    "_type": "Custom",
    "type": "Counter",
    "name": "#{name}",
    "description": "#{description}",
    "labels": { #{labels} },
    "value": #{value}
  }
  JSON

  payload = +"POST /send-metrics HTTP/1.1\r\n"
  payload << "Host: #{host}\r\n"
  payload << "Connection: Close\r\n"
  payload << "Content-Type: application/json\r\n"
  payload << "Content-Length: #{json.bytesize}\r\n"
  payload << "\r\n"
  payload << json

  socket = TCPSocket.new host, port
  begin
    socket.write payload
    socket.flush
    result = socket.read
    # An empty response has no first line; treat it as a failed report rather
    # than raising on nil.
    status_line = result.to_s.lines.first.to_s.strip
    if status_line != "HTTP/1.1 200 OK"
      error("Failed to report metric #{result}")
    end
  ensure
    # Close on every path so a failed exchange does not leak the descriptor.
    socket.close
  end
rescue => e
  error("Failed to send metric to Prometheus #{e}")
end

# Increments the critical_dns_successes_total counter by 1, with no labels.
def report_success
  send_counter('critical_dns_successes_total', 'critical DNS resolution success', nil, 1)
end

def report_failure(errors)
  errors.each do |host, count|
    send_counter('critical_dns_failures_total', 'critical DNS resolution failures', host ? { host: host } : nil, count)
  end
end

def nilempty(v)
  if v.nil?
    nil
  elsif v.respond_to?(:empty?) && v.empty?
    nil
  else
    v
  end
end

def env_srv_var(env_name)
  "#{env_name}_SRV"
end

def env_srv_name(env_name)
  nilempty(ENV[env_srv_var(env_name)])
end

# Reads the environment variable env_name as an integer.
#
# Returns the variable's value converted with `to_i` when it is set to a
# non-empty string. When the variable is unset OR set to an empty string,
# `default` is used instead: a nil or empty default is returned unchanged, any
# other default is converted with `to_i`.
#
# Falling back on an empty value matters because callers compare the result
# with integers and use it as a TCP port; an empty string would break both.
def env_as_int(env_name, default = nil)
  raw = ENV[env_name]
  return raw.to_i unless raw.nil? || raw.empty?

  default_blank = default.nil? || (default.respond_to?(:empty?) && default.empty?)
  default_blank ? default : default.to_i
end

def run_and_report(hostname_vars)
  errors = run(hostname_vars)
  if errors.empty?
    report_success
  else
    report_failure(errors)
  end
end

# Performs one resolve / health-check / hosts-file update cycle.
#
# hostname_vars: names of the environment variables whose hostnames should be
#                cached.
#
# For each variable a ResolverCache and HealthyCache are created on first use
# and reused on later runs. The healthy address found for each hostname is
# compared with what the hosts file currently returns for it, and the file is
# rewritten only when at least one hostname changed.
#
# Returns a Hash of hostname => failure count. It is empty on full success; a
# nil key with count 1 records an unexpected StandardError during the run.
#
# Only StandardError is handled here. Interrupt, SignalException, SystemExit
# and similar are allowed to propagate so the process can be stopped while a
# run is in progress.
def run(hostname_vars)
  resolved = {}
  errors = Hash.new(0)

  begin
    hostname_vars.each do |var|
      name = ENV[var]
      HOST_RESOLVER_CACHE[var] ||= ResolverCache.new(
        if (srv_name = env_srv_name(var))
          SRVName.new(srv_name, SRV_PRIORITY_FILTERS[env_srv_var(var)])
        else
          Name.new(name)
        end
      )

      HOST_HEALTHY_CACHE[var] ||= HealthyCache.new(HOST_RESOLVER_CACHE[var], HEALTH_CHECKS[var.to_sym])

      begin
        address = HOST_HEALTHY_CACHE[var].first_healthy
        resolved[name] = [address]
      rescue EmptyCacheError => e
        error("#{var}: #{name}: #{e}")
        errors[name] += 1
      end
    end

    hosts_content = File.read(HOSTS_PATH)
    hosts = Resolv::Hosts.new(HOSTS_PATH)

    changed = false
    resolved.each do |hostname, ips|
      if hosts.getaddresses(hostname).map(&:to_s).sort != ips.sort
        log("IP addresses for #{hostname} changed to #{ips}")
        hosts_content = swap_address(hosts_content, hostname, ips)
        changed = true
      end
    end

    File.write(HOSTS_PATH, hosts_content) if changed
  rescue => e
    error("unhandled exception during run: #{e}")
    errors[nil] = 1
  end

  errors
end

# If any of the host variables are an explicit IP we will not attempt to cache
# them.
all_hostname_vars = CRITICAL_HOST_ENV_VARS.select do |name|
  host = ENV[name]
  if nilempty(host).nil?
    # don't attempt to cache host vars that aren't present in the environment
    false
  else
    begin
      IPAddr.new(host)
      # host is an IPv4 / IPv6 address
      false
    rescue IPAddr::InvalidAddressError, IPAddr::AddressFamilyError
      true
    end
  end
end

# Populate the SRV_PRIORITY_FILTERS for any name that has a priority present in
# the environment. If no priority thresholds are found for the name, the default
# is that no filtering based on priority will be performed.
CRITICAL_HOST_ENV_VARS.each do |v|
  if (name = env_srv_name(v))
    max = env_as_int("#{env_srv_var(v)}_PRIORITY_LE", SRV_PRIORITY_THRESHOLD_MAX)
    min = env_as_int("#{env_srv_var(v)}_PRIORITY_GE", SRV_PRIORITY_THRESHOLD_MIN)
    if max > SRV_PRIORITY_THRESHOLD_MAX ||
        min < SRV_PRIORITY_THRESHOLD_MIN ||
        min > max
      raise "invalid priority threshold set for #{v}"
    end

    SRV_PRIORITY_FILTERS[env_srv_var(v)] = PrioFilter.new(min, max)
  end
end

options = {
  once: false,
}
OptionParser.new do |opts|
  opts.on("--once", "run script once instead of indefinitely") do |o|
    options[:once] = true
  end
end.parse!

if options[:once]
  errors = run(all_hostname_vars)
  exit errors.empty? ? 0 : 1
end

while true
  run_and_report(all_hostname_vars)
  sleep REFRESH_SECONDS
end