#!/usr/bin/python3
###############################################################################
#                                                                             #
# IPFire.org - A linux based firewall                                         #
# Copyright (C) 2016  Michael Tremer                                          #
#                                                                             #
# This program is free software: you can redistribute it and/or modify        #
# it under the terms of the GNU General Public License as published by        #
# the Free Software Foundation, either version 3 of the License, or           #
# (at your option) any later version.                                         #
#                                                                             #
# This program is distributed in the hope that it will be useful,             #
# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the               #
# GNU General Public License for more details.                                #
#                                                                             #
# You should have received a copy of the GNU General Public License           #
# along with this program.  If not, see <http://www.gnu.org/licenses/>.       #
#                                                                             #
###############################################################################

import argparse
import datetime
import daemon
import ipaddress
import logging
import logging.handlers
import os
import re
import signal
import stat
import subprocess
import tempfile

import inotify.adapters

# TTL that is put into every resource record generated for a lease
# (see Lease.rrset)
LOCAL_TTL = 60

def setup_logging(loglevel=logging.INFO):
	"""
	Configures the "dhcp" logger to send its messages to the local
	syslog socket (/dev/log, facility "daemon") and returns the logger.

	loglevel is applied to both the logger and the syslog handler.
	"""
	logger = logging.getLogger("dhcp")
	logger.setLevel(loglevel)

	# Everything goes to syslog, prefixed with logger name and PID
	syslog = logging.handlers.SysLogHandler(address="/dev/log", facility="daemon")
	syslog.setLevel(loglevel)
	syslog.setFormatter(
		logging.Formatter("%(name)s[%(process)d]: %(message)s")
	)

	logger.addHandler(syslog)

	return logger

# Module-wide logger; setup_logging() attaches the syslog handler to the
# logger of the same name
log = logging.getLogger("dhcp")

def ip_address_to_reverse_pointer(address):
	"""
	Returns the in-addr.arpa name for a dotted address, i.e. the
	dot-separated parts in reverse order followed by ".in-addr.arpa".
	"""
	octets = reversed(address.split("."))

	return "%s.in-addr.arpa" % ".".join(octets)

def reverse_pointer_to_ip_address(rr):
	"""
	Returns the dotted address that is encoded in the first four
	labels of a reverse pointer name.
	"""
	# Only the first four labels carry the address (in reverse order)
	octets = rr.split(".")[:4]
	octets.reverse()

	return ".".join(octets)

