diff --git a/src/aws_encryption_sdk/internal/crypto/elliptic_curve.py b/src/aws_encryption_sdk/internal/crypto/elliptic_curve.py index a7f1d1446..ce1e72224 100644 --- a/src/aws_encryption_sdk/internal/crypto/elliptic_curve.py +++ b/src/aws_encryption_sdk/internal/crypto/elliptic_curve.py @@ -171,6 +171,9 @@ def generate_ecc_signing_key(algorithm): :returns: Generated signing key :raises NotSupportedError: if signing algorithm is not supported on this platform """ - if not isinstance(algorithm.signing_algorithm_info, type(ec.EllipticCurve)): + try: + if not issubclass(algorithm.signing_algorithm_info, ec.EllipticCurve): + raise NotSupportedError("Unsupported signing algorithm info") + except TypeError: raise NotSupportedError("Unsupported signing algorithm info") return ec.generate_private_key(curve=algorithm.signing_algorithm_info(), backend=default_backend()) diff --git a/test/unit/test_crypto_elliptic_curve.py b/test/unit/test_crypto_elliptic_curve.py index 58a07557a..3e4048730 100644 --- a/test/unit/test_crypto_elliptic_curve.py +++ b/test/unit/test_crypto_elliptic_curve.py @@ -349,22 +349,17 @@ def test_ecc_public_numbers_from_compressed_point(patch_ec, patch_ecc_decode_com assert test == sentinel.public_numbers_instance -def test_generate_ecc_signing_key_supported(patch_default_backend, patch_ec): - patch_ec.generate_private_key.return_value = sentinel.raw_signing_key - mock_algorithm_info = MagicMock(return_value=sentinel.algorithm_info, spec=patch_ec.EllipticCurve) - mock_algorithm = MagicMock(signing_algorithm_info=mock_algorithm_info) +def test_generate_ecc_signing_key_supported(patch_default_backend): + patch_default_backend.return_value = sentinel.backend + mock_algorithm = MagicMock(signing_algorithm_info=ec.SECP384R1) test_signing_key = generate_ecc_signing_key(algorithm=mock_algorithm) - patch_ec.generate_private_key.assert_called_once_with( - curve=sentinel.algorithm_info, backend=patch_default_backend.return_value - ) - assert test_signing_key is sentinel.raw_signing_key + assert test_signing_key is not None def test_generate_ecc_signing_key_unsupported(patch_default_backend, patch_ec): - mock_algorithm_info = MagicMock(return_value=sentinel.algorithm_info) - mock_algorithm = MagicMock(signing_algorithm_info=mock_algorithm_info) + mock_algorithm = MagicMock(signing_algorithm_info="not_a_class") with pytest.raises(NotSupportedError) as excinfo: generate_ecc_signing_key(algorithm=mock_algorithm)