#!/usr/bin/env python
"""
Copyright 2000-2024 Citrix Systems, Inc. All rights reserved.
This software and documentation contain valuable trade secrets
and proprietary property belonging to Citrix Systems, Inc.
None of this software and documentation may be copied,
duplicated or disclosed without the express written permission
of Citrix Systems, Inc.
"""


import sys
import time
import json
import os
import psutil
import signal

import rainman_core.common.constants as CONST
from rainman_core.common.logger import RainLogger
from rainman_core.common import rain
from rainman_core.common.exception import *
from rainman_core.common.rain_globals import init_rain_globals
from rainman_core.common.ipc_lock import get_conf_lock
from rainman_core import autoheal


# Module-wide logging: ``rlog`` is the RainLogger wrapper (its error_trace()
# is used by the exception handlers below) and ``log`` is the logger object
# it exposes.
rlog = RainLogger(CONST.SCALE_LOG_FILE_NAME, CONST.DEFAULT_LOG_LEVEL)
log = rlog.logger
# Shared handles used by every function in this file: ``config`` is passed to
# the rain.* getters, ``local`` performs the service-group operations and
# ``cloud`` performs the cloud-group / event-queue operations.
config, local, cloud = init_rain_globals()
# Configuration lock; main() takes it read-only while reading groups/services.
conf_lock = get_conf_lock()
# Path of this daemon's pid file. Empty until main() fills it in, which is
# how clear_pid_file() knows there is nothing to delete.
pid_file = ""
# Service-group member states that sync_local_to_cloud_groups() treats as
# "already draining or disabled".
DRAINING_OR_DISABLED_STATE = ["OUT OF SERVICE", "GOING OUT OF SERVICE"]

def handle_signal(signum, frame):
    """Signal handler installed by main() for SIGUSR1 and SIGUSR2.

    SIGUSR1 terminates the daemon through exit_daemon(11); SIGUSR2 is only
    logged. Any other signal number is reported as critical.

    Args:
        signum: number of the delivered signal.
        frame: interrupted stack frame (unused).
    """
    log.info(f"Received signal - {signum}")

    if signum == signal.SIGUSR1:
        log.info("No longer primary/CCO node, exiting")
        exit_daemon(11)
        return

    if signum == signal.SIGUSR2:
        log.info("SIGUSR2 received, ignoring")
        return

    log.critical(f"Asked to handle an unexpected signal '{signum}'")


def sync_local_to_cloud_groups(groups, services):
    """Reconcile local service groups with the servers of their cloud groups.

    For every cloud group in ``groups`` and every service group in
    ``services`` whose ``group_name`` matches it:

    * servers present in the cloud but not locally (compared by ``ip``) are
      added to the service group;
    * servers present locally but not in the cloud are removed;
    * servers the cloud reports as standby, and that are not already
      draining/disabled locally, are drained gracefully with a zero delay;
    * servers that are draining/disabled locally but active in the cloud are
      re-enabled.

    Failures are logged per group / service group and never abort the rest of
    the sync. Only ``Exception`` is caught so that a ``SystemExit`` raised by
    the signal handler while a sync is in progress still stops the daemon.

    Args:
        groups: cloud group objects to reconcile.
        services: service-group objects; only those bound to a group in
            ``groups`` are touched.
    """
    for grp in groups:
        try:
            cloud_servers = cloud.get_servers_in_group(grp)
            # Lookup tables are built once per cloud group instead of once
            # per comparison (assumes ``ip`` / ``name`` are hashable values
            # such as strings).
            cloud_ips = {srv.ip for srv in cloud_servers}
            cloud_by_name = {srv.name: srv for srv in cloud_servers}
        except Exception:
            rlog.error_trace(f"failed to get server list for cloud->group {grp.name}")
            continue

        for svc_grp in services:
            try:
                if svc_grp.group_name != grp.name:
                    continue

                sg_servers = local.get_servers_in_group(svc_grp)
                sg_ips = {srv.ip for srv in sg_servers}
                sg_by_name = {srv.name: srv for srv in sg_servers}

                add_servers = [
                    srv for srv in cloud_servers if srv.ip not in sg_ips]
                remove_servers = [
                    srv for srv in sg_servers if srv.ip not in cloud_ips]

                for server in add_servers:
                    log.info(
                        f"adding server {server.name!r} to servicegroup {svc_grp.name!r}")
                    local.add_server_to_group(server, svc_grp)

                for server in remove_servers:
                    log.info(
                        f"removing server {server.name!r} from servicegroup {svc_grp.name!r}")
                    local.remove_server_from_group(server, svc_grp)

                # State reconciliation is keyed by server name and uses the
                # membership snapshot taken before the add/remove pass above.
                servers_in_standby = [
                    cloud_srv for name, cloud_srv in cloud_by_name.items()
                    if cloud_srv.state == cloud.SERVER_STANDBY_STR
                    and name in sg_by_name
                    and sg_by_name[name].state not in DRAINING_OR_DISABLED_STATE]
                for server in servers_in_standby:
                    local.drain_server_in_group(server, svc_grp, 0, graceful="YES")

                servers_in_inservice = [
                    sg_srv for name, sg_srv in sg_by_name.items()
                    if sg_srv.state in DRAINING_OR_DISABLED_STATE
                    and name in cloud_by_name
                    and cloud_by_name[name].state == cloud.SERVER_ACTIVE_STR]
                for server in servers_in_inservice:
                    local.enable_server_in_group(server, svc_grp)

            except Exception:
                rlog.error_trace(f"syncronize failed for cloud->group:ns->servicegroup {grp.name}:{svc_grp.name}")


