from upgrade_cert import cert_upgrade
from ldapaccess import ldapaccess as ld
from utils import *
import logging
import time
import logging
import getpass
import os


if __name__ == "__main__":
    # Entry point: parse the CLI arguments, authenticate against the LDAP/AD
    # server, then run one of three operations over the matching directory
    # entries -- OTPSecret encryption, OTPSecret decryption, or certificate
    # upgrade -- and finally print a summary of what was changed.

    start_time = time.time()
    args = get_arguments()

    # Append to app.log; each record carries time, level and source location.
    logging.basicConfig(
        filename='app.log',
        level=logging.DEBUG,
        filemode='a',
        format='%(asctime)s - %(name)s - %(levelname)s [%(filename)s:%(lineno)s-%(funcName)s() ] - %(message)s',
        datefmt='%d-%b-%y %H:%M:%S',
    )
    # Blank lines visually separate successive runs in the appended log.
    logging.info("\n\n")

    # Bound before the try so the cleanup in `finally` can always reference it,
    # even when the failure happens before any LDAP object has been created.
    inst = None

    try:
        # Publish every parsed argument as a module-level name (USER, SERVER,
        # PORT, operation, cert_path, ...). At module scope this is exactly what
        # exec(key + '=val') did, without generating and executing source code.
        globals().update(args)

        PASSWORD = getpass.getpass("INPUT: Please enter the LDAP AD Admin password - ")

        # The admin must be able to authenticate; connect_to_server (from utils)
        # is expected to exit the script if authentication fails.
        connect_to_server(USER, PASSWORD, SERVER, PORT)

        verification_result = verify_arguments(args, operation, target_attribute, new_cert_path, source_attribute)

        # Restrict the search to one account when a user name was given. The
        # name is inserted verbatim (no RFC 4515 escaping), so wildcards such as
        # '*' work, but literal parentheses must be escaped by the caller.
        # NOTE(review): when AD_user is None, search_filter must already be
        # defined elsewhere (parsed args or utils) -- confirm.
        if AD_user is not None:
            search_filter = "(name=" + AD_user + ")"

        if not verification_result['result']:
            print("ERROR: There seems to be an error, ", verification_result['value'])
        else:
            encryption = False
            decryption = False
            cert_upgradation = False
            if operation in ('0', 'encryption'):
                encryption = True
                print("WARN: You are choosing to encrypt the OTPSecret which requires encryption feature to enabled for OTP config, please make sure the encryption parameter is turned ON in otpparameters")
            elif operation in ('1', 'decryption'):
                decryption = True
            else:
                cert_upgradation = True

            if not cert_upgradation:
                # Encryption works with a single certificate: keep the first one
                # and only warn when more than one was actually supplied.
                if encryption:
                    if len(cert_path) > 1:
                        print("INFO: You seem to have provided multiple certificates,using the first certificate for encryption")
                    cert_path = cert_path[0]
                inst = ld(SERVER, USER, PASSWORD,
                          PORT, source_attribute,
                          target_attribute, cert_path,
                          encryption)
                inst.connect()
                search_result = inst.retrieve_entries(search_base,
                                                      search_filter=search_filter)
                # retrieve_entries signals failure with False; compare by
                # identity so that an empty (falsy) result, if it returns one,
                # is not mistaken for an error.
                if search_result is False:
                    logging.error("Error in searching the directory")
                elif decryption:
                    # Decrypting in place would overwrite the source attribute.
                    if source_attribute == target_attribute:
                        print('WARN: Any data other than OTPSecret in the source attribute might lost as it is overwritten, please select a different target attribute')
                        print('exiting..')
                    else:
                        decrypt_and_restore(inst, search_result)
                else:
                    traverse_entries(inst, search_result)
            else:
                inst = cert_upgrade(SERVER, USER, PASSWORD,
                                    PORT, source_attribute,
                                    cert_path, new_cert_path)
                inst.connect()
                logging.info("connection status %s", inst.connection.result)

                search_result = inst.retrieve_entries(search_base,
                                                      search_filter=search_filter)
                if search_result is False:
                    logging.error("Error in searching")
                else:
                    inst.traverse_entries(search_result)

            print('INFO: Number of entries fetched', inst.total_entries)
            print('INFO: Number of entries modified', inst.number_of_entries)
            print('INFO: The number of entries that were not modified are', inst.number_of_entries_not_modified)
            print('INFO: Total time of execution of script', time.time() - start_time, 's')

            # Write the list of entries that were left untouched. The report is
            # (re)created on every run; a symlink at that path is refused so the
            # write cannot be redirected onto another file.
            report_path = 'unmodified_users.txt'
            if not os.path.islink(report_path):
                with open(report_path, 'w', encoding='utf-8') as f:
                    for i, item in enumerate(inst.not_modified):
                        # Tuple-wrapped args so a tuple `item` is not unpacked by %.
                        f.write("%s  %s\n" % (i, item))

    except Exception as er:
        # Top-level boundary: keep the full traceback in the log and tell the
        # operator, instead of recording only the exception's type name.
        logging.exception("Error in processing - %s ", type(er).__name__)
        print("ERROR: Processing failed -", type(er).__name__, er)
    finally:
        # Release the LDAP connection on success as well as on failure. This is
        # best-effort: a failing unbind must not mask the original error, and
        # `inst` may be None (or not yet connected) if we failed early.
        connection = getattr(inst, 'connection', None)
        if connection is not None:
            try:
                connection.unbind()
            except Exception:
                logging.exception("Error while unbinding the LDAP connection")


    