import hashlib
import base64
import os
import time
import subprocess
import shlex
import binascii
import struct
import uuid
import requests
from Crypto.Signature import pkcs1_15
from Crypto.Hash import SHA256
from Crypto.PublicKey import RSA
import json
import argparse
import platform
import re
import sys
import logging
import logging.handlers

# Prefix for the NetScaler entity names this script creates (vservers,
# policies, certKeys, policy expressions).
SECURE_ACCESS_PREFIX = "_spa"
# Prefix for the files this script writes under /nsconfig/ssl.
SECURE_ACCESS_FILE_PREFIX = "ns-spa"
# Gateway RSA key pair, (re)generated by create_public_key_xml() when registering.
PRIVATE_KEY_PATH = f"/nsconfig/ssl/{SECURE_ACCESS_FILE_PREFIX}-gw.key"
PUBLIC_KEY_PATH = f"/nsconfig/ssl/{SECURE_ACCESS_FILE_PREFIX}-gw-public.key"
# Base64 registration metadata saved from the command line and re-read on later runs.
SPA_CONF_PATH = "/var/spa/.conf/_spa.conf"
# SAML IdP signing certificate written from the metadata's "samlSigningCert".
SPA_SN_SAML_IDP_CERT_PATH = f"/nsconfig/ssl/{SECURE_ACCESS_FILE_PREFIX}-sn-saml-idp.cert"
# Self-signed gateway certificate generated for PRIVATE_KEY_PATH.
SPA_GW_CERT_PATH = f"/nsconfig/ssl/{SECURE_ACCESS_FILE_PREFIX}-gw.cert"
# Service name sent in the key-registration request and used in the status URL.
SERVICE_NAME = "netappliance"
# Tokens matched against the 'show ns config' HA line in is_ns_primary().
# VPX_HA_STATE_SECONDARY is not referenced elsewhere in this file.
VPX_HA_STATE_PREFIX = 'Node:'
VPX_HA_STATE_STANDALONE = 'Standalone'
VPX_HA_STATE_PRIMARY = 'Primary'
VPX_HA_STATE_SECONDARY = 'Secondary'
LOG_FILE = '/var/log/spareg.log'
# The integrated-cache policy caches requests carrying this header; the
# header's value is used as the cache hit selector.
SECURE_ACCESS_CACHE_HEADER = "X-Citrix-SecureAccess-Cache"
AUTH_VS_NAME = f"{SECURE_ACCESS_PREFIX}_auth_vs"
VPN_VS_NAME = f"{SECURE_ACCESS_PREFIX}_vpn_vs"
# Passed to 'set cache parameter -memLimit' (presumably MB — TODO confirm).
CACHE_MEM_LIMIT = 500
# Passed as the cache content group's -relExpiry (presumably seconds — TODO confirm).
CACHE_REL_EXPIRY = 300
# System CA bundle from which the DigiCert Global Root CA is extracted.
FREEBSD_CA_CERTS_PATH = "/usr/local/share/certs/ca-root-nss.crt"
DIGICERT_ROOT_CA_PATH = f"/nsconfig/ssl/{SECURE_ACCESS_FILE_PREFIX}-digicert_global_root_ca.crt"
# Intermediate certificate downloaded from DigiCert and verified against the root.
DIGICERT_INTERMEDIATE_CERT_NAME = "DigiCertTLSRSASHA2562020CA1.crt"
DIGICERT_INTERMEDIATE_CERT_PATH = f"/nsconfig/ssl/{SECURE_ACCESS_FILE_PREFIX}-{DIGICERT_INTERMEDIATE_CERT_NAME}"
DIGICERT_INTERMEDIATE_CERT_DOWNLOAD_URL = f"https://cacerts.digicert.com/{DIGICERT_INTERMEDIATE_CERT_NAME}"
# certKey names under which both DigiCert certs are bound as 'vpn global' CA certs.
DIGICERT_ROOT_CA_CERT_KEY = f"{SECURE_ACCESS_PREFIX}_digicert_root_ca_certkey"
DIGICERT_INTERMEDIATE_CERT_KEY = f"{SECURE_ACCESS_PREFIX}_digicert_intermediate_ca_certkey"

# ANSI terminal escape sequences used for console output.
COLORS = {
    'RED': '\033[91m',
    'GREEN': '\033[92m',
    'YELLOW': '\033[93m',
    'DEFAULT': '\033[0m',
    'BOLD': '\033[1m',
    'UNDERLINE': '\033[4m',
}

# Service base URLs keyed by the lowercased metadata "deployment" value.
# NOTE(review): "dev" points at the production hosts in all three tables —
# confirm this is intentional.
TRUST_API_BASE_URLS = {
    "production": "https://trust.citrixnetworkapi.net",
    "staging": "https://trust.citrixnetworkapistaging.net",
    "dev": "https://trust.citrixnetworkapi.net"
}
WORKSPACES_API_BASE_URLS = {
    "production": "https://trust.citrixworkspacesapi.net",
    "staging": "https://trust.ctxwsstgapi.net",
    "dev": "https://trust.citrixworkspacesapi.net"
}
SECURE_PRIVATE_ACCESS_URLS = {
    "production": "https://policy.netscalergateway.net",
    "staging": "https://policy.netscalergatewaystaging.net",
    "dev": "https://policy.netscalergateway.net"
}

# Send all records to LOG_FILE through ONE handler opened in 'w' mode, so each
# run starts with a fresh log. Previously basicConfig(filename=...) also
# attached an append-mode handler to the same file. Every record was then
# written twice through two independent file descriptors, which could also
# overwrite or interleave each other's output.
file_handler = logging.FileHandler(LOG_FILE, mode='w')
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
logging.getLogger().setLevel(logging.INFO)
logging.getLogger().addHandler(file_handler)

