"""
Copyright 2000-2024 Citrix Systems, Inc. All rights reserved.
This software and documentation contain valuable trade secrets
and proprietary property belonging to Citrix Systems, Inc.
None of this software and documentation may be copied,
duplicated or disclosed without the express written permission
of Citrix Systems, Inc.
"""


import json
import time
import sys
import os
import requests
import threading
import queue
import base64
import fcntl
import datetime
from random import SystemRandom
import shutil
import certifi
import os
import glob
import requests

from functools import wraps
import azurelinuxagent.common.utils.shellutil as shellutil
from http.server import BaseHTTPRequestHandler, HTTPServer
import socketserver

import adal
from azure.identity import DefaultAzureCredential
from azure.identity import ClientSecretCredential
from msrestazure.azure_exceptions import CloudError
from msrest.exceptions import AuthenticationError
from msrestazure.azure_active_directory import AdalAuthentication
from msrestazure.azure_cloud import AZURE_PUBLIC_CLOUD
from azure.mgmt.monitor import MonitorManagementClient
from azure.mgmt.compute import ComputeManagementClient
from azure.mgmt.network import NetworkManagementClient
from azure.mgmt.resource.resources import ResourceManagementClient
from azure.mgmt.resourcegraph import ResourceGraphClient
from azure.mgmt.resourcegraph.models import (QueryRequestOptions, QueryRequest)
# TODO: Update autoscale related code to be compatible with latest pkg
from azure.mgmt.monitor.v2015_04_01.aio.operations import AutoscaleSettingsOperations as autoscale
from azure.mgmt.monitor.models import (
    ScaleDirection,
    AutoscaleProfile,
    ScaleRule,
    ScaleAction,
    WebhookNotification,
    WebhookReceiver,
    ActionGroupResource,
    ActivityLogAlertResource,
    ActivityLogAlertLeafCondition,
    ActivityLogAlertAllOfCondition,
    ActivityLogAlertActionList,
    AutoscaleNotification,
    AutoscaleSettingResource)
from azurelinuxagent.common.osutil.nsvpx import NSVPXOSUtil

from rainman_core.common.logger import RainLogger
from rainman_core.common import rain
from rainman_core.common.base import base_cloud_driver
from rainman_core.common.rain import rainman_config, server, group, group_info, event
from rainman_core.common.exception import *
from rainman_core.common.stats import stats_config

# --------------------------------------------------------------------------
# XXX FreeBSD 11.x: the requests library prefers IPv6 by default on 11.x.
# Turning off urllib3's IPv6 support forces requests onto IPv4 instead.
# --------------------------------------------------------------------------
requests.packages.urllib3.util.connection.HAS_IPV6 = False


# True when this appliance runs on Azure Stack rather than public Azure;
# evaluated once, at import time.
azurestack_platform = bool(NSVPXOSUtil.check_ns_on_azstack())


log = RainLogger.getLogger()

# Rainman event queue (unbounded: maxsize=0). WebhookHandler.do_POST pushes
# raw POST bodies here; azure_config.get_events_from_queue drains it.
rain_event_queue = queue.Queue(maxsize=0)

def rand_hex_8():
    """Return 8 lowercase hex digits built from 32 bits of SystemRandom output."""
    bits = SystemRandom().getrandbits(32)
    return format(bits, "08x")


def rand_hex_32():
    """Return 32 lowercase hex digits: four independent rand_hex_8() draws joined."""
    return "".join(rand_hex_8() for _ in range(4))


def json_loads_to_ascii(json_text):
    """Decode *json_text* with json.loads and normalise the result via _to_ascii.

    Every decoded JSON object goes through _to_ascii as the object_hook, and
    the top-level value is passed through once more with dict rebuilding
    disabled. Note that _to_ascii returns str values unchanged.
    """
    decoded = json.loads(json_text, object_hook=_to_ascii)
    return _to_ascii(decoded, ignore_dicts=True)

# Directory of extra CA certificates (*.pem) appended to certifi's bundle
# by merge_certs()/merge_files().
CA_CERT_DIR = "/nsconfig/.AZURE/ca_cert_dir/"
CA_CERT_PEM_PATH = "/nsconfig/.AZURE/ca_cert_dir/*.pem"
# Optional endpoint-override file read by get_endpoint().
ENDPOINT_JSON_FILE = "/nsconfig/.AZURE/api_endpoints.json"
# Output of merge_certs(): certifi bundle + the PEMs above.
COMBINED_CA_CERT = "/nsconfig/.AZURE/combined.pem"

# Format of ENDPOINT_JSON_FILE (must be valid JSON; any key may be omitted,
# and a missing file, malformed JSON or missing key falls back to the
# AZURE_PUBLIC_CLOUD value for that endpoint):
#
# {
#     "ad-end-point": "https://xxxxxx",
#     "azure-api-resource-manager": "https://xxx",
#     "active_directory_resource_id": "https://xxx"
# }

def get_requests_cert_path():
    """Return the filesystem path of certifi's CA certificate bundle."""
    bundle_path = certifi.where()
    return bundle_path

def merge_files(outfile):
    """Concatenate certifi's CA bundle and the custom PEM files into *outfile*.

    :param outfile: binary file object opened for writing. certifi's bundle
        is written first, then every file matching CA_CERT_PEM_PATH in the
        order glob returns them.
    """
    def _append(src_path):
        # Stream one source file into the combined output.
        with open(src_path, 'rb') as src:
            shutil.copyfileobj(src, outfile)

    _append(get_requests_cert_path())
    for pem_path in glob.glob(CA_CERT_PEM_PATH):
        _append(pem_path)

def merge_certs():
    """Point requests at a combined CA bundle when a custom CA directory exists.

    If CA_CERT_DIR exists and is not a regular file, certifi's bundle plus
    the custom PEMs are written to COMBINED_CA_CERT and REQUESTS_CA_BUNDLE is
    set to that file. Otherwise any REQUESTS_CA_BUNDLE value is removed.
    """
    has_custom_ca_dir = os.path.exists(CA_CERT_DIR) and not os.path.isfile(CA_CERT_DIR)
    if not has_custom_ca_dir:
        os.environ.pop('REQUESTS_CA_BUNDLE', None)
        log.debug("No certificate directory found for override")
        return

    with open(COMBINED_CA_CERT, 'wb') as combined:
        merge_files(combined)
    os.environ['REQUESTS_CA_BUNDLE'] = COMBINED_CA_CERT

def get_endpoint(service, path=None):
    """Return the value stored under *service* in the endpoint JSON file.

    :param service: key to look up, e.g. "ad-end-point".
    :param path: JSON file to read; defaults to ENDPOINT_JSON_FILE.
    :raises OSError: if the file cannot be opened.
    :raises ValueError: if the file does not contain valid JSON.
    :raises KeyError: if *service* is not present in the file.
    """
    if path is None:
        path = ENDPOINT_JSON_FILE
    # ``with`` guarantees the handle is closed even when json.load raises
    # on malformed content.
    with open(path) as f_json:
        json_data = json.load(f_json)
    return json_data[service]

def _endpoint_or_public_default(service, public_cloud_attr):
    """Return get_endpoint(*service*), or the public-cloud value on any failure.

    Any exception while reading the override file (missing file, bad JSON,
    missing key) is logged at debug level, and the attribute named
    *public_cloud_attr* of ``AZURE_PUBLIC_CLOUD.endpoints`` is returned.
    """
    try:
        return get_endpoint(service)
    except Exception as e:
        log.debug("Caught exception while reading api_endpoints.json : %s", e)
        return getattr(AZURE_PUBLIC_CLOUD.endpoints, public_cloud_attr)


def get_active_directory_endpoint():
    """Return the Azure AD login endpoint ("ad-end-point" override or public cloud)."""
    return _endpoint_or_public_default("ad-end-point", "active_directory")


def get_api_resource_manager_endpoint():
    """Return the ARM endpoint ("azure-api-resource-manager" override or public cloud)."""
    return _endpoint_or_public_default("azure-api-resource-manager", "resource_manager")


def get_active_directory_resource_id():
    """Return the AD resource id ("active_directory_resource_id" override or public cloud)."""
    return _endpoint_or_public_default("active_directory_resource_id",
                                       "active_directory_resource_id")

def _to_ascii(data, ignore_dicts=False):
    """Recursively rebuild decoded JSON containers.

    Despite the name, str values (and every other scalar) are returned
    unchanged. Lists are rebuilt with each element processed. A dict is
    rebuilt with processed keys and values unless *ignore_dicts* is true, in
    which case the same dict object is returned. Nested values are always
    processed with ``ignore_dicts=True``.
    """
    if isinstance(data, list):
        return [_to_ascii(element, ignore_dicts=True) for element in data]
    if isinstance(data, dict) and not ignore_dicts:
        rebuilt = {}
        for key, value in list(data.items()):
            rebuilt[_to_ascii(key, ignore_dicts=True)] = _to_ascii(value, ignore_dicts=True)
        return rebuilt
    return data