def is_present(groupList, group):
    """Tell whether any entry of ``groupList`` belongs to ``group``.

    Args:
        groupList: iterable of objects exposing a ``group`` attribute.
        group: value to look for.

    Returns:
        True if at least one entry's ``group`` equals ``group``, else False.
    """
    return any(entry.group == group for entry in groupList)


def get_removed_server(group_name):
    """Return the local servers of a group that the cloud no longer reports.

    The group is looked up in the configuration, the servers of its first
    configured service group are read from the local side and compared (by
    ``ip``) with the servers the cloud lists for the group.

    Args:
        group_name: name of the configured group to inspect.

    Returns:
        List of local server objects whose ``ip`` is absent from the cloud
        group. Empty when the group has no services or when the lookup fails
        with ``config_not_present`` / ``config_failed`` (a warning is logged).
    """
    try:
        group = rain.group().get(config, group_name)
        configured_services = group.list_all_services_in_group(config)
        if not configured_services:
            # Without this guard the [0] lookup below raised an uncaught
            # IndexError for a group that has no service group bound to it.
            log.warning("No services configured for group %s, nothing to remove",
                        group_name)
            return []

        group_servers = local.get_servers_in_group(configured_services[0])
        cloud_servers = cloud.get_servers_in_group(group)
        # Built once; assumes ``ip`` values are hashable (e.g. strings).
        cloud_ips = {srv.ip for srv in cloud_servers}
        return [srv for srv in group_servers if srv.ip not in cloud_ips]

    except (config_not_present, config_failed) as e:
        log.warning("Failed to syncronize %s: %s" % (group_name, str(e)))
    return []


