import argparse
import re
import json
import encryption as crypto
import logging 
from include import *


def get_arguments(argv=None):
    '''
    Build the command line parser, parse the arguments and return them as a dictionary.

    input  - argv : optional list of argument strings to parse. When None (the default)
                    argparse reads the process command line (sys.argv[1:]), which is the
                    original behaviour of this function.
    output - dictionary keyed by the "dest" of every option (SERVER, PORT, USER, AD_user,
             search_base, search_filter, operation, source_attribute, target_attribute,
             cert_path, new_cert_path). cert_path is a list because it is declared with nargs='*'.

    argparse terminates the process (SystemExit) when a required option is missing or a value is invalid.
    '''
    parser = argparse.ArgumentParser(
        prog="OTP Secret encryption tool",
        description="This tool lets the admin convert the plaintext OTP secret to encrypted format and revert back to the plaintext format as desired. It also has the functionality to upgrade the certificate that has been stored in the AD",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # connection options
    parser.add_argument('-Host', action="store", dest="SERVER", required=True, help="The IP or the hostname of the AD server where the OTPSecret is stored")
    parser.add_argument('-Port', action="store", dest="PORT", type=int, default=636, help="use port 389 for plaintext encryption and 636 for SSL", required=False)
    parser.add_argument('-username', action="store", dest="USER", required=True, help="The LDAP bind DN of the admin of the active directory ")

    # search options
    parser.add_argument('-filter_by_user_name', action="store", dest="AD_user", required=False, help="The samAccountName of the user whose operations are to be performed.")
    parser.add_argument('-search_base', action='store', dest='search_base', required=True, help="The base argument for the search in LDAP directory")
    parser.add_argument('-search_filter', action="store", dest='search_filter', required=False, default='(cn=*)', help="The filter to be applied to search the ldap directory, if any")

    # operation and the attributes / certificates it works on
    parser.add_argument('-operation', action='store', dest='operation', choices=["0", "1", "2", "encryption", "decryption", "cert_upgrade"], required=True, help="The operation being requested 0 or encryption, 1 or decryption, 2 or cert_upgrade")
    parser.add_argument('-source_attribute', action='store', dest='source_attribute', required=True, help="The attribute where the OTPSecret is stored currently")
    parser.add_argument('-target_attribute', action='store', dest='target_attribute', required=False, help="The attribute where the OTPSecret is to be stored")
    parser.add_argument('-cert_path', action='store', dest='cert_path', nargs='*', required=True, help="The absolute path of the certificate in X.509 standard to be used for public key validation")
    parser.add_argument('-new_cert_path', action='store', dest='new_cert_path', required=False, help="The absolute path of the new certificate in X.509 standard to be used for public key validation")

    # parse_args(None) falls back to sys.argv[1:], so existing callers are unaffected
    args = parser.parse_args(argv)
    return vars(args)

def connect_to_server( username, password, host, port ):
    '''
    Check that a connection to the LDAP server can be opened and bound with the given credentials.

    input  - username, password : bind credentials handed to Connection
             host, port         : server address handed to Server
    output - True when the bind succeeded.
             On any failure an ERROR line is printed and the process exits with status 1.

    NOTE(review): the Connection object is neither returned nor explicitly closed here, so callers
    only learn that the bind worked - confirm that this is the intended contract.
    '''
    try:
        # third positional argument is presumably use_ssl (ldap3 Server signature) - TODO confirm
        server = Server(host, port, True)
        Connection(server, username, password, auto_bind=True)
    except exceptions.LDAPSocketOpenError:
        print('ERROR:Unable to reach the server, please check the server credentials and port number ')
        sys.exit(1)
    except exceptions.LDAPExceptionError as er:
        # interpolate the error: print() does not apply %-formatting to extra positional arguments
        print("ERROR:Could not connect to the server due to the following error %s" % er)
        sys.exit(1)
    except Exception:
        # "except Exception" instead of a bare except so Ctrl-C / SystemExit are not reported as connection errors
        print("ERROR:Could not connect to the server due to an error\n")
        sys.exit(1)
    print('INFO: Connection Successful..')
    return True



def verify_arguments(args_dict, operation, target_attribute, new_cert_path, source_attribute):
    '''
    utility function to verify that the user has provided all the arguments required by the requested operation

    Checks done:
    If operation is encryption ('0')   -- a target attribute must be provided.
    If operation is decryption ('1')   -- a target attribute must be provided and it must differ from the source attribute.
    If operation is cert_upgrade ('2') -- a value for new_cert_path must be provided.
    Any other operation value is rejected.

    The other checks are done in CLI itself.

    input  - args_dict is accepted for compatibility with existing callers but is not read here;
             the remaining parameters are the already extracted CLI values.
    output - dictionary {'result': bool, 'value': str}; 'value' holds the error message when
             'result' is False and is empty otherwise.
    '''
    output = {'result': True, 'value': ""}

    if operation in ('0', 'encryption'):
        if target_attribute is None:
            output['result'] = False
            # the target attribute receives the *encrypted* data for this operation
            output['value'] = "Please provide a -target_attribute value for storing the encrypted data."
    elif operation in ('1', 'decryption'):
        if target_attribute is None or target_attribute == source_attribute:
            output['result'] = False
            output['value'] = "Please provide a new attribute value for storing the decrypted data"
    elif operation in ('2', 'cert_upgrade'):
        if new_cert_path is None:
            output['result'] = False
            output['value'] = "Please provide a valid new certificate path"
    else:
        output['result'] = False
        output['value'] = "Please provide the correct operation value"

    return output

def is_json(myjson):
    '''
    verify if the object being passed can be parsed as json

    input  - str, bytes or bytearray holding the candidate document
    output - True  : json.loads accepted the value
             False : the value is not valid json, or is not text/bytes at all (e.g. None)
    '''
    try:
        json.loads(myjson)
    except (ValueError, TypeError):
        # ValueError covers malformed json (and undecodable bytes);
        # TypeError is raised by json.loads for non str/bytes input such as None
        return False
    return True

def OTPSecret_regex_verify(input):
    '''
    input  - string
    output - True : If the string conforms to the format of a plaintext OTP
             False : wrong format

    Accepted shape (surrounding whitespace is ignored): an optional leading "[" and "'", the marker "#@",
    then zero or more "<device>=<value>&," groups (each optionally followed by "'"),
    and an optional closing "]".
    '''
    # Raw string: the regex escapes (\[ \. \- \* \]) are not valid Python string escapes and
    # trigger a warning in a normal literal. The compiled pattern is unchanged.
    # NOTE(review): the value class uses the range A-z, which besides letters also admits
    # [ \ ] ^ _ and ` - kept as-is so exactly the same strings are accepted; confirm whether A-Z was meant.
    pattern = r"^\[?'?#@([a-zA-Z0-9\._()\-%!\*'~`+]*=[a-zA-z0-9&=-]*&,'?)*\]?"
    return re.fullmatch(pattern, input.strip()) is not None

def traverse_entries(
        ldap,
        entries
        ):
    '''
    traverse the AD for all the retrieved users,
    access the source attribute and either encrypt it (ldap.encryption == True)
    or convert it to plaintext json, storing the result through the ldap object.

    input - ldap    : the ldap helper object; its counters (number_of_entries, total_entries,
                      number_of_entries_not_modified, not_modified) and crypt member are reset here
            entries : iterable of search results; each is expected to be a mapping with a 'dn' key
                      and an 'attributes' mapping. Results without 'dn' are skipped.

    Errors are logged and never propagated. Every entry that has a DN but was not modified
    (wrong format, failed store, or an exception while processing it) is recorded in
    ldap.not_modified and counted in ldap.number_of_entries_not_modified.
    '''
    ldap.crypt = crypto.crypto()
    entered = False
    ldap.number_of_entries = 0
    ldap.total_entries = 0
    ldap.number_of_entries_not_modified = 0
    ldap.not_modified = []
    try:
        for entry in entries:
            entered = True
            # reset per entry so the error path below never sees values of a previous entry
            DN = None
            modified = False
            try:
                logging.info("\n\n")
                # NOTE(review): this logs the complete entry, which presumably includes the
                # plaintext source attribute - confirm this is acceptable for the log destination.
                logging.info("%s",entry)
                if 'dn' not in entry:
                    continue
                DN = entry['dn']
                attribute = str(entry['attributes'][ldap.source_attribute])
                ldap.target_value = entry['attributes'][ldap.target_attribute]
                ldap.current_cert_path = ldap.cert_path
                if OTPSecret_regex_verify(attribute):
                    ldap.total_entries += 1
                    attribute_list = attribute.split('&,')
                    # drop the first two characters (the "#@" marker).
                    # NOTE(review): when the value was stringified from a list it starts with "['",
                    # so this strips "['" and leaves the marker in place - confirm intended.
                    attribute_list[0] = attribute_list[0][2:]
                    if ldap.encryption == True:
                        if not ldap.encrypt_and_update(attribute_list, DN):
                            logging.error("encrypt and update failed for %s",DN)
                        else:
                            modified = True
                            ldap.number_of_entries += 1
                    else:
                        # convert the data in json format and update
                        values_dict = {}
                        for device in attribute_list:
                            name, separator, secret = device.partition('=')
                            if not separator:
                                # the fragment after the final '&,' (e.g. '' or "']") carries no
                                # "device=secret" pair; indexing it used to raise IndexError
                                continue
                            values_dict[name] = secret
                        if not ldap.convert_to_json_and_store(values_dict, DN):
                            logging.error("error in storing for %s", DN)
                        else:
                            ldap.number_of_entries += 1
                            modified = True
                else:
                    logging.error("error in the format of the attribute")
            except Exception as er:
                # only the exception type is logged, as before, so no attribute data ends up in the message
                logging.error("Error in traversing the entry due to following error %s ",type(er).__name__)

            # bookkeeping runs for failed *and* errored entries (skipped results have no DN)
            if DN is not None and not modified:
                ldap.number_of_entries_not_modified += 1
                ldap.not_modified.append(DN)
        if not entered:
            logging.error("No Entries to be updated")
    except Exception as er:
        # covers failures raised while iterating over entries itself
        logging.error("Error in traversing the entry due to following error %s ",type(er).__name__)

def store_plaintext_in_json(DN, attribute, ldap):
    '''
    Parse a legacy "#@device=secret&,device=secret&," attribute into a device -> secret mapping
    and hand it to ldap.convert_to_json_and_store; the secrets are stored without encryption.

    input  - DN        : distinguished name of the entry being updated
             attribute : the legacy attribute value, must start with "#@"
             ldap      : ldap helper object; total_entries / number_of_entries are updated on it
    output - True  : the mapping was stored
             False : the attribute does not start with "#@" or the store call reported failure
    '''
    # guard clause: anything without the "#@" marker is rejected before touching the ldap object
    if attribute[:2] != "#@":
        logging.error("error in the format of attribute for %s", DN)
        return False

    ldap.total_entries += 1
    fragments = attribute.split('&,')
    fragments[0] = fragments[0][2:]  # strip the "#@" marker from the first fragment

    # fragments without '=' (such as the empty one after the final '&,') are ignored
    split_fragments = (fragment.split('=', 1) for fragment in fragments)
    secrets_by_device = {pair[0]: pair[1] for pair in split_fragments if len(pair) == 2}

    stored = ldap.convert_to_json_and_store(secrets_by_device, DN)
    if not stored:
        logging.error("error in storing for %s", DN)
        return False

    logging.info("Plaintext Update successful for %s", DN)
    ldap.number_of_entries += 1
    return True


def verify_encryption(
    ldap, attribute, DN
    ):
    '''
    Check that the stored encrypted data can still be decrypted.

    input  - ldap      : ldap helper object providing crypt.verify_certificate, crypt.decrypt
                         and generate_symmetric_key
             attribute : json document with the layout {"otpdata": {"devices": {device: token}}}
                         where every token is "kid.iv.ciphertext", each part url-safe base64
             DN        : distinguished name, only used in the error log
    output - dictionary device -> decrypted secret when every device could be decrypted
             False when anything goes wrong (the error type is logged, nothing is raised)
    '''
    try:
        otp_payload = json.loads(attribute)['otpdata']
        plaintext_by_device = {}

        for device_name, token in otp_payload["devices"].items():
            # split the token into its three encoded parts, then decode them
            kid_b64, iv_b64, ciphertext_b64 = token.split('.')
            key_id = urlsafe_b64decode(kid_b64)
            cipher_bytes = urlsafe_b64decode(ciphertext_b64)
            init_vector = urlsafe_b64decode(iv_b64)

            # the certificate has to be verified before the key is generated
            certificate_ok = ldap.crypt.verify_certificate(key_id, ldap)
            if certificate_ok != True:
                logging.error('error in verifying the certificate')
                raise ValueError

            key = ldap.generate_symmetric_key()
            plaintext = ldap.crypt.decrypt(key, init_vector, cipher_bytes)

            # a None result is treated as a failed decryption
            if plaintext is None:
                raise ValueError
            plaintext_by_device[device_name] = plaintext

        return plaintext_by_device
    except Exception as er:
        logging.error("error in processing the encrypted data due to this error %s, skipping the entry %s ", type(er).__name__, DN)
        return False
        


def restore_plaintext(DN, ldap, decrypted_secret):
    '''
    Rebuild the legacy plaintext attribute from the decrypted per-device data
    and write it to ldap.target_attribute of the given entry.

    input - DN               : distinguished name of the entry to update
            ldap             : ldap helper object providing update() and target_attribute
            decrypted_secret : mapping device -> json document; each document must contain
                               a 'secret' member, every other member is appended as "&key=value"

    The written value has the form "#@device=secret&key=value&,device=secret&,".
    '''
    device_segments = []
    for device, document in decrypted_secret.items():
        fields = json.loads(document)
        # the secret always comes first, directly after the device name
        segment = device + "=" + fields['secret']
        fields.pop('secret', None)
        # remaining members follow in document order as "&key=value"
        segment += "".join("&" + str(name) + "=" + value for name, value in fields.items())
        device_segments.append(segment + "&,")

    ldap.update(DN, ldap.target_attribute, "#@" + "".join(device_segments))


def decrypt_and_restore(ldap, entries):
    '''
    decrypt the user parameters and restore back in the legacy format

    input  - ldap    : ldap helper object; its counters (number_of_entries, total_entries,
                       number_of_entries_not_modified, not_modified) and crypt member are reset here
             entries : iterable of search results with a 'dn' key and an 'attributes' mapping;
                       results without 'dn' are skipped (they still count in total_entries)
    output - None after all entries were processed.
             False as soon as a source attribute is neither a string nor a list; the remaining
             entries are then not processed (behaviour kept from the original implementation).

    The source attribute may hold the encrypted json directly or url-safe base64 encoded json.
    Every entry with a DN that could not be restored (bad format, failed decryption or an
    exception) is recorded in ldap.not_modified / ldap.number_of_entries_not_modified.
    '''
    ldap.crypt = crypto.crypto()
    entered = False
    ldap.number_of_entries = 0
    ldap.total_entries = 0
    ldap.number_of_entries_not_modified = 0
    ldap.not_modified = []

    for entry in entries:
        # reset per entry so the error path never reports the DN of a previous entry
        DN = None
        restored = False
        try:
            ldap.total_entries += 1
            entered = True
            if 'dn' not in entry:
                continue
            DN = entry['dn']

            attribute = entry['attributes'][ldap.source_attribute]
            ldap.target_value = entry['attributes'][ldap.target_attribute]
            if isinstance(attribute, str):
                logging.info("Proccessing a Single valued attribute")
                attribute_value = attribute
            elif isinstance(attribute, list):
                logging.info("Processing a multi valued attribute")
                # only the first value is used; an empty list raises and is handled below
                attribute_value = attribute[0]
            else:
                logging.info("Operation failed due to unsupported attribute type")
                return False

            # accept plain json, otherwise try the base64 encoded form (decoded only once)
            if is_json(attribute_value):
                encrypted_json = attribute_value
            else:
                decoded_value = urlsafe_b64decode(attribute_value)
                encrypted_json = decoded_value if is_json(decoded_value) else None

            if encrypted_json is None:
                logging.error("Error in the format of the source attribute for DN %s",DN)
            else:
                decrypted_secret = verify_encryption(ldap, encrypted_json, DN)
                # verify_encryption returns False on failure; an empty dict is a valid result
                if decrypted_secret is False:
                    logging.error("error in decryption for %s", DN)
                else:
                    restore_plaintext(DN, ldap, decrypted_secret)
                    ldap.number_of_entries += 1
                    logging.info("Restore successful..")
                    restored = True
        except Exception:
            logging.error("error in restoring the plaintext back for %s",DN)

        # single bookkeeping point for every failure path (skipped results have no DN)
        if DN is not None and not restored:
            ldap.not_modified.append(DN)
            ldap.number_of_entries_not_modified += 1

    if not entered:
        logging.error("No Entries to be updated")
    
