#!/opt/imunify360/venv/bin/python3
"""
The watchdog script that checks the webshield and restarts it if error found
"""

import json
import logging
import logging.handlers
import os
import requests
import subprocess
import sys
import uuid
from time import ctime, time, sleep
from argparse import ArgumentParser
from traceback import format_exception
from types import TracebackType


import sentry_sdk
from sentry_sdk import configure_scope


# Module-level logger used by every part of this script.
log = logging.getLogger(__name__)


def _syslog_exception(etype: type[BaseException], evalue: BaseException, etb: TracebackType):
    """``sys.excepthook`` replacement that sends uncaught exceptions to the log.

    The complete formatted traceback is logged at ERROR level through the
    module logger. If logging itself fails, a short notice is written to
    stderr and the interpreter's default hook prints the original traceback,
    so the error is not lost. This hook never raises.

    :param etype: class of the uncaught exception (any ``BaseException``,
        e.g. ``KeyboardInterrupt``, can reach an excepthook).
    :param evalue: the exception instance.
    :param etb: its traceback object.
    """
    try:
        log.error('%s', ''.join(format_exception(etype, evalue, etb)).rstrip())
    except BaseException:
        try:
            sys.stderr.write('Error reporting error\n')
            sys.stderr.flush()
            # Fall back to the interpreter's own hook so the traceback is still visible.
            sys.__excepthook__(etype, evalue, etb)
        except BaseException:
            pass


def _setup_logging(debug: bool, to_stderr: bool):
    """Configure root logging for the watchdog.

    :param debug: log at DEBUG level instead of INFO.
    :param to_stderr: log to stderr instead of syslog (``/dev/log``).

    Also routes uncaught exceptions through ``_syslog_exception``, captures
    ``warnings`` into logging, and sets ``logging.raiseExceptions = False`` so
    failing handlers do not print their own tracebacks. If the syslog socket
    cannot be opened, logging falls back to stderr instead of aborting.
    """
    logging.raiseExceptions = False
    logging.captureWarnings(True)
    sys.excepthook = _syslog_exception

    handler = None
    syslog_error = None
    if not to_stderr:
        try:
            handler = logging.handlers.SysLogHandler('/dev/log')
            msgfmt = 'imunify360-watchdog: %(levelname)s: %(message)s'
        except OSError as error:
            # No syslog socket (e.g. a minimal container): keep the watchdog
            # running and logging rather than dying before any output.
            syslog_error = error
    if handler is None:
        handler = logging.StreamHandler()
        msgfmt = '%(asctime)s %(levelname)s: %(message)s'

    handler.setFormatter(logging.Formatter(msgfmt))
    root = logging.getLogger()
    root.addHandler(handler)
    root.setLevel(logging.DEBUG if debug else logging.INFO)

    if syslog_error is not None:
        log.warning("Syslog is unavailable, logging to stderr, reason: '%s'", syslog_error)