def add_group_to_draining(draining, group, slow_server):
    """Pick the servers to drain for ``group`` and start draining them.

    The selection depends on the cloud platform:

    * AZURE / AZURESTACK: nothing is selected here (drain_progress() fills
      the list in later from get_removed_server());
    * GCP: the servers returned by get_removed_server();
    * otherwise: the single server named ``slow_server`` if one is given,
      else ``local_servers[group.drain_count:]``.

    A ``rain.drain`` record is appended to ``draining`` and every selected
    server is put into drain on each service group of ``group``.

    Args:
        draining: list of active drain records; mutated in place.
        group: configured group object (``name``, ``drain_count`` and
            ``drain_time`` are read).
        slow_server: name of a specific server to drain, or None.
    """
    services = group.list_all_services_in_group(config)
    if not services:
        log.warning("Stale group entry present, no services present for group")
        # Nothing can be drained without a service group; previously the
        # code fell through and crashed on services[0].
        return
    local_servers = local.get_servers_in_group(services[0])
    drain_servers = []

    try:
        platform = cloud.get_cloud_platform()
        if platform in ('AZURE', 'AZURESTACK'):
            drain_servers = []
        elif platform == "GCP":
            drain_servers = get_removed_server(group.name)
        else:
            # We need to have at least as many servers left as we want to
            # drain: this bare lookup exists only to raise IndexError when
            # there are too few local servers.
            local_servers[group.drain_count]
            if slow_server is None:
                log.info("slow_server NOT Detected")
                # NOTE(review): this selects every server from index
                # drain_count onwards, not drain_count servers - confirm
                # that this is the intended selection.
                drain_servers = local_servers[group.drain_count:]
            else:
                log.info("slow_server %s detected" % (slow_server))
                for local_server in local_servers:
                    if local_server.name == slow_server:
                        drain_servers.append(local_server)
                        break
        drain = rain.drain(group.name, drain_servers, int(group.drain_time))
        drain.last_activity = time.time()
        draining.append(drain)

    except IndexError:
        log.warning(
            "Attempted to drain %s from %s, but not enough servers left: %s"
            % (group.drain_count, group.name, len(local_servers)))
        return

    for service in services:
        for server in drain_servers:
            try:
                log.info("Draining server: %s", str(server.name))
                local.drain_server_in_group(server, service, group.drain_time)
            except config_not_present:
                # Report the server that was missing (the original logged
                # the drain count here, which made the message meaningless).
                log.warning(
                    "Attempted to drain %s from %s, but it was not found."
                    % (server.name, service.name))


def drain_progress(draining):
    """Advance every drain whose drain time has elapsed.

    For each expired drain record the group and its service groups are read
    from the configuration, then each server is handled:

    * on AZURE / AZURESTACK / GCP the server list is replaced by
      get_removed_server() and each server is removed from all service
      groups right away;
    * elsewhere a server is removed from the service groups and from the
      cloud group only once none of the service groups reports a status
      other than "OUT OF SERVICE" for it.

    Finished servers are dropped from the record, and a record with no
    servers left is dropped from ``draining``.

    Both lists are iterated through snapshots: removing from a list while
    iterating it directly made the loop skip the following entry.

    Args:
        draining: list of drain records; mutated in place.
    """
    for drain in list(draining):

        # Make sure the drain time has been reached before we bother checking on it
        if (time.time() - drain.drain_time) < drain.last_activity:
            continue

        log.info("Drain time for group %s expired, checking servers" %
                 (drain.group))

        try:
            group = rain.group().get(config, drain.group)
            services = group.list_all_services_in_group(config)
        except config_not_present:
            log.warning(
                "Group %s is missing from config"
                % (drain.group))
            continue

        # On these platforms the servers to drop are whatever the cloud no
        # longer lists, so no status polling is done.
        remove_immediately = cloud.get_cloud_platform() in (
            'AZURE', 'AZURESTACK', 'GCP')
        if remove_immediately:
            drain.servers = get_removed_server(group.name)

        for server in list(drain.servers):

            log.info(
                "Checking status of server %s"
                % (server.name))

            if remove_immediately:
                for service in services:
                    try:
                        local.remove_server_from_group(server, service)
                    except Exception:
                        # Best effort, but let SystemExit from the signal
                        # handler through.
                        log.warning("Exception while draining '%s' from '%s'."
                                    % (server.name, service.name))
                drain.servers.remove(server)
                log.info("removed server %s from draining list" %
                         server.name)
                continue

            for service in services:
                try:
                    if local.get_server_status_in_group(server, service) != "OUT OF SERVICE":
                        # Still serving somewhere: check again on a later pass.
                        break
                except config_failed:
                    log.warning("server '%s' is not found in local->'%s'"
                                % (server.name, service.name))
            else:
                # No service group reported the server as still in service.
                log.info("Drain complete for '%s', removing it." %
                         server.name)
                for service in services:
                    try:
                        local.remove_server_from_group(server, service)
                    except config_failed:
                        log.warning("server '%s' is not found in local->'%s'"
                                    % (server.name, service.name))
                try:
                    cloud.remove_server_from_group(server, group)
                except config_failed:
                    log.debug("server is not found in cloud")

                drain.servers.remove(server)

        if len(drain.servers) == 0:
            draining.remove(drain)


