From 8993318551ad2fc70175b0a1c056399e29715d4f Mon Sep 17 00:00:00 2001 From: Eu Jing Chua Date: Wed, 24 Jul 2024 23:33:45 +0000 Subject: [PATCH 1/6] Initial fixes for token-based fileshare auth --- .../services/remote/azure/azure_auth.py | 10 ++++++-- .../services/remote/azure/azure_fileshare.py | 6 ++--- .../mlos_bench/services/types/__init__.py | 2 ++ .../types/azure_authenticator_type.py | 24 +++++++++++++++++++ 4 files changed, 37 insertions(+), 5 deletions(-) create mode 100644 mlos_bench/mlos_bench/services/types/azure_authenticator_type.py diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py index e674a943256..2055845815d 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py @@ -10,17 +10,18 @@ from typing import Any, Callable, Dict, List, Optional, Union import azure.identity as azure_id +import azure.core.credentials as azure_cred from azure.keyvault.secrets import SecretClient from pytz import UTC from mlos_bench.services.base_service import Service -from mlos_bench.services.types.authenticator_type import SupportsAuth +from mlos_bench.services.types.azure_authenticator_type import SupportsAzureAuth from mlos_bench.util import check_required_params _LOG = logging.getLogger(__name__) -class AzureAuthService(Service, SupportsAuth): +class AzureAuthService(Service, SupportsAzureAuth): """Helper methods to get access to Azure services.""" _REQ_INTERVAL = 300 # = 5 min @@ -56,6 +57,7 @@ def __init__( [ self.get_access_token, self.get_auth_headers, + self.get_credential, ], ), ) @@ -133,3 +135,7 @@ def get_access_token(self) -> str: def get_auth_headers(self) -> dict: """Get the authorization part of HTTP headers for REST API calls.""" return {"Authorization": "Bearer " + self.get_access_token()} + + def get_credential(self) -> azure_cred.TokenCredential: + """Return the Azure SDK credential object.""" + return self._cred diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py index 5ff1b638a32..763c781acdb 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py @@ -13,7 +13,7 @@ from mlos_bench.services.base_fileshare import FileShareService from mlos_bench.services.base_service import Service -from mlos_bench.services.types.authenticator_type import SupportsAuth +from mlos_bench.services.types.azure_authenticator_type import SupportsAzureAuth from mlos_bench.util import check_required_params _LOG = logging.getLogger(__name__) @@ -66,14 +66,14 @@ def _get_share_client(self) -> ShareClient: """Get the Azure file share client object.""" if self._share_client is None: assert self._parent is not None and isinstance( - self._parent, SupportsAuth + self._parent, SupportsAzureAuth ), "Authorization service not provided. Include service-auth.jsonc?" self._share_client = ShareClient.from_share_url( self._SHARE_URL.format( account_name=self.config["storageAccountName"], fs_name=self.config["storageFileShareName"], ), - credential=self._parent.get_access_token(), + credential=self._parent.get_credential(), token_intent="backup", ) return self._share_client diff --git a/mlos_bench/mlos_bench/services/types/__init__.py b/mlos_bench/mlos_bench/services/types/__init__.py index e2d0cb55b5a..14e720dec13 100644 --- a/mlos_bench/mlos_bench/services/types/__init__.py +++ b/mlos_bench/mlos_bench/services/types/__init__.py @@ -7,6 +7,7 @@ """ from mlos_bench.services.types.authenticator_type import SupportsAuth +from mlos_bench.services.types.azure_authenticator_type import SupportsAzureAuth from mlos_bench.services.types.config_loader_type import SupportsConfigLoading from mlos_bench.services.types.fileshare_type import SupportsFileShareOps from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning @@ -19,6 +20,7 @@ __all__ = [ "SupportsAuth", + "SupportsAzureAuth", "SupportsConfigLoading", "SupportsFileShareOps", "SupportsHostProvisioning", diff --git a/mlos_bench/mlos_bench/services/types/azure_authenticator_type.py b/mlos_bench/mlos_bench/services/types/azure_authenticator_type.py new file mode 100644 index 00000000000..ccef2b77fdd --- /dev/null +++ b/mlos_bench/mlos_bench/services/types/azure_authenticator_type.py @@ -0,0 +1,24 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +"""Protocol interface for authentication for the Azure services.""" + +from typing import Protocol, runtime_checkable +from mlos_bench.services.types.authenticator_type import SupportsAuth +import azure.core.credentials as azure_cred + + +@runtime_checkable +class SupportsAzureAuth(SupportsAuth, Protocol): + """Protocol interface for authentication for the Azure services.""" + + def get_credential(self) -> azure_cred.TokenCredential: + """ + Get the credential object for Azure services. + + Returns + ------- + credential : azure_cred.TokenCredential + TokenCredential object. + """ From 4a964eec6b6c0a95c3422a23c708ac05ecaea313 Mon Sep 17 00:00:00 2001 From: Eu Jing Chua Date: Mon, 29 Jul 2024 21:43:37 +0000 Subject: [PATCH 2/6] Fix typing and mocks --- .../services/remote/azure/azure_auth.py | 4 ++-- .../services/remote/azure/azure_fileshare.py | 16 +++++++++---- .../mlos_bench/services/types/__init__.py | 2 -- .../services/types/authenticator_type.py | 16 +++++++++++-- .../types/azure_authenticator_type.py | 24 ------------------- .../test_load_service_config_examples.py | 19 ++++++++++++++- .../services/remote/mock/mock_auth_service.py | 6 ++++- 7 files changed, 50 insertions(+), 37 deletions(-) delete mode 100644 mlos_bench/mlos_bench/services/types/azure_authenticator_type.py diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py index 2055845815d..f0f37ade705 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py @@ -15,13 +15,13 @@ from pytz import UTC from mlos_bench.services.base_service import Service -from mlos_bench.services.types.azure_authenticator_type import SupportsAzureAuth +from mlos_bench.services.types.authenticator_type import SupportsAuth from mlos_bench.util import check_required_params _LOG = logging.getLogger(__name__) -class AzureAuthService(Service, SupportsAzureAuth): +class AzureAuthService(Service, SupportsAuth[azure_cred.TokenCredential]): """Helper methods to get access to Azure services.""" _REQ_INTERVAL = 300 # = 5 min diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py index 763c781acdb..0b63e219750 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py @@ -8,12 +8,13 @@ import os from typing import Any, Callable, Dict, List, Optional, Set, Union +import azure.core.credentials as azure_cred from azure.core.exceptions import ResourceNotFoundError from azure.storage.fileshare import ShareClient from mlos_bench.services.base_fileshare import FileShareService from mlos_bench.services.base_service import Service -from mlos_bench.services.types.azure_authenticator_type import SupportsAzureAuth +from mlos_bench.services.types.authenticator_type import SupportsAuth from mlos_bench.util import check_required_params _LOG = logging.getLogger(__name__) @@ -60,20 +61,25 @@ def __init__( "storageFileShareName", }, ) + assert self._parent is not None and isinstance( + self._parent, SupportsAuth + ), "Authorization service not provided. Include service-auth.jsonc?" + self._auth_service: SupportsAuth[azure_cred.TokenCredential] = self._parent self._share_client: Optional[ShareClient] = None def _get_share_client(self) -> ShareClient: """Get the Azure file share client object.""" if self._share_client is None: - assert self._parent is not None and isinstance( - self._parent, SupportsAzureAuth - ), "Authorization service not provided. Include service-auth.jsonc?" + credential = self._auth_service.get_credential() + assert isinstance( + credential, azure_cred.TokenCredential + ), f"Expected a TokenCredential, but got {type(credential)} instead." self._share_client = ShareClient.from_share_url( self._SHARE_URL.format( account_name=self.config["storageAccountName"], fs_name=self.config["storageFileShareName"], ), - credential=self._parent.get_credential(), + credential=credential, token_intent="backup", ) return self._share_client diff --git a/mlos_bench/mlos_bench/services/types/__init__.py b/mlos_bench/mlos_bench/services/types/__init__.py index 14e720dec13..e2d0cb55b5a 100644 --- a/mlos_bench/mlos_bench/services/types/__init__.py +++ b/mlos_bench/mlos_bench/services/types/__init__.py @@ -7,7 +7,6 @@ """ from mlos_bench.services.types.authenticator_type import SupportsAuth -from mlos_bench.services.types.azure_authenticator_type import SupportsAzureAuth from mlos_bench.services.types.config_loader_type import SupportsConfigLoading from mlos_bench.services.types.fileshare_type import SupportsFileShareOps from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning @@ -20,7 +19,6 @@ __all__ = [ "SupportsAuth", - "SupportsAzureAuth", "SupportsConfigLoading", "SupportsFileShareOps", "SupportsHostProvisioning", diff --git a/mlos_bench/mlos_bench/services/types/authenticator_type.py b/mlos_bench/mlos_bench/services/types/authenticator_type.py index 6f99dd6bce3..b01c30d42de 100644 --- a/mlos_bench/mlos_bench/services/types/authenticator_type.py +++ b/mlos_bench/mlos_bench/services/types/authenticator_type.py @@ -4,11 +4,13 @@ # """Protocol interface for authentication for the cloud services.""" -from typing import Protocol, runtime_checkable +from typing import Protocol, TypeVar, runtime_checkable + +T_co = TypeVar("T_co", covariant=True) @runtime_checkable -class SupportsAuth(Protocol): +class SupportsAuth(Protocol[T_co]): """Protocol interface for authentication for the cloud services.""" def get_access_token(self) -> str: @@ -30,3 +32,13 @@ def get_auth_headers(self) -> dict: access_header : dict HTTP header containing the access token. """ + + def get_credential(self) -> T_co: + """ + Get the credential object for cloud services. + + Returns + ------- + credential : T + Cloud-specific credential object. + """ diff --git a/mlos_bench/mlos_bench/services/types/azure_authenticator_type.py b/mlos_bench/mlos_bench/services/types/azure_authenticator_type.py deleted file mode 100644 index ccef2b77fdd..00000000000 --- a/mlos_bench/mlos_bench/services/types/azure_authenticator_type.py +++ /dev/null @@ -1,24 +0,0 @@ -# -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# -"""Protocol interface for authentication for the Azure services.""" - -from typing import Protocol, runtime_checkable -from mlos_bench.services.types.authenticator_type import SupportsAuth -import azure.core.credentials as azure_cred - - -@runtime_checkable -class SupportsAzureAuth(SupportsAuth, Protocol): - """Protocol interface for authentication for the Azure services.""" - - def get_credential(self) -> azure_cred.TokenCredential: - """ - Get the credential object for Azure services. - - Returns - ------- - credential : azure_cred.TokenCredential - TokenCredential object. - """ diff --git a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py index 55453270808..e010fd140b9 100644 --- a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py @@ -48,11 +48,28 @@ def test_load_service_config_examples( config_path: str, ) -> None: """Tests loading a config example.""" + parent: Service = config_loader_service config = config_loader_service.load_config(config_path, ConfigSchema.SERVICE) + # Add other services that require a SupportsAuth parent service as necessary. + requires_auth_service_parent = { + "AzureFileShareService", + } + config_class_name = str(config.get("class", "MISSING CLASS")).rsplit(".", maxsplit=1)[-1] + if config_class_name in requires_auth_service_parent: + # AzureFileShareService requires an auth service to be loaded as well. + auth_service_config = config_loader_service.load_config( + "services/remote/mock/mock_auth_service.jsonc", + ConfigSchema.SERVICE, + ) + auth_service = config_loader_service.build_service( + config=auth_service_config, + parent=config_loader_service, + ) + parent = auth_service # Make an instance of the class based on the config. service_inst = config_loader_service.build_service( config=config, - parent=config_loader_service, + parent=parent, ) assert service_inst is not None assert isinstance(service_inst, Service) diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py index 482f9ee2a9b..b1228217a52 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py @@ -13,7 +13,7 @@ _LOG = logging.getLogger(__name__) -class MockAuthService(Service, SupportsAuth): +class MockAuthService(Service, SupportsAuth[str]): """A collection Service functions for mocking authentication ops.""" def __init__( @@ -32,6 +32,7 @@ def __init__( [ self.get_access_token, self.get_auth_headers, + self.get_credential, ], ), ) @@ -41,3 +42,6 @@ def get_access_token(self) -> str: def get_auth_headers(self) -> dict: return {"Authorization": "Bearer " + self.get_access_token()} + + def get_credential(self) -> str: + return "MOCK CREDENTIAL" From efd3a0812e3cb3223c5e06aac241ae3e664c2de0 Mon Sep 17 00:00:00 2001 From: Eu Jing Chua Date: Mon, 29 Jul 2024 21:49:44 +0000 Subject: [PATCH 3/6] Fix linting --- mlos_bench/mlos_bench/services/remote/azure/azure_auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py index f0f37ade705..caa98cb8b37 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py @@ -9,8 +9,8 @@ from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Union -import azure.identity as azure_id import azure.core.credentials as azure_cred +import azure.identity as azure_id from azure.keyvault.secrets import SecretClient from pytz import UTC From 1eae60727c0034db28f41aa2717b4f45fa50488d Mon Sep 17 00:00:00 2001 From: Sergiy Matusevych Date: Fri, 2 Aug 2024 15:02:46 -0700 Subject: [PATCH 4/6] make sure we initialize before invoking get_credential(); minor cosmetic fixes --- .../services/remote/azure/azure_auth.py | 25 ++++++++----------- .../services/remote/azure/azure_fileshare.py | 6 ++--- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py index caa98cb8b37..271fd9b667c 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py @@ -9,8 +9,8 @@ from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Union -import azure.core.credentials as azure_cred -import azure.identity as azure_id +from azure.core.credentials import TokenCredential +from azure.identity import DefaultAzureCredential, CertificateCredential from azure.keyvault.secrets import SecretClient from pytz import UTC @@ -21,7 +21,7 @@ _LOG = logging.getLogger(__name__) -class AzureAuthService(Service, SupportsAuth[azure_cred.TokenCredential]): +class AzureAuthService(Service, SupportsAuth[TokenCredential]): """Helper methods to get access to Azure services.""" _REQ_INTERVAL = 300 # = 5 min @@ -69,8 +69,7 @@ def __init__( self._token_expiration_ts = datetime.now(UTC) # Typically, some future timestamp. # Login as the first identity available, usually ourselves or a managed identity - self._cred: Union[azure_id.DefaultAzureCredential, azure_id.CertificateCredential] - self._cred = azure_id.DefaultAzureCredential() + self._cred: Union[DefaultAzureCredential, CertificateCredential] = DefaultAzureCredential() # Verify info required for SP auth early if "spClientId" in self.config: @@ -84,13 +83,14 @@ def __init__( }, ) - def _init_sp(self) -> None: + def get_credential(self) -> TokenCredential: + """Return the Azure SDK credential object.""" # Perform this initialization outside of __init__ so that environment loading tests # don't need to specifically mock keyvault interactions out # Already logged in as SP - if isinstance(self._cred, azure_id.CertificateCredential): - return + if isinstance(self._cred, CertificateCredential): + return self._cred sp_client_id = self.config["spClientId"] keyvault_name = self.config["keyVaultName"] @@ -110,17 +110,18 @@ def _init_sp(self) -> None: cert_bytes = b64decode(secret.value) # Reauthenticate as the service principal. - self._cred = azure_id.CertificateCredential( + self._cred = CertificateCredential( tenant_id=tenant_id, client_id=sp_client_id, certificate_data=cert_bytes, ) + return self._cred def get_access_token(self) -> str: """Get the access token from Azure CLI, if expired.""" # Ensure we are logged as the Service Principal, if provided if "spClientId" in self.config: - self._init_sp() + self.get_credential() ts_diff = (self._token_expiration_ts - datetime.now(UTC)).total_seconds() _LOG.debug("Time to renew the token: %.2f sec.", ts_diff) @@ -135,7 +136,3 @@ def get_access_token(self) -> str: def get_auth_headers(self) -> dict: """Get the authorization part of HTTP headers for REST API calls.""" return {"Authorization": "Bearer " + self.get_access_token()} - - def get_credential(self) -> azure_cred.TokenCredential: - """Return the Azure SDK credential object.""" - return self._cred diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py index 0b63e219750..6fa447da225 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py @@ -8,7 +8,7 @@ import os from typing import Any, Callable, Dict, List, Optional, Set, Union -import azure.core.credentials as azure_cred +from azure.core.credentials import TokenCredential from azure.core.exceptions import ResourceNotFoundError from azure.storage.fileshare import ShareClient @@ -64,7 +64,7 @@ def __init__( assert self._parent is not None and isinstance( self._parent, SupportsAuth ), "Authorization service not provided. Include service-auth.jsonc?" - self._auth_service: SupportsAuth[azure_cred.TokenCredential] = self._parent + self._auth_service: SupportsAuth[TokenCredential] = self._parent self._share_client: Optional[ShareClient] = None def _get_share_client(self) -> ShareClient: @@ -72,7 +72,7 @@ def _get_share_client(self) -> ShareClient: if self._share_client is None: credential = self._auth_service.get_credential() assert isinstance( - credential, azure_cred.TokenCredential + credential, TokenCredential ), f"Expected a TokenCredential, but got {type(credential)} instead." self._share_client = ShareClient.from_share_url( self._SHARE_URL.format( From 826db25f181d46f8c0c91d9643bdf3c6a9454ea5 Mon Sep 17 00:00:00 2001 From: Sergiy Matusevych Date: Fri, 2 Aug 2024 15:59:22 -0700 Subject: [PATCH 5/6] make sure we are consistent in using the credentials in the entire auth service --- .../services/remote/azure/azure_auth.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py index 271fd9b667c..96a1deb2090 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py @@ -67,9 +67,7 @@ def __init__( self._access_token = "RENEW *NOW*" self._token_expiration_ts = datetime.now(UTC) # Typically, some future timestamp. - - # Login as the first identity available, usually ourselves or a managed identity - self._cred: Union[DefaultAzureCredential, CertificateCredential] = DefaultAzureCredential() + self._cred: Optional[TokenCredential] = None # Verify info required for SP auth early if "spClientId" in self.config: @@ -87,15 +85,18 @@ def get_credential(self) -> TokenCredential: """Return the Azure SDK credential object.""" # Perform this initialization outside of __init__ so that environment loading tests # don't need to specifically mock keyvault interactions out + if self._cred is not None: + return self._cred - # Already logged in as SP - if isinstance(self._cred, CertificateCredential): + self._cred = DefaultAzureCredential() + if "spClientId" not in self.config: return self._cred sp_client_id = self.config["spClientId"] keyvault_name = self.config["keyVaultName"] cert_name = self.config["certName"] tenant_id = self.config["tenant"] + _LOG.debug("Log in with Azure Service Principal %s", sp_client_id) # Get a client for fetching cert info keyvault_secrets_client = SecretClient( @@ -119,15 +120,11 @@ def get_credential(self) -> TokenCredential: def get_access_token(self) -> str: """Get the access token from Azure CLI, if expired.""" - # Ensure we are logged as the Service Principal, if provided - if "spClientId" in self.config: - self.get_credential() - ts_diff = (self._token_expiration_ts - datetime.now(UTC)).total_seconds() _LOG.debug("Time to renew the token: %.2f sec.", ts_diff) if ts_diff < self._req_interval: _LOG.debug("Request new accessToken") - res = self._cred.get_token("https://management.azure.com/.default") + res = self.get_credential().get_token("https://management.azure.com/.default") self._token_expiration_ts = datetime.fromtimestamp(res.expires_on, tz=UTC) self._access_token = res.token _LOG.info("Got new accessToken. Expiration time: %s", self._token_expiration_ts) From 10588fceae4383b96c82aee842a4582ea8e1bb0b Mon Sep 17 00:00:00 2001 From: Eu Jing Chua Date: Fri, 2 Aug 2024 23:16:07 +0000 Subject: [PATCH 6/6] Fix linting --- mlos_bench/mlos_bench/services/remote/azure/azure_auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py index 96a1deb2090..dccb5740ce2 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py @@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, List, Optional, Union from azure.core.credentials import TokenCredential -from azure.identity import DefaultAzureCredential, CertificateCredential +from azure.identity import CertificateCredential, DefaultAzureCredential from azure.keyvault.secrets import SecretClient from pytz import UTC