class UnboundDHCPLeasesBridge(object):
	"""
	Watches the DHCP leases file, the fix leases file and the static
	hosts file and hands the resulting list of leases over to an
	UnboundConfigWriter whenever one of them has been modified.
	"""

	def __init__(self, dhcp_leases_file, fix_leases_file, unbound_leases_file, hosts_file):
		self.leases_file = dhcp_leases_file
		self.fix_leases_file = fix_leases_file
		self.hosts_file = hosts_file

		# Static hosts (FQDN -> sorted list of IP addresses). Properly
		# loaded in run(), but initialised here so that a signal-triggered
		# update before that does not fail on a missing attribute.
		self.hosts = {}

		self.unbound = UnboundConfigWriter(unbound_leases_file)
		self.running = False

	def run(self):
		"""
		Imports all files once and then processes inotify events until
		terminate() has been called.
		"""
		log.info("Unbound DHCP Leases Bridge started on %s" % self.leases_file)
		self.running = True

		# Initial setup
		self.hosts = self.read_static_hosts()
		self.update_dhcp_leases()

		i = inotify.adapters.Inotify([
			self.leases_file,
			self.fix_leases_file,
			self.hosts_file,
		])

		for event in i.event_gen():
			# End if we are requested to terminate
			if not self.running:
				break

			if event is None:
				continue

			header, type_names, watch_path, filename = event

			# Update leases after leases file has been modified
			if "IN_MODIFY" in type_names:
				# Reload hosts
				if watch_path == self.hosts_file:
					self.hosts = self.read_static_hosts()

				self.update_dhcp_leases()

			# If the file is deleted, we re-add the watcher
			if "IN_IGNORED" in type_names:
				i.add_watch(watch_path)

		log.info("Unbound DHCP Leases Bridge terminated")

	def update_dhcp_leases(self, signum=None, frame=None):
		"""
		Reads all leases and passes those that should be resolvable on
		to the Unbound writer.

		signum and frame are ignored. They are only accepted so that this
		method can be installed as a signal handler, which is always
		called with these two arguments.
		"""
		# Don't bother with any dynamic leases that don't have a hostname
		leases = [l for l in DHCPLeases(self.leases_file) if l.fqdn]

		leases.extend(FixLeases(self.fix_leases_file))

		# Remove any inactive or expired leases and skip any leases that
		# also are a static host. The cheap checks go first because
		# resolving the FQDN of a lease is comparatively expensive.
		leases = [
			l for l in leases
			if l.active and not l.expired and l.fqdn not in self.hosts
		]

		# Dump leases (only if somebody is going to see it)
		if leases and log.isEnabledFor(logging.DEBUG):
			log.debug("DHCP Leases:")
			for lease in leases:
				log.debug("  %s:" % lease.fqdn)
				log.debug("    State: %s" % lease.binding_state)
				log.debug("    Start: %s" % lease.time_starts)
				log.debug("    End  : %s" % lease.time_ends)

		self.unbound.update_dhcp_leases(leases)

	def read_static_hosts(self):
		"""
		Parses the static hosts file and returns a dictionary that maps
		the name of each enabled host to the sorted list of its addresses.

		Each line is expected to have five comma-separated fields:
		enabled, IP address, hostname, domain name, generate PTR.
		"""
		log.info("Reading static hosts from %s" % self.hosts_file)

		hosts = {}
		with open(self.hosts_file) as f:
			for line in f:
				line = line.rstrip()

				try:
					enabled, ipaddr, hostname, domainname, generateptr = line.split(",")
				except ValueError:
					log.warning("Could not parse line: %s" % line)
					continue

				# Skip any disabled entries
				if not enabled == "on":
					continue

				# Join whatever parts of the name have been set
				fqdn = ".".join(part for part in (hostname, domainname) if part)

				# An entry without any name cannot collide with a lease
				if not fqdn:
					log.warning("Skipping static host without a name: %s" % line)
					continue

				addresses = hosts.setdefault(fqdn, [])
				addresses.append(ipaddr)
				addresses.sort()

		# Dump everything in the logs
		log.debug("Static hosts:")
		for name in hosts:
			log.debug("  %-20s : %s" % (name, ", ".join(hosts[name])))

		return hosts

	def terminate(self, signum=None, frame=None):
		"""
		Asks run() to leave its event loop.

		signum and frame are ignored. They are only accepted so that this
		method can be installed as a signal handler.
		"""
		self.running = False


class DHCPLeases(object):
	"""
	Reads a DHCP leases file and provides the most recent lease for
	every IP address found in it.

	The file is parsed once when the object is created. Iterating over
	the object yields Lease objects.
	"""

	regex_leaseblock = re.compile(r"lease (?P<ipaddr>\d+\.\d+\.\d+\.\d+) {(?P<config>[\s\S]+?)\n}")

	def __init__(self, path):
		self.path = path

		self._leases = self._parse()

	def __iter__(self):
		return iter(self._leases)

	def _parse(self):
		"""
		Returns a list with one Lease per IP address, in the order in
		which the addresses first appear in the file.
		"""
		log.info("Reading DHCP leases from %s" % self.path)

		with open(self.path) as f:
			# Read entire leases file
			data = f.read()

		# Maps each IP address to its lease. A dictionary keeps the
		# insertion order and spares us from searching a list for every
		# single block of the file.
		leases = {}

		for match in self.regex_leaseblock.finditer(data):
			ipaddr = match.group("ipaddr")
			properties = self._parse_block(match.group("config"))

			# Skip any leases without a hardware address
			# (e.g. abandoned leases)
			if "hardware" not in properties:
				continue

			lease = Lease(ipaddr, properties)

			# If a lease for this IP address is already known,
			# replace it with the most recent lease
			known = leases.get(ipaddr)
			if known is not None:
				lease = max(lease, known)

			leases[ipaddr] = lease

		return list(leases.values())

	def _parse_block(self, block):
		"""
		Parses the body of a lease block and returns a dictionary that
		maps the first word of every statement to the rest of it.

		"option" and "set" statements and any lines that are not
		terminated by a semicolon are ignored. Statements that consist
		of a single word only (e.g. "abandoned;") are stored with an
		empty value.
		"""
		properties = {}

		for line in block.splitlines():
			# Remove any leading and trailing whitespace
			line = line.strip()

			# Invalid line if it doesn't end with ;
			if not line.endswith(";"):
				continue

			# Remove trailing ; from line
			statement = line[:-1].rstrip()
			if not statement:
				continue

			# We skip all options and sets
			if statement.startswith(("option", "set")):
				continue

			# Split by first space. partition() never fails, not even
			# when the statement has no value at all.
			key, _, val = statement.partition(" ")
			properties[key] = val

		return properties


