# Copyright Citrix Systems, Inc. All rights reserved.

"""RSA key encoding/decoding.

Encode keys using the encode_public/encode_private functions. They will return
a base64-encoded representation based on the .NET RSA Key XML Format.

Decode keys using the decode_public/decode_private functions. They will return
the original rsa.PublicKey/rsa.PrivateKey objects.
"""
from cwc.util import getbytes, frombytes, ENCODING
from xml.etree.ElementTree import Element, SubElement
from defusedxml.ElementTree import tostring, fromstring
import base64
import rsa

# Element names of the RSA key XML document that the encode_*/decode_*
# functions in this module build and parse. The trailing comments name the
# rsa key attribute each element carries.
_XML_TAG_RSAKEY = 'RSAKeyValue'     # root element
_XML_TAG_MODULUS = 'Modulus'        # key.n
_XML_TAG_EXPONENT = 'Exponent'      # key.e
_XML_TAG_P = 'P'                    # priv_key.p
_XML_TAG_Q = 'Q'                    # priv_key.q
_XML_TAG_DP = 'DP'                  # priv_key.exp1
_XML_TAG_DQ = 'DQ'                  # priv_key.exp2
_XML_TAG_INVERSEQ = 'InverseQ'      # priv_key.coef
_XML_TAG_D = 'D'                    # priv_key.d

def encode_public(pub_key):
    """Serializes an rsa.PublicKey as base64-wrapped RSA key XML.

    The key's ``n`` and ``e`` attributes are written, in that order, as the
    Modulus and Exponent children of the root element.

    :param pub_key: the public key to encode
    :returns:       a base64-encoded .NET RSA Key (XML Format) public key
    """
    key_elem = Element(_XML_TAG_RSAKEY)

    # (tag, value, pack_power_two) in document order. Only the exponent is
    # written without power-of-two padding.
    fields = (
        (_XML_TAG_MODULUS, pub_key.n, True),
        (_XML_TAG_EXPONENT, pub_key.e, False),
    )
    for tag, value, pack in fields:
        _xmlwrite(key_elem, tag, value, pack)

    return _encodexml(key_elem)

def decode_public(encoded_key):
    """Rebuilds an rsa.PublicKey from base64-wrapped RSA key XML.

    :param encoded_key: the encoded key to decode, as produced by
                        encode_public
    :returns:           an rsa.PublicKey built from the Modulus and Exponent
                        elements
    """
    key_elem = _decodexml(encoded_key)

    # Arguments are evaluated left to right, so the modulus is read first.
    return rsa.PublicKey(
        _xmlread(key_elem, _XML_TAG_MODULUS),
        _xmlread(key_elem, _XML_TAG_EXPONENT),
    )

def encode_private(priv_key):
    """Serializes an rsa.PrivateKey as base64-wrapped RSA key XML.

    The elements are written in the order Modulus, Exponent, P, Q, DP, DQ,
    InverseQ, D, taken from the key's ``n``, ``e``, ``p``, ``q``, ``exp1``,
    ``exp2``, ``coef`` and ``d`` attributes respectively.

    :param priv_key:    the private key to encode
    :returns:           a base64-encoded .NET RSA Key (XML Format) private key
    """
    key_elem = Element(_XML_TAG_RSAKEY)

    # (tag, value, pack_power_two) in document order. Every component except
    # the exponent is padded to a power-of-two length.
    fields = (
        (_XML_TAG_MODULUS, priv_key.n, True),
        (_XML_TAG_EXPONENT, priv_key.e, False),
        (_XML_TAG_P, priv_key.p, True),
        (_XML_TAG_Q, priv_key.q, True),
        (_XML_TAG_DP, priv_key.exp1, True),
        (_XML_TAG_DQ, priv_key.exp2, True),
        (_XML_TAG_INVERSEQ, priv_key.coef, True),
        (_XML_TAG_D, priv_key.d, True),
    )
    for tag, value, pack in fields:
        _xmlwrite(key_elem, tag, value, pack)

    return _encodexml(key_elem)

def decode_private(encoded_key):
    """Rebuilds an rsa.PrivateKey from base64-wrapped RSA key XML.

    :param encoded_key: the encoded key to decode, as produced by
                        encode_private
    :returns:           an rsa.PrivateKey object
    """
    key_elem = _decodexml(encoded_key)

    # (constructor argument, tag), listed in the order the elements are read.
    read_order = (
        ('n', _XML_TAG_MODULUS),
        ('e', _XML_TAG_EXPONENT),
        ('p', _XML_TAG_P),
        ('q', _XML_TAG_Q),
        ('exp1', _XML_TAG_DP),
        ('exp2', _XML_TAG_DQ),
        ('coef', _XML_TAG_INVERSEQ),
        ('d', _XML_TAG_D),
    )
    parts = {}
    for name, tag in read_order:
        parts[name] = _xmlread(key_elem, tag)

    # NOTE(review): exp1/exp2/coef are passed positionally here; confirm the
    # installed rsa release's PrivateKey constructor still accepts them.
    return rsa.PrivateKey(
        parts['n'], parts['e'], parts['d'], parts['p'], parts['q'],
        parts['exp1'], parts['exp2'], parts['coef'],
    )

def _xmlwrite(root, tag, value, pack_power_two=True):
    """Appends a child element holding an integer as base64 text.

    :param root:            the parent element to append to
    :param tag:             the name of the new child element
    :param value:           the integer to store
    :param pack_power_two:  forwarded to _tobytes to select its padding mode
    :returns:               the newly created child element
    """
    child = SubElement(root, tag)
    encoded = base64.b64encode(_tobytes(value, pack_power_two))
    # frombytes comes from cwc.util; its result is stored as the element text.
    child.text = frombytes(encoded)
    return child

def _xmlread(root, tag):
    """Reads the integer stored as base64 text in a child element.

    :param root:    the element to search with ``find``
    :param tag:     the name of the child element to read
    :returns:       the integer decoded from the child's text
    :raises AttributeError: if no child named ``tag`` is found, because
                            ``find`` then returns None
    """
    node = root.find(tag)
    # getbytes comes from cwc.util and prepares the text for base64 decoding.
    raw = base64.b64decode(getbytes(node.text))
    return _frombytes(raw)

def _encodexml(root):
    """Serializes an element tree and wraps the result in base64.

    :param root:    the root element to serialize
    :returns:       the base64 form of the serialized XML, passed through
                    cwc.util.frombytes
    """
    serialized = tostring(root, ENCODING)
    return frombytes(base64.b64encode(serialized))

def _decodexml(encoded_key):
    """Unwraps a base64-encoded XML document and parses it.

    :param encoded_key: the base64-encoded XML, as produced by _encodexml
    :returns:           the root element returned by defusedxml's fromstring
    """
    decoded = base64.b64decode(getbytes(encoded_key))
    # The decoded payload goes through cwc.util.frombytes before parsing.
    document = frombytes(decoded)
    return fromstring(document)

def _tobytes(value, pack_power_two):
    bitlen = value.bit_length()

    if pack_power_two:
        bytelen = max(1, _next_power_two(bitlen) // 8)
    else:
        bytelen = bitlen // 8 if bitlen % 8 == 0 else bitlen // 8 + 1

    return value.to_bytes(bytelen, byteorder='big', signed = False)

def _frombytes(bytes):
    try:
        return int.from_bytes(bytes, byteorder='big', signed = False)
    except AttributeError:
        return int(bytes.encode('hex'), 16)

def _next_power_two(n):
    return 1 << (n - 1).bit_length()