class Watchdog:
    """Health checks for the webshield proxy and the wafd daemon.

    An instance is built from an already-parsed status mapping (in this
    script, the JSON printed by ``imunify360-wsctl status``).
    ``check_wafd()`` and ``ensure_webshield()`` then probe the services and
    restart/reload them when they look unhealthy.
    """

    request_timeout = 4  # base HTTP timeout (s); multiplied by the attempt number
    subprocess_timeout = 30  # timeout (s) for commands run via collect_output()
    user_agent = 'Webshield-watchdog-agent'
    sentry_dsn_path = '/usr/share/imunify360-webshield/sentry'
    package_name = 'imunify360-webshield-bundle'
    license_path = '/var/imunify360/license.json'
    # Holds the Unix timestamp at which the webshield was last marked inaccessible.
    flag_path = '/var/imunify360/webshield_broken'
    wafd_sock_path = '/var/run/imunify360/libiplists-daemon.sock'
    wafd_check_binary = 'i360_wafd_check'
    # Different ports for different modes, see gen_ports_conf.py.
    selfcheck_ports = (52224, 52228)
    selfcheck_attempts = 3  # self-check attempts per port
    selfcheck_retry_delay = 2  # pause (s) between consecutive failed attempts

    def __init__(self, status):
        """Read the expected and actual webshield state from *status*.

        :param status: parsed status mapping. Anything that is not a dict,
            including missing or ``null`` sub-sections, is treated as empty,
            so both flags default to False instead of raising.
        """
        status = self._as_dict(status)
        platform = self._as_dict(status.get('platform'))
        self.should_proxy_be_running = platform.get('should_proxy_be_running', False)
        units = self._as_dict(status.get('units'))
        webshield_unit = self._as_dict(units.get('imunify360-webshield'))
        self.is_running = webshield_unit.get('active', False)
        self.sentry_dsn = self._get_dsn()
        self.__show_flag_status()

    @staticmethod
    def _as_dict(value):
        """Return *value* if it is a dict, otherwise an empty dict."""
        return value if isinstance(value, dict) else {}

    def __show_flag_status(self):
        """Log whether the 'webshield inaccessible' flag file holds a timestamp."""
        if (ts := self._get_flag_timestamp()) is None:
            log.info("Flag '%s' is not set", self.flag_path)
        else:
            log.info("Flag '%s' is set on %s, now is %s", self.flag_path, ctime(ts), ctime(time()))

    @classmethod
    def _get_server_id(cls):
        """Return the ``id`` field of the license JSON, or ``'none'`` if unavailable."""
        try:
            with open(cls.license_path) as f:
                data = json.load(f)
        except Exception:
            return 'none'
        if not isinstance(data, dict):
            return 'none'
        return data.get('id', 'none')

    @classmethod
    def _get_dsn(cls):
        """Return the stripped Sentry DSN from ``sentry_dsn_path``, or None if unreadable."""
        try:
            with open(cls.sentry_dsn_path) as f:
                return f.read().strip()
        except Exception:
            return

    def _init_sentry(self):
        """Initialise Sentry with the package version and tag events with the server id.

        NOTE(review): nothing in this script calls this method - confirm
        whether Sentry reporting is meant to be enabled.
        """
        sentry_sdk.init(dsn=self.sentry_dsn, release=self._imunify360_version())
        with configure_scope() as scope:
            scope.user = {'id': self._get_server_id()}

    def _make_http_request(self, i: int, port: int):
        """Send one self-check GET request to the webshield on *port*.

        The timeout grows with the attempt number *i*. A random uuid in the
        query string makes every request URL unique. Returns True if the
        request completed (with any HTTP status), False if ``requests`` raised.
        """
        url = f'http://0.0.0.0:{port}/selfcheck?uuid={uuid.uuid4()}'
        curr_timeout = self.request_timeout * i
        try:
            requests.get(url, headers={'User-Agent': self.user_agent}, allow_redirects=False, timeout=curr_timeout)
        except Exception as error:
            log.debug("Webshield is not accessible on port %d, reason: '%s'", port, error)
            return False
        return True

    def _check_http_request(self):
        """Probe the webshield self-check endpoint on each known port.

        Each port in ``selfcheck_ports`` is tried ``selfcheck_attempts``
        times. Returns True on the first request that completes, and False
        once every attempt has failed. There is no pause after the very last
        attempt because nothing follows it.
        """
        attempts = [
            (port, attempt)
            for port in self.selfcheck_ports
            for attempt in range(1, self.selfcheck_attempts + 1)
        ]
        for index, (port, attempt) in enumerate(attempts, start=1):
            log.info('Checking webshield status: attempt %d, port %d', attempt, port)
            if self._make_http_request(attempt, port):
                return True
            if index < len(attempts):
                sleep(self.selfcheck_retry_delay)
        return False

    def _reload_units(self):
        """Ask ``imunify360-wsctl`` to reload the webshield units (exit status is ignored)."""
        # WARN: Never touch units directly!
        # Use imunify360-wsctl reload instead.
        log.info('Reloading units')
        subprocess.run(['imunify360-wsctl', 'reload'], check=False)

    @classmethod
    def collect_output(cls, cmd):
        """Run *cmd* and return its decoded stdout.

        Returns ``''`` if the command cannot be started, times out after
        ``subprocess_timeout`` seconds, or exits with a non-zero status.
        Undecodable bytes are replaced rather than raising.
        """
        try:
            cp = subprocess.run(
                cmd,
                stdin=subprocess.DEVNULL,
                stdout=subprocess.PIPE,
                stderr=subprocess.DEVNULL,
                timeout=cls.subprocess_timeout,
            )
        except (OSError, subprocess.TimeoutExpired):
            return ''
        if cp.returncode != 0:
            return ''
        return cp.stdout.decode(errors='replace')

    @classmethod
    def _get_rpm_version(cls):
        """Return ``VERSION-RELEASE`` of ``package_name`` from rpm, or ``''``."""
        cmd = ['rpm', '-q', '--queryformat=%{VERSION}-%{RELEASE}', cls.package_name]
        return cls.collect_output(cmd)

    @classmethod
    def _get_dpkg_version(cls):
        """Return the ``Version:`` field from ``dpkg --status``, or None if absent or empty."""
        cmd = ['dpkg', '--status', cls.package_name]
        out = cls.collect_output(cmd)
        if not out:
            return
        for line in out.splitlines():
            if line.startswith('Version:'):
                fields = line.strip().split()
                # A bare "Version:" line has no value; don't raise IndexError.
                return fields[1] if len(fields) > 1 else None

    @classmethod
    def _imunify360_version(cls):
        """Return the installed ``package_name`` version via rpm, falling back to dpkg."""
        version = cls._get_rpm_version()
        if not version:
            version = cls._get_dpkg_version()
        return version

    @classmethod
    def _get_flag_timestamp(cls):
        """Return the integer timestamp stored in the flag file, or None if missing/unparsable."""
        try:
            with open(cls.flag_path) as o:
                return int(o.read().strip())
        except Exception:
            pass

    @classmethod
    def _put_flag_timestamp(cls):
        """Best-effort write of the current Unix time into the flag file (errors are ignored)."""
        tms = int(time())
        try:
            with open(cls.flag_path, 'w') as w:
                w.write(f'{tms}')
        except Exception:
            pass

    @classmethod
    def _set_flag(cls):
        """Refresh the flag if it is unset or at least 60 seconds old.

        Returns True when a new timestamp was written (the caller then
        reloads), and False when the flag is fresher than 60 seconds. Because
        ``_put_flag_timestamp`` ignores write errors, True is returned even if
        the write failed.
        """
        tms = cls._get_flag_timestamp()
        if not tms or time() - tms >= 60:  # 1 minute
            cls._put_flag_timestamp()
            return True
        return False

    @classmethod
    def _remove_flag_if_exists(cls):
        """Delete the flag file. Return True if it was removed, False otherwise."""
        try:
            os.unlink(cls.flag_path)
        except FileNotFoundError:
            return False
        except OSError as error:
            log.warning("Failed to remove flag '%s', reason: '%s'", cls.flag_path, error)
            return False
        return True

    def ensure_webshield(self):
        """Reload the webshield if it is not running or not answering the self-check.

        A reload for an inaccessible webshield is only triggered when
        ``_set_flag`` reports that the flag was (re)written, i.e. at most
        once per 60 seconds. The flag is cleared again once the webshield
        answers.
        """
        if self.is_running:
            log.info('Webshield is running, checking if it is accessible')
            if self._check_http_request():
                log.info('Webshield is accessible')
                if self._remove_flag_if_exists():
                    log.info('Webshield has been resumed')
            else:
                log.info('Webshield is inaccessible')
                if self._set_flag():
                    log.error('Webshield has been marked as inaccessible')
                    self._reload_units()
        else:
            log.error('Webshield is not running. Restarting.')
            self._reload_units()

    def check_wafd(self):
        """
        The wafd is expected to be running by all means
        because not only the webshield is dependent on it.
        We call small wafd utility to check wafd is responsive.
        Otherwise we'll try to restart wafd.
        """
        log.info('Checking wafd')
        # A fixed unique-local IPv6 address used only as a probe query.
        check_ip = 'fd3a:4d0d:a778:4bfa:4760:2825:b8bc:52e1'
        cmd = [self.wafd_check_binary, '-path', self.wafd_sock_path, check_ip]
        try:
            p = subprocess.run(cmd, check=True, timeout=2, capture_output=True)
        except Exception as error:
            # On any exception we just fall through to restart wafd
            log.info("Failed to check wafd, reason: '%s'", error)
        else:
            out = p.stdout.decode('utf-8')
            log.debug("Wafd check output: '%s'.", out)
            # status 0: ALLOW; status 4: ENGINE_OFF
            if 'Response' in out and ('status: 0' in out or 'status: 4' in out):
                # We got a sensible response, so wafd is running and responsive.
                # Nothing to do, return
                log.info('Wafd is running and responsible')
                return
        log.error('Wafd is not responsible, trying to restart it')
        try:
            subprocess.run(['systemctl', 'restart', 'imunify360-wafd'], check=True)
        except Exception as error:
            log.error("Failed to restart wafd, reason: '%s'", error)