class WebhookHandler(BaseHTTPRequestHandler):
    """HTTP handler that receives Azure webhook POSTs.

    A POST whose ``tokenid`` query parameter matches the current (or the
    previous) ``azure_config.webhook_key`` has its body pushed onto
    ``rain_event_queue``. Otherwise the body is dropped and
    ``azure_config.invalid_webhook_message`` is incremented.
    """

    server_version = 'WebHook/1.0'
    sys_version = ''
    default_request_version = 'WebHook/1.0'
    log = log

    def _set_headers(self):
        """Send a bare 200 response with a text/html content type."""
        self.send_response(200)
        self.send_header('Content-type', 'text/html')
        self.end_headers()

    def _get_params_dict(self):
        """Parse the query string of ``self.path`` into a ``{key: value}`` dict.

        Values are not URL-decoded. Pairs without ``=`` are ignored, and a
        repeated key keeps its last value.
        """
        param_d = {}
        self.log.debug("GET params in URL path: %s" % self.path)
        try:
            params = self.path.split("?", 1)
            if len(params) == 2:
                for par in params[-1].split("&"):
                    k_v_pair = par.split("=", 1)
                    if len(k_v_pair) == 2:
                        param_d[k_v_pair[0]] = k_v_pair[1]
        except Exception:
            self.log.warning("Failed to parse GET elements")
        return param_d

    @staticmethod
    def _tokens_match(candidate, expected):
        """Compare two token strings in constant time.

        The candidate comes from an untrusted request path. Both sides are
        UTF-8 encoded first, because hmac.compare_digest rejects str
        arguments that contain non-ASCII characters.
        """
        import hmac
        return hmac.compare_digest(candidate.encode('utf-8'),
                                   expected.encode('utf-8'))

    def _is_valid_token(self):
        """Return True if the request carries an acceptable ``tokenid``.

        An empty ``azure_config.webhook_key`` disables the check, and every
        request is accepted. The previous key is also accepted, so a request
        signed just before a key rotation still validates.
        """
        if not azure_config.webhook_key:
            self.log.warning("No Token to validate")
            return True

        get_params = self._get_params_dict()
        self.log.debug("GET params: %s" % get_params)
        if "tokenid" in get_params:
            key = get_params["tokenid"]
            if self._tokens_match(key, azure_config.webhook_key):
                self.log.info("Token verified")
                return True
            if key and self._tokens_match(key, azure_config.previous_webhook_key):
                self.log.info("Token verified with previous token")
                return True
            self.log.warning("invalid Token %s" % key)
        return False

    def do_GET(self):
        # We should not be handling GET request
        pass

    def do_HEAD(self):
        # We should not be handling HEAD request
        self._set_headers()

    def do_POST(self):
        """Validate the token and queue the request body; always reply 200."""
        try:
            # Size of the body. A missing header makes int(None) raise
            # TypeError, which is handled below as an invalid message.
            content_length = int(self.headers['Content-Length'])
            if content_length < 0:
                raise ValueError("negative Content-Length: %d" % content_length)

            # Switch the request stream to non-blocking mode, presumably so a
            # short or late body cannot block the read indefinitely.
            OFLAGS = fcntl.fcntl(self.rfile, fcntl.F_GETFL)
            nflags = OFLAGS | os.O_NONBLOCK
            fcntl.fcntl(self.rfile, fcntl.F_SETFL, nflags)
            # sometimes a valid event is not available by now
            time.sleep(2)

            post_data = self.rfile.read(content_length)
            if self._is_valid_token():
                rain_event_queue.put(post_data)
            else:
                self.log.warning(
                    "data dropped due to invalid token. Data: %s" % post_data)
                azure_config.invalid_webhook_message += 1
        except Exception as e:
            self.log.error("Exception in POST: %s" % str(e))
            azure_config.invalid_webhook_message += 1
        self._set_headers()


class azure_reporting():
    """Publish NetScaler statistics to Azure Monitor as custom metrics.

    Samples are POSTed to ``https://<region>.monitoring.azure.com/<VM
    resource id>/metrics``. The AAD bearer token is acquired through
    ``adal`` using the credentials from ``cloud_config.get_credentials()``.
    """

    # stat name -> [unit, second value]. NOTE(review): the second value looks
    # like a default threshold (None = no threshold); confirm against callers.
    sys_statinfo = {
        "cpuusagepcnt": ["Percent", 80],
        "mgmtcpuusagepcnt": ["Percent", 80],
        "memsizemb": ["Megabytes", None],
        "memuseinmb": ["Megabytes", None],
        "memusagepcnt": ["Percent", 90]}

    server_statinfo = {
        "avgsvrttfb": ["Milliseconds", 100],
        "surgecount": ["Count", 30],
        "cursrvrconnections": ["Count", 30],
        "svrestablishedconn": ["Count", 30]}

    def __init__(self, cloud_config):
        """Capture identity/region details from *cloud_config* and apply the log level.

        :param cloud_config: cloud driver (presumably an azure_config) that
            provides credentials, instance id, region, subscription, VM name
            and resource group.
        """
        self.log = log
        self.cloud_config = cloud_config
        self.authenticated = False
        self.token = None
        # Values may be bytes (base64-decoded credential lines); they are
        # converted to text where they are used.
        (self.tenantid, self.client, self.client_key) = cloud_config.get_credentials()
        self.instanceid = cloud_config.get_own_instanceid()
        self.region = cloud_config.get_own_region()
        self.credentials = cloud_config.authentication()
        self.subscription_id = cloud_config.get_subscription()
        self.instanceName = cloud_config.get_own_vmname()
        self.resource_groupe_name = cloud_config.get_resource_group_name()
        # Placeholder: publish_metrics() rebuilds these headers with the AAD
        # access token before publishing.
        self.headers = {'Authorization': 'Bearer {}'.format(
            self.credentials), 'Content-Type': 'application/json'}
        self.vm_resource_id = f"subscriptions/{self.subscription_id}/resourceGroups/{self.resource_groupe_name}/providers/Microsoft.Compute/virtualMachines/{self.instanceName}"
        self.url = f"https://{self.region}.monitoring.azure.com/{self.vm_resource_id}/metrics"
        self.stats = stats_config()
        self.set_log_level()

    @staticmethod
    def _to_text(value):
        """Return *value* decoded as UTF-8 if it is bytes, otherwise unchanged."""
        if isinstance(value, bytes):
            return value.decode("utf-8")
        return value

    @staticmethod
    def _redact_headers(headers):
        """Return a plain-dict copy of *headers* with any Authorization value masked."""
        return {k: ('<redacted>' if k.lower() == 'authorization' else v)
                for k, v in headers.items()}

    def _auth_headers(self):
        """Return request headers carrying the current AAD bearer token."""
        return {'Authorization': 'Bearer {}'.format(
            self.token), 'Content-Type': 'application/json'}

    def set_log_level(self):
        """Reload the stats config and apply its ``loglevel`` (default INFO).

        Failures are logged, not raised.
        """
        # Bound before the try so the except clause can always format it,
        # even when reload_config() fails before a level is read.
        loglevel = None
        try:
            self.stats.reload_config()
            loglevel = self.stats.config.get('loglevel', 'INFO')
            self.log.setLevel(loglevel)
            self.log.debug("Setting log level to \"%s\"" % (loglevel))
        except Exception as e:
            self.log.error(
                "Failure seting log level to \"%s\": \"%s\"" % (loglevel, e))

    def get_subscription_id(self):
        return self.cloud_config.get_subscription()

    def get_instance_id(self):
        return self.cloud_config.get_own_instanceid()

    def formulate_metric(self, metric, stat_value, stats):
        """Build one Azure Monitor custom-metric payload from collected stats.

        :param metric: dict with ``resource``, ``property``, ``name``,
            ``unit`` and ``dimensions`` (a list whose first entry has a
            ``Name`` key).
        :param stat_value: nested mapping; ``stat_value[resource][property]``
            must be convertible with float().
        :param stats: object whose ``config`` dict may provide ``namespace``
            (default "Citrix ADC").
        :returns: JSON-serialisable dict holding a single sample
            (min = max = sum = value, count = 1).
        """
        resource = metric["resource"]
        property = metric["property"]
        value = float(stat_value[resource][property])
        namespace = stats.config.get("namespace", "Citrix ADC")
        self.log.debug("namespace = {}".format(namespace))
        data = {
            # Timestamp in UTC. NOTE(review): the string carries no offset, so
            # the service presumably interprets it as UTC; confirm.
            "time": datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%S'),
            "data": {
                "baseData": {
                    "metric": metric["name"]+" "+metric["unit"],
                    "namespace": namespace,
                    "dimNames": [
                        metric["dimensions"][0]["Name"]
                    ],
                    "series": [
                        {
                            # NOTE(review): dimValues repeats the dimension
                            # *name*; it likely should carry the dimension's
                            # value. Confirm the shape of metric["dimensions"]
                            # before changing it.
                            "dimValues": [
                                metric["dimensions"][0]["Name"]
                            ],
                            "min": value,
                            "max": value,
                            "sum": value,
                            "count": 1
                        }
                    ]
                }
            }
        }
        self.log.debug("New metric sample for %s (count=%s, sum=%s, ts=%s)" % (
            data.get('data').get('baseData').get('metric'), data.get('data').get('baseData').get(
                'series')[0].get("count"), data.get('data').get('baseData').get('series')[0].get("sum"),
            data.get('time')))
        return data

    def azure_manage_connection(self):
        """Acquire an AAD token for Azure Monitor unless already authenticated."""
        if self.authenticated is False:
            LOGIN_ENDPOINT = get_active_directory_endpoint()
            RESOURCE = "https://monitoring.azure.com/"
            # Credential values may arrive as bytes (base64-decoded file
            # lines); concatenating bytes to str would raise TypeError.
            tenant_id = self._to_text(self.tenantid)
            client_id = self._to_text(self.client)
            client_secret = self._to_text(self.client_key)
            # Authenticate to Azure Cloud
            context = adal.AuthenticationContext(
                LOGIN_ENDPOINT + '/' + tenant_id)
            auth_response = context.acquire_token_with_client_credentials(
                RESOURCE, client_id, client_secret)
            self.token = auth_response.get("accessToken")
            self.log.debug("Expires In : {}".format(
                auth_response.get("expiresIn")))
            self.authenticated = True
            self.log.debug("self.cloud_config.authenticated: %s" %
                           self.authenticated)
            # Never write the bearer token itself to the log.
            self.log.debug("Authorization Bearer token acquired: %s" %
                           bool(self.token))
        return

    def publish_metrics(self, metric_data, name_space):
        """POST each payload in *metric_data* to Azure Monitor.

        On a 401 the token is refreshed and the POST is retried once. Errors
        for individual metrics are logged and do not stop the loop.

        :param name_space: unused; formulate_metric() embeds the namespace
            in each payload.
        """
        self.azure_manage_connection()
        self.headers = self._auth_headers()
        # publish metrics to Azure Monitor
        for metric in metric_data:
            try:
                resp = requests.post(self.url, data=json.dumps(
                    metric), headers=self.headers)
                self.log.debug("Publishing Request headers: %s" %
                               self._redact_headers(resp.request.headers))
                self.log.debug("Publishing Request url: %s" % resp.request.url)
                self.log.debug("Publishing Request body: %s" %
                               resp.request.body)
                if resp.status_code != 200:
                    self.log.debug(
                        "Publishing Response Status Code: %s" % resp.status_code)
                if resp.status_code == 401:
                    # Token presumably expired: force re-authentication and
                    # retry this metric once.
                    self.authenticated = False
                    self.azure_manage_connection()
                    self.headers = self._auth_headers()
                    resp = requests.post(self.url, data=json.dumps(
                        metric), headers=self.headers)
                    if resp.status_code != 200:
                        self.log.debug(
                            "Publishing Response Status Code: %s" % resp.status_code)
            except Exception as e:
                self.log.error("Publishing Error: %s" % e)

    def configure_polling_for_groups(self, group_names):
        raise NotImplementedError(
            '%s: configure_polling_for_groups' % (self.__class__.__name__))


