From 340187ccc9b30656d60f6018760543ea6c058dfe Mon Sep 17 00:00:00 2001 From: Ritwik Saha Date: Wed, 28 Jan 2026 19:58:29 -0500 Subject: [PATCH] poc Other alternate project owner and insecure connections --- cmlutils/base.py | 3 ++ cmlutils/cdswctl.py | 8 +-- cmlutils/constants.py | 3 +- cmlutils/project_entrypoint.py | 11 +++++ cmlutils/projects.py | 90 ++++++++++++++++++++++++++++------ cmlutils/utils.py | 40 ++++++++++++--- cmlutils/validator.py | 28 +++++++++-- 7 files changed, 151 insertions(+), 32 deletions(-) diff --git a/cmlutils/base.py b/cmlutils/base.py index 9d52973..ea2d5be 100644 --- a/cmlutils/base.py +++ b/cmlutils/base.py @@ -17,6 +17,7 @@ def __init__( api_key: str, ca_path: str, project_slug: str, + skip_tls_verification: bool = False, ) -> None: self.host = host self.username = username @@ -24,6 +25,7 @@ def __init__( self.api_key = api_key self.ca_path = ca_path self.project_slug = project_slug + self.skip_tls_verification = skip_tls_verification @property def apiv2_key(self) -> str: @@ -42,6 +44,7 @@ def apiv2_key(self) -> str: api_key=self.api_key, json_data=json_data, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) response_dict = response.json() _apiv2_key = response_dict["apiKey"] diff --git a/cmlutils/cdswctl.py b/cmlutils/cdswctl.py index fb2899e..9b6edb5 100644 --- a/cmlutils/cdswctl.py +++ b/cmlutils/cdswctl.py @@ -23,11 +23,11 @@ def _get_cdswctl_download_url(host: str) -> str: return final_url -def _download_and_extract(url: str, ca_path: str): +def _download_and_extract(url: str, ca_path: str, skip_tls_verification: bool = False): file_name = url.split("/")[-1] dir_path = _cdswctl_tmp_dir_path() file_path = os.path.join(dir_path, file_name) - download_file(url=url, filepath=file_path, ca_path=ca_path) + download_file(url=url, filepath=file_path, ca_path=ca_path, skip_tls_verification=skip_tls_verification) if Path(constants.BASE_PATH_CDSWCTL) in Path(dir_path).parents: tf = tarfile.open(file_path) tf.extractall(dir_path) @@ -51,9 +51,9 @@ def _cdswctl_tmp_dir_path() -> str: return dirpath -def obtain_cdswctl(host: str, ca_path: str) -> str: +def obtain_cdswctl(host: str, ca_path: str, skip_tls_verification: bool = False) -> str: file_url = _get_cdswctl_download_url(host) - expected_cdswctl_path = _download_and_extract(file_url, ca_path=ca_path) + expected_cdswctl_path = _download_and_extract(file_url, ca_path=ca_path, skip_tls_verification=skip_tls_verification) logging.info( "Expected cdsw path for cdswctl for file transfer %s", expected_cdswctl_path ) diff --git a/cmlutils/constants.py b/cmlutils/constants.py index b05ad9c..a03f71d 100644 --- a/cmlutils/constants.py +++ b/cmlutils/constants.py @@ -20,6 +20,7 @@ OUTPUT_DIR_KEY = "output_dir" PROJECT_NAME_KEY = "project_name" CA_PATH_KEY = "ca_path" +SKIP_TLS_VERIFICATION_KEY = "skip_tls_verification" MAX_API_PAGE_LENGTH = 30 @@ -56,7 +57,7 @@ class ApiV1Endpoints(Enum): API_KEY = "/api/v1/users/$username/apikey" RUNTIMES = "/api/v1/runtimes" USER_INFO = "/api/v1/users/$username" - PROJECTS_SUMMARY = "/api/v1/users/$username/projects-summary?all=true&context=$username&sortColumn=updated_at&projectName=$projectName&limit=$limit&offset=$offset" + PROJECTS_SUMMARY = "/api/v1/users/$username/projects-summary?all=true&scope=all&context=$username&sortColumn=updated_at&projectName=$projectName&limit=$limit&offset=$offset" """Mapping of old fields v1 to new fields of v2""" diff --git a/cmlutils/project_entrypoint.py b/cmlutils/project_entrypoint.py index a5e0ce6..7b8e593 100644 --- a/cmlutils/project_entrypoint.py +++ b/cmlutils/project_entrypoint.py @@ -72,6 +72,9 @@ def _read_config_file(file_path: str, project_name: str): print("Key %s is missing from config file." % (key)) raise output_config[CA_PATH_KEY] = config.get(project_name, CA_PATH_KEY, fallback="") + output_config[constants.SKIP_TLS_VERIFICATION_KEY] = config.getboolean( + project_name, constants.SKIP_TLS_VERIFICATION_KEY, fallback=False + ) return output_config else: print("Validation error: cannot find config file:", file_path) @@ -103,6 +106,7 @@ def project_export_cmd(project_name): apiv1_key = config[API_V1_KEY] output_dir = config[OUTPUT_DIR_KEY] ca_path = config[CA_PATH_KEY] + skip_tls_verification = config[constants.SKIP_TLS_VERIFICATION_KEY] output_dir = get_absolute_path(output_dir) ca_path = get_absolute_path(ca_path) @@ -121,6 +125,7 @@ def project_export_cmd(project_name): ca_path=ca_path, project_slug=project_name, owner_type="", + skip_tls_verification=skip_tls_verification, ) creator_username, project_slug, owner_type = pobj.get_creator_username() if creator_username is None: @@ -139,6 +144,7 @@ def project_export_cmd(project_name): apiv1_key=apiv1_key, ca_path=ca_path, project_slug=project_slug, + skip_tls_verification=skip_tls_verification, ) for v in validators: validation_response = v.validate() @@ -164,6 +170,7 @@ def project_export_cmd(project_name): ca_path=ca_path, project_slug=project_slug, owner_type=owner_type, + skip_tls_verification=skip_tls_verification, ) start_time = time.time() pexport.transfer_project_files(log_filedir=log_filedir) @@ -225,6 +232,7 @@ def project_import_cmd(project_name, verify): apiv1_key = config[API_V1_KEY] local_directory = config[OUTPUT_DIR_KEY] ca_path = config[CA_PATH_KEY] + skip_tls_verification = config[constants.SKIP_TLS_VERIFICATION_KEY] local_directory = get_absolute_path(local_directory) ca_path = get_absolute_path(ca_path) log_filedir = os.path.join(local_directory, project_name, "logs") @@ -238,6 +246,7 @@ def project_import_cmd(project_name, verify): top_level_dir=local_directory, ca_path=ca_path, project_slug=project_name, + skip_tls_verification=skip_tls_verification, ) logging.info("Started importing project: %s", project_name) try: @@ -248,6 +257,7 @@ def project_import_cmd(project_name, verify): top_level_directory=local_directory, apiv1_key=apiv1_key, ca_path=ca_path, + skip_tls_verification=skip_tls_verification, ) logging.info("Begin validating for import.") for v in validators: @@ -296,6 +306,7 @@ def project_import_cmd(project_name, verify): top_level_dir=local_directory, ca_path=ca_path, project_slug=project_slug, + skip_tls_verification=skip_tls_verification, ) start_time = time.time() if verify: diff --git a/cmlutils/projects.py b/cmlutils/projects.py index 11d83c1..2241885 100644 --- a/cmlutils/projects.py +++ b/cmlutils/projects.py @@ -44,12 +44,18 @@ def is_project_configured_with_runtimes( api_key: str, ca_path: str, project_slug: str, + skip_tls_verification: bool = False, ) -> bool: endpoint = Template(ApiV1Endpoints.PROJECT.value).substitute( username=username, project_name=project_slug ) response = call_api_v1( - host=host, endpoint=endpoint, method="GET", api_key=api_key, ca_path=ca_path + host=host, + endpoint=endpoint, + method="GET", + api_key=api_key, + ca_path=ca_path, + skip_tls_verification=skip_tls_verification, ) response_dict = response.json() return ( @@ -67,6 +73,7 @@ def get_ignore_files( ssh_port: str, project_slug: str, top_level_dir: str, + skip_tls_verification: bool = False, ) -> str: endpoint = Template(ApiV1Endpoints.PROJECT_FILE.value).substitute( username=username, project_name=project_slug, filename=constants.FILE_NAME @@ -78,7 +85,12 @@ def get_ignore_files( project_name, ) response = call_api_v1( - host=host, endpoint=endpoint, method="GET", api_key=api_key, ca_path=ca_path + host=host, + endpoint=endpoint, + method="GET", + api_key=api_key, + ca_path=ca_path, + skip_tls_verification=skip_tls_verification, ) a = response.text + "\n" + constants.FILE_NAME with open( @@ -130,8 +142,8 @@ def get_ignore_files( raise e -def get_rsync_enabled_runtime_id(host: str, api_key: str, ca_path: str) -> int: - runtime_list = get_cdsw_runtimes(host=host, api_key=api_key, ca_path=ca_path) +def get_rsync_enabled_runtime_id(host: str, api_key: str, ca_path: str, skip_tls_verification: bool = False) -> int: + runtime_list = get_cdsw_runtimes(host=host, api_key=api_key, ca_path=ca_path, skip_tls_verification=skip_tls_verification) for runtime in runtime_list: if "rsync" in runtime["edition"].lower(): logging.info("Rsync enabled runtime is available.") @@ -140,10 +152,15 @@ def get_rsync_enabled_runtime_id(host: str, api_key: str, ca_path: str) -> int: return -1 -def get_cdsw_runtimes(host: str, api_key: str, ca_path: str) -> list[dict[str, Any]]: +def get_cdsw_runtimes(host: str, api_key: str, ca_path: str, skip_tls_verification: bool = False) -> list[dict[str, Any]]: endpoint = "api/v1/runtimes" response = call_api_v1( - host=host, endpoint=endpoint, method="GET", api_key=api_key, ca_path=ca_path + host=host, + endpoint=endpoint, + method="GET", + api_key=api_key, + ca_path=ca_path, + skip_tls_verification=skip_tls_verification, ) response_dict = response.json() return response_dict["runtimes"] @@ -280,12 +297,13 @@ def __init__( ca_path: str, project_slug: str, owner_type: str, + skip_tls_verification: bool = False, ) -> None: self._ssh_subprocess = None self.top_level_dir = top_level_dir self.project_id = None self.owner_type = owner_type - super().__init__(host, username, project_name, api_key, ca_path, project_slug) + super().__init__(host, username, project_name, api_key, ca_path, project_slug, skip_tls_verification) self.metrics_data = dict() # Get CDSW project info using API v1 @@ -299,6 +317,7 @@ def get_project_infov1(self): method="GET", api_key=self.api_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return response.json() @@ -313,6 +332,7 @@ def get_project_env(self): method="GET", api_key=self.api_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return response.json() @@ -330,12 +350,14 @@ def get_creator_username(self): limit=constants.MAX_API_PAGE_LENGTH, offset=offset * constants.MAX_API_PAGE_LENGTH, ) + logging.info("Endpoint: %s", endpoint) response = call_api_v1( host=self.host, endpoint=endpoint, method="GET", api_key=self.api_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) """ @@ -388,6 +410,7 @@ def get_models_listv1(self, project_id: int): api_key=self.api_key, json_data=json_data, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return response.json() @@ -402,6 +425,7 @@ def get_jobs_listv1(self): method="GET", api_key=self.api_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return response.json() @@ -416,6 +440,7 @@ def get_app_listv1(self): method="GET", api_key=self.api_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return response.json() @@ -434,6 +459,7 @@ def get_model_infov1(self, model_id: str): api_key=self.api_key, json_data=json_data, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return response.json() @@ -448,6 +474,7 @@ def get_job_infov1(self, job_id: int): method="GET", api_key=self.api_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return response.json() @@ -462,6 +489,7 @@ def get_app_infov1(self, app_id: int): method="GET", api_key=self.api_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return response.json() @@ -474,6 +502,7 @@ def get_all_runtimes(self): method="GET", api_key=self.api_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return response.json() @@ -492,11 +521,12 @@ def transfer_project_files(self, log_filedir: str): api_key=self.api_key, ca_path=self.ca_path, project_slug=self.project_slug, + skip_tls_verification=self.skip_tls_verification, ): rsync_enabled_runtime_id = get_rsync_enabled_runtime_id( - host=self.host, api_key=self.api_key, ca_path=self.ca_path + host=self.host, api_key=self.api_key, ca_path=self.ca_path, skip_tls_verification=self.skip_tls_verification ) - cdswctl_path = obtain_cdswctl(host=self.host, ca_path=self.ca_path) + cdswctl_path = obtain_cdswctl(host=self.host, ca_path=self.ca_path, skip_tls_verification=self.skip_tls_verification) login_response = cdswctl_login( cdswctl_path=cdswctl_path, host=self.host, @@ -527,6 +557,7 @@ def transfer_project_files(self, log_filedir: str): ssh_port=port, project_slug=self.project_slug, top_level_dir=self.top_level_dir, + skip_tls_verification=self.skip_tls_verification, ) test_file_size( sshport=port, @@ -554,11 +585,12 @@ def verify_project_files(self, log_filedir: str): api_key=self.api_key, ca_path=self.ca_path, project_slug=self.project_slug, + skip_tls_verification=self.skip_tls_verification, ): rsync_enabled_runtime_id = get_rsync_enabled_runtime_id( - host=self.host, api_key=self.api_key, ca_path=self.ca_path + host=self.host, api_key=self.api_key, ca_path=self.ca_path, skip_tls_verification=self.skip_tls_verification ) - cdswctl_path = obtain_cdswctl(host=self.host, ca_path=self.ca_path) + cdswctl_path = obtain_cdswctl(host=self.host, ca_path=self.ca_path, skip_tls_verification=self.skip_tls_verification) login_response = cdswctl_login( cdswctl_path=cdswctl_path, host=self.host, @@ -586,6 +618,7 @@ def verify_project_files(self, log_filedir: str): ssh_port=port, project_slug=self.project_slug, top_level_dir=self.top_level_dir, + skip_tls_verification=self.skip_tls_verification, ) result = verify_files( sshport=port, @@ -913,10 +946,11 @@ def __init__( top_level_dir: str, ca_path: str, project_slug: str, + skip_tls_verification: bool = False, ) -> None: self._ssh_subprocess = None self.top_level_dir = top_level_dir - super().__init__(host, username, project_name, api_key, ca_path, project_slug) + super().__init__(host, username, project_name, api_key, ca_path, project_slug, skip_tls_verification) self.metrics_data = dict() def get_creator_username(self): @@ -939,6 +973,7 @@ def get_creator_username(self): method="GET", api_key=self.api_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) """ @@ -968,9 +1003,9 @@ def get_creator_username(self): def transfer_project(self, log_filedir: str, verify=False): result = None rsync_enabled_runtime_id = get_rsync_enabled_runtime_id( - host=self.host, api_key=self.apiv2_key, ca_path=self.ca_path + host=self.host, api_key=self.apiv2_key, ca_path=self.ca_path, skip_tls_verification=self.skip_tls_verification ) - cdswctl_path = obtain_cdswctl(host=self.host, ca_path=self.ca_path) + cdswctl_path = obtain_cdswctl(host=self.host, ca_path=self.ca_path, skip_tls_verification=self.skip_tls_verification) login_response = cdswctl_login( cdswctl_path=cdswctl_path, host=self.host, @@ -1019,9 +1054,9 @@ def transfer_project(self, log_filedir: str, verify=False): def verify_project(self, log_filedir: str): rsync_enabled_runtime_id = get_rsync_enabled_runtime_id( - host=self.host, api_key=self.apiv2_key, ca_path=self.ca_path + host=self.host, api_key=self.apiv2_key, ca_path=self.ca_path, skip_tls_verification=self.skip_tls_verification ) - cdswctl_path = obtain_cdswctl(host=self.host, ca_path=self.ca_path) + cdswctl_path = obtain_cdswctl(host=self.host, ca_path=self.ca_path, skip_tls_verification=self.skip_tls_verification) login_response = cdswctl_login( cdswctl_path=cdswctl_path, host=self.host, @@ -1070,6 +1105,7 @@ def create_project_v2(self, proj_metadata) -> str: user_token=self.apiv2_key, json_data=proj_metadata, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) json_resp = response.json() return json_resp["id"] @@ -1089,6 +1125,7 @@ def convert_project_to_engine_based(self, proj_patch_metadata) -> bool: api_key=self.api_key, json_data=proj_patch_metadata, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return True except KeyError as e: @@ -1107,6 +1144,7 @@ def create_model_v2(self, proj_id: str, model_metadata) -> str: user_token=self.apiv2_key, json_data=model_metadata, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) json_resp = response.json() return json_resp["id"] @@ -1127,6 +1165,7 @@ def create_model_build_v2( user_token=self.apiv2_key, json_data=model_metadata, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return @@ -1142,6 +1181,7 @@ def create_application_v2(self, proj_id: str, app_metadata) -> str: user_token=self.apiv2_key, json_data=app_metadata, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) json_resp = response.json() return json_resp["id"] @@ -1159,6 +1199,7 @@ def stop_application_v2(self, proj_id: str, app_id: str) -> None: method="POST", user_token=self.apiv2_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return @@ -1174,6 +1215,7 @@ def create_job_v2(self, proj_id: str, job_metadata) -> str: user_token=self.apiv2_key, json_data=job_metadata, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) json_resp = response.json() return json_resp["id"] @@ -1192,6 +1234,7 @@ def update_job_v2(self, proj_id: str, job_id: str, job_metadata) -> None: user_token=self.apiv2_key, json_data=job_metadata, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return @@ -1204,6 +1247,7 @@ def get_all_runtimes(self): method="GET", api_key=self.api_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return response.json() @@ -1220,6 +1264,7 @@ def get_spark_runtimeaddons(self): method="GET", user_token=self.apiv2_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) result_list = response.json()["runtime_addons"] if result_list: @@ -1237,6 +1282,7 @@ def get_all_runtimes_v2(self, page_token=""): method="GET", user_token=self.apiv2_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) result_list = response.json() if result_list: @@ -1258,6 +1304,7 @@ def check_project_exist(self, project_name: str) -> str: method="GET", user_token=self.apiv2_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) project_list = response.json()["projects"] if project_list: @@ -1284,6 +1331,7 @@ def check_model_exist(self, model_name: str, proj_id: str) -> bool: method="GET", user_token=self.apiv2_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) model_list = response.json()["models"] if model_list: @@ -1310,6 +1358,7 @@ def check_job_exist(self, job_name: str, script: str, proj_id: str) -> str: method="GET", user_token=self.apiv2_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) job_list = response.json()["jobs"] if job_list: @@ -1336,6 +1385,7 @@ def check_app_exist(self, subdomain: str, proj_id: str) -> bool: method="GET", user_token=self.apiv2_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) app_list = response.json()["applications"] if app_list: @@ -1357,6 +1407,7 @@ def get_models_listv2(self, proj_id: str): method="GET", user_token=self.apiv2_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return response.json() @@ -1370,6 +1421,7 @@ def get_models_detailv2(self, proj_id: str, model_id: str): method="GET", user_token=self.apiv2_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return response.json() @@ -1383,6 +1435,7 @@ def get_jobs_listv2(self, proj_id: str): method="GET", user_token=self.apiv2_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return response.json() @@ -1396,6 +1449,7 @@ def get_application_listv2(self, proj_id: str): method="GET", user_token=self.apiv2_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return response.json() @@ -1459,6 +1513,7 @@ def create_models(self, project_id: str, models_metadata_filepath: str): api_key=self.api_key, ca_path=self.ca_path, project_slug=self.project_slug, + skip_tls_verification=self.skip_tls_verification, ) model_metadata_list = read_json_file(models_metadata_filepath) if model_metadata_list != None: @@ -1536,6 +1591,7 @@ def create_stoppped_applications(self, project_id: str, app_metadata_filepath: s api_key=self.api_key, ca_path=self.ca_path, project_slug=self.project_slug, + skip_tls_verification=self.skip_tls_verification, ) app_metadata_list = read_json_file(app_metadata_filepath) if app_metadata_list != None: @@ -1599,6 +1655,7 @@ def create_paused_jobs(self, project_id: str, job_metadata_filepath: str): api_key=self.api_key, ca_path=self.ca_path, project_slug=self.project_slug, + skip_tls_verification=self.skip_tls_verification, ) job_metadata_list = read_json_file(job_metadata_filepath) src_tgt_job_mapping = {} @@ -1686,6 +1743,7 @@ def get_project_infov2(self, proj_id: str): method="GET", user_token=self.apiv2_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return response.json() diff --git a/cmlutils/utils.py b/cmlutils/utils.py index 2bc4c67..9183e79 100644 --- a/cmlutils/utils.py +++ b/cmlutils/utils.py @@ -19,6 +19,7 @@ def call_api_v1( api_key: str, json_data: dict = None, ca_path: str = "", + skip_tls_verification: bool = False, ) -> requests.Response: url = urllib.parse.urljoin(host, endpoint) s = requests.Session() @@ -31,6 +32,15 @@ def call_api_v1( s.mount("https://", HTTPAdapter(max_retries=retries)) headers = {"Content-Type": "application/json"} resp = None + + # Determine SSL verification setting + if skip_tls_verification: + verify_setting = False + elif ca_path != "": + verify_setting = ca_path + else: + verify_setting = True + try: if json_data != None: resp = s.request( @@ -39,7 +49,7 @@ def call_api_v1( auth=(api_key, ""), headers=headers, json=json_data, - verify=ca_path if ca_path != "" else True, + verify=verify_setting, ) else: resp = s.request( @@ -47,7 +57,7 @@ def call_api_v1( url=url, auth=(api_key, ""), headers=headers, - verify=ca_path if ca_path != "" else True, + verify=verify_setting, ) resp.raise_for_status() # Raise an exception for 4xx or 5xx errors return resp @@ -64,6 +74,7 @@ def call_api_v2( user_token: str, json_data: dict = None, ca_path: str = "", + skip_tls_verification: bool = False, ) -> requests.Response: url = urllib.parse.urljoin(host, endpoint) s = requests.Session() @@ -79,6 +90,15 @@ def call_api_v2( "Authorization": "Bearer {}".format(user_token), } resp = None + + # Determine SSL verification setting + if skip_tls_verification: + verify_setting = False + elif ca_path != "": + verify_setting = ca_path + else: + verify_setting = True + try: if json_data != None: resp = s.request( @@ -86,14 +106,14 @@ def call_api_v2( url=url, headers=headers, json=json_data, - verify=ca_path if ca_path != "" else True, + verify=verify_setting, ) else: resp = s.request( method=method.upper(), url=url, headers=headers, - verify=ca_path if ca_path != "" else True, + verify=verify_setting, ) resp.raise_for_status() # Raise an exception for 4xx or 5xx errors return resp @@ -104,8 +124,16 @@ def call_api_v2( raise -def download_file(url: str, filepath: str, ca_path: str = ""): - with requests.get(url, stream=True, verify=ca_path if ca_path != "" else True) as r: +def download_file(url: str, filepath: str, ca_path: str = "", skip_tls_verification: bool = False): + # Determine SSL verification setting + if skip_tls_verification: + verify_setting = False + elif ca_path != "": + verify_setting = ca_path + else: + verify_setting = True + + with requests.get(url, stream=True, verify=verify_setting) as r: with open(filepath, "wb") as f: shutil.copyfileobj(r.raw, f) diff --git a/cmlutils/validator.py b/cmlutils/validator.py index e86a7c1..37e8117 100644 --- a/cmlutils/validator.py +++ b/cmlutils/validator.py @@ -68,7 +68,7 @@ def validate(self) -> ValidationResponse: class UserNameImportValidator(ImportValidators): def __init__( - self, host: str, username: str, apiv1_key: str, project_name: str, ca_path: str + self, host: str, username: str, apiv1_key: str, project_name: str, ca_path: str, skip_tls_verification: bool = False ): self.validation_name = "check if user is present" self.host = host @@ -76,6 +76,7 @@ def __init__( self.apiv1_key = apiv1_key self.project_name = project_name self.ca_path = ca_path + self.skip_tls_verification = skip_tls_verification def validate(self) -> ValidationResponse: endpoint = Template(ApiV1Endpoints.USER_INFO.value).substitute( @@ -88,6 +89,7 @@ def validate(self) -> ValidationResponse: method="GET", api_key=self.apiv1_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return ValidationResponse( validation_name=self.validation_name, @@ -125,7 +127,7 @@ def validate(self) -> ValidationResponse: class RsyncRuntimeAddonExistsImportValidator(ImportValidators): def __init__( - self, host: str, username: str, apiv1_key: str, project_name: str, ca_path: str + self, host: str, username: str, apiv1_key: str, project_name: str, ca_path: str, skip_tls_verification: bool = False ): self.validation_name = "check if rsync is present" self.host = host @@ -133,11 +135,12 @@ def __init__( self.apiv1_key = apiv1_key self.project_name = project_name self.ca_path = ca_path + self.skip_tls_verification = skip_tls_verification def validate(self) -> ValidationResponse: rsync_enabled_runtime_id = -1 rsync_enabled_runtime_id = get_rsync_enabled_runtime_id( - host=self.host, api_key=self.apiv1_key, ca_path=self.ca_path + host=self.host, api_key=self.apiv1_key, ca_path=self.ca_path, skip_tls_verification=self.skip_tls_verification ) if rsync_enabled_runtime_id != -1: return ValidationResponse( @@ -160,7 +163,7 @@ def validate(self) -> ValidationResponse: class UsernameValidator(ExportValidators): def __init__( - self, host: str, username: str, apiv1_key: str, project_name: str, ca_path: str + self, host: str, username: str, apiv1_key: str, project_name: str, ca_path: str, skip_tls_verification: bool = False ): self.validation_name = "check if user is present" self.host = host @@ -168,6 +171,7 @@ def __init__( self.apiv1_key = apiv1_key self.project_name = project_name self.ca_path = ca_path + self.skip_tls_verification = skip_tls_verification def validate(self) -> ValidationResponse: endpoint = Template(ApiV1Endpoints.USER_INFO.value).substitute( @@ -180,6 +184,7 @@ def validate(self) -> ValidationResponse: method="GET", api_key=self.apiv1_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return ValidationResponse( validation_name=self.validation_name, @@ -224,6 +229,7 @@ def __init__( project_name: str, ca_path: str, project_slug: str, + skip_tls_verification: bool = False, ): self.validation_name = "Validate if the project {} belongs to user {}".format( project_name, username @@ -235,6 +241,7 @@ def __init__( self.project_name = project_name self.ca_path = ca_path self.project_slug = project_slug + self.skip_tls_verification = skip_tls_verification def validate(self) -> ValidationResponse: endpoint = Template(ApiV1Endpoints.PROJECT.value).substitute( @@ -247,6 +254,7 @@ def validate(self) -> ValidationResponse: method="GET", api_key=self.apiv1_key, ca_path=self.ca_path, + skip_tls_verification=self.skip_tls_verification, ) return ValidationResponse( validation_name=self.validation_name, @@ -292,6 +300,7 @@ def __init__( project_name: str, ca_path: str, project_slug: str, + skip_tls_verification: bool = False, ): self.validation_name = "check if rsync is present" self.host = host @@ -300,6 +309,7 @@ def __init__( self.project_name = project_name self.ca_path = ca_path self.project_slug = project_slug + self.skip_tls_verification = skip_tls_verification def validate(self) -> ValidationResponse: rsync_enabled_runtime_id = -1 @@ -310,9 +320,10 @@ def validate(self) -> ValidationResponse: api_key=self.apiv1_key, ca_path=self.ca_path, project_slug=self.project_slug, + skip_tls_verification=self.skip_tls_verification, ): rsync_enabled_runtime_id = get_rsync_enabled_runtime_id( - host=self.host, api_key=self.apiv1_key, ca_path=self.ca_path + host=self.host, api_key=self.apiv1_key, ca_path=self.ca_path, skip_tls_verification=self.skip_tls_verification ) if rsync_enabled_runtime_id != -1: return ValidationResponse( @@ -344,6 +355,7 @@ def initialize_import_validators( top_level_directory: str, apiv1_key: str, ca_path: str, + skip_tls_verification: bool = False, ) -> List[ImportValidators]: return [ DirectoriesAndFilesValidator( @@ -357,6 +369,7 @@ def initialize_import_validators( apiv1_key=apiv1_key, project_name=project_name, ca_path=ca_path, + skip_tls_verification=skip_tls_verification, ), RsyncRuntimeAddonExistsImportValidator( host=host, @@ -364,6 +377,7 @@ def initialize_import_validators( apiv1_key=apiv1_key, project_name=project_name, ca_path=ca_path, + skip_tls_verification=skip_tls_verification, ), ] @@ -376,6 +390,7 @@ def initialize_export_validators( apiv1_key: str, ca_path: str, project_slug: str, + skip_tls_verification: bool = False, ) -> List[ExportValidators]: return [ TopLevelDirectoryValidator(top_level_directory=top_level_directory), @@ -385,6 +400,7 @@ def initialize_export_validators( apiv1_key=apiv1_key, project_name=project_name, ca_path=ca_path, + skip_tls_verification=skip_tls_verification, ), ProjectBelongsToUserValidator( host=host, @@ -393,6 +409,7 @@ def initialize_export_validators( project_name=project_name, ca_path=ca_path, project_slug=project_slug, + skip_tls_verification=skip_tls_verification, ), RsyncRuntimeAddonExistsExportValidator( host=host, @@ -401,5 +418,6 @@ def initialize_export_validators( project_name=project_name, ca_path=ca_path, project_slug=project_slug, + skip_tls_verification=skip_tls_verification, ), ]