def is_server_draining_and_remove(draining, group, server):
    """Drop ``server`` from the drain records of ``group``.

    Every record in ``draining`` whose ``group`` equals ``group`` and whose
    ``servers`` list contains ``server`` has that entry removed. Nothing is
    returned.

    Args:
        draining: list of drain records (``group`` and ``servers`` are read).
        group: group name to match.
        server: value looked up in each record's ``servers`` list.
    """
    for record in draining:
        if record.group != group or server not in record.servers:
            continue
        log.info("Removing server from draining list")
        record.servers.remove(server)


def init_event_queue():
    """Load (or create) the "default" event-queue config and register it.

    Returns:
        The event-queue config object, or None on AZURE when
        cloud.get_event_queue_details() yields None.

    NOTE(review): both ``return`` statements sit inside ``finally``. A return
    in ``finally`` discards any exception still propagating from the ``try``
    or ``except`` bodies, so errors raised there never reach the caller.
    Confirm this is intended before restructuring.
    """
    # Fallback object, used when no stored "default" entry can be read.
    queue_conf = rain.event_queue()
    try:
        queue_conf = rain.event_queue().get(config, "default")
    except (config_not_present, config_file_error):
        # No stored entry: build one from the cloud's queue details and
        # persist it only if details were obtained.
        log.info("Adding queue_conf to rainman.conf...")
        queue_conf.name = "default"
        queue_conf.details = cloud.get_event_queue_details()
        if queue_conf.details is not None:
            queue_conf.add(config, queue_conf)
    finally:
        # On AZURE the details are always re-read and written back.
        if cloud.get_cloud_platform() == 'AZURE':
            queue_conf.details = cloud.get_event_queue_details()
            if queue_conf.details is None:
                return None
            queue_conf.update(config, queue_conf)
        cloud.add_event_queue(queue_conf)
        return queue_conf


def get_configured_groups():
    """Read every configured group from the configuration.

    Returns:
        Whatever ``rain.group().get(config)`` yields, or an empty list (after
        logging a warning) when the configuration has no groups or cannot be
        read.
    """
    try:
        groups = rain.group().get(config)
    except (config_not_present, config_file_error):
        log.warning("No configured groups found.")
        groups = []
    return groups


def get_configured_services():
    """Read every configured service group from the configuration.

    Returns:
        Whatever ``rain.service().get(config)`` yields, or an empty list
        (after logging a warning) when the configuration has no services or
        cannot be read.
    """
    try:
        found = rain.service().get(config)
    except (config_not_present, config_file_error):
        log.warning("No configured services found.")
        found = []
    return found

def get_configured_groups_diff(new_groups, old_groups):
    """Return the names of groups present in ``new_groups`` but not before.

    Args:
        new_groups: iterable of objects exposing ``name`` (current config).
        old_groups: iterable of objects exposing ``name`` (previous config).

    Returns:
        List of names, in ``new_groups`` order, that do not occur among the
        names of ``old_groups``.
    """
    # Collect the old names once: this avoids rescanning ``old_groups`` for
    # every new group and also works when ``old_groups`` is a one-shot
    # iterator. Names are assumed hashable (they are used as strings here).
    known_names = {grp.name for grp in old_groups}
    return [grp.name for grp in new_groups if grp.name not in known_names]

def is_primary_node():
    """Tell whether this node is allowed to run the daemon.

    Returns:
        True when ``local.get_node_config()`` reports 'Primary', 'CCO' or
        'StandAlone'; False for any other value.
    """
    return local.get_node_config() in ('Primary', 'CCO', 'StandAlone')


def wait_for_config():
    """Block until a signal arrives, then re-read the "default" event queue.

    Returns:
        The stored event-queue config, or None when it is still missing or
        the configuration file cannot be read.
    """
    # Sleeps until any signal is delivered to the process.
    signal.pause()
    try:
        queue_conf = rain.event_queue().get(config, "default")
    except (config_not_present, config_file_error):
        log.info("No groups configured yet...")
        queue_conf = None
    return queue_conf