class WebHookListener(threading.Thread):
    """Daemon thread that serves WebhookHandler on ``(ip, port)``.

    The thread starts itself from the constructor. The HTTPServer is created
    inside run(), i.e. in the listener thread.
    """

    # Seconds stop() waits for run() to finish creating (or failing to
    # create) the HTTPServer before deciding there is nothing to shut down.
    startup_wait = 5

    def __init__(self, ip, port):
        threading.Thread.__init__(self)
        self.log = log
        self.ip = ip
        self.port = port
        self.daemon = True
        # HTTPServer instance; stays None if binding the socket fails.
        self.webhook = None
        # Set by run() once the server exists or its creation has failed.
        self._server_ready = threading.Event()
        self.start()

    def run(self):
        server_address = (self.ip, self.port)
        try:
            self.webhook = HTTPServer(server_address, WebhookHandler)
        except Exception as e:
            self.log.error("Unable to start Webhook listener on %s:%s : %s" % (
                self.ip, self.port, e))
            return
        finally:
            self._server_ready.set()
        self.log.info('Starting Webhook listener...')
        try:
            self.webhook.serve_forever()
        finally:
            # Release the listening socket even if serve_forever() raises.
            self.webhook.server_close()

    def stop(self):
        """Ask serve_forever() to return; no-op if the server was never created."""
        self._server_ready.wait(self.startup_wait)
        if self.webhook is None:
            self.log.debug("Webhook listener was not running; nothing to stop")
            return
        self.webhook.shutdown()

    def stopped(self):
        return not self.is_alive()


class azure_config(base_cloud_driver):
    # Class attributes, shared by every azure_config instance. The webhook
    # fields below are also read/updated by WebhookHandler via azure_config.*.
    # NOTE(review): mutable list shared across instances; usage not visible here.
    cur_configured_groups = []
    # Count of webhook POSTs dropped for a bad token or that raised while handled.
    invalid_webhook_message = 0
    log = log
    # Current webhook token; an empty value disables token checking.
    webhook_key = ""
    # Key in force before the last renew_webhook_key(); still accepted.
    previous_webhook_key = ""
    # 5 minutes in seconds; copied to self.timer_duration in __init__.
    # Presumably how long a webhook key stays valid; confirm.
    webhook_expire_duration = 5 * 60

    def __init__(self):
        """Set up client placeholders, then authenticate against Azure.

        authentication() is called from here, so construction may perform
        network requests (e.g. the IMDS managed-identity probe) and may raise
        azure_authentication_failure.
        """
        # Azure management clients; populated by the authentication routines.
        self.monitorMgmt = None
        self.computeMgmt = None
        self.networkMgmt = None
        self.resourceMgmt = None
        self.resourceGraph = None
        self.authenticated = False
        # Azure Instance Metadata Service (IMDS) instance endpoint.
        self.meta_data_url = "http://169.254.169.254/metadata/instance?api-version=2017-08-01"
        # Port the webhook listener binds to; also placed in the webhook URI.
        self.port = 9001
        # WebHookListener thread; created lazily by start_webhook().
        self.web_hook_listenner = None
        # Must be set before authentication(): client construction reads it.
        self.subscription_id = self.get_subscription()
        # May be None when legacy authentication finds no credentials file.
        self.credentials = self.authentication()
        self.timer_duration = azure_config.webhook_expire_duration
        self.log_folder = "/flash/nsconfig/.AZURE/"
        # NOTE(review): placeholder values, not yet defined.
        self.SERVER_STANDBY_STR = "TODOTODO"
        self.SERVER_ACTIVE_STR = "TODOTODO"

    @classmethod
    def renew_webhook_key(cls):
        """Rotate the webhook token; the outgoing key stays valid as previous_webhook_key."""
        cls.previous_webhook_key = cls.webhook_key
        fresh_key = rand_hex_32()
        cls.webhook_key = fresh_key

    def get_credentials(self):
        """Read the base64-encoded credentials file.

        Each non-blank line of /nsconfig/.AZURE/cred is base64-decoded and
        returned, in file order, as ``bytes``. Callers in this module unpack
        the result as (tenant id, client id, client secret).

        :returns: list of decoded ``bytes`` values.
        :raises azure_credentials_not_present: if the file does not exist.
        """
        cred_file_path = '/nsconfig/.AZURE/cred'
        if not os.path.exists(cred_file_path):
            self.log.critical(
                'Authentication failed: %s not present' % (cred_file_path))
            raise azure_credentials_not_present(
                'Authentication failed: /nsconfig/.AZURE/cred not present')
        with open(cred_file_path, 'r') as fp:
            # Skip blank lines (e.g. a trailing empty line). Decoding one
            # yields b'' as an extra element, which breaks callers' 3-way
            # unpacking.
            cred = [base64.b64decode(line) for line in fp if line.strip()]
        return cred

    def get_own_instanceid(self):
        """Return the appliance UUID from ``sysctl netscaler.sysuuid``, or "" on failure."""
        cmd = "sysctl netscaler.sysuuid |awk '{print $2}'"
        ret, output = shellutil.run_get_output(cmd)
        first_line = output.partition('\n')[0]
        if ret:
            self.log.error("Unable to get sys uuid\n")
            return ""
        return first_line.strip()

    def is_managed_identity_assigned(self, timeout=5):
        """Return True if IMDS issues a managed-identity token for ARM.

        Probes the Azure Instance Metadata Service token endpoint. A 200
        response means a managed identity is assigned to this VM.

        :param timeout: seconds to wait for IMDS. Without a timeout, requests
            never times out, so an unreachable IMDS could block forever.
        :returns: False for any non-200 response or request error
            (connection failure, timeout), so authentication() falls back to
            legacy authentication instead of failing.
        """
        params = {"api-version": "2018-02-01", "resource": "https://management.azure.com/"}

        # Get access token
        try:
            response = requests.get("http://169.254.169.254/metadata/identity/oauth2/token",
                                    headers={"Metadata": "true"}, params=params,
                                    timeout=timeout)
        except requests.exceptions.RequestException as e:
            self.log.debug("Managed identity probe failed: %s" % (str(e)))
            return False
        return response.status_code == 200

    def _init_api_clients(self,credentials):
        """Create the Azure management clients from *credentials*.

        ``self.authenticated`` is set only after every client has been
        constructed, so a failure part-way through does not mark the object
        authenticated.
        """
        # Resolve the ARM endpoint once: each lookup re-reads api_endpoints.json.
        arm_endpoint = get_api_resource_manager_endpoint()
        self.monitorMgmt = MonitorManagementClient(credentials, self.subscription_id, base_url=arm_endpoint)
        self.computeMgmt = ComputeManagementClient(credentials, self.subscription_id, base_url=arm_endpoint)
        self.networkMgmt = NetworkManagementClient(credentials, self.subscription_id, base_url=arm_endpoint)
        self.resourceMgmt = ResourceManagementClient(credentials, self.subscription_id, base_url=arm_endpoint)
        self.resourceGraph = ResourceGraphClient(credentials, base_url=arm_endpoint)
        self.authenticated = True
        self.log.debug("Successfully authenticated")

    def authentication(self):
        """Authenticate, build the API clients and return the credential object.

        DefaultAzureCredential is used when not on Azure Stack and a managed
        identity is assigned. Otherwise, or if that attempt raises,
        authentication_legacy() is used.
        """
        use_managed_identity = (not azurestack_platform) and self.is_managed_identity_assigned()
        if not use_managed_identity:
            return self.authentication_legacy()
        try:
            credentials = DefaultAzureCredential()
            self._init_api_clients(credentials)
        except Exception as e:
            self.log.warning("Authentication failed using Managed identity: %s", (str(e)))
            self.log.info('Trying Legacy authentication ')
            return self.authentication_legacy()
        return credentials

    def authentication_legacy(self):
        """Authenticate with the stored service-principal credentials (or Azure Stack).

        Reads the credentials file, merges any custom CA certificates into
        the requests CA bundle, builds a credential and then the management
        clients.

        :returns: the credential object, or None when the credentials file is
            missing (logged as critical, not raised).
        :raises azure_authentication_failure: on any other failure.
        """
        try:
            (TENANT_ID, CLIENT, KEY) = self.get_credentials()

            merge_certs()

            # NOTE(review): these honour api_endpoints.json overrides but are
            # not passed to ClientSecretCredential (e.g. as ``authority``), so
            # a custom AD endpoint is not applied on this path. Confirm.
            LOGIN_ENDPOINT = get_active_directory_endpoint()
            RESOURCE = get_active_directory_resource_id()

            if azurestack_platform:
                # NOTE(review): ``ns_restapi`` is not imported in the visible
                # part of this file. If it is undefined, this raises NameError,
                # reported below as azure_authentication_failure.
                client = ns_restapi.get_nsrestapi_client()
                client.run()
                mystack_cloud = client.getendpoint_restapi()
                base_url = mystack_cloud.endpoints.management
                credentials, _subid = client.get_credentials()
                self.monitorMgmt = MonitorManagementClient(
                    credentials, self.subscription_id, base_url=base_url)
                self.computeMgmt = ComputeManagementClient(
                    credentials, self.subscription_id, api_version='2017-03-30', base_url=base_url)
                self.networkMgmt = NetworkManagementClient(
                    credentials, self.subscription_id, api_version='2016-09-01', base_url=base_url)
                self.resourceMgmt = ResourceManagementClient(
                    credentials, self.subscription_id, base_url=base_url)
                # Mark authenticated only once every client exists.
                self.authenticated = True
                self.log.debug("Successfully authenticated")
            else:
                credentials = ClientSecretCredential(
                    tenant_id=TENANT_ID.decode("utf-8"),
                    client_id=CLIENT.decode("utf-8"),
                    client_secret=KEY.decode("utf-8")
                )
                self._init_api_clients(credentials)
            return credentials
        except azure_credentials_not_present:
            self.log.critical(
                'Authentication failed: azure_credentials_not_present')
            return None
        except Exception as e:
            self.log.critical('Authentication failed: %s' % (str(e)))
            raise azure_authentication_failure(str(e)) from e

    def is_authenticated(self):
        """Return ``self.authenticated``.

        The authentication routines set this flag, and it is cleared when an
        AuthenticationError is caught.
        """
        authenticated = self.authenticated
        return authenticated

    def azure_managed_connection(func):
        """Decorator (used inside this class body): authenticate lazily before a call.

        When the instance (the first positional argument) has
        ``authenticated is False``, ``authentication()`` runs first. Its
        exceptions propagate to the caller.
        """
        @wraps(func)
        def ensure_authenticated(*args, **kargs):
            instance = args[0]
            if instance.authenticated is False:
                instance.authentication()
            return func(*args, **kargs)
        return ensure_authenticated

    @azure_managed_connection
    def prepare_resource_graph_request(self, _query, page_token=None):
        """Build a Resource Graph QueryRequest scoped to this subscription.

        :param _query: query text.
        :param page_token: skip token from a previous page, or None for the
            first page.
        """
        return QueryRequest(
            subscriptions=[self.subscription_id],
            query=_query,
            options=QueryRequestOptions(skip_token=page_token),
        )

