Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
from binascii import hexlify
from typing import cast, NamedTuple, Union, Dict, Any, Optional
from typing import NamedTuple, Union, Dict, Any, Optional

from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
Expand Down Expand Up @@ -131,7 +131,7 @@ def get_client_credential(
certificate_data: Optional[bytes] = None,
send_certificate_chain: bool = False,
**_: Any
) -> Dict:
) -> Dict[str, Union[str, bytes]]:
"""Load a certificate from a filesystem path or bytes, return it as a dict suitable for msal.ClientApplication.

:param str certificate_path: Path to a PEM or PKCS12 certificate file.
Expand All @@ -151,22 +151,26 @@ def get_client_credential(
elif not certificate_data:
raise ValueError('CertificateCredential requires a value for either "certificate_path" or "certificate_data"')

password_bytes: Optional[bytes] = None
if password:
# if password is already bytes, no need to encode.
if isinstance(password, str):
password = password.encode("utf-8")
password = cast("Optional[bytes]", password)
password_bytes = password.encode("utf-8")
else:
password_bytes = password

if b"-----BEGIN" in certificate_data:
cert = load_pem_certificate(certificate_data, password)
cert = load_pem_certificate(certificate_data, password_bytes)
else:
cert = load_pkcs12_certificate(certificate_data, password)
cert = load_pkcs12_certificate(certificate_data, password_bytes)
password = None # load_pkcs12_certificate returns cert.pem_bytes decrypted

if not isinstance(cert.private_key, RSAPrivateKey):
raise ValueError("The certificate must have an RSA private key because RS256 is used for signing")

client_credential = {"private_key": cert.pem_bytes, "thumbprint": hexlify(cert.fingerprint).decode("utf-8")}
client_credential: Dict[str, Union[str, bytes]] = {
"private_key": cert.pem_bytes.decode("utf-8"),
"thumbprint": hexlify(cert.fingerprint).decode("utf-8"),
}
if password:
client_credential["passphrase"] = password

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
import base64
from typing import Optional
from typing import Optional, Union
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding
Expand All @@ -14,11 +14,21 @@
class AadClientCertificate:
"""Wraps 'cryptography' to provide the crypto operations AadClient requires for certificate authentication.

:param bytes pem_bytes: bytes of a a PEM-encoded certificate including the (RSA) private key
:param bytes password: (optional) the certificate's password
:param pem_bytes: PEM-encoded certificate including the (RSA) private key. May be ``bytes`` or a ``str``;
``str`` values are encoded as UTF-8 before being handed to ``cryptography``.
:paramtype pem_bytes: bytes or str
:param password: (optional) the certificate's password. May be ``bytes`` or a ``str``; ``str`` values are
encoded as UTF-8.
:paramtype password: bytes or str or None
"""

def __init__(self, pem_bytes: bytes, password: Optional[bytes] = None) -> None:
def __init__(
self, pem_bytes: Union[bytes, str], password: Optional[Union[bytes, str]] = None
) -> None:
if isinstance(pem_bytes, str):
pem_bytes = pem_bytes.encode("utf-8")
if isinstance(password, str):
password = password.encode("utf-8")
private_key = serialization.load_pem_private_key(pem_bytes, password=password, backend=default_backend())
if not isinstance(private_key, RSAPrivateKey):
raise ValueError("The certificate must have an RSA private key because RS256 is used for signing")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from azure.identity import CertificateCredential, TokenCachePersistenceOptions
from azure.identity._enums import RegionalAuthority
from azure.identity._constants import EnvironmentVariables
from azure.identity._credentials.certificate import load_pkcs12_certificate
from azure.identity._credentials.certificate import load_pkcs12_certificate, get_client_credential
from azure.identity._internal.user_agent import USER_AGENT
from cryptography import x509
from cryptography.hazmat.backends import default_backend
Expand Down Expand Up @@ -58,6 +58,22 @@ def test_non_rsa_key():
CertificateCredential("tenant-id", "client-id", certificate_data=open(EC_CERT_PATH, "rb").read())


def test_get_client_credential_returns_str_private_key():
"""get_client_credential should return private_key as a string."""

cred_dict = get_client_credential(PEM_CERT_PATH)

assert isinstance(cred_dict["private_key"], str)
assert "-----BEGIN" in cred_dict["private_key"]
assert "passphrase" not in cred_dict

cred_dict = get_client_credential(PFX_CERT_WITH_PASSWORD_PATH, password=CERT_PASSWORD)

assert isinstance(cred_dict["private_key"], str)
assert "-----BEGIN" in cred_dict["private_key"]
assert "passphrase" not in cred_dict


def test_tenant_id_validation():
"""The credential should raise ValueError when given an invalid tenant_id"""

Expand Down
5 changes: 4 additions & 1 deletion sdk/identity/azure-identity/tests/test_obo.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,5 +307,8 @@ def test_client_certificate_with_params():
)

assert "passphrase" in credential._client_credential
assert credential._client_credential["passphrase"] == cert_password.encode("utf-8")
assert credential._client_credential["passphrase"] == cert_password
assert "private_key" in credential._client_credential
assert isinstance(credential._client_credential["private_key"], str)
assert "-----BEGIN" in credential._client_credential["private_key"]
assert "public_certificate" in credential._client_credential
Loading