def clean_wait_for_config(tk_refresh_th):
    """Stop event intake, wait for configuration, then start intake again.

    The Pub/Sub listener (GCP) or webhook (other platforms) is stopped and
    the token-refresh timer cancelled before blocking in wait_for_config().
    If a queue config comes back, the listener and the timer are started
    again; otherwise both are left stopped.

    Args:
        tk_refresh_th: token-refresh timer object, or None if there is none.

    Returns:
        The event-queue config from wait_for_config(), or None if there is
        still no configuration.
    """
    # The stop calls are made with no queue config.
    stop_listener = (cloud.stop_pub_sub
                     if cloud.get_cloud_platform() == "GCP"
                     else cloud.stop_webhook)
    stop_listener(None)

    # Stopping token refresh thread
    if tk_refresh_th is not None:
        tk_refresh_th.cancel()

    refreshed_conf = wait_for_config()
    if refreshed_conf is None:
        return None

    start_listener = (cloud.start_pub_sub
                      if cloud.get_cloud_platform() == "GCP"
                      else cloud.start_webhook)
    start_listener(refreshed_conf)

    # Starting token refresh thread
    # NOTE(review): start() is called on the same object that was cancelled
    # above - confirm rain.looped_timer supports being restarted.
    if tk_refresh_th is not None:
        tk_refresh_th.start()
    return refreshed_conf


def wait_for_event(tk_refresh_th):
    """Sleep until configuration shows up, using the platform's wait routine.

    Args:
        tk_refresh_th: token-refresh timer handed to clean_wait_for_config().

    Returns:
        The event-queue config found after waking up, or None. Platforms
        other than AZURE, GCP and AWS do not wait and yield None.
    """
    log.info("No groups configured, Sleeping")

    if cloud.get_cloud_platform() in ("AZURE", "GCP"):
        return clean_wait_for_config(tk_refresh_th)

    if cloud.get_cloud_platform() == "AWS":
        return wait_for_config()

    return None


def check_events_from_queue():
    """Tell whether the current cloud platform delivers events via a queue.

    Returns:
        True on AWS, AZURE and GCP; False on any other platform.
    """
    return cloud.get_cloud_platform() in ("AWS", "AZURE", "GCP")


def check_running_process():
    """Detect another running instance of this daemon.

    Scans the process table for a process other than the current one whose
    command line has 'python' in its first element and 'rain_scale' in its
    second.

    Returns:
        1 if such a process exists, otherwise 0.
    """
    my_pid = os.getpid()
    ret = 0

    for proc in psutil.process_iter():
        try:
            pinfo = proc.as_dict(attrs=['pid', 'cmdline'])
        except psutil.NoSuchProcess:
            # The process ended while the table was being walked.
            continue

        # as_dict() substitutes None for fields it cannot read (access
        # denied, zombie process); treat that as an empty command line
        # instead of crashing on len(None).
        cmdline = pinfo['cmdline'] or []
        if (pinfo['pid'] != my_pid
                and len(cmdline) > 1
                and 'python' in cmdline[0]
                and 'rain_scale' in cmdline[1]):
            log.debug("Another rain_scale process %d is running, %d exiting..." % (
                pinfo['pid'], my_pid))
            ret = 1

    return ret


def clear_pid_file():
    """Delete the daemon's pid file, if one was recorded.

    Removal is best effort: this runs on the way out of the process (see
    exit_daemon), so a missing or undeletable file must not raise and keep
    the daemon from exiting with the intended status. That situation occurs
    when the pid file could not be created, or when it was already removed
    by an earlier exit attempt.
    """
    if not pid_file:
        return
    try:
        os.remove(pid_file)
    except FileNotFoundError:
        log.debug("pid file %s already absent", pid_file)
    except OSError as err:
        log.warning("Not able to remove pid file %s: %s", pid_file, err)


def exit_daemon(ret_code):
    """Remove the pid file and terminate the process.

    Args:
        ret_code: process exit status.

    Raises:
        SystemExit: always, carrying ``ret_code``.
    """
    clear_pid_file()
    raise SystemExit(ret_code)


