diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py b/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py index e7ce646e9987..e1c2664001c5 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/certificate.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ from binascii import hexlify -from typing import cast, NamedTuple, Union, Dict, Any, Optional +from typing import NamedTuple, Union, Dict, Any, Optional from cryptography import x509 from cryptography.hazmat.primitives import hashes, serialization @@ -131,7 +131,7 @@ def get_client_credential( certificate_data: Optional[bytes] = None, send_certificate_chain: bool = False, **_: Any -) -> Dict: +) -> Dict[str, Union[str, bytes]]: """Load a certificate from a filesystem path or bytes, return it as a dict suitable for msal.ClientApplication. :param str certificate_path: Path to a PEM or PKCS12 certificate file. @@ -151,22 +151,26 @@ def get_client_credential( elif not certificate_data: raise ValueError('CertificateCredential requires a value for either "certificate_path" or "certificate_data"') + password_bytes: Optional[bytes] = None if password: - # if password is already bytes, no need to encode. if isinstance(password, str): - password = password.encode("utf-8") - password = cast("Optional[bytes]", password) + password_bytes = password.encode("utf-8") + else: + password_bytes = password if b"-----BEGIN" in certificate_data: - cert = load_pem_certificate(certificate_data, password) + cert = load_pem_certificate(certificate_data, password_bytes) else: - cert = load_pkcs12_certificate(certificate_data, password) + cert = load_pkcs12_certificate(certificate_data, password_bytes) password = None # load_pkcs12_certificate returns cert.pem_bytes decrypted if not isinstance(cert.private_key, RSAPrivateKey): raise ValueError("The certificate must have an RSA private key because RS256 is used for signing") - client_credential = {"private_key": cert.pem_bytes, "thumbprint": hexlify(cert.fingerprint).decode("utf-8")} + client_credential: Dict[str, Union[str, bytes]] = { + "private_key": cert.pem_bytes.decode("utf-8"), + "thumbprint": hexlify(cert.fingerprint).decode("utf-8"), + } if password: client_credential["passphrase"] = password diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aadclient_certificate.py b/sdk/identity/azure-identity/azure/identity/_internal/aadclient_certificate.py index 1a5a8a20d973..f59ff8de02f8 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aadclient_certificate.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aadclient_certificate.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import base64 -from typing import Optional +from typing import Optional, Union from cryptography import x509 from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import padding @@ -14,11 +14,21 @@ class AadClientCertificate: """Wraps 'cryptography' to provide the crypto operations AadClient requires for certificate authentication. - :param bytes pem_bytes: bytes of a a PEM-encoded certificate including the (RSA) private key - :param bytes password: (optional) the certificate's password + :param pem_bytes: PEM-encoded certificate including the (RSA) private key. May be ``bytes`` or a ``str``; + ``str`` values are encoded as UTF-8 before being handed to ``cryptography``. + :paramtype pem_bytes: bytes or str + :param password: (optional) the certificate's password. May be ``bytes`` or a ``str``; ``str`` values are + encoded as UTF-8. + :paramtype password: bytes or str or None """ - def __init__(self, pem_bytes: bytes, password: Optional[bytes] = None) -> None: + def __init__( + self, pem_bytes: Union[bytes, str], password: Optional[Union[bytes, str]] = None + ) -> None: + if isinstance(pem_bytes, str): + pem_bytes = pem_bytes.encode("utf-8") + if isinstance(password, str): + password = password.encode("utf-8") private_key = serialization.load_pem_private_key(pem_bytes, password=password, backend=default_backend()) if not isinstance(private_key, RSAPrivateKey): raise ValueError("The certificate must have an RSA private key because RS256 is used for signing") diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential.py b/sdk/identity/azure-identity/tests/test_certificate_credential.py index 24015b4da414..b72144b879a6 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential.py @@ -10,7 +10,7 @@ from azure.identity import CertificateCredential, TokenCachePersistenceOptions from azure.identity._enums import RegionalAuthority from azure.identity._constants import EnvironmentVariables -from azure.identity._credentials.certificate import load_pkcs12_certificate +from azure.identity._credentials.certificate import load_pkcs12_certificate, get_client_credential from azure.identity._internal.user_agent import USER_AGENT from cryptography import x509 from cryptography.hazmat.backends import default_backend @@ -58,6 +58,22 @@ def test_non_rsa_key(): CertificateCredential("tenant-id", "client-id", certificate_data=open(EC_CERT_PATH, "rb").read()) +def test_get_client_credential_returns_str_private_key(): + """get_client_credential should return private_key as a string.""" + + cred_dict = get_client_credential(PEM_CERT_PATH) + + assert isinstance(cred_dict["private_key"], str) + assert "-----BEGIN" in cred_dict["private_key"] + assert "passphrase" not in cred_dict + + cred_dict = get_client_credential(PFX_CERT_WITH_PASSWORD_PATH, password=CERT_PASSWORD) + + assert isinstance(cred_dict["private_key"], str) + assert "-----BEGIN" in cred_dict["private_key"] + assert "passphrase" not in cred_dict + + def test_tenant_id_validation(): """The credential should raise ValueError when given an invalid tenant_id""" diff --git a/sdk/identity/azure-identity/tests/test_obo.py b/sdk/identity/azure-identity/tests/test_obo.py index bbcd50e6c603..9d03c40ce363 100644 --- a/sdk/identity/azure-identity/tests/test_obo.py +++ b/sdk/identity/azure-identity/tests/test_obo.py @@ -307,5 +307,8 @@ def test_client_certificate_with_params(): ) assert "passphrase" in credential._client_credential - assert credential._client_credential["passphrase"] == cert_password.encode("utf-8") + assert credential._client_credential["passphrase"] == cert_password + assert "private_key" in credential._client_credential + assert isinstance(credential._client_credential["private_key"], str) + assert "-----BEGIN" in credential._client_credential["private_key"] assert "public_certificate" in credential._client_credential