#   Imporant: Using this function assumes that the passed query generates a table with only one column
    @azure_managed_connection
    def resource_graph_request(self, query):
        resources = self.resourceGraph.resources(self.prepare_resource_graph_request(query))
        results = [item for sublist in resources.data.rows for item in sublist ]

        # Check for paged results
        while resources.skip_token is not None:
            self.log.debug("Retrieving " +
                           str(resources.count) + " paged records")
            resources = self.resourceGraph.resources(
                self.prepare_resource_graph_request(query, resources.skip_token))
            results.append(
                [item for sublist in resources.data.rows for item in sublist])

        self.log.debug("Azure Graph - Query: " + str(query) +
                       " Results: " + str(results))
        return results

    @azure_managed_connection
    def get_group(self, group_names=None):
        """Return one ``group`` (name, region, resource group) per autoscaling group.

        :param group_names: passed as ``group`` to _get_autoscaling_group().
            When None, every group from _get_autoscaling_groups() is used.
        :raises azure_authentication_failure: on AuthenticationError (the
            authenticated flag is cleared first). Other errors propagate.
        """
        try:
            if group_names is not None:
                autoscaling_groups = self._get_autoscaling_group(group=group_names)
            else:
                autoscaling_groups = self._get_autoscaling_groups()
            result = []
            for asg in autoscaling_groups:
                entry = group()
                entry.name = asg.name
                entry.region = asg.location
                entry.locations = self._get_resource_group_from_id(asg.id)
                result.append(entry)
                self.log.debug("get_group : adding autoscale group %s" %
                               (asg.name))
        except AuthenticationError as e:
            self.log.critical('Authentication failed: %s' % (str(e)))
            self.authenticated = False
            raise azure_authentication_failure(str(e))
        return result

    @azure_managed_connection
    def get_group_info(self, group_names=None):
        """Return one ``group_info`` (name, region, resource group, drain=False) per group.

        :param group_names: passed as ``group`` to _get_autoscaling_group().
            When None, every group from _get_autoscaling_groups() is used.
        :raises azure_authentication_failure: on AuthenticationError (the
            authenticated flag is cleared first). Other errors propagate.
        """
        try:
            if group_names is not None:
                autoscaling_groups = self._get_autoscaling_group(group=group_names)
            else:
                autoscaling_groups = self._get_autoscaling_groups()
            result = []
            for asg in autoscaling_groups:
                info = group_info()
                info.name = asg.name
                info.region = asg.location
                info.locations = self._get_resource_group_from_id(asg.id)
                info.drain = False
                result.append(info)
                self.log.debug("get_group_info : adding autoscale group %s" % (
                    asg.name))
        except AuthenticationError as e:
            self.log.critical('Authentication failed: %s' % (str(e)))
            self.authenticated = False
            raise azure_authentication_failure(str(e))
        return result

    # Min Server constraint in group
    @azure_managed_connection
    def get_min_servers_in_group(self, group_name):
        """Always return 0: deleting servers is not managed in Azure."""
        minimum = 0
        return minimum
    # Servers in groups
    @azure_managed_connection
    def get_servers_in_group(self, asggroup):
        """Return a ``server`` (name, private IP) for each VM in the group's scale set(s).

        :param asggroup: group with ``name`` and ``rg``. ``rg`` is filled in
            from get_resource_group_name() when None. When *asggroup* is None,
            every autoscaling group is processed.
        :raises azure_authentication_failure: on AuthenticationError.
            CloudError and other exceptions are logged and re-raised.
        """
        result = []
        config = rainman_config()
        asg_info = group()

        try:
            if asggroup is None:
                autoscale_groups = self._get_autoscaling_groups()
            else:
                if asggroup.rg is None:
                    asggroup.rg = self.get_resource_group_name()
                autoscale_groups = self._get_autoscaling_group(group=asggroup.name,rg=asggroup.rg)

            for autoscale_group in autoscale_groups:
                self.log.debug("Processing autoscale_group: %s " %
                               (autoscale_group.name))

                # NOTE(review): asg_info is rebound to the return value of
                # .get() and never read afterwards. With several groups, the
                # next iteration calls .get() on that value; confirm it is a
                # ``group``.
                asg_info = asg_info.get(config, autoscale_group.name)
                if azurestack_platform:
                    # In order to get the vmss name we use the provider
                    # Microsoft.Compute/virtualMachineScaleSets. The output of
                    # this endpoint is only the name of the VMSS and not the
                    # autoscale group name.
                    vmss_name = autoscale_group.name
                    rg_name = self._get_resource_group_from_id(autoscale_group.id)
                else:
                    vmss_name, rg_name = self._get_vmscaleset_name(autoscale_group)

                self.log.debug("get_servers_in_group : RG name = %s vmscalet name = %s" % (
                    rg_name, vmss_name))

                vmSSVMPaged = self._get_vms_in_vmss(rg_name, vmss_name)
                if vmSSVMPaged is None:
                    self.log.debug("Cloud group %s is empty." % (vmss_name))
                    # Nothing to iterate: looping over None raises TypeError.
                    continue
                for vmSSVM in vmSSVMPaged:
                    priv_ip = None
                    self.log.debug(
                        "get_servers_in_group : instance_id = %s" % (vmSSVM.instance_id))
                    try:
                        vm = self._get_vm_in_vmss(
                            rg_name, vmss_name, vmSSVM.instance_id)
                        if vm is None:
                            continue
                        interfaces = self._get_vm_nw_intfs_in_vmss(
                            rg_name, vmss_name, vmSSVM.instance_id)
                        if interfaces is None:
                            continue

                        # The first IP configuration of the last interface
                        # that has one provides the private IP.
                        for interface in interfaces:
                            name = vm.name
                            try:
                                interface.ip_configurations[0]
                            except IndexError:
                                pass
                            else:
                                priv_ip = interface.ip_configurations[0].private_ip_address

                    except CloudError as e:
                        self.log.debug('getting vm failed : %s' % (str(e)))
                    if priv_ip is None:
                        continue
                    this_server = server()
                    this_server.name = name
                    this_server.ip = priv_ip
                    result.append(this_server)
                    self.log.debug("Cloud group %s contains: %s %s" %
                                   (vmss_name, name, priv_ip))

        except AuthenticationError as e:
            self.log.critical('Authentication failed: %s' % (str(e)))
            self.authenticated = False
            raise azure_authentication_failure(str(e))
        except CloudError as e:
            self.log.info(
                'Unable to get servers : Cloud Error : %s' % (str(e)))
            raise
        except Exception as e:
            self.log.info('Unable to get servers :  %s' % (str(e)))
            raise
        return result

    @azure_managed_connection
    def add_server_to_group(self, server, group):
        """No-op on Azure: nothing is added and None is returned.

        Authentication is still ensured by the decorator before returning.
        """
        return None

    @azure_managed_connection
    def remove_server_from_group(self, server, asggroup):
        """Delete the scale-set VM that backs *server*.

        Scans the scale set of each autoscaling group in *asggroup* (all
        groups when None) for the VM matching ``server.name`` (see
        _match_vmss_vm). The match is deleted via _delete_vm_in_vmss(), and
        scanning stops after the first deletion. Per-group CloudError and
        other errors are logged at debug level.

        :raises azure_authentication_failure: on AuthenticationError.
        """
        config = rainman_config()
        asg_info = group()
        try:
            if asggroup is None:
                autoscale_groups = self._get_autoscaling_groups()
            else:
                autoscale_groups = self._get_autoscaling_group(group=asggroup.name,rg=asggroup.rg)

            for autoscale_group in autoscale_groups:
                self.log.debug("Processing autoscale_group: %s" %
                               (autoscale_group.name))
                asg_info = asg_info.get(config, autoscale_group.name)

                vmss_name, rg_name = self._get_vmscaleset_name(autoscale_group)
                self.log.debug("remove_server_from_group : RG name = %s vmscalet name = %s" % (rg_name, vmss_name))

                vmSSVMPaged = self._get_vms_in_vmss(rg_name, vmss_name)
                if vmSSVMPaged is None:
                    self.log.debug("Cloud group %s is empty." % (vmss_name))
                    continue
                try:
                    # A single VM object (or None), not a list: the
                    # ``.instance_id`` access below needs one VM.
                    vmSSVM = self._match_vmss_vm(vmSSVMPaged, server.name)
                    if vmSSVM is not None:
                        self._delete_vm_in_vmss(
                            rg_name, vmss_name, vmSSVM.instance_id)
                        self.log.debug("server %s is deleted from vmss %s." % (
                            server.name, vmss_name))
                        break
                    else:
                        self.log.debug("server %s is not present in vmss %s." % (
                            server.name, vmss_name))
                except CloudError as e:
                    self.log.debug('removing  vm failed : %s' % (str(e)))
                except Exception as e:
                    self.log.debug('VM is already deleted : %s' % (str(e)))
        except AuthenticationError as e:
            self.log.critical('Authentication failed: %s' % (str(e)))
            self.authenticated = False
            raise azure_authentication_failure(str(e))

    def _match_vmss_vm(self, vms, server_name):
        """Pick the VM in *vms* that corresponds to *server_name*, or None.

        An exact name match wins. Failing that, the single VM whose name is a
        substring of *server_name* is used. An ambiguous substring match
        returns None, so the wrong VM is never deleted.
        """
        candidates = [vm for vm in vms if vm.name in server_name]
        exact = [vm for vm in candidates if vm.name == server_name]
        if exact:
            return exact[0]
        if len(candidates) == 1:
            return candidates[0]
        if candidates:
            self.log.debug("server %s matches several VMs (%s); not deleting" % (
                server_name, ", ".join(vm.name for vm in candidates)))
        return None

    @azure_managed_connection
    def get_instance_freeips(self, mac=None):
        """Return free IPs for the interface identified by *mac*, or None.

        Returns None when get_instance_count() is below 2, or on any error
        other than AuthenticationError. Load-balancer public IPs are
        preferred; when that list is empty, get_nw_intf_ips(mac) is returned.

        :raises azure_authentication_failure: on AuthenticationError.
        """
        if self.get_instance_count() < 2:
            self.log.info("Only one interface is present")
            return None
        try:
            lb_ips = self._get_lb_public_ip(mac)
            return lb_ips if len(lb_ips) > 0 else self.get_nw_intf_ips(mac)
        except AuthenticationError as e:
            self.log.critical('Authentication failed: %s' % (str(e)))
            self.authenticated = False
            raise azure_authentication_failure(str(e))
        except Exception as e:
            self.log.info("Unable to get free ips : %s" % (str(e)))
            return None

    def get_event_queue_details(self):
        """Return a JSON description of the webhook (queue name, name, URI), or None.

        A fresh webhook key is generated on every call, so the returned URI
        carries the current token. Returns None when get_own_publicip()
        raises.
        """
        try:
            nsip = self.get_own_publicip()
        except Exception:
            self.log.info("Public IP for interface 0 is not assigned.")
            return None
        # renew webhook_key for each webhook json
        azure_config.renew_webhook_key()
        webhook_uri = 'https://{}:{}?tokenid={}'.format(
            nsip, self.port, azure_config.webhook_key)
        details = {
            'queuename': "AZURE_RAINMAN_EVENT_QUEUE",
            'webhookname': 'CitrixAutoScaleWebhook',
            'webhookuri': webhook_uri,
        }
        return json.dumps(details)

    @azure_managed_connection
    def add_event_queue(self, queue_conf):
        """No-op on Azure; always returns ``None``.

        Events are read from the in-process ``rain_event_queue`` by
        ``get_events_from_queue``, so nothing is created here (presumably the
        method exists to satisfy a shared cloud interface).
        """
        return None

    @azure_managed_connection
    def start_webhook(self, queue_conf):
        """Start the webhook listener on the primary private IP, once.

        ``queue_conf.details`` must be JSON containing ``queuename`` and
        ``webhookname`` (a missing key raises ``KeyError``).  If a listener
        is already running, it is left untouched.
        """
        details = json.loads(queue_conf.details)
        _queue_name = details['queuename']
        _webhook_name = details['webhookname']

        # Starting Webhook listener only if none is active yet
        if self.web_hook_listenner is not None:
            return
        self.web_hook_listenner = WebHookListener(
            self.get_own_privateip(), self.port)

    @azure_managed_connection
    def remove_event_queue(self, queue_conf):
        """No-op on Azure; always returns ``None``.

        Nothing is created by ``add_event_queue``, so nothing is removed here.
        """
        return None

    @azure_managed_connection
    def stop_webhook(self, queue_conf):
        """Stop the running webhook listener, if any, and forget it.

        The listener reference is cleared even when ``stop()`` raises, so a
        later ``start_webhook`` creates a fresh listener instead of seeing a
        stale, half-stopped one.  Stop failures are logged, not raised.
        """
        listener = getattr(self, 'web_hook_listenner', None)
        if listener is None:
            return
        try:
            listener.stop()
        except Exception as e:
            self.log.warning("Failed to stop webhook listener: %s" % (str(e)))
        finally:
            self.web_hook_listenner = None

    @azure_managed_connection
    def get_events_from_queue(self, queue_conf):
        """Drain pending webhook messages and convert them into events.

        Messages are taken from the module-level ``rain_event_queue`` without
        blocking until it reports ``queue.Empty``.  Each is converted with
        ``message_to_event``.  Messages that yield ``None`` or raise are
        skipped, so one bad payload does not drop the rest.

        :param queue_conf: queue configuration (unused here).
        :returns: list of events, possibly empty.
        """
        messages = []
        while True:
            try:
                messages.append(rain_event_queue.get(False))
            except queue.Empty:
                self.log.debug("Event queue is empty")
                break

        events = []
        for message in messages:
            try:
                evt = self.message_to_event(message)
            except Exception as e:
                self.log.debug("Not able to get event: %s" % (str(e)))
                continue
            if evt is not None:
                events.append(evt)
        return events

    def do_timer_events(self, queue_conf):
        """Periodic hook: re-publish the webhook to all configured groups.

        Any failure from the renewal is logged at error level and swallowed,
        so the timer keeps running.
        """
        log = self.log
        log.debug("Timer event starting")
        try:
            self.__renew_webhook(queue_conf)
        except Exception as err:
            log.error("Timer event: failed to renew webhook: %s" % str(err))
        log.debug("Timer event finished")

    def do_rain_tags_timer_events(self):
        """Periodic hook for the tags daemon; currently only pauses 1 second.

        An exception raised during the pause is logged at error level and
        swallowed.
        """
        log = self.log
        log.debug("Timer event starting")
        try:
            time.sleep(1)
        except Exception as err:
            log.error("Timer event: error: %s" % str(err))
        log.debug("Timer event finished")

    def __renew_webhook(self, queue_conf):
        """Regenerate the webhook details and push them to every group.

        ``queue_conf.details`` is rebuilt with ``get_event_queue_details``
        (which renews the webhook key) and persisted.  Then every group from
        the rainman configuration is reconfigured.  When no details can be
        built (no public IP), nothing is configured: without a webhook URI
        the per-group step could only fail.
        """
        queue_conf.details = self.get_event_queue_details()
        if not queue_conf.details:
            self.log.debug("Webhook details unavailable; skipping renewal")
            return
        config = rainman_config()
        queue_conf.update(config, queue_conf)
        try:
            conf_groups = rain.group().get(config, None)
            for conf_group in conf_groups:
                self.configure_events_for_group(queue_conf, conf_group)
        except (config_not_present, config_file_error):
            self.log.debug("No group configured")

    @azure_managed_connection
    def configure_events_for_groups(self, queue_conf, group_names):
        """Reconcile webhook notifications with the desired set of groups.

        Groups in *group_names* that are not yet tracked in
        ``self.cur_configured_groups`` are configured first.  Then tracked
        groups missing from *group_names* have their notification removed.
        The tracking list is updated only after both passes complete.

        :param queue_conf: queue configuration holding the webhook details.
        :param group_names: names of the groups that should be configured.
        """
        config = rainman_config()

        newly_added = []
        for name in group_names:
            if name in self.cur_configured_groups:
                continue
            newly_added.append(name)
            self.configure_events_for_group(
                queue_conf, group().get(config, name))

        stale = []
        for name in self.cur_configured_groups:
            if name in group_names:
                continue
            stale.append(name)
            self.remove_notification_config_from_group(queue_conf, name)

        for name in stale:
            self.cur_configured_groups.remove(name)
        self.cur_configured_groups.extend(newly_added)

    def remove_notification_config_from_group(self, queue_conf, group_name):
        """Detach this appliance's webhook from a group's autoscale setting.

        The first webhook on a ``Scale`` notification whose
        ``properties['webhookname']`` equals the configured webhook name is
        removed.  The setting is written back only if something was removed.
        Webhooks without properties are skipped rather than aborting the scan.

        :param queue_conf: queue configuration whose ``details`` is the JSON
            produced by ``get_event_queue_details``.
        :param group_name: name of the autoscale setting to update.
        :raises config_failed: on authentication failure, when the setting
            has no scale-down rule, or on any other error.
        """
        try:
            self.log.debug(
                "Removing notification configuration to asg: %s" % group_name)
            details = json.loads(queue_conf.details)
            webhookname = details['webhookname']

            asgs = self._get_autoscaling_group(group=group_name)
            # Assuming only one group is returned
            asg = asgs[0]

            down_policy = self.get_scale_down_policy(asg)
            if down_policy is None:
                self.log.info(
                    "Did not find a scale down policy for %s" % asg.name)
                raise config_failed(
                    "Did not find a scale down policy for %s" % asg.name)

            removed = False
            for ntfy in asg.notifications or []:
                if ntfy.operation not in ("Scale", "scale") or not ntfy.webhooks:
                    continue
                for webh in ntfy.webhooks:
                    props = webh.properties
                    # properties is optional; check it before indexing.
                    if props and 'webhookname' in props and props['webhookname'] == webhookname:
                        self.log.debug("Found Webhook configured")
                        # Safe: we stop iterating right after removing.
                        ntfy.webhooks.remove(webh)
                        removed = True
                        break
                if removed:
                    break
            if removed:
                self._set_autoscaling_group(group_name, asg)
        except AuthenticationError as e:
            self.log.critical('Authentication failed: %s' % (str(e)))
            self.authenticated = False
            raise config_failed("Authentication failed: %s" % (str(e)))
        except Exception as e:
            self.log.debug(
                "remove_notification_config_from_group failed : %s" % (str(e)))
            raise config_failed(
                "remove_notification_config_from_group failed : %s" % (str(e)))

    def cleanup_on_group_remove(self, group):
        """No per-group cleanup is performed on Azure; returns ``None``."""
        return None

    def configure_events_for_group(self, queue_conf, group):
        """Attach this appliance's webhook to a group's autoscale setting.

        The flow is:

        * Look up the autoscale setting named ``group.name``.
        * Ensure a ``Scale`` notification carries a webhook whose
          ``properties['webhookname']`` matches ours and whose
          ``service_uri`` is the current URI, adding or updating it as needed.
        * Write the setting back only when it changed.
        * Store the scale-down step size in ``group.drain_count`` and
          persist the group.

        Failures when writing the setting are logged, not raised.

        :param queue_conf: queue configuration whose ``details`` is the JSON
            produced by ``get_event_queue_details``.
        :param group: configured group object exposing ``name``,
            ``drain_count`` and ``update(config, group)``.
        :raises config_failed: on authentication failure, when no scale-down
            rule exists, or on any other error.
        """
        try:

            config = rainman_config()

            asg_info = group
            self.log.debug(
                "Putting notification configuration to asg: %s" % group.name)
            details = json.loads(queue_conf.details)
            webhookuri = details['webhookuri']
            webhookname = details['webhookname']
            # The URI embeds the webhook token (tokenid=...); log the name only.
            self.log.debug("Fetching Webhook %s" % (webhookname))

            asgs = self._get_autoscaling_group(group=group.name)
            # Assuming only one group is returned
            asg = asgs[0]

            down_policy = self.get_scale_down_policy(asg)
            if down_policy is None:
                self.log.info(
                    "Did not find a scale down policy for %s" % asg.name)
                raise config_failed(
                    "Did not find a scaling  policy for %s" % asg.name)

            scale_ntfys = [ntfy for ntfy in (asg.notifications or [])
                           if ntfy.operation in ("Scale", "scale")]

            existing = None
            for ntfy in scale_ntfys:
                for webh in ntfy.webhooks or []:
                    props = webh.properties
                    # properties is optional; check it before indexing.
                    if props and 'webhookname' in props and props['webhookname'] == webhookname:
                        existing = webh
                        break
                if existing is not None:
                    break

            if existing is not None:
                self.log.debug("Webhook already configured")
                needs_update = existing.service_uri != webhookuri
                if needs_update:
                    existing.service_uri = webhookuri
            else:
                webhook_notify = WebhookNotification(
                    service_uri=webhookuri,
                    properties={'webhookname': webhookname})
                if scale_ntfys:
                    # Add to an existing Scale notification, creating its
                    # webhook list if the service returned none.
                    target = scale_ntfys[0]
                    if target.webhooks is None:
                        target.webhooks = []
                    target.webhooks.insert(0, webhook_notify)
                else:
                    # No Scale notification yet: add one, keeping any others.
                    new_ntfy = AutoscaleNotification(webhooks=[webhook_notify])
                    asg.notifications = list(asg.notifications or []) + [new_ntfy]
                needs_update = True

            if needs_update:
                try:
                    self._set_autoscaling_group(group.name, asg)
                except Exception as e:
                    self.log.info("Webhook not configured : %s" % (str(e)))

            asg_info.drain_count = int(down_policy.scale_action.value)
            asg_info.update(config, asg_info)
        except AuthenticationError as e:
            self.log.critical('Authentication failed: %s' % (str(e)))
            self.authenticated = False
            raise config_failed("Authentication failed: %s" % (str(e)))
        except Exception as e:
            self.log.debug("configure_events_for_group failed : %s" % (str(e)))
            raise config_failed(
                "configure_events_for_group failed : %s" % (str(e)))

    def get_scale_down_policy(self, asg):
        """Return the first rule of *asg* whose scale action decreases capacity.

        Profiles and their rules are scanned in order.  Returns ``None`` if no
        rule has a ``Decrease`` direction (plain string or enum value).
        """
        all_rules = (rule for profile in asg.profiles for rule in profile.rules)
        for rule in all_rules:
            direction = rule.scale_action.direction
            if direction == 'Decrease' or direction == ScaleDirection.decrease:
                return rule
        return None

    def get_scale_up_policy(self, asg):
        """Return the first rule of *asg* whose scale action increases capacity.

        Profiles and their rules are scanned in order.  Returns ``None`` if no
        rule has an ``Increase`` direction (plain string or enum value).
        """
        all_rules = (rule for profile in asg.profiles for rule in profile.rules)
        for rule in all_rules:
            direction = rule.scale_action.direction
            if direction == 'Increase' or direction == ScaleDirection.increase:
                return rule
        return None

    def message_to_event(self, message):
        """Convert one raw autoscale webhook payload into a scaling ``event``.

        The payload is decoded with ``json_loads_to_ascii``.  It must be a
        JSON object with an ``operation`` key and a ``context`` object that
        has ``name`` and a ``subscriptionId`` equal to
        ``self.subscription_id``.  Anything else increments
        ``azure_config.invalid_webhook_message`` and yields ``None``.
        ``None`` is also returned when ``context['name']`` is not a
        configured group.

        :param message: raw webhook request body.
        :returns: ``event('ALARM', name, None, 'drain')`` when the operation
            contains ``'Scale In'``, otherwise
            ``event('LAUNCH', name, None, 'sync')``; or ``None``.
        """
        try:
            body = json_loads_to_ascii(message)
        except Exception:
            azure_config.invalid_webhook_message += 1
            return None

        # Validate the shape before indexing so malformed payloads are counted
        # as invalid instead of raising out of this method.
        context = body.get('context') if isinstance(body, dict) else None
        if (not isinstance(context, dict)
                or 'operation' not in body
                or 'name' not in context
                or 'subscriptionId' not in context
                or self.subscription_id != context['subscriptionId']):
            azure_config.invalid_webhook_message += 1
            return None

        self.log.info("Received message: %s" % (body))
        asg_name = context['name']
        comparison = body['operation']
        config = rainman_config()
        try:
            group().get(config, asg_name)
        except Exception:
            self.log.info("%s groups is not configured" % (asg_name))
            return None

        if 'Scale In' in comparison:
            self.log.info("Received a alarm for scale down. Processing...")
            return event('ALARM', asg_name, None, 'drain')
        self.log.info("Received a alarm for scale up. Processing...")
        return event('LAUNCH', asg_name, None, 'sync')

    def _get_lb_public_ip(self, mac):
        """Return public frontend IPs of load balancers that serve this NIC.

        The VM NIC whose MAC matches *mac* (dashes ignored, case-insensitive)
        is located.  For each backend pool it belongs to, the owning load
        balancer is fetched, and for every frontend that has a load-balancing
        rule targeting that pool, the frontend's public IP is collected.

        Resource groups are taken from each resource ID rather than assumed
        to be the VM's.  Frontends without a public IP (internal load
        balancers) and unallocated public IPs are skipped.

        :param mac: MAC address of the interface to inspect.
        :returns: de-duplicated list of public IP strings; empty when the NIC
            is not found or no public frontend serves it.
        """
        result = []
        if not mac:
            return result

        def _frontend_serves_pool(lb, frontend, pool_name):
            # True if any rule on *frontend* forwards to *pool_name* of *lb*.
            frontend_rules = {
                r.id.split('/')[-1] for r in (frontend.load_balancing_rules or [])}
            for rule in lb.load_balancing_rules or []:
                if rule.name not in frontend_rules:
                    continue
                pool = rule.backend_address_pool
                if pool is not None and pool.id.split('/')[-1] == pool_name:
                    return True
            return False

        rg_name = self.get_resource_group_name()
        vm_name = self.get_own_vmname()
        vm = self.computeMgmt.virtual_machines.get(rg_name, vm_name)

        wanted_mac = mac.replace('-', '').lower()
        found_nic = None
        for intf in vm.network_profile.network_interfaces:
            nic = self.networkMgmt.network_interfaces.get(
                self._get_resource_group_from_id(intf.id), intf.id.split("/")[-1])
            nic_mac = (nic.mac_address or '').replace('-', '').lower()
            if nic_mac == wanted_mac:
                found_nic = nic
                break
        if found_nic is None:
            return result

        for ip_cfg in found_nic.ip_configurations or []:
            pools = ip_cfg.load_balancer_backend_address_pools
            if pools is None:
                continue
            self.log.info('Found Load balancer configuration')
            for pool in pools:
                id_parts = pool.id.split('/')
                pool_name = id_parts[-1]
                lb_name = id_parts[-3]
                lb = self.networkMgmt.load_balancers.get(
                    self._get_resource_group_from_id(pool.id), lb_name)
                for frontend in lb.frontend_ip_configurations or []:
                    if not _frontend_serves_pool(lb, frontend, pool_name):
                        continue
                    if frontend.public_ip_address is None:
                        # Internal (private) frontend: no public IP to report.
                        continue
                    pip_id = frontend.public_ip_address.id
                    pip = self.networkMgmt.public_ip_addresses.get(
                        self._get_resource_group_from_id(pip_id),
                        pip_id.split('/')[-1])
                    if pip.ip_address and pip.ip_address not in result:
                        result.append(pip.ip_address)
        return result

    def _get_vmscaleset_name(self, asg=None):
        """Split an autoscale setting's target URI into its VMSS parts.

        :param asg: autoscale setting with a ``target_resource_uri``.
        :returns: ``(vmss_name, resource_group)`` taken from the last and
            fifth ``/``-separated segments of the URI, or ``(None, None)``
            when the setting has no target URI.
        """
        target_uri = asg.target_resource_uri
        if target_uri is None:
            return None, None
        segments = target_uri.split('/')
        return segments[-1], segments[4]

    def _get_vmscaleset(self, rg_name, asg):
        """Fetch the VM scale set targeted by autoscale setting *asg*.

        :param rg_name: resource group of the scale set.  When falsy, the
            resource group embedded in the setting's target URI is used.
        :param asg: autoscale setting whose ``target_resource_uri`` names
            the scale set.
        :returns: the scale-set object returned by the compute client.
        """
        # _get_vmscaleset_name returns a (name, resource_group) tuple; passing
        # the whole tuple as the name was a bug.
        vmssname, uri_rg = self._get_vmscaleset_name(asg)
        return self.computeMgmt.virtual_machine_scale_sets.get(
            rg_name or uri_rg, vmssname)

    def _get_autoscaling_group(self, group=None, rg=None):
        """Return a one-element list with the scaling group named *group*.

        On Azure Stack the group is a VM scale set; otherwise it is an
        autoscale setting.  With *rg* the object is fetched directly from
        that resource group; without it the name is resolved among
        connected groups, and that lookup returns ``[]`` when nothing
        matches, making the result ``[[]]``.  With no *group*, all
        connected autoscaling groups are returned instead.  Errors propagate.
        """
        if group is None:
            return self._get_autoscaling_groups()
        if azurestack_platform:
            if rg:
                found = self.computeMgmt.virtual_machine_scale_sets.get(rg, group)
            else:
                found = self._get_vmss_by_name(group)
        elif rg:
            found = self.monitorMgmt.autoscale_settings.get(rg, group)
        else:
            found = self._get_autoscalesetting_by_name(group)
        return [found]

    def _set_autoscaling_group(self, group=None, asr=None):
        """Write autoscale setting *asr* back under the name *group*.

        The resource group is derived from ``asr.id``.

        :returns: ``'Success'`` after calling ``create_or_update``, or
            ``None`` (without any call) if either argument is ``None``.
        """
        if group is None or asr is None:
            return None
        self.monitorMgmt.autoscale_settings.create_or_update(
            self._get_resource_group_from_id(asr.id), group, asr)
        return 'Success'

    def _get_autoscaling_groups(self):
        """Return every scaling group reachable from this instance's VNet.

        On Azure Stack these are the connected VM scale sets themselves.
        Otherwise they are the subscription's autoscale settings whose target
        URI (compared lower-cased) is one of the connected scale sets.
        """
        if azurestack_platform:
            scale_sets, _ = self._get_connected_vmss()
            return scale_sets
        settings = self._get_asettings_in_subscription()
        _, connected_ids = self._get_connected_vmss()
        return [setting for setting in settings
                if setting.target_resource_uri.lower() in connected_ids]

    def _get_vms_in_vmss(self, rgname, vmssname):
        """List the VM instances of scale set *vmssname* in *rgname*."""
        vmss_vms = self.computeMgmt.virtual_machine_scale_set_vms
        return vmss_vms.list(rgname, vmssname)

    def _get_vm_in_vmss(self, rgname, vmssname, id):
        """Fetch the scale-set VM with instance id *id*."""
        vmss_vms = self.computeMgmt.virtual_machine_scale_set_vms
        return vmss_vms.get(rgname, vmssname, id)

    def _get_vm_nw_intfs_in_vmss(self, rgname, vmssname, id):
        """List the network interfaces of scale-set VM instance *id*."""
        nics = self.networkMgmt.network_interfaces
        return nics.list_virtual_machine_scale_set_vm_network_interfaces(
            rgname, vmssname, id)

    def _delete_vm_in_vmss(self, rgname, vmssname, id):
        """Request deletion of scale-set VM instance *id*.

        The client's return value is discarded and the call's completion
        is not waited on here; ``'Success'`` is returned once the request
        call returns.
        """
        vmss_vms = self.computeMgmt.virtual_machine_scale_set_vms
        vmss_vms.delete(rgname, vmssname, id)
        return 'Success'

    def get_instance_metadata(self):
        """Return the instance metadata document as a dict.

        ``self.meta_data_url`` is queried with the ``Metadata: True`` header
        and a 10 second timeout, and the parsed JSON is returned on HTTP 200.
        Otherwise, on Azure Stack, ``compute`` and ``network.interface`` are
        assembled from the ns_restapi client.  Elsewhere an empty dict is
        returned.  Request errors are logged as warnings, not raised.
        """
        metadata = {}
        try:
            response = requests.get(
                self.meta_data_url, headers={'Metadata': 'True'}, timeout=10)
            if response and response.status_code == 200:
                return response.json()
        except Exception as err:
            self.log.warning(
                "Unable to retrieve instance metadata (%s)" % str(err))

        # No metadata from the service: on Azure Stack fall back to ns_restapi.
        if azurestack_platform:
            nsrest = ns_restapi.get_nsrestapi_client()
            nsrest.run()
            metadata['compute'] = nsrest.get_compute_dict()
            metadata['network'] = {'interface': nsrest.get_network()}
        return metadata

    def get_subscription(self):
        """Return ``compute.subscriptionId`` from the instance metadata."""
        compute = self.get_instance_metadata()['compute']
        return compute['subscriptionId']

    def get_resource_group_name(self):
        """Return ``compute.resourceGroupName`` from the instance metadata."""
        compute = self.get_instance_metadata()['compute']
        return compute['resourceGroupName']

    def get_own_vmname(self):
        """Return this VM's name (``compute.name``) from instance metadata."""
        compute = self.get_instance_metadata()['compute']
        return compute['name']

    def get_nw_intf_ip(self, id):
        """Return the first private IPv4 address of metadata interface *id*.

        :param id: zero-based index into ``network.interface``.
        """
        interface = self.get_instance_metadata()['network']['interface'][id]
        first_ip = interface['ipv4']['ipAddress'][0]
        return first_ip['privateIpAddress']

    def get_nic_ip(self, nic):
        """Return the IP configurations of *nic*, re-fetched by its ID.

        The NIC name is the last segment of ``nic.id`` and its resource
        group is the fifth segment.
        """
        id_parts = nic.id.split('/')
        nic_rg, nic_name = id_parts[4], id_parts[-1]
        fresh = self.networkMgmt.network_interfaces.get(nic_rg, nic_name)
        return fresh.ip_configurations

    def get_vm_nics(self, vmId):
        """Return ``[vm_name, network_interfaces]`` for the VM with ID *vmId*.

        The VM name is the ninth ``/``-separated segment of *vmId*.  The VM
        is looked up in this instance's own resource group.
        """
        vm_name = vmId.split("/")[8]
        vm = self.computeMgmt.virtual_machines.get(
            self.get_resource_group_name(), vm_name)
        return [vm_name, vm.network_profile.network_interfaces]

    def _get_resource_group_from_id(self, resource_id):
        """Return the resource-group segment of an ARM resource ID.

        Leading slashes are ignored, so for
        ``/subscriptions/<sub>/resourceGroups/<rg>/...`` this returns
        ``<rg>`` (the fourth segment).
        """
        trimmed = resource_id.lstrip("/")
        return trimmed.split("/", 4)[3]

    def _get_own_vnet(self):
        """Return the virtual network of this VM's primary (first) NIC.

        Resource groups of the NIC and VNet are read from their IDs, so they
        may differ from the VM's own resource group.
        """
        vm = self.computeMgmt.virtual_machines.get(
            self.get_resource_group_name(), self.get_own_vmname())
        nic_parts = vm.network_profile.network_interfaces[0].id.split('/')
        nic = self.networkMgmt.network_interfaces.get(nic_parts[-5], nic_parts[-1])
        subnet_parts = nic.ip_configurations[0].subnet.id.split('/')
        # .../resourceGroups/<rg>/.../virtualNetworks/<vnet>/subnets/<subnet>
        return self.networkMgmt.virtual_networks.get(
            subnet_parts[-7], subnet_parts[-3])

    def _get_peer_vnets(self, vnet):
        """Return the resource IDs of the remote VNets peered with *vnet*."""
        peerings = self.networkMgmt.virtual_network_peerings.list(
            self._get_resource_group_from_id(vnet.id), vnet.name)
        return [peering.remote_virtual_network.id for peering in peerings]

    def _get_asettings_in_subscription(self):
        """Return the monitor client's listing of all autoscale settings."""
        settings_ops = self.monitorMgmt.autoscale_settings
        return settings_ops.list_by_subscription()

    def _get_autoscalesetting_by_name(self, group):
        """Resolve a connected autoscale setting by name.

        When several settings share the name, the one in this instance's
        resource group (case-insensitive) wins; otherwise the first match is
        used.  NOTE(review): an empty list (not ``None``) is returned when
        nothing matches; callers appear to rely on its falsiness.
        """
        matches = [s for s in self._get_autoscaling_groups() if s.name == group]
        if len(matches) > 1:
            own_rg = self.get_resource_group_name().lower()
            for candidate in matches:
                if self._get_resource_group_from_id(candidate.id).lower() == own_rg:
                    return candidate
        return matches[0] if matches else matches

    def _get_vmss_by_name(self, group):
        """Resolve a connected VM scale set by name.

        When several scale sets share the name, the one in this instance's
        resource group (case-insensitive) wins; otherwise the first match is
        used.  NOTE(review): an empty list (not ``None``) is returned when
        nothing matches.
        """
        scale_sets, _ = self._get_connected_vmss()
        matches = [ss for ss in scale_sets if ss.name == group]
        if len(matches) > 1:
            own_rg = self.get_resource_group_name().lower()
            for candidate in matches:
                if self._get_resource_group_from_id(candidate.id).lower() == own_rg:
                    return candidate
        return matches[0] if matches else matches

    def _get_connected_vmss(self):
        """Find VM scale sets whose first NIC sits in this VM's VNet or a peer.

        Each scale set in the subscription is re-fetched.  The VNet of its
        first NIC/IP configuration's subnet is looked up and compared with
        this instance's VNet and its peered VNets.  Scale sets without a
        usable network profile are skipped rather than aborting discovery.

        :returns: ``(scale_sets, lowercase_scale_set_ids)``.
        """
        vmss = []
        vmssids = []
        vnet = self._get_own_vnet()
        peers = self._get_peer_vnets(vnet)
        # Compare IDs case-insensitively, as the VMSS ID matching in this
        # class already does; peering remote IDs may differ in case from
        # the VNet's own ID (NOTE(review): per ARM case-insensitive IDs).
        reachable = {vnet_id.lower() for vnet_id in [vnet.id] + peers}
        scale_set_ops = self.computeMgmt.virtual_machine_scale_sets
        for listed in scale_set_ops.list_all():
            id_parts = listed.id.split("/")
            scale_set = scale_set_ops.get(id_parts[4], id_parts[-1])
            try:
                subnet_id = (scale_set.virtual_machine_profile.network_profile
                             .network_interface_configurations[0]
                             .ip_configurations[0].subnet.id)
                subnet_parts = subnet_id.split('/')
            except (AttributeError, IndexError, TypeError):
                self.log.debug(
                    "Skipping VMSS %s: no subnet in its network profile" % scale_set.id)
                continue
            vnet_of_vmss = self.networkMgmt.virtual_networks.get(
                subnet_parts[-7], subnet_parts[-3])
            if vnet_of_vmss.id.lower() in reachable:
                vmss.append(scale_set)
                vmssids.append(scale_set.id.lower())
        return vmss, vmssids

    def get_client_facing_interface(self, mac):
        """Return the metadata index of the interface with MAC *mac*, or -1.

        The comparison is case-insensitive.  On Azure Stack the MAC is read
        from ``properties.macAddress``, otherwise from ``macAddress``.
        """
        interfaces = self.get_instance_metadata()['network']['interface']
        for position, intf in enumerate(interfaces):
            if azurestack_platform:
                imd_mac = intf['properties']['macAddress']
            else:
                imd_mac = intf['macAddress']
            if imd_mac.lower() == mac.lower():
                return position
        return -1

    def get_nw_intf_ips(self, mac):
        """Return the private IPs of the interface with MAC *mac*.

        Addresses are taken in metadata order from ``ipv4.ipAddress`` and
        collection stops at the first entry without a ``privateIpAddress``.

        :param mac: MAC address of the client-facing interface.
        :returns: list of private IP strings; empty when the interface is
            not found.
        """
        result = []
        intf_idx = self.get_client_facing_interface(mac)
        if intf_idx == -1:
            self.log.info("Not able to find Client facing interface ")
            return result
        ips = self.get_instance_metadata(
        )['network']['interface'][intf_idx]['ipv4']['ipAddress']
        for ip in ips:
            try:
                private_ip = ip['privateIpAddress']
            except (KeyError, TypeError):
                self.log.info("get_nw_intf_ips : private IP Address not found")
                break
            result.append(private_ip)
        return result

    def get_own_privateip(self):
        """Return the first private IP of metadata interface 0."""
        primary_index = 0
        return self.get_nw_intf_ip(primary_index)

    def get_own_publicip(self):
        """Return ``publicIpAddress`` of the first IP on metadata interface 0.

        Raises ``KeyError``/``IndexError`` when the metadata lacks it.
        """
        primary = self.get_instance_metadata()['network']['interface'][0]
        first_ip = primary['ipv4']['ipAddress'][0]
        return first_ip['publicIpAddress']

    def get_own_region(self):
        """Return ``compute.location`` from the instance metadata."""
        compute = self.get_instance_metadata()['compute']
        return compute['location']

    def get_instance_count(self):
        """Return how many network interfaces the instance metadata lists."""
        return len(self.get_instance_metadata()['network']['interface'])

    def get_cloud_platform(self):
        """Return ``"AZURESTACK"`` on Azure Stack, else ``"AZURE"``."""
        return "AZURESTACK" if azurestack_platform else "AZURE"

    def validate_intf_count(self, intf_count):
        """Return True when *intf_count* is below the limit of 3 interfaces."""
        limit = 3
        return intf_count < limit

    def check_event_queue(self):
        """Return False on Azure Stack and True on any other platform."""
        return self.get_cloud_platform() != "AZURESTACK"

    def get_ftu_filename(self):
        """Return the path of the Azure ``ftumode`` marker file."""
        return os.path.join('/flash/nsconfig/.AZURE', 'ftumode')

    def get_daemon_pid_file(self):
        """Return the PID-file path of the rain_scale daemon."""
        return os.path.join('/flash/nsconfig/.AZURE', 'rain_scale.pid')

    def get_rain_tags_daemon_pid_file(self):
        """Return the PID-file path of the rain_tags daemon."""
        return os.path.join('/flash/nsconfig/.AZURE', 'rain_tags.pid')

    def get_rain_stats_daemon_pid_file(self):
        """Return the PID-file path of the rain_stats daemon."""
        return os.path.join('/flash/nsconfig/.AZURE', 'rain_stats.pid')

    def get_azure_reporting(self):
        """Return a new ``azure_reporting`` object bound to this instance."""
        reporter = azure_reporting(self)
        return reporter

    def check_privileges(self, feature):
        """No privilege check is performed on Azure; returns ``None``."""
        return None
