Source code for josepy.jwk

"""JSON Web Key."""
import abc
import json
import logging
import math
from typing import (
    Any,
    Callable,
    Dict,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)

import cryptography.exceptions
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec, rsa

import josepy.util
from josepy import errors, json_util, util

logger = logging.getLogger(__name__)


[docs]class JWK(json_util.TypedJSONObjectWithFields, metaclass=abc.ABCMeta): """JSON Web Key.""" type_field_name = 'kty' TYPES: Dict[str, Type['JWK']] = {} cryptography_key_types: Tuple[Type[Any], ...] = () """Subclasses should override.""" required: Sequence[str] = NotImplemented """Required members of public key's representation as defined by JWK/JWA.""" _thumbprint_json_dumps_params: Dict[str, Union[Optional[int], Sequence[str], bool]] = { # "no whitespace or line breaks before or after any syntactic # elements" 'indent': None, 'separators': (',', ':'), # "members ordered lexicographically by the Unicode [UNICODE] # code points of the member names" 'sort_keys': True, } key: Any
[docs] def thumbprint(self, hash_function: Callable[[], hashes.HashAlgorithm] = hashes.SHA256) -> bytes: """Compute JWK Thumbprint. https://tools.ietf.org/html/rfc7638 :returns: bytes """ digest = hashes.Hash(hash_function(), backend=default_backend()) digest.update(json.dumps( {k: v for k, v in self.to_json().items() if k in self.required}, **self._thumbprint_json_dumps_params).encode()) # type: ignore[arg-type] return digest.finalize()
[docs] @abc.abstractmethod def public_key(self) -> 'JWK': # pragma: no cover """Generate JWK with public key. For symmetric cryptosystems, this would return ``self``. """ raise NotImplementedError()
@classmethod def _load_cryptography_key(cls, data: bytes, password: Optional[bytes] = None, backend: Optional[Any] = None) -> Any: backend = default_backend() if backend is None else backend exceptions = {} # private key? loader_private: Any for loader_private in (serialization.load_pem_private_key, serialization.load_der_private_key): try: return loader_private(data, password, backend) except (ValueError, TypeError, cryptography.exceptions.UnsupportedAlgorithm) as error: exceptions[str(loader_private)] = error # public key? loader_public: Any for loader_public in (serialization.load_pem_public_key, serialization.load_der_public_key): try: return loader_public(data, backend) except (ValueError, cryptography.exceptions.UnsupportedAlgorithm) as error: exceptions[str(loader_public)] = error # no luck raise errors.Error('Unable to deserialize key: {0}'.format(exceptions))
[docs] @classmethod def load(cls, data: bytes, password: Optional[bytes] = None, backend: Optional[Any] = None) -> 'JWK': """Load serialized key as JWK. :param str data: Public or private key serialized as PEM or DER. :param str password: Optional password. :param backend: A `.PEMSerializationBackend` and `.DERSerializationBackend` provider. :raises errors.Error: if unable to deserialize, or unsupported JWK algorithm :returns: JWK of an appropriate type. :rtype: `JWK` """ try: key = cls._load_cryptography_key(data, password, backend) except errors.Error as error: logger.debug('Loading symmetric key, asymmetric failed: %s', error) return JWKOct(key=data) if cls.typ is not NotImplemented and not isinstance(key, cls.cryptography_key_types): raise errors.Error('Unable to deserialize {0} into {1}'.format( key.__class__, cls.__class__)) for jwk_cls in cls.TYPES.values(): if isinstance(key, jwk_cls.cryptography_key_types): return jwk_cls(key=key) raise errors.Error('Unsupported algorithm: {0}'.format(key.__class__))
[docs]@JWK.register class JWKOct(JWK): """Symmetric JWK.""" typ = 'oct' __slots__ = ('key',) required = ('k', JWK.type_field_name) key: bytes
[docs] def fields_to_partial_json(self) -> Dict[str, str]: # TODO: An "alg" member SHOULD also be present to identify the # algorithm intended to be used with the key, unless the # application uses another means or convention to determine # the algorithm used. return {'k': json_util.encode_b64jose(self.key)}
[docs] @classmethod def fields_from_json(cls, jobj: Mapping[str, Any]) -> 'JWKOct': return cls(key=json_util.decode_b64jose(jobj['k']))
[docs] def public_key(self) -> 'JWKOct': return self
[docs]@JWK.register class JWKRSA(JWK): """RSA JWK. :ivar key: :class:`~cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateKey` or :class:`~cryptography.hazmat.primitives.asymmetric.rsa.RSAPublicKey` wrapped in :class:`~josepy.util.ComparableRSAKey` """ typ = 'RSA' cryptography_key_types = (rsa.RSAPublicKey, rsa.RSAPrivateKey) __slots__ = ('key',) required = ('e', JWK.type_field_name, 'n') key: josepy.util.ComparableRSAKey def __init__(self, *args: Any, **kwargs: Any) -> None: if 'key' in kwargs and not isinstance( kwargs['key'], util.ComparableRSAKey): kwargs['key'] = util.ComparableRSAKey(kwargs['key']) super().__init__(*args, **kwargs) @classmethod def _encode_param(cls, data: int) -> str: """Encode Base64urlUInt. :type data: long :rtype: unicode """ length = max(data.bit_length(), 8) # decoding 0 length = math.ceil(length / 8) return json_util.encode_b64jose(data.to_bytes(byteorder="big", length=length)) @classmethod def _decode_param(cls, data: str) -> int: """Decode Base64urlUInt.""" try: binary = json_util.decode_b64jose(data) if not binary: raise errors.DeserializationError() return int.from_bytes(binary, byteorder="big") except ValueError: # invalid literal for long() with base 16 raise errors.DeserializationError()
[docs] def public_key(self) -> 'JWKRSA': return type(self)(key=self.key.public_key())
[docs] @classmethod def fields_from_json(cls, jobj: Mapping[str, Any]) -> 'JWKRSA': n, e = (cls._decode_param(jobj[x]) for x in ('n', 'e')) public_numbers = rsa.RSAPublicNumbers(e=e, n=n) # public key if 'd' not in jobj: return cls(key=public_numbers.public_key(default_backend())) # private key d = cls._decode_param(jobj['d']) if ('p' in jobj or 'q' in jobj or 'dp' in jobj or 'dq' in jobj or 'qi' in jobj or 'oth' in jobj): # "If the producer includes any of the other private # key parameters, then all of the others MUST be # present, with the exception of "oth", which MUST # only be present when more than two prime factors # were used." p, q, dp, dq, qi, = all_params = tuple( jobj.get(x) for x in ('p', 'q', 'dp', 'dq', 'qi')) if tuple(param for param in all_params if param is None): raise errors.Error( 'Some private parameters are missing: {0}'.format( all_params)) p, q, dp, dq, qi = tuple( cls._decode_param(str(x)) for x in all_params) # TODO: check for oth else: # cryptography>=0.8 p, q = rsa.rsa_recover_prime_factors(n, e, d) dp = rsa.rsa_crt_dmp1(d, p) dq = rsa.rsa_crt_dmq1(d, q) qi = rsa.rsa_crt_iqmp(p, q) key = rsa.RSAPrivateNumbers( p, q, d, dp, dq, qi, public_numbers).private_key( default_backend()) return cls(key=key)
[docs] def fields_to_partial_json(self) -> Dict[str, Any]: if isinstance(self.key._wrapped, rsa.RSAPublicKey): numbers = self.key.public_numbers() params = { 'n': numbers.n, 'e': numbers.e, } else: # rsa.RSAPrivateKey private = self.key.private_numbers() public = self.key.public_key().public_numbers() params = { 'n': public.n, 'e': public.e, 'd': private.d, 'p': private.p, 'q': private.q, 'dp': private.dmp1, 'dq': private.dmq1, 'qi': private.iqmp, } return {key: self._encode_param(value) for key, value in params.items()}
[docs]@JWK.register class JWKEC(JWK): """EC JWK. :ivar key: :class:`~cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey` or :class:`~cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey` wrapped in :class:`~josepy.util.ComparableECKey` """ typ = 'EC' __slots__ = ('key',) cryptography_key_types = ( ec.EllipticCurvePublicKey, ec.EllipticCurvePrivateKey) required = ('crv', JWK.type_field_name, 'x', 'y') key: josepy.util.ComparableECKey def __init__(self, *args: Any, **kwargs: Any) -> None: if 'key' in kwargs and not isinstance( kwargs['key'], util.ComparableECKey): kwargs['key'] = util.ComparableECKey(kwargs['key']) super().__init__(*args, **kwargs) @classmethod def _encode_param(cls, data: int, length: int) -> str: """Encode Base64urlUInt. :type data: long :type key_size: long :rtype: unicode """ return json_util.encode_b64jose(data.to_bytes(byteorder="big", length=length)) @classmethod def _decode_param(cls, data: str, name: str, valid_length: int) -> int: """Decode Base64urlUInt.""" try: binary = json_util.decode_b64jose(data) if len(binary) != valid_length: raise errors.DeserializationError( f'Expected parameter "{name}" to be {valid_length} bytes ' f'after base64-decoding; got {len(binary)} bytes instead' ) return int.from_bytes(binary, byteorder="big") except ValueError: # invalid literal for long() with base 16 raise errors.DeserializationError() @classmethod def _curve_name_to_crv(cls, curve_name: str) -> str: if curve_name == 'secp256r1': return 'P-256' if curve_name == 'secp384r1': return 'P-384' if curve_name == 'secp521r1': return 'P-521' raise errors.SerializationError() @classmethod def _crv_to_curve(cls, crv: str) -> ec.EllipticCurve: # crv is case-sensitive if crv == 'P-256': return ec.SECP256R1() if crv == 'P-384': return ec.SECP384R1() if crv == 'P-521': return ec.SECP521R1() raise errors.DeserializationError() @classmethod def expected_length_for_curve(cls, curve: ec.EllipticCurve) -> int: if isinstance(curve, ec.SECP256R1): return 32 elif isinstance(curve, ec.SECP384R1): return 48 elif isinstance(curve, ec.SECP521R1): return 66 raise ValueError(f'Unexpected curve: {curve}')
[docs] def fields_to_partial_json(self) -> Dict[str, Any]: params = {} if isinstance(self.key._wrapped, ec.EllipticCurvePublicKey): public = self.key.public_numbers() elif isinstance(self.key._wrapped, ec.EllipticCurvePrivateKey): private = self.key.private_numbers() public = self.key.public_key().public_numbers() params['d'] = private.private_value else: raise errors.SerializationError( 'Supplied key is neither of type EllipticCurvePublicKey ' 'nor EllipticCurvePrivateKey' ) params['x'] = public.x params['y'] = public.y params = { key: self._encode_param(value, self.expected_length_for_curve(public.curve)) for key, value in params.items() } params['crv'] = self._curve_name_to_crv(public.curve.name) return params
[docs] @classmethod def fields_from_json(cls, jobj: Mapping[str, Any]) -> 'JWKEC': curve = cls._crv_to_curve(jobj['crv']) expected_length = cls.expected_length_for_curve(curve) x, y = (cls._decode_param(jobj[n], n, expected_length) for n in ('x', 'y')) public_numbers = ec.EllipticCurvePublicNumbers(x=x, y=y, curve=curve) # private key if 'd' not in jobj: return cls(key=public_numbers.public_key(default_backend())) # private key d = cls._decode_param(jobj['d'], 'd', expected_length) key = ec.EllipticCurvePrivateNumbers(d, public_numbers).private_key(default_backend()) return cls(key=key)
[docs] def public_key(self) -> 'JWKEC': # Unlike RSAPrivateKey, EllipticCurvePrivateKey does not contain public_key() if hasattr(self.key, 'public_key'): key = self.key.public_key() else: key = self.key.public_numbers().public_key(default_backend()) return type(self)(key=key)