def log_message(message):
    """Write *message* to the log at INFO level and echo it on the console."""
    for emit in (logging.info, print):
        emit(message)

def log_error(message):
    """Write *message* to the log at ERROR level and echo it on the console."""
    sinks = [logging.error, print]
    for sink in sinks:
        sink(message)

def run_command(command):
    """Run *command* without a shell and return its exit status.

    The command string is tokenised with shlex.split. stdout is discarded.
    stderr is captured and logged when the command exits with a non-zero
    status.

    Returns:
        The process exit status (0 on success). Previously nothing was
        returned, so callers could not detect failures; callers that ignore
        the value are unaffected.

    Raises:
        Any exception from tokenising or launching the command (e.g. ValueError
        for unbalanced quotes, FileNotFoundError for a missing executable),
        after logging it.
    """
    try:
        logging.info(f"Command: {command}")
        args = shlex.split(command)
        result = subprocess.run(args, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
    except Exception as e:
        logging.error(f"Error running command. Exception {e}.\n")
        raise
    if result.returncode != 0:
        # Decode leniently: a non-UTF-8 byte in stderr must not replace the
        # real failure with an unrelated UnicodeDecodeError.
        stderr = result.stderr.decode('utf-8', errors='replace').strip()
        logging.error(f"Error (exit status {result.returncode}): {stderr}")
    return result.returncode

def read_command_output(cmd):
    """Run *cmd* without a shell and return its stdout decoded as UTF-8.

    Args:
        cmd: Command line as ``str`` or ASCII ``bytes``; tokenised with
            shlex.split.

    Returns:
        The command's stdout; stderr is discarded. A non-zero exit status does
        not raise (callers inspect the output themselves), but it is now
        logged instead of being silently ignored.

    Raises:
        Any exception from decoding *cmd*, tokenising it or launching the
        process (logged, then re-raised), or UnicodeDecodeError if stdout is
        not valid UTF-8.
    """
    try:
        if isinstance(cmd, bytes):
            cmd = cmd.decode('ascii')
        logging.info(f"Command: {cmd}")
        args = shlex.split(cmd)
        result = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
    except Exception as e:
        logging.error(f"Error reading command output. Exception {e}.\n")
        raise
    if result.returncode != 0:
        logging.error("Command exited with status %d: %s", result.returncode, cmd)
    return result.stdout.decode('utf-8')

def cli_exec(clicmd):
    """Run a NetScaler CLI command through ``nscli`` and return its output.

    The whole *clicmd* string is passed to nscli as a single argument; no
    local shell is involved. The first "Done" marker is removed from the
    output.

    Returns:
        The remaining output, stripped. The output is now returned even when
        it contains no "Done" marker; previously the function fell through
        and returned None, which crashed callers such as is_ns_primary.
        Returns '' if nscli could not be run at all (e.g. it is missing).

    Raises:
        Exception: when nscli exits with a non-zero status. The message is
            the command output with the "Done" marker removed. Previously the
            quoted command text was spliced in where "Done" had been.
    """
    cli_output = ''
    try:
        logging.info("CLI_COMMAND: %s", shlex.quote(clicmd))
        # Same argv that shlex.split('... ' + shlex.quote(clicmd)) produced:
        # clicmd travels as one argument.
        cmd = ['/netscaler/nscli', '-U', '%%%%:.:.', clicmd]
        cli_output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
        cli_output = cli_output.decode("utf-8").strip()
        cli_output = cli_output.replace("Done", "", 1).strip()
        logging.info("Output: %s", cli_output)
        return cli_output
    except subprocess.CalledProcessError as e:
        error_output = e.output.decode("utf-8").strip()
        error_output = error_output.replace("Done", "", 1).strip()
        logging.error(f"{error_output}")
        raise Exception(f"{error_output}")
    except Exception as e:
        logging.error(f"Error in executing command. Exception {e}.\n output {cli_output}")
        return ''

def cli_exec_ignoreerror(clicmd):
    """Best-effort variant of cli_exec: run *clicmd* via nscli and never raise.

    In this file it is mostly used for unbind/rm clean-up commands whose
    failure is tolerated.

    Returns:
        The command output with its first "Done" marker removed and stripped.
        It is now returned even when no "Done" marker is present; previously
        None was returned in that case, consistent with the cli_exec fix.
        Returns '' if nscli exits non-zero or cannot be run; the error is
        logged.
    """
    cli_output = ''
    try:
        logging.info("CLI_COMMAND: %s", shlex.quote(clicmd))
        # clicmd is passed to nscli as a single argument, no shell involved.
        cmd = ['/netscaler/nscli', '-U', '%%%%:.:.', clicmd]
        cli_output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
        cli_output = cli_output.decode("utf-8").strip()
        cli_output = cli_output.replace("Done", "", 1).strip()
        logging.info("Output: %s", cli_output)
        return cli_output
    except subprocess.CalledProcessError as e:
        error_output = e.output.decode("utf-8").strip()
        logging.error(f"{error_output}")
        return ''
    except Exception as e:
        logging.error(f"Error in executing command. Exception {e}.\n output {cli_output}")
        return ''

def _prompt_cli_token(prompt, validate=None):
    """Prompt until the operator enters one non-blank value without spaces.

    The value is interpolated unquoted into nscli commands, so embedded
    whitespace would split it into several CLI arguments. Such input is
    rejected, as is anything for which *validate* raises ValueError.
    """
    while True:
        value = input(prompt).strip()
        if not value or any(ch.isspace() for ch in value):
            log_error("Error: enter a single value without spaces.")
            continue
        if validate is not None:
            try:
                validate(value)
            except ValueError as e:
                log_error(f"Error: {e}")
                continue
        return value


def _configure_cache():
    """(Re)create the integrated-cache selector, content group and policy."""
    cli_exec_ignoreerror(f"unbind cache policylabel _reqBuiltinDefaults -policyName {SECURE_ACCESS_PREFIX}_cache_policy")
    cli_exec_ignoreerror(f"rm cache policy {SECURE_ACCESS_PREFIX}_cache_policy")
    cli_exec_ignoreerror(f"rm cache contentGroup {SECURE_ACCESS_PREFIX}_CG")
    cli_exec_ignoreerror(f"rm cache selector {SECURE_ACCESS_PREFIX}_selector")

    cli_exec_ignoreerror(f"set cache parameter -memLimit {CACHE_MEM_LIMIT}")
    cli_exec(f"add cache selector {SECURE_ACCESS_PREFIX}_selector \"HTTP.REQ.HEADER(\\\"{SECURE_ACCESS_CACHE_HEADER}\\\")\"")
    cli_exec(f"add cache contentGroup {SECURE_ACCESS_PREFIX}_CG -relExpiry {CACHE_REL_EXPIRY} -hitSelector {SECURE_ACCESS_PREFIX}_selector")
    cli_exec(f"add cache policy {SECURE_ACCESS_PREFIX}_cache_policy -rule \"HTTP.REQ.HEADER(\\\"{SECURE_ACCESS_CACHE_HEADER}\\\").EXISTS\" -action CACHE -storeInGroup {SECURE_ACCESS_PREFIX}_CG")
    cli_exec(f"bind cache policylabel _reqBuiltinDefaults -policyName {SECURE_ACCESS_PREFIX}_cache_policy -priority 10 -gotoPriorityExpression END")


def _configure_auth_vserver(vpn_fqdn, vpn_vs_cert):
    """(Re)create the SAML SP authentication vserver and its certificates."""
    # NOTE(review): on re-runs this removes the auth vserver before the
    # authnProfile that references it is removed (in _configure_vpn_vserver);
    # if NetScaler refuses to remove a referenced vserver, the 'add' below
    # fails — confirm.
    cli_exec_ignoreerror(f"unbind authentication vserver {AUTH_VS_NAME} -policy {SECURE_ACCESS_PREFIX}_saml_sp_pol -type REQUEST")
    cli_exec_ignoreerror(f"rm authentication Policy {SECURE_ACCESS_PREFIX}_saml_sp_pol")
    cli_exec_ignoreerror(f"rm authentication samlAction {SECURE_ACCESS_PREFIX}_sp_action")
    cli_exec_ignoreerror(f"rm authentication vserver {AUTH_VS_NAME}")
    cli_exec_ignoreerror(f"rm ssl certKey {SECURE_ACCESS_PREFIX}_gw_kp")
    cli_exec_ignoreerror(f"rm ssl certKey {SECURE_ACCESS_PREFIX}_saml_idp_cert")

    cli_exec(f"add authentication vserver {AUTH_VS_NAME} SSL 0.0.0.0")
    cli_exec(f"add ssl certKey {SECURE_ACCESS_PREFIX}_gw_kp -cert {SPA_GW_CERT_PATH} -key {PRIVATE_KEY_PATH} -inform PEM")
    cli_exec(f"add ssl certKey {SECURE_ACCESS_PREFIX}_saml_idp_cert -cert {SPA_SN_SAML_IDP_CERT_PATH}")
    cli_exec(f"add authentication samlAction {SECURE_ACCESS_PREFIX}_sp_action -samlIdPCertName {SECURE_ACCESS_PREFIX}_saml_idp_cert -samlRedirectUrl https://{vpn_fqdn} -Attribute1 encrypted_dsauthtoken -Attribute2 encrypted_idtoken")
    cli_exec(f"add authentication Policy {SECURE_ACCESS_PREFIX}_saml_sp_pol -rule true -action {SECURE_ACCESS_PREFIX}_sp_action")
    cli_exec(f"bind authentication vserver {AUTH_VS_NAME} -policy {SECURE_ACCESS_PREFIX}_saml_sp_pol -priority 100 -gotoPriorityExpression END -type REQUEST")
    cli_exec(f"bind ssl vserver {AUTH_VS_NAME} -certkeyName {vpn_vs_cert}")


def _configure_vpn_vserver(vpn_vs_ip, vpn_vs_cert, vpn_fqdn, spa_url):
    """(Re)create the SPA VPN vserver, its authn profile and session policy."""
    cli_exec_ignoreerror(f"unbind vpn vserver {VPN_VS_NAME} -policy {SECURE_ACCESS_PREFIX}_session_pol")
    cli_exec_ignoreerror(f"rm vpn sessionPolicy {SECURE_ACCESS_PREFIX}_session_pol")
    cli_exec_ignoreerror(f"rm vpn sessionAction {SECURE_ACCESS_PREFIX}_session_act")
    cli_exec_ignoreerror(f"rm vpn vserver {VPN_VS_NAME}")
    cli_exec_ignoreerror(f"rm authentication authnProfile {SECURE_ACCESS_PREFIX}_gw_authn_profile")

    cli_exec(f"add vpn vserver {VPN_VS_NAME} SSL {vpn_vs_ip} 443")
    cli_exec(f"bind ssl vserver {VPN_VS_NAME} -certkeyName {vpn_vs_cert}")
    cli_exec(f"add authentication authnProfile {SECURE_ACCESS_PREFIX}_gw_authn_profile -authnVsName {AUTH_VS_NAME}")
    cli_exec(f"set vpn vserver {VPN_VS_NAME} -authnProfile {SECURE_ACCESS_PREFIX}_gw_authn_profile -vserverFqdn {vpn_fqdn} -Listenpolicy NONE -tcpProfileName nstcp_default_XA_XD_profile -icaOnly OFF -dtls OFF -securePrivateAccess ENABLED")
    cli_exec(f"add vpn sessionAction {SECURE_ACCESS_PREFIX}_session_act -transparentInterception ON -clientlessVpnMode OFF -icaProxy OFF -clientChoices OFF -defaultAuthorizationAction DENY -useMIP NS -useIIP OFF -sessTimeout 60 -secureBrowse DISABLED")
    cli_exec(f"add vpn sessionPolicy {SECURE_ACCESS_PREFIX}_session_pol true {SECURE_ACCESS_PREFIX}_session_act")
    cli_exec(f"bind vpn vserver {VPN_VS_NAME} -policy {SECURE_ACCESS_PREFIX}_session_pol -priority 100 -gotoPriorityExpression END")
    cli_exec(f"bind vpn vserver {VPN_VS_NAME} -securePrivateAccessUrl {spa_url}")


def _configure_policy_expressions(instance_id, resource_location_id, customer_id):
    """(Re)create the policy expressions holding the instance/customer ids."""
    expressions = (
        ("instance_id", instance_id),
        ("resource_location_id", resource_location_id),
        ("customer_id", customer_id),
    )
    # All removals first, then all additions (same order as before).
    for suffix, _value in expressions:
        cli_exec_ignoreerror(f"rm policyexpression {SECURE_ACCESS_PREFIX}_{suffix}")
    for suffix, value in expressions:
        cli_exec(f"add policyexpression {SECURE_ACCESS_PREFIX}_{suffix} \"\\\"{value}\\\"\"")


def _configure_pse_ca_certs():
    """(Re)bind the DigiCert root/intermediate CA certs and enable backend cert validation."""
    cli_exec_ignoreerror(f"unbind vpn global -cacert {DIGICERT_ROOT_CA_CERT_KEY}")
    cli_exec_ignoreerror(f"unbind vpn global -cacert {DIGICERT_INTERMEDIATE_CERT_KEY}")
    cli_exec_ignoreerror(f"rm ssl certKey {DIGICERT_ROOT_CA_CERT_KEY}")
    cli_exec_ignoreerror(f"rm ssl certKey {DIGICERT_INTERMEDIATE_CERT_KEY}")

    cli_exec(f"add ssl certKey {DIGICERT_ROOT_CA_CERT_KEY} -cert {DIGICERT_ROOT_CA_PATH}")
    cli_exec(f"add ssl certKey {DIGICERT_INTERMEDIATE_CERT_KEY} -cert {DIGICERT_INTERMEDIATE_CERT_PATH}")
    cli_exec(f"bind vpn global -cacert {DIGICERT_ROOT_CA_CERT_KEY}")
    cli_exec(f"bind vpn global -cacert {DIGICERT_INTERMEDIATE_CERT_KEY}")
    cli_exec("set vpn parameter -backendcertValidation ENABLED")


def execute_cli_commands(metadata):
    """Prompt for the Gateway vServer IP and certKey, then (re)create the SPA
    configuration on the appliance via nscli and save it.

    Each section first removes entities left by a previous run (best effort,
    via cli_exec_ignoreerror) and then adds them again with cli_exec, which
    raises on failure.

    Raises:
        ValueError: if metadata["deployment"] has no entry in
            SECURE_PRIVATE_ACCESS_URLS. This is checked before anything is
            prompted for or changed; previously the literal "None" was bound
            as the SPA URL after the old configuration had been removed.
        Exception: propagated from cli_exec when a configuration command fails.
    """
    import ipaddress

    deployment_type = metadata.get("deployment").lower()
    spa_url = SECURE_PRIVATE_ACCESS_URLS.get(deployment_type)
    if spa_url is None:
        raise ValueError(f"Unknown deployment type: {deployment_type}")

    vpn_vs_ip = _prompt_cli_token("Enter Gateway vServer IP: ", ipaddress.ip_address)
    vpn_vs_cert = _prompt_cli_token("Enter Gateway certKeyName: ")
    resource_location_id = metadata.get("resourceLocationId")
    instance_id = metadata.get("instanceId")
    customer_id = metadata.get("customerId")
    vpn_fqdn = metadata.get("gatewayFqdn")

    log_message("Configuring the Gateway...")
    # Enable the required features (best effort; errors are ignored).
    cli_exec_ignoreerror("enable feature SSL SSLVPN AAA REWRITE IC")
    _configure_cache()
    _configure_auth_vserver(vpn_fqdn, vpn_vs_cert)
    _configure_vpn_vserver(vpn_vs_ip, vpn_vs_cert, vpn_fqdn, spa_url)
    _configure_policy_expressions(instance_id, resource_location_id, customer_id)
    _configure_pse_ca_certs()

    cli_exec("save config")

def encode_to_base64(n):
    """Encode a non-negative integer as unpadded URL-safe base64 of its
    minimal big-endian byte string (at least one byte).

    Used for the ``<Exponent>`` of the RSAKeyValue XML, whose CryptoBinary
    values are big-endian octet strings. The previous implementation packed
    the value little-endian with ``struct.pack('<Q')``. That gave the right
    answer only for byte-palindromic values such as 65537 (0x010001), and it
    raised for values >= 2**64.

    NOTE(review): XML-DSig CryptoBinary is standard, padded base64. The
    unpadded URL-safe form is kept for compatibility with the existing
    consumer; both are identical for 65537 ("AQAB").

    Raises:
        ValueError: if *n* is negative.
    """
    if n < 0:
        raise ValueError("encode_to_base64 requires a non-negative integer")
    data = n.to_bytes(max(1, (n.bit_length() + 7) // 8), 'big')
    return base64.urlsafe_b64encode(data).rstrip(b'=').decode('utf-8')

def generate_nonce(instance_id, private_key_path):
    """Return a signed nonce ``"<instance_id>:<unix_time>,<base64 signature>"``.

    The ``"<instance_id>:<unix_time>"`` payload is hashed with SHA-256 and
    signed (RSA PKCS#1 v1.5) with the PEM private key at *private_key_path*.

    Raises:
        IOError/OSError: if the key file cannot be read (logged, then re-raised).
    """
    payload = "{}:{}".format(instance_id, int(time.time()))

    try:
        with open(private_key_path, 'r') as pem_file:
            key_pem = pem_file.read()
    except IOError as err:
        logging.error(f"Error reading private key file: {err}")
        raise

    signer = pkcs1_15.new(RSA.import_key(key_pem))
    raw_signature = signer.sign(SHA256.new(payload.encode('utf-8')))
    return payload + "," + base64.b64encode(raw_signature).decode('utf-8')

def create_public_key_xml():
    """Regenerate the gateway RSA key pair and return the public key as
    base64-encoded ``<RSAKeyValue>`` XML.

    Side effects:
        Deletes and regenerates PRIVATE_KEY_PATH (2048-bit RSA via
        ``openssl genrsa``) and PUBLIC_KEY_PATH (PEM public key).

    The modulus and exponent are read by parsing the PEM public key with
    ``RSA.import_key``. Previously they were scraped from ``openssl rsa
    -modulus``/``-text`` output. That depended on OpenSSL's human-readable
    format, and because run_command does not raise, a failed key generation
    surfaced only as an opaque IndexError or binascii error.

    Raises:
        OSError: if the public key file was not produced.
        ValueError: if the file cannot be parsed as an RSA key.
    """
    log_message("Creating keypair")
    run_command(f"rm -f {PRIVATE_KEY_PATH}")
    run_command(f"rm -f {PUBLIC_KEY_PATH}")
    run_command(f"openssl genrsa -out {PRIVATE_KEY_PATH} 2048")
    run_command(f"openssl rsa -in {PRIVATE_KEY_PATH} -outform PEM -pubout -out {PUBLIC_KEY_PATH}")

    with open(PUBLIC_KEY_PATH, 'r') as key_file:
        public_key = RSA.import_key(key_file.read())

    # Modulus: standard base64 of the big-endian magnitude, the same bytes
    # the previous unhexlify of openssl's modulus hex produced.
    n = public_key.n
    modulus = base64.b64encode(n.to_bytes((n.bit_length() + 7) // 8, 'big')).decode('utf-8')
    exponent = encode_to_base64(public_key.e)

    public_key_xml = f"<RSAKeyValue><Modulus>{modulus}</Modulus><Exponent>{exponent}</Exponent></RSAKeyValue>"
    return base64.b64encode(public_key_xml.encode('utf-8')).decode('utf-8')

def is_ns_primary():
    """Return True if this appliance's HA state allows running the script.

    Runs ``show ns config`` filtered to the 'Node:' line. The node is
    accepted when the text after 'Node:' starts with 'Primary' or
    'Standalone'.

    Returns False, with an error logged, for any other state. This now
    includes output from which no HA state can be read (no 'Node:' line, or
    empty/None output), which previously crashed with IndexError or
    AttributeError.
    """
    cmd_str = "show ns config | grep 'Node:' "
    node_ha_state = cli_exec(cmd_str) or ''
    parts = node_ha_state.split(VPX_HA_STATE_PREFIX)
    if len(parts) < 2:
        logging.error(f"Unable to determine HA state from output: {node_ha_state!r}")
        return False
    ha_state = parts[1]
    if ha_state.strip().startswith(VPX_HA_STATE_PRIMARY):
        logging.info("Node is primary")
        return True
    if ha_state.strip().startswith(VPX_HA_STATE_STANDALONE):
        logging.info("Node is standalone")
        return True
    logging.error(f"Node is {ha_state}")
    return False

def read_and_parse_metadata():
    """Load the base64-encoded registration metadata JSON.

    Source: the optional positional command-line argument ``base64_data``;
    if it is absent, the copy previously saved at SPA_CONF_PATH.

    A newly supplied argument is written to SPA_CONF_PATH only after it has
    been decoded and parsed successfully. A mistyped value therefore no
    longer overwrites a previously saved, valid configuration.

    Returns:
        The parsed JSON value, or None when no data is available.

    Raises:
        Any base64/UTF-8/JSON/I-O error, after logging it.
    """
    try:
        parser = argparse.ArgumentParser(description='Register SPA Instance')
        parser.add_argument("base64_data", nargs='?')
        args = parser.parse_args()

        from_argument = bool(args.base64_data)
        if from_argument:
            logging.info("Base64 data provided as argument.")
            base64_data = args.base64_data.strip()
        elif os.path.exists(SPA_CONF_PATH):
            logging.info("Base64 data read from file.")
            with open(SPA_CONF_PATH, 'r') as conf_file:
                base64_data = conf_file.read().strip()
        else:
            log_error("Error: No base64 data provided.")
            return None

        # Parse before persisting so invalid input never replaces the saved copy.
        metadata = json.loads(base64.b64decode(base64_data).decode('utf-8'))

        if from_argument:
            os.makedirs(os.path.dirname(SPA_CONF_PATH), exist_ok=True)
            with open(SPA_CONF_PATH, 'w') as conf_file:
                conf_file.write(base64_data)
    except Exception as e:
        log_error(f"Error reading and parsing metadata: {e}")
        raise
    return metadata

def is_metadata_valid(metadata):
    """Validate the registration metadata before it is used.

    Checks, in order:
      * *metadata* is a dict (a JSON object);
      * every required field is present and not None;
      * "deployment" is a string naming an environment known to all three
        *_URLS tables (compared lowercased). Otherwise later steps would
        build URLs from None;
      * "gatewayFqdn" passes is_fqdn_valid (which may strip a ':port'
        suffix in place).

    Returns True when valid; otherwise logs the first problem and returns False.
    """
    if not isinstance(metadata, dict):
        log_error("Error: metadata is not a JSON object.")
        return False

    required_fields = ["deployment", "resourceLocationId", "instanceId", "customerId", "gatewayFqdn"]
    for field in required_fields:
        if field not in metadata or metadata[field] is None:
            log_error(f"Error: {field} is missing or None in metadata.")
            return False

    deployment = metadata["deployment"]
    url_tables = (TRUST_API_BASE_URLS, WORKSPACES_API_BASE_URLS, SECURE_PRIVATE_ACCESS_URLS)
    if not isinstance(deployment, str) or not all(deployment.lower() in urls for urls in url_tables):
        log_error(f"Error: Unknown deployment type {deployment}.")
        return False

    if not is_fqdn_valid(metadata):
        log_error(f"Error: Gateway FQDN {metadata['gatewayFqdn']} is not in a valid FQDN format.")
        return False

    return True

# Host-name syntax, compiled once. Each label is 1-63 letters, digits or
# hyphens and does not start or end with a hyphen. The whole name is at most
# 253 characters: the 255-octet wire-format limit of RFC 1035 minus the
# length and root octets. '\Z' (not '$') plus fullmatch ensures a trailing
# newline cannot slip through.
_FQDN_REGEX = re.compile(
    r'(?=.{1,253}\Z)(?!-)[A-Za-z0-9-]{1,63}(?<!-)(\.(?!-)[A-Za-z0-9-]{1,63}(?<!-))*'
)


def is_fqdn_valid(metadata):
    """Check that metadata['gatewayFqdn'] is a syntactically valid host name.

    An optional trailing ``:<port>`` is ignored for validation. When the
    remaining name is valid, the port is also removed from
    metadata['gatewayFqdn'] in place.

    Returns:
        The regex match object (truthy) when valid, otherwise None.
        Non-string values are rejected instead of raising TypeError.

    The previous pattern used re.match with '$', which also matches just
    before a trailing newline. A value ending in a newline was therefore
    accepted, and the newline was later interpolated into CLI commands.
    """
    fqdn = metadata['gatewayFqdn']
    if not isinstance(fqdn, str):
        return None
    fqdn = re.sub(r':\d+\Z', '', fqdn)
    match = _FQDN_REGEX.fullmatch(fqdn)
    if match:
        metadata['gatewayFqdn'] = fqdn
    return match

def write_saml_cert(data):
    """Write the SAML IdP signing certificate and generate the gateway cert.

    1. data["samlSigningCert"] (PEM, or a bare base64 body) is written to
       SPA_SN_SAML_IDP_CERT_PATH. A bare body is wrapped in PEM armour with
       64-character lines, as RFC 7468 requires.
    2. A self-signed certificate for the key at PRIVATE_KEY_PATH is written to
       SPA_GW_CERT_PATH with ``openssl req -new -x509 -days 10000``.

    The openssl config is written to a temporary file, and openssl is run
    with an argument list. Previously a ``bash -c "..."`` here-document ran
    under ``shell=True``, which required bash and nested shell quoting.

    Raises:
        ValueError: if samlSigningCert is missing or empty (previously an
            AttributeError on None).
        OSError / subprocess.CalledProcessError: on file or openssl failure.
        All errors are logged before being re-raised.
    """
    import tempfile

    openssl_config = (
        "[req]\n"
        "distinguished_name=req_distinguished_name\n"
        "prompt=no\n"
        "[req_distinguished_name]\n"
        "C=US\n"
        "ST=California\n"
        "L=San Francisco\n"
        "O=Citrix Gateway\n"
        "OU=SPA Gateway\n"
        "CN=SPA Gateway default\n"
    )
    try:
        saml_signing_cert = data.get("samlSigningCert")
        if not saml_signing_cert:
            raise ValueError("samlSigningCert is missing from metadata.")
        logging.info("Writing SAML IDP certificate...")
        saml_signing_cert = saml_signing_cert.strip()
        if not saml_signing_cert.startswith("-----BEGIN CERTIFICATE-----"):
            body = "".join(saml_signing_cert.split())
            lines = [body[i:i + 64] for i in range(0, len(body), 64)]
            saml_signing_cert = (
                "-----BEGIN CERTIFICATE-----\n" + "\n".join(lines) + "\n-----END CERTIFICATE-----"
            )

        with open(SPA_SN_SAML_IDP_CERT_PATH, 'w') as cert_file:
            cert_file.write(saml_signing_cert + "\n")

        logging.info("Generating Gateway certificate...")
        with tempfile.NamedTemporaryFile('w', suffix='.cnf', delete=False) as config_file:
            config_file.write(openssl_config)
            config_path = config_file.name
        try:
            subprocess.run(
                [
                    "openssl", "req", "-new", "-x509",
                    "-key", PRIVATE_KEY_PATH,
                    "-outform", "PEM",
                    "-out", SPA_GW_CERT_PATH,
                    "-days", "10000",
                    "-config", config_path,
                ],
                check=True,
            )
        finally:
            os.remove(config_path)
    except Exception as e:
        log_error(f"Error writing certificates: {e}")
        raise

def show_reverse_countdown(total_duration, polling_interval):
    """Render a one-line countdown for a single polling interval.

    Starting at *total_duration* seconds, the remaining time (mm:ss, red
    below two minutes, yellow otherwise) is redrawn with a spinner once per
    second. It stops after *polling_interval* seconds or when the remaining
    time reaches zero, whichever comes first.

    Both arguments are coerced to int, since they may come from JSON numbers
    that ``range`` would reject as floats. Negative times are never rendered
    when less than one interval remains; previously e.g. "-1:59" was shown.
    """
    spinner = ['|', '/', '-', '\\']
    spinner_index = 0
    total_duration = int(total_duration)
    polling_interval = int(polling_interval)
    stop = max(total_duration - polling_interval, 0)

    for remaining in range(total_duration, stop, -1):
        mins, secs = divmod(remaining, 60)
        timeformat = f'{mins:02d}:{secs:02d}'
        time_color = COLORS['RED'] if mins < 2 else COLORS['YELLOW']
        print(f"\rRegistration code expires in {time_color}{timeformat}{COLORS['DEFAULT']}{spinner[spinner_index]}", end="")
        sys.stdout.flush()
        time.sleep(1)
        spinner_index = (spinner_index + 1) % len(spinner)

def fetch_status_and_show_countdown(verification_url, total_duration, polling_interval):
    """Poll *verification_url* until it answers HTTP 200 or time runs out.

    Between polls the countdown is shown for *polling_interval* seconds.

    Returns:
        True as soon as a poll returns HTTP 200; False once *total_duration*
        seconds have been used up.

    Robustness changes:
      * Each request has a timeout, so a stalled connection cannot hang the
        script forever.
      * A transient network error is logged and retried on the next poll
        instead of aborting the whole registration.
      * A missing or non-positive polling interval falls back to 5 seconds;
        previously an interval of 0 looped forever and None raised TypeError.
    """
    if not polling_interval or polling_interval <= 0:
        polling_interval = 5

    while total_duration > 0:
        try:
            response = requests.get(verification_url, timeout=15)
        except requests.RequestException as e:
            logging.warning(f"Registration status check failed: {e}")
        else:
            if response.status_code == 200:
                log_message("\nGateway Registration successful.")
                return True

        show_reverse_countdown(total_duration, polling_interval)
        total_duration -= polling_interval

    log_error("Registration code expired. Re-initiate registration process.")
    return False

def register_key(metadata, krs_data):
    """Submit *krs_data* to the trust (KRS) token endpoint and wait for approval.

    On HTTP 200 the response's registration code is displayed. The
    Workspaces public-key URL for this instance is then polled until the
    code is used or expires.

    Returns:
        True if registration completed. Otherwise False: unknown deployment,
        HTTP 409 (already registered), any other non-200 status (now logged
        with its status code), a response missing required fields, or an
        expired code.

    Raises:
        Any network or JSON error, after logging it. The POST now has a
        timeout so an unresponsive endpoint cannot hang the script.
    """
    try:
        deployment_type = metadata.get("deployment").lower()
        customer_id = metadata.get("customerId")
        instance_id = metadata.get("instanceId")

        krs_url = TRUST_API_BASE_URLS.get(deployment_type)
        if krs_url is None:
            log_error(f"Unknown deployment type: {deployment_type}")
            return False
        log_message("Registering your device ...")
        # Keep the signed nonce out of the log file.
        loggable = dict(krs_data, nonce="<redacted>") if isinstance(krs_data, dict) else krs_data
        logging.info(f"KRS Data : {loggable}")
        response = requests.post(f"{krs_url}/root/trust/v1/token", json=krs_data, timeout=30)
        logging.info(f"Response : {response}")
        if response.status_code == 409:
            log_error("Instance already registered")
            return False
        if response.status_code != 200:
            log_error(f"Key registration request failed with HTTP status {response.status_code}.")
            return False

        response_data = response.json()
        code = response_data.get("code")
        expiry_timestamp = response_data.get("expiryTimestamp")
        polling_interval = response_data.get("pollingInterval")
        current_timestamp = response_data.get("timestamp")
        if None in (code, expiry_timestamp, polling_interval, current_timestamp):
            log_error("Key registration response is missing required fields.")
            return False

        total_duration = expiry_timestamp - current_timestamp
        bold_green = COLORS['GREEN'] + COLORS['BOLD']
        print(f"{COLORS['YELLOW']}Registration Code: {bold_green}{code}{COLORS['DEFAULT']}")
        logging.info(f"Registration Code: {code}")
        print("Waiting for registration to complete...", end="")

        workspaces_url = WORKSPACES_API_BASE_URLS.get(deployment_type)
        status_url = f"{workspaces_url}/{customer_id}/publickeys/{SERVICE_NAME}/{instance_id}"
        return fetch_status_and_show_countdown(status_url, total_duration, polling_interval)
    except Exception as e:
        log_error(f"Error registering key: {e}")
        raise

def prepare_krs_data(data):
    """Build the JSON body for the key-registration (KRS) token request.

    A fresh key pair is generated via create_public_key_xml(), and a signed
    nonce is produced for data["instanceId"]. data["gatewayFqdn"] is reported
    as the device hostname.

    Returns:
        The request dict, or None if no nonce could be produced.
    """
    instance_id = data.get("instanceId")
    hostname = data.get("gatewayFqdn")
    public_key = create_public_key_xml()
    signed_nonce = generate_nonce(instance_id, PRIVATE_KEY_PATH)
    if signed_nonce is None:
        logging.error("Error generating nonce.")
        return None

    def describe(prop, value, label):
        # One deviceInfo entry with an en-US display name.
        return {
            "property": prop,
            "value": value,
            "displayName": [{"name": label, "locale": "en-US"}],
        }

    device_properties = [
        describe("hostname", hostname, "Hostname"),
        describe("product", "Gateway for SPA mixed mode", "Product Name"),
        describe("instanceID", instance_id, "Gateway ID"),
    ]

    return {
        "service": SERVICE_NAME,
        "instanceID": instance_id,
        "publicKey": public_key,
        "nonce": signed_nonce,
        "deviceInfo": {"properties": device_properties},
        "gwDeploymentType": "SPA",
    }

def check_if_registered(metadata):
    """Return True if this instance's public key is already registered.

    Queries ``<workspaces>/<customerId>/publickeys/<SERVICE_NAME>/<instanceId>``;
    HTTP 200 means registered. If the instance is registered but
    PRIVATE_KEY_PATH does not exist, an error is logged and the process exits
    with status 1.

    The request now has a timeout, so an unresponsive endpoint cannot hang
    the script indefinitely.

    Raises:
        Any error from building or sending the request, after logging it.
    """
    try:
        deployment_type = metadata.get("deployment").lower()
        customer_id = metadata.get("customerId")
        instance_id = metadata.get("instanceId")
        workspaces_url = WORKSPACES_API_BASE_URLS.get(deployment_type)
        status_url = f"{workspaces_url}/{customer_id}/publickeys/{SERVICE_NAME}/{instance_id}"
        response = requests.get(status_url, timeout=30)
        if response.status_code != 200:
            return False
        if not os.path.exists(PRIVATE_KEY_PATH):
            log_error("Gateway is already registered but private key is missing. Regenerate metadata and try again.")
            sys.exit(1)
        log_message("Gateway is already registered.")
        return True
    except Exception as e:
        log_error(f"Error checking registration status: {e}")
        raise

def main():
    """Register this gateway (if needed) and configure it.

    Returns:
        A process exit status: 0 on success or when the operator declines
        reconfiguration, 1 on any failure, 130 when interrupted with Ctrl-C.
        Previously None was returned in every case, so the script finished
        with status 0 even after a failure.
    """
    try:
        if not is_ns_primary():
            log_error("Error: This script can only be run on the primary node.")
            return 1

        metadata = read_and_parse_metadata()
        if not metadata or not is_metadata_valid(metadata):
            log_error("Error: Failed to read and parse metadata.")
            return 1

        if check_if_registered(metadata):
            response = input("Do you want to reconfigure Gateway? (y/n): ")
            if response.strip().lower() != 'y':
                log_message("Exiting...")
                return 0
            log_message("Reconfiguring Gateway with existing keys...")
        else:
            krs_data = prepare_krs_data(metadata)
            if not krs_data:
                log_error("Error generating key registration data")
                return 1
            if not register_key(metadata, krs_data):
                log_error("Error: Key registration failed.")
                return 1

        write_saml_cert(metadata)
        prepare_pse_ca_certs()
        execute_cli_commands(metadata)
        log_message("Gateway successfully configured.")
        return 0

    except KeyboardInterrupt:
        log_error("Exiting...")
        return 130
    except Exception as e:
        # Keep the full traceback in the log file for troubleshooting.
        logging.exception("Unhandled error during Gateway registration")
        log_error(f"Error registering Gateway\n {e}")
        return 1

def _extract_awk_range(text, start_marker, end_marker):
    """Return the lines awk '/start/,/end/' would print (plain substring match)."""
    selected = []
    in_range = False
    for line in text.splitlines(keepends=True):
        if not in_range and start_marker in line:
            in_range = True
        if in_range:
            # awk terminates every printed record with a newline.
            selected.append(line if line.endswith("\n") else line + "\n")
            if end_marker in line:
                in_range = False
    return "".join(selected)


def prepare_pse_ca_certs():
    """Stage the DigiCert CA certificates used for backend cert validation.

    1. Copy the "DigiCert Global Root CA" entry out of FREEBSD_CA_CERTS_PATH
       into DIGICERT_ROOT_CA_PATH. This is the same line range the previous
       awk command selected, now done in Python without ``shell=True``.
       A missing entry now raises instead of leaving an empty file.
    2. Download the intermediate certificate to DIGICERT_INTERMEDIATE_CERT_PATH.
       The download uses a timeout, and HTTP errors raise; previously
       ``curl -o`` silently saved an error page.
    3. Verify the intermediate against the root with ``openssl verify``.
       Success requires exit status 0 and a trailing ": OK", not just the
       substring "OK" anywhere in the output.

    Raises:
        Any I/O, HTTP or verification error, after logging it.
    """
    try:
        logging.info("Extracting DigiCert Global Root CA...")
        # latin-1 round-trips arbitrary bytes, so non-ASCII text in the
        # bundle cannot cause a decode error.
        with open(FREEBSD_CA_CERTS_PATH, 'r', encoding='latin-1') as bundle:
            root_ca = _extract_awk_range(bundle.read(), "DigiCert Global Root CA", "END CERTIFICATE")
        if "END CERTIFICATE" not in root_ca:
            raise ValueError(f"DigiCert Global Root CA not found in {FREEBSD_CA_CERTS_PATH}")
        with open(DIGICERT_ROOT_CA_PATH, 'w', encoding='latin-1') as root_file:
            root_file.write(root_ca)
        logging.info(f"DigiCert Global Root CA extracted to {DIGICERT_ROOT_CA_PATH}")

        logging.info(f"Downloading DigiCert intermediate certificate: {DIGICERT_INTERMEDIATE_CERT_DOWNLOAD_URL}")
        response = requests.get(DIGICERT_INTERMEDIATE_CERT_DOWNLOAD_URL, timeout=60)
        response.raise_for_status()
        with open(DIGICERT_INTERMEDIATE_CERT_PATH, 'wb') as cert_file:
            cert_file.write(response.content)
        logging.info(f"DigiCert Intermediate certificate downloaded to {DIGICERT_INTERMEDIATE_CERT_PATH}")

        logging.info("Verifying the downloaded certificate...")
        result = subprocess.run(
            ["openssl", "verify", "-CAfile", DIGICERT_ROOT_CA_PATH, DIGICERT_INTERMEDIATE_CERT_PATH],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        output = result.stdout.decode('utf-8', errors='replace').strip()
        if result.returncode == 0 and output.endswith(": OK"):
            logging.info("Intermediate certificate successfully verified.")
        else:
            details = result.stderr.decode('utf-8', errors='replace').strip()
            logging.error(f"openssl verify output: {output} {details}")
            log_error("Intermediate certificate verification failed.")
            raise ValueError("Certificate verification failed.")

    except Exception as e:
        log_error(f"Error configuring CA certificates: {e}")
        raise

if __name__ == "__main__":
    # Exit with main()'s status (None is treated as 0) so callers and
    # automation can detect failures.
    sys.exit(main())