def _parse_status(out):
    """Decode ``imunify360-wsctl status`` output into a dict.

    Empty output, invalid JSON, or JSON that is not an object all yield
    ``{}``, so the watchdog falls back to its defaults instead of crashing.
    """
    if not out:
        return {}
    try:
        status = json.loads(out)
    except json.JSONDecodeError:
        return {}
    return status if isinstance(status, dict) else {}


def main():
    """Run one watchdog pass.

    Parses the CLI flags and configures logging, then reads the current
    state from ``imunify360-wsctl status``. wafd is always checked; the
    webshield is checked only when the platform reports that the proxy
    should be running.
    """
    parser = ArgumentParser()
    parser.add_argument('-d', '--debug', action='store_true', help='Enable debug logging')
    parser.add_argument('-s', '--stderr', action='store_true', help='Log to stderr instead of syslog')
    args = parser.parse_args()
    _setup_logging(args.debug, args.stderr)

    out = Watchdog.collect_output(['imunify360-wsctl', 'status'])
    # This is the webshield control tool's status, not wafd's.
    log.debug("imunify360-wsctl status: '%s'.", out)
    w = Watchdog(_parse_status(out))
    w.check_wafd()
    if w.should_proxy_be_running:
        w.ensure_webshield()
    else:
        log.info('Webshield is not expected, skipping')


if __name__ == '__main__':
    main()