class FixLeases(object):
	"""
	Reads the fix leases file and provides a Lease for every enabled
	entry.

	Leases that were returned by the previous parse of the same path
	but are no longer in the file are set to "free" and returned as well.
	"""

	# Result of the latest parse for each path (shared by all instances)
	cache = {}

	def __init__(self, path):
		self.path = path

		parsed = self._parse()

		self._leases = parsed
		self.cache[self.path] = parsed

	def __iter__(self):
		return iter(self._leases)

	def _parse(self):
		"""
		Returns the list of leases for the file at self.path.

		Every line needs to have seven comma-separated fields of which
		the first three (hardware address, IP address, enabled) and the
		last one (hostname) are used.
		"""
		log.info("Reading fix leases from %s" % self.path)

		result = []
		now = datetime.datetime.utcnow()

		with open(self.path) as f:
			for row in f.readlines():
				row = row.rstrip()

				fields = row.split(",")
				if len(fields) != 7:
					log.warning("Could not parse line: %s" % row)
					continue

				hwaddr, ipaddr, enabled, a, b, c, hostname = fields

				# Skip any disabled leases
				if enabled != "on":
					continue

				# Fix leases start right now and never end
				properties = {
					"binding"         : "state active",
					"client-hostname" : hostname,
					"hardware"        : "ethernet %s" % hwaddr,
					"starts"          : now.strftime("%w %Y/%m/%d %H:%M:%S"),
					"ends"            : "never",
				}

				result.append(Lease(ipaddr, properties))

		# Try finding any deleted leases
		for previous in self.cache.get(self.path, []):
			if previous not in result:
				# Free the deleted lease
				previous.free()
				result.append(previous)

		return result