def main():
    '''
    Run the scale daemon.

    This daemon is designed to be restarted when the config changes.
    Using this scheme allows us to minimize the number of times we
    scrape the local and cloud configuration and keep unnecessary API
    calls to a minimum when we might not be doing any work as a result
    of those API calls. For example, every configured cloud group needs to
    send notifications to us. This only needs to be done once, but you
    need to make API calls to check in the first place. Avoiding checking
    each poll cycle.
    At startup check if Event queue exists else exit and retry again.

    Exit statuses used here: 9 when not on a primary node, 10 when another
    instance is already running, 12 when the event queue details cannot be
    obtained on AZURE/GCP. The loop itself never returns; the process ends
    through exit_daemon() or an unhandled exception.
    '''
    global pid_file

    my_pid = os.getpid()
    if not is_primary_node():
        log.info(
            f"Not on a primary node. process {my_pid!r} exiting on start...")
        sys.exit(9)
    if check_running_process():
        log.info(
            f"process already running. Duplicate process {my_pid!r} exiting on start...")
        sys.exit(10)
    log.info(f"process {my_pid!r} started")

    try:
        pid_file = cloud.get_daemon_pid_file()
        with open(pid_file, "w") as fp:
            fp.write("%d" % os.getpid())
    except IOError as e:
        # Lazy logger arguments: the previous '"..." + str(e) % (pid_file)'
        # applied % to str(e) alone and raised TypeError instead of logging.
        log.error("Not able to create pid file %s error: %s", pid_file, e)

    signal.signal(signal.SIGUSR1, handle_signal)
    signal.signal(signal.SIGUSR2, handle_signal)

    draining = []
    queue_conf = None
    configured_groups = None
    configured_services = None

    config.upgrade_conf()

    queue_conf = rain.event_queue()
    if cloud.check_event_queue():
        try:
            queue_conf = rain.event_queue().get(config, "default")
        except (config_not_present, config_file_error):
            log.info("No groups configured, Sleeping")
            queue_conf = wait_for_config()
        finally:
            # NOTE(review): wait_for_config() may return None, in which case
            # the attribute assignment below fails on AZURE/GCP - confirm how
            # that case should be handled.
            if cloud.get_cloud_platform() in ("AZURE", "GCP"):
                queue_conf.details = cloud.get_event_queue_details()
                if queue_conf.details is None:
                    log.critical("Not able to update queue")
                    exit_daemon(12)
                queue_conf.update(config, queue_conf)
                if len(queue_conf.details) != 0:
                    if cloud.get_cloud_platform() == "GCP":
                        cloud.start_pub_sub(queue_conf)
                    else:
                        cloud.start_webhook(queue_conf)

    tk_refresh_th = None
    if cloud.get_cloud_platform() == "AZURE":
        try:
            tk_refresh_th = rain.looped_timer(
                cloud.timer_duration, cloud.do_timer_events, queue_conf)
            tk_refresh_th.start()
        except Exception:
            # The timer is optional; Exception (not a bare except) so that a
            # SystemExit from the signal handler is not swallowed.
            rlog.error_trace("timer failed to initiate, running without timer.")

    first_iteration = True
    # Back-dated so that the first pass refreshes the configuration.
    previous = time.time() - 20
    ah = autoheal.AutoHeal()
    configured_groups_previous = []
    log.info("init success, entering daemon loop.")
    while True:
        current = time.time()
        time_diff = current - previous
        # Configuration is re-read and a full sync planned every 20 seconds.
        if first_iteration or time_diff >= 20:
            conf_lock.acquire_ro()
            configured_groups = get_configured_groups()
            configured_groups_diff = get_configured_groups_diff(
                configured_groups, configured_groups_previous)
            configured_groups_previous = configured_groups
            configured_services = get_configured_services()
            conf_lock.release_ro()

            if check_events_from_queue():
                if not configured_groups:
                    # Nothing to manage: block until signalled.
                    queue_conf = wait_for_event(tk_refresh_th)
                    if queue_conf is None:
                        continue

            log.debug(f"configured groups: {configured_groups!r}")
            log.debug(f"configured services: {configured_services!r}")
            sync_groups = []
            sync_services = []

            if first_iteration:
                sync_groups.extend(configured_groups)
                sync_services.extend(configured_services)
                first_iteration = False

            # Groups that are currently draining are left out of the sync.
            for group in configured_groups:
                if group not in sync_groups and not is_present(draining, group.name):
                    log.debug(f"sync group planned: {group.name!r}")
                    sync_groups.append(group)

            sync_group_names = [g.name for g in sync_groups]
            for service in configured_services:
                if service not in sync_services and service.group_name in sync_group_names:
                    log.debug(f"sync service planned: {service.name!r}")
                    sync_services.append(service)

            if cloud.check_event_queue():
                try:
                    cloud.configure_events_for_groups(
                        queue_conf, sync_group_names)
                except config_failed as e:
                    log.warning("Configuration failed: %s", str(e))
            # On GCP sync is required only once when new ASG is added or if Pub/Sub
            # initialization fails.
            if cloud.get_cloud_platform() == "GCP":
                if cloud.ps_initialized == False or (configured_groups_diff):
                    sync_local_to_cloud_groups(sync_groups, sync_services)
            else:
                sync_local_to_cloud_groups(sync_groups, sync_services)
            previous = current

        if check_events_from_queue():
            events = []
            sync_groups = []
            if configured_groups:
                # Only events for configured groups are handled; the name
                # set is built once instead of once per event.
                configured_names = {g.name for g in configured_groups}
                events = [event for event in cloud.get_events_from_queue(
                    queue_conf) if event.group in configured_names]

            for event in events:
                log.info(f"handling event: {event!s}")
                if event.action == 'sync':
                    try:
                        sync_groups.append(
                            rain.group().get(config, event.group))
                    except Exception:
                        rlog.error_trace("failed to add group for sync from event.")
                    if event.event == 'EC2_INSTANCE_TERMINATE':
                        is_server_draining_and_remove(
                            draining, event.group, event.server)
                elif event.action == 'drain':
                    if cloud.get_cloud_platform() == 'GCP':
                        # Looked up through a fresh rain.group(), as the
                        # ALARM branch does, rather than through whatever
                        # object the name 'group' was last bound to.
                        group = rain.group().get(config, event.group)
                        if group.drain == "true":
                            add_group_to_draining(
                                draining, group, event.server)
                        else:
                            sync_groups.append(group)
                        continue

                    if event.event == 'ALARM':
                        log.debug("Draining for group %s" % event.group)
                        group = rain.group().get(config, event.group)
                        if is_present(draining, event.group):
                            log.info(
                                "%s group is in draining list, Alarm is ignored" % event.group)
                            continue
                        cloud_servers = cloud.get_servers_in_group(group)
                        server_count = len(cloud_servers)
                        # Never drain below the group's minimum, nor the
                        # last remaining server.
                        if server_count <= cloud.get_min_servers_in_group(event.group) or server_count <= 1:
                            log.info("There is only %d server in ASG, %s. Alarm is ignored" % (
                                server_count, event.group))
                        else:
                            add_group_to_draining(
                                draining, group, event.server)
                    else:
                        log.debug("Skipping non ALARM event for %s" %
                                  (event.group))

            if sync_groups:
                sync_local_to_cloud_groups(sync_groups, configured_services)

        drain_progress(draining)

        if previous > ah.latest_run_ts + CONST.AUTOHEAL_INTERVAL:
            ah.run()
            ah.log_autoheal_run_stats()

        # Sleeping for 1 second
        time.sleep(1)

    # Not reached: the loop above has no normal exit.
    exit_daemon(0)


if __name__ == "__main__":
    try:
        main()
    except BaseException as exc:
        # Whatever ended main(), make a best-effort attempt to give back the
        # read lock it may still be holding.
        try:
            conf_lock.release_ro()
        except Exception:
            pass
        # A SystemExit is a deliberate exit (sys.exit / exit_daemon) whose
        # status must reach the parent process. The former bare "except:"
        # caught it as well and turned every exit code into -1.
        if isinstance(exc, SystemExit):
            raise
        if log:
            rlog.error_trace()
        exit_daemon(-1)
