"""
Copyright 2000-2023 Citrix Systems, Inc. All rights reserved.
This software and documentation contain valuable trade secrets
and proprietary property belonging to Citrix Systems, Inc.
None of this software and documentation may be copied,
duplicated or disclosed without the express written permission
of Citrix Systems, Inc.
"""

import functools
import time
import traceback

from rainman_core.common.logger import RainLogger
from rainman_core.common import rain
from rainman_core.common.rain_globals import init_rain_globals
from rainman_core.common.exception import config_file_error
from rainman_core.common.ipc_lock import get_conf_lock


# Module-wide logger, obtained from RainLogger.
log = RainLogger.getLogger()
# Shared handles returned by init_rain_globals().  In this file `config` is
# passed to the rain.*().get()/remove() calls, `local` is used for the
# ns-side operations (servicegroup, lb, binding, running config) and `cloud`
# for the autoscale-group operations.
config, local, cloud = init_rain_globals()


def wrap_check(func):
    """Decorate a check function so it always yields a 4-bit return code.

    The wrapped function's result is masked with 0xF.  If the function
    raises an ``Exception`` the error is logged as a warning and 0xF is
    returned instead, so one failing check cannot abort the caller.

    ``functools.wraps`` keeps the wrapped function's name and docstring on
    the returned wrapper.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs) -> int:
        try:
            # Only the low nibble is kept; callers store one hex digit
            # per check.
            return 0xF & func(*args, **kwargs)
        except Exception as err:
            log.warning("Exception seen in %s : %s", str(func), err)
        return 0xF
    return wrapper


class AutoHeal:
    """
    Class hosting Rainman self heal logic and functions specific to
    repair.

    Call run() to start autoheal from the rain_scale process.  run() calls
    the _check_* methods and records their return codes in self.retcode;
    log_autoheal_run_stats() logs a summary of the recorded codes.

    All repair has to happen in methods named _check_*.  These methods
    must return a non negative int value such that:
        0 :  no issues found                                [0x0]
        1 :  issue found and fixed                          [0x1]
        n :  found issue, fix attempt failed. (2<=n<=13)    [0x2:0xD]
        14:  found issue but cannot fix via autoheal.       [0xE]
        15:  Exception seen via decorator wrap_check        [0xF]

    _check_*(...) methods rely on returning a unique value per exit path
    for debugging rather than logging text messages.

    The class is a singleton: every AutoHeal() call returns the same
    instance, and the instance state is initialised only once so that the
    previous run's return codes survive repeated AutoHeal() calls.

    NOTE(review): several lookups below catch BaseException.  That also
    swallows KeyboardInterrupt/SystemExit; it is kept as-is because the
    project's exception hierarchy is not visible here -- confirm whether
    `except Exception` would be sufficient.
    """
    __instance = None

    def __new__(cls, *args, **kwargs):
        """Return the single shared instance, creating it on first use."""
        if cls.__instance is None:
            cls.__instance = super().__new__(cls)
        return cls.__instance

    def __init__(self) -> None:
        """Initialise run state once; later AutoHeal() calls keep it."""
        # __init__ is invoked on every AutoHeal() call even though __new__
        # hands back the same object; without this guard each call would
        # wipe retcode/prev_retcode and the run statistics.
        if getattr(self, "_initialized", False):
            return
        self.latest_start_ts = 0
        self.latest_run_ts = 0
        # One inner list of return codes per check category, in the order
        # run() appends them.
        self.retcode = []
        self.prev_retcode = []
        self._retcode_str = ""
        self._prev_retcode_str = ""
        self.run_stats = {}
        self.prev_run_stats = {}
        self.ns_running_conf = ""
        self.ipc_conf_lock = get_conf_lock()
        self._initialized = True

    def get_retcode_str(self, previous=False) -> str:
        """Render return codes as hex digits, one comma separated group
        per check category (e.g. [[0, 1], [14, 15]] -> "01,ef").

        previous: render prev_retcode instead of retcode.
        Returns '' when there are no codes recorded.
        """
        retcode = self.prev_retcode if previous else self.retcode
        if not retcode:
            return ''
        return ",".join(
            "".join(f"{el:x}" for el in line) for line in retcode)

    def _running_conf_lines(self) -> list:
        """Return the cached ns running config as a list of lines.

        NOTE(review): the attribute is initialised to "" but iterated line
        by line by the checks; both a single string and an iterable of
        lines are accepted here -- confirm what get_nsrunningconfig()
        returns.
        """
        conf = self.ns_running_conf
        if not conf:
            return []
        if isinstance(conf, str):
            return conf.splitlines()
        return list(conf)

    def _add_lines_mentioning(self, name: str) -> list:
        """Return the space-split tokens of every running-config line that
        contains `name` and whose first token is "add" (case-insensitive).
        """
        candidates = (line.split(" ")
                      for line in self._running_conf_lines() if name in line)
        return [tokens for tokens in candidates
                if tokens[0].lower() == "add"]

    @staticmethod
    def _load_entries(factory):
        """Return factory().get(config), or [] if the lookup raises."""
        try:
            return factory().get(config)
        except BaseException:
            return []

    @staticmethod
    def _stat_key(code: int) -> str:
        """Map a check return code to its run_stats bucket key."""
        if code == 0:
            return "p"
        if code == 1:
            return "i"
        if 1 < code < 0xE:
            return "f"
        if code == 0xE:
            return "m"
        if code == 0xF:
            return "x"
        return "u"

    @wrap_check
    def _check_cloud_asg(self, asg_name: str) -> int:
        """Check that the cloud side knows the autoscale group.

        Returns 0 when cloud.get_group_info() is truthy, 2 when it is
        falsy, 0xE when the lookup raises.
        """
        try:
            if cloud.get_group_info(asg_name):
                return 0
        except BaseException:
            # not modifying anything in cloud, only log a warning.
            log.warning("autoheal: group %s not found", asg_name)
            return 0xE
        return 2

    @wrap_check
    def _check_asg(self, asg_name: str, sg_name: str, cp_name="") -> int:
        """Check that the rain config has the autoscale group.

        Returns 0 when the group is found and 2 when the lookup returns a
        falsy value.  When the group lookup raises, the servicegroup is
        looked up instead: 3 if that lookup raises as well, otherwise a
        critical message asking for manual re-creation is logged (naming
        the cloud profile if cp_name is given, else the servicegroup) and
        0xE is returned.
        """
        try:
            return 0 if rain.group().get(config, asg_name) else 2
        except BaseException:
            try:
                if rain.service().get(config, sg_name):
                    log.warning(
                        "autoheal: cp/sg exist without asg %s", asg_name)
            except BaseException:
                return 3
        if cp_name:
            log.critical(
                "BROKEN CLOUD PROFILE '%s'. please remove and add cloudprofile again.",
                cp_name)
        else:
            log.critical(
                "BROKEN SERVICEGROUP '%s', please remove and add servicegroup again.",
                sg_name)
        return 0xE

    @wrap_check
    def _check_ns_sg(self, sg_name: str) -> int:
        """Ensure the servicegroup exists on the ns side, creating it from
        the rain config when missing.

        Returns 0 if present, 1 if created, 2 if the rain config has no
        such servicegroup, 3 if the config lookup raises, 4 if creation
        raises.
        """
        if local.get_group(sg_name):
            return 0
        try:
            sg = rain.service().get(config, sg_name)
            if not sg:
                return 2
        except BaseException:
            return 3
        try:
            local.add_group(sg)
        except BaseException:
            return 4
        log.info("autoheal: created missing servicegroup %s", sg_name)
        return 1

    @wrap_check
    def _check_ns_lb(self, lb_name: str) -> int:
        """Ensure the loadbalancer exists on the ns side, creating it from
        the rain config when missing.

        Returns 0 if present, 1 if created, 2 if the rain config has no
        such loadbalancer, 3 if the config lookup raises, 4 if creation
        raises.
        """
        if local.get_lb_exist(lb_name):
            return 0
        try:
            lb = rain.loadbalancer().get(config, lb_name)
            if not lb:
                return 2
        except BaseException:
            return 3
        try:
            local.add_lb(lb)
        except BaseException:
            return 4
        log.info("autoheal: created missing loadbalancer %s", lb_name)
        return 1

    @wrap_check
    def _check_ns_binding(self, lb_name: str, sg_name: str) -> int:
        """Ensure the servicegroup is bound to the loadbalancer on the ns
        side, adding the binding when missing.

        Returns 0 if bound, 1 if the binding was added, 2 if the
        servicegroup lookup raises, 3 if the loadbalancer lookup raises,
        4 if adding the binding raises.
        """
        if local.get_lb_sg_b_exist(lb_name, sg_name):
            return 0
        try:
            sg = rain.service().get(config, sg_name)
        except BaseException:
            return 2
        try:
            lb = rain.loadbalancer().get(config, lb_name)
        except BaseException:
            return 3
        try:
            local.add_group_to_lb(sg, lb)
        except BaseException:
            return 4
        log.info("autoheal: created missing binding %s:%s", sg_name, lb_name)
        return 1

    @wrap_check
    def _check_ns_conf_cp(self, cp: rain.group_lb_alarm_binding) -> int:
        """Check that the cached ns running config has exactly one
        "add cloud profile <name>" (or "add cloudprofile <name>") line for
        the cloud profile and that it mentions the profile's group or
        service and its lb.

        Returns 0 if consistent, 2 if more than one line matches, 3 if
        neither cp.group nor cp.service is among the line's tokens, 4 if
        cp.lb_name is not among them, 0xE if the running config is unknown
        or has no matching line.
        """
        if not self.ns_running_conf:
            log.warning("autoheal: ns config is not known, check skipped.")
            return 0xE
        add_lines = self._add_lines_mentioning(cp.name)
        # Length guards: a short "add ..." line that merely contains the
        # name as a substring must not raise IndexError.
        matches = [t for t in add_lines
                   if len(t) > 3 and t[1].lower() == "cloud"
                   and t[2].lower() == "profile" and t[3] == cp.name]
        if not matches:
            matches = [t for t in add_lines
                       if len(t) > 2 and t[1].lower() == "cloudprofile"
                       and t[2] == cp.name]
        if matches:
            if len(matches) > 1:
                return 2
            tokens = matches[0]
            if not (cp.group in tokens or cp.service in tokens):
                return 3
            if cp.lb_name not in tokens:
                return 4
            return 0
        log.warning(
            "cloud profile '%s': missing from ns config, please remove and add cloudprofile again",
            cp.name)
        return 0xE

    @wrap_check
    def _check_ns_conf_sg(self, sg: rain.service) -> int:
        """Check that the cached ns running config has exactly one
        "add servicegroup <name>" line for the servicegroup.

        Returns 0 if exactly one line matches, 2 if several match, 0xE if
        the running config is unknown or has no matching line.
        """
        if not self.ns_running_conf:
            # Same guard as _check_ns_conf_cp: without a config snapshot
            # "missing from ns config" would be a misleading message.
            log.warning("autoheal: ns config is not known, check skipped.")
            return 0xE
        matches = [t for t in self._add_lines_mentioning(sg.name)
                   if len(t) > 2 and t[1].lower() == "servicegroup"
                   and t[2] == sg.name]
        if matches:
            return 2 if len(matches) > 1 else 0
        log.warning(
            "servicegroup '%s': missing from ns config, please remove and add servicegroup again",
            sg.name)
        return 0xE

    @wrap_check
    def _check_remove_stale_asg(self, asg: rain.group) -> int:
        """Remove an autoscale group that run() classified as stale.

        Cloud cleanup and notification removal are best effort (failures
        are logged as warnings); the group is then removed from the rain
        config.  Returns 1; an exception from asg.remove() surfaces as 0xF
        through wrap_check.
        """
        try:
            cloud.cleanup_on_group_remove(asg)
        except Exception as err:
            log.warning("autoheal: failed cleanup: group %s: %s",
                        asg.name, str(err))
        try:
            if cloud.check_event_queue():
                queue_conf = rain.event_queue().get(config, "default")
                cloud.remove_notification_config_from_group(
                    queue_conf, asg.name)
        except Exception as err:
            log.warning("autoheal: failed to clear notification: %s",
                        str(err))
        asg.remove(config, asg)
        log.info("autoheal: no longer watching autoscale %s", asg.name)
        return 1

    @wrap_check
    def _check_remove_stale_lb(self, lb: rain.loadbalancer) -> int:
        """Remove a loadbalancer that run() classified as stale.

        Removal on the ns side is best effort (failure is logged as a
        warning); the lb is then removed from the rain config.  Returns 1;
        an exception from lb.remove() surfaces as 0xF through wrap_check.
        """
        try:
            local.remove_lb(lb)
        except Exception as err:
            # Log the actual lb name (was the literal text "lb.name").
            log.warning("autoheal: failed to remove lb %s: %s",
                        lb.name, str(err))
        lb.remove(config, lb)
        log.info("autoheal: no longer watching lb %s", lb.name)
        return 1

    def run(self) -> None:
        """Run one autoheal pass and record the return codes.

        The previous codes move to prev_retcode; retcode is rebuilt with
        one list per category, in this order: asg in cloud, cp -> asg,
        cp -> sg in ns, cp -> lb in ns, cp -> binding in ns, cp in ns
        conf, sg -> asg, sg -> sg in ns, sg in ns conf, stale asg removal,
        stale lb removal.  The config lock is held while the rain config
        is read and the checks run.  Exceptions are logged, not raised.
        """
        log.debug("started")
        self.latest_start_ts = time.time()
        self.prev_retcode, self.retcode = self.retcode, []
        # Only release the lock in `finally` if acquire() actually succeeded.
        lock_held = False
        try:
            try:
                self.ns_running_conf = local.get_nsrunningconfig()
            except Exception as err:
                # Autoheal will use existing value of running conf.
                log.warning("autoheal: failed to get ns conf: %s", err)

            self.ipc_conf_lock.acquire()
            lock_held = True

            cps = self._load_entries(rain.group_lb_alarm_binding)
            asgs = self._load_entries(rain.group)
            sgs = self._load_entries(rain.service)
            lbs = self._load_entries(rain.loadbalancer)

            # Ordered (de-duplicated) lists instead of sets of objects:
            # log_autoheal_run_stats() compares codes positionally between
            # runs, so the iteration order has to be reproducible.
            # NOTE(review): a failed load yields [] above, which makes
            # every asg/lb look stale below -- confirm that get(config)
            # raising really means "no entries".
            asgs_sg = list(dict.fromkeys(sg.group_name for sg in sgs))
            sgs_cp = {cp.service for cp in cps}
            sgs_non_cp = [sg for sg in sgs if sg.name not in sgs_cp]
            asgs_stale = [asg for asg in asgs if asg.name not in asgs_sg]
            lbs_cp = {cp.lb_name for cp in cps}
            lbs_stale = [lb for lb in lbs if lb.name not in lbs_cp]

            log.debug("asgs_sg: %s", asgs_sg)
            log.debug("sgs_cp: %s", sgs_cp)
            log.debug("sgs_non_cp: %s", sgs_non_cp)
            log.debug("asgs_stale: %s", asgs_stale)
            log.debug("lbs_cp: %s", lbs_cp)
            log.debug("lbs_stale: %s", lbs_stale)

            # autoscalgroups (asgs): includes all
            log.debug("asg in cloud")
            self.retcode.append(
                [self._check_cloud_asg(asg) for asg in asgs_sg])

            # cloud profile entries
            log.debug("cloud profile(s) -> asg in rain")
            self.retcode.append(
                [self._check_asg(cp.group, cp.service, cp.name) for cp in cps])
            log.debug("cloud profile(s) -> sg in ns")
            self.retcode.append([self._check_ns_sg(cp.service) for cp in cps])
            log.debug("cloud profile(s) -> lb in ns")
            self.retcode.append([self._check_ns_lb(cp.lb_name) for cp in cps])
            log.debug("cloud profile(s) -> binding in ns")
            self.retcode.append(
                [self._check_ns_binding(cp.lb_name, cp.service) for cp in cps])
            log.debug("cloud profile(s) in ns.conf")
            self.retcode.append([self._check_ns_conf_cp(cp) for cp in cps])

            # non cloud profile entries
            log.debug("servicegroup(s) -> asg in rain")
            self.retcode.append(
                [self._check_asg(sg.group_name, sg.name) for sg in sgs_non_cp])
            log.debug("servicegroup(s) -> sg in ns")
            self.retcode.append(
                [self._check_ns_sg(sg.name) for sg in sgs_non_cp])
            log.debug("servicegroup(s) in ns.conf")
            self.retcode.append(
                [self._check_ns_conf_sg(sg) for sg in sgs_non_cp])

            # remove stale entries: asg
            log.debug("autoheal: stale asg")
            self.retcode.append(
                [self._check_remove_stale_asg(asg) for asg in asgs_stale])

            # remove stale entries: lb
            log.debug("autoheal: stale lb")
            self.retcode.append(
                [self._check_remove_stale_lb(lb) for lb in lbs_stale])

            # alarm entries ???

            log.debug("autoheal: finished")
        except config_file_error as err:
            log.warning("autoheal: %s", err)
        except Exception as err:
            log.error(
                "Exception in autoheal::run: %s\n%s",
                err,
                traceback.format_exc())
        finally:
            if lock_held:
                self.ipc_conf_lock.release()
        self.latest_run_ts = time.time()

    def log_autoheal_run_stats(self) -> None:
        """
        Summarise the latest run against the previous one and log it.

        Nothing is logged until a previous run's codes are available.
        Codes are compared position by position, so entries without a
        counterpart in the previous run are not counted.

        run_stats keys:
            p : passed scenarios, no autoheal needed.
            i : fixed scenarios.
            f : issues found and not fixed.
            m : issue found and critical alert logged for manual fix.
            x : Exception seen with autoheal script.
            n : new issue.
            r : recurring non passing issue
            u : unknown state
            t : time taken by autoheal overall
        """
        self.prev_run_stats = self.run_stats
        self.run_stats = {"p": 0, "i": 0, "f": 0, "m": 0, "x": 0,
                          "n": 0, "r": 0, "u": 0, "t": 0}

        if not self.prev_retcode:
            # not enough data points to stat autoheal run issues.
            return

        # Render the previous codes directly: _retcode_str is not updated
        # on a run that returned early above, so copying it would log a
        # stale (empty) previous string.
        self._prev_retcode_str = self.get_retcode_str(previous=True)
        self._retcode_str = self.get_retcode_str()
        log.info("AH:RC: %s", self._retcode_str)
        log.debug("AH:PC: %s", self._prev_retcode_str)

        stats = self.run_stats
        try:
            for cur, pre in zip(self.retcode, self.prev_retcode):
                for cur_el, pre_el in zip(cur, pre):
                    stats[self._stat_key(cur_el)] += 1

                    if cur_el != 0 and pre_el == 0:
                        stats["n"] += 1
                    elif cur_el > 1 and pre_el > 1 and cur_el == pre_el:
                        stats["r"] += 1
        except Exception as err:
            log.error("Exception in run_stats: %s\n%s",
                      err, traceback.format_exc())
        try:
            stats["t"] = round(
                self.latest_run_ts - self.latest_start_ts, 3)
            log.info("AH:ST: %s", self.run_stats)
            log.debug("AH:PT: %s", self.prev_run_stats)
        except Exception as err:
            log.error("Exception in print run_stats: %s\n%s",
                      err, traceback.format_exc())