class Lease(object):
	"""
	A single lease for an IP address.

	properties is a dictionary of raw statements. The keys "binding",
	"hardware", "client-hostname", "starts" and "ends" are evaluated.

	Two leases are considered equal if they have the same IP address and
	the same hardware address.
	"""

	def __init__(self, ipaddr, properties):
		self.ipaddr = ipaddr
		self._properties = properties

	def __repr__(self):
		return "<%s %s for %s (%s)>" % (self.__class__.__name__,
			self.ipaddr, self.hwaddr, self.hostname)

	def __eq__(self, other):
		if not isinstance(other, Lease):
			return NotImplemented

		return self.ipaddr == other.ipaddr and self.hwaddr == other.hwaddr

	def __hash__(self):
		# Has to agree with __eq__
		return hash((self.ipaddr, self.hwaddr))

	def __gt__(self, other):
		"""
		A lease is greater than another lease for the same IP and
		hardware address if it started later.
		"""
		if not isinstance(other, Lease):
			return NotImplemented

		# Leases of different clients cannot be ordered
		if not self == other:
			return False

		ours, theirs = self.time_starts, other.time_starts

		# A lease without a start time is never the more recent one
		if ours is None or theirs is None:
			return ours is not None

		return ours > theirs

	@property
	def binding_state(self):
		"""
		The binding state (e.g. "active" or "free") or None
		"""
		state = self._properties.get("binding")

		if not state:
			return None

		# The value looks like "state active"
		_, _, value = state.partition(" ")

		return value or None

	def free(self):
		"""
		Sets the binding state of this lease to "free"
		"""
		self._properties.update({
			"binding" : "state free",
		})

	@property
	def active(self):
		return self.binding_state == "active"

	@property
	def hwaddr(self):
		"""
		The hardware address of the client or None
		"""
		hardware = self._properties.get("hardware")

		if not hardware:
			return None

		# The value looks like "ethernet 00:11:22:33:44:55"
		_, _, address = hardware.partition(" ")

		return address or None

	@property
	def hostname(self):
		"""
		The hostname sent by the client, or None if there is none or
		if it is not a valid hostname
		"""
		hostname = self._properties.get("client-hostname")

		if hostname is None:
			return None

		# Remove any ""
		hostname = hostname.replace("\"", "")

		# Only return valid hostnames. fullmatch() is used because $
		# would also accept a trailing line-break.
		if re.fullmatch(r"[A-Z0-9\-]{1,63}", hostname, re.I):
			return hostname

		return None

	@property
	def domain(self):
		"""
		The domain name of the zone (GREEN or BLUE) this lease belongs
		to. Falls back to "localdomain".

		The settings files are read on every access.
		"""
		# Load ethernet settings
		ethernet_settings = self.read_settings("/var/ipfire/ethernet/settings")

		# Load DHCP settings
		dhcp_settings = self.read_settings("/var/ipfire/dhcp/settings")

		address = ipaddress.ip_address(self.ipaddr)

		for zone in ("GREEN", "BLUE"):
			if not dhcp_settings.get("ENABLE_%s" % zone) == "on":
				continue

			netaddr = ethernet_settings.get("%s_NETADDRESS" % zone)
			submask = ethernet_settings.get("%s_NETMASK" % zone)

			# Skip zones that have not been configured completely
			if not netaddr or not submask:
				continue

			try:
				subnet = ipaddress.ip_network("%s/%s" % (netaddr, submask), strict=False)
			except ValueError:
				continue

			if address in subnet:
				# A zone without a domain name must not result
				# in names like "host.None"
				return dhcp_settings.get("DOMAIN_NAME_%s" % zone) or "localdomain"

		# Fall back to localdomain if no match could be found
		return "localdomain"

	@staticmethod
	def read_settings(filename):
		"""
		Reads a file of KEY=VALUE lines and returns it as dictionary.
		Lines that do not contain a = are ignored.
		"""
		settings = {}

		with open(filename) as f:
			for line in f:
				# Remove line-breaks
				line = line.rstrip()

				k, sep, v = line.partition("=")

				# Skip empty or malformed lines
				if not sep:
					continue

				settings[k] = v

		return settings

	@property
	def fqdn(self):
		"""
		Hostname and domain joined by a dot, or None if the lease does
		not have a valid hostname
		"""
		hostname = self.hostname

		if hostname:
			return "%s.%s" % (hostname, self.domain)

		return None

	@staticmethod
	def _parse_time(s):
		return datetime.datetime.strptime(s, "%w %Y/%m/%d %H:%M:%S")

	@staticmethod
	def _utcnow():
		# Naive UTC timestamp that can be compared to the parsed times
		return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

	@property
	def time_starts(self):
		starts = self._properties.get("starts")

		if starts:
			return self._parse_time(starts)

		return None

	@property
	def time_ends(self):
		"""
		The end of the lease, or None if it never ends
		"""
		ends = self._properties.get("ends")

		if not ends or ends == "never":
			return None

		return self._parse_time(ends)

	@property
	def expired(self):
		"""
		True if the current time is outside of the lifetime of this
		lease, i.e. it has not started yet or it has already ended.
		"""
		now = self._utcnow()

		starts = self.time_starts
		if starts is not None and starts > now:
			return True

		ends = self.time_ends
		if ends is not None and now > ends:
			return True

		return False

	@property
	def rrset(self):
		"""
		A list with the forward and the reverse record for this lease.
		Every record is a tuple of (name, TTL, class and type, content).
		"""
		# Resolve the FQDN only once since that requires reading
		# the settings files
		fqdn = self.fqdn

		# If the lease does not have a valid FQDN, we cannot create any RRs
		if fqdn is None:
			return []

		ttl = "%s" % LOCAL_TTL

		return [
			# Forward record
			(fqdn, ttl, "IN A", self.ipaddr),

			# Reverse record
			(ip_address_to_reverse_pointer(self.ipaddr), ttl, "IN PTR", fqdn),
		]


class UnboundConfigWriter(object):
	"""
	Writes leases as local-data records to a configuration file for
	Unbound and applies changes to the running daemon by calling
	unbound-control.

	The leases passed to this class need to provide an "rrset" attribute
	which is a list of tuples of strings (name, TTL, type, content).
	"""

	def __init__(self, path):
		self.path = path

		# Leases that have successfully been sent to Unbound
		self._cached_leases = []

	def update_dhcp_leases(self, leases):
		"""
		Rewrites the configuration file and adds or removes records
		in Unbound so that it matches the passed list of leases.

		Nothing is done if the list matches what has been sent to
		Unbound before. Leases that could not be added or removed are
		tried again on the next call.
		"""
		# Find any leases that have expired or do not exist any more
		# but are still in the unbound local data
		removed_leases = [l for l in self._cached_leases if l not in leases]

		# Find any leases that have been added
		new_leases = [l for l in leases if l not in self._cached_leases]

		# End here if nothing has changed
		if not new_leases and not removed_leases:
			return

		# Write out all leases
		self.write_dhcp_leases(leases)

		# Update unbound about changes
		for l in removed_leases:
			try:
				for name, ttl, rrtype, content in l.rrset:
					log.debug("Removing records for %s" % name)
					self._control("local_data_remove", name)

			# If the lease cannot be removed we will try the next one
			except Exception as e:
				log.warning("Could not remove lease %r: %s" % (l, e))
				continue

			# If the removal was successful, we will remove it from the cache
			else:
				self._cached_leases.remove(l)

		for l in new_leases:
			try:
				for rr in l.rrset:
					log.debug("Adding new record %s" % " ".join(rr))
					self._control("local_data", *rr)

			# If the lease cannot be added we will try the next one
			except Exception as e:
				log.warning("Could not add lease %r: %s" % (l, e))
				continue

			# Add lease to cache when successfully added
			else:
				self._cached_leases.append(l)

	def write_dhcp_leases(self, leases):
		"""
		Atomically replaces the configuration file with one that has
		a local-data line for every record of every lease.
		"""
		# The temporary file has to be on the same filesystem as the
		# destination or it cannot be renamed into place
		directory = os.path.dirname(self.path) or "."
		prefix = ".%s." % os.path.basename(self.path)

		f = tempfile.NamedTemporaryFile(mode="w", dir=directory,
			prefix=prefix, delete=False)

		try:
			with f:
				for l in leases:
					for rr in l.rrset:
						f.write("local-data: \"%s\"\n" % " ".join(rr))

				# Make file readable for everyone
				os.fchmod(f.fileno(), stat.S_IRUSR|stat.S_IWUSR|stat.S_IRGRP|stat.S_IROTH)

			os.replace(f.name, self.path)

		# Do not leave the temporary file behind if anything went wrong
		except BaseException:
			try:
				os.unlink(f.name)
			except OSError:
				pass

			raise

	def _control(self, *args):
		"""
		Runs unbound-control with the given arguments.

		Raises subprocess.CalledProcessError if the command failed and
		OSError if it could not be executed at all.
		"""
		command = ["unbound-control"]
		command.extend(args)

		try:
			# Capture stderr, too, so that it ends up in the log
			subprocess.check_output(command, stderr=subprocess.STDOUT)

		# Log any errors
		except subprocess.CalledProcessError as e:
			log.critical("Could not run %s, error code: %s: %s" % (
				" ".join(command), e.returncode, e.output))

			raise

		except OSError as e:
			log.critical("Could not run %s: %s" % (" ".join(command), e))

			raise


if __name__ == "__main__":
	parser = argparse.ArgumentParser(description="Bridge for DHCP Leases and Unbound DNS")

	# Daemon Stuff
	parser.add_argument("--daemon", "-d", action="store_true",
		help="Launch as daemon in background")
	parser.add_argument("--verbose", "-v", action="count", help="Be more verbose")

	# Paths
	parser.add_argument("--dhcp-leases", default="/var/state/dhcp/dhcpd.leases",
		metavar="PATH", help="Path to the DHCPd leases file")
	parser.add_argument("--unbound-leases", default="/etc/unbound/dhcp-leases.conf",
		metavar="PATH", help="Path to the unbound configuration file")
	parser.add_argument("--fix-leases", default="/var/ipfire/dhcp/fixleases",
		metavar="PATH", help="Path to the fix leases file")
	parser.add_argument("--hosts", default="/var/ipfire/main/hosts",
		metavar="PATH", help="Path to static hosts file")

	# Parse command line arguments
	args = parser.parse_args()

	# Map the number of -v switches to a log level
	# (args.verbose is None if the switch has not been passed at all)
	verbosity = args.verbose or 0

	if verbosity >= 2:
		loglevel = logging.DEBUG
	elif verbosity == 1:
		loglevel = logging.INFO
	else:
		loglevel = logging.WARN

	bridge = UnboundDHCPLeasesBridge(args.dhcp_leases, args.fix_leases,
		args.unbound_leases, args.hosts)

	# Signal handlers are always called with the signal number and the
	# current stack frame, which the bridge has no use for
	ctx = daemon.DaemonContext(detach_process=args.daemon)
	ctx.signal_map = {
		signal.SIGHUP  : lambda signum, frame: bridge.update_dhcp_leases(),
		signal.SIGTERM : lambda signum, frame: bridge.terminate(),
	}

	with ctx:
		# Setup logging only after the daemon context has been entered,
		# because entering it closes all open file descriptors which
		# would include the connection to syslog
		setup_logging(loglevel)

		bridge.run()
