diff --git a/.gitignore b/.gitignore
index 4060180..60e3481 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,6 +25,7 @@ share/python-wheels/
 .installed.cfg
 *.egg
 MANIFEST
+ILSVRC2012_Images
 
 # PyInstaller
 # Usually these files are written by a python script from a template
@@ -81,6 +82,7 @@ target/
 # IPython
 profile_default/
 ipython_config.py
+tutorial.ipynb
 
 # PEP 582; used by e.g. github.com/David-OConnor/pyflow
 __pypackages__/
diff --git a/README.md b/README.md
index 79f83e9..dd25e9d 100644
--- a/README.md
+++ b/README.md
@@ -56,7 +56,7 @@ Commands:
 ### Version
 ```cmd
 $ odl version
-odl version, current: 0.0.2, svc: 1.8
+odl version, current: 0.0.10, svc: 2.0
 ```
 
 ### Login
@@ -64,11 +64,11 @@ Login with opendatalab username and password. If you haven't an opendatalab acco
 
 ```cmd
 $ odl login
-Username []: wangrui@pjlab.org.cn
+Username []: someone@example.com
 Password []:
-Login successfully as wangrui@pjlab.org.cn
+Login successfully as someone@example.com
 or
-$ odl login -u wangrui@pjlab.org.cn
+$ odl login -u someone@example.com
 Password[]:
 ```
 
@@ -77,7 +77,7 @@ Logout current opendatalab account
 ```cmd
 $ odl logout
 Do you want to logout? [y/N]: y
-wangrui@pjlab.org.cn.com logout
+someone@example.com logout
 ```
 
@@ -178,4 +178,4 @@ if __name__ == '__main__':
 ```
 
 ## Documentation
-More information can be found on the [documentation site](https://opendatalab.org.cn/docs)
+More information can be found on the [documentation site](https://opendatalab.org.cn/docs)
\ No newline at end of file
diff --git a/opendatalab/__init__.py b/opendatalab/__init__.py
index bc338b4..4b545f1 100644
--- a/opendatalab/__init__.py
+++ b/opendatalab/__init__.py
@@ -6,8 +6,7 @@
 
 """OpenDataLab python SDK."""
 
-from opendatalab.__version__ import __version__
-from opendatalab.__version__ import __url__
+from opendatalab.__version__ import __url__, __version__
 from opendatalab.client.client import Client
 
-__all__ = ["__url__", "__version__", "Client"]
\ No newline at end of file
+__all__ = ["__url__", "__version__", "Client"]
diff --git a/opendatalab/__version__.py b/opendatalab/__version__.py
index 9e1f177..a24df51 100644
--- a/opendatalab/__version__.py
+++ b/opendatalab/__version__.py
@@ -8,7 +8,7 @@
 """OpenDataLab python SDK version info."""
 
 __url__ = "https://opendatalab.org.cn"
-__version__ = "0.0.2"
-__svc__ = '1.8'
+__version__ = "0.0.10"
+__svc__ = '2.0'
 odl_clientId = "kmz3bkwzlaa3wrq8pvwa"
 uaa_url_prefix = "https://sso.openxlab.org.cn/gw/uaa-be"
diff --git a/opendatalab/cli/cmd.py b/opendatalab/cli/cmd.py
index 6aa07c0..3984cbe 100644
--- a/opendatalab/cli/cmd.py
+++ b/opendatalab/cli/cmd.py
@@ -8,7 +8,7 @@
 
 import click
 
-from opendatalab.__version__ import __version__, __url__, __svc__
+from opendatalab.__version__ import __svc__, __url__, __version__
 from opendatalab.cli.custom import CustomCommand
 from opendatalab.cli.utility import ContextInfo
 
@@ -91,9 +91,7 @@ def login(obj: ContextInfo, username: str, password: str):
     implement_login(obj, username, password)
 
 
-@command(synopsis=(
-    "$ odl ls dataset # list dataset files",
-    "$ odl ls dataset/sub_dir # list dataset/sub_dir files",))
+@command(synopsis=("$ odl ls dataset # list dataset files",))
 @click.argument("name", nargs=1)
 @click.pass_obj
 def ls(obj: ContextInfo, name: str) -> None:
@@ -121,7 +119,7 @@ def search(obj: ContextInfo, keywords):
     implement_search(obj, keywords)
 
 
-@command(synopsis=("$ odl info dataset_name # show dataset info.",))
+@command(synopsis=("$ odl info dataset_name # show dataset info",))
 @click.argument("name", nargs=1)
 @click.pass_obj
 def info(obj: ContextInfo, name):
@@ -137,33 +135,30 @@ def info(obj: ContextInfo, name):
 @command(synopsis=("$ odl get dataset_name # get dataset files into local",))
 @click.argument("name", nargs=1)
 @click.option(
-    "--thread",
-    "-t",
-    default=8,
-    help="Number of thread for download",
-    show_default=True,
+    "--dest",
+    "-d",
+    default='',
+    help="Desired dataset store path",
+    show_default=True,
 )
 @click.option(
-    "--limit_speed",
-    "-l",
-    default=0,
-    help="Download limit speed: KB/s, 0 is unlimited",
-    show_default=True,
+    "--workers",
+    "-w",
+    default=8,
+    help="Number of download workers",
+    show_default=True,
 )
 @click.pass_obj
-def get(obj: ContextInfo, name, thread, limit_speed):
+def get(obj: ContextInfo, name, dest, workers):
     """Get(Download) dataset files into local path.\f
-
     Args:
         obj (ContextInfo): context info\f
         name (str): dataset name\f
-        thread (int): multi-thread number\f
-        limit_speed (int): limit download speed, for not limit set value to 0
+        dest (str): desired dataset store path\f
+        workers (int): number of download workers\f
     """
-
     from opendatalab.cli.get import implement_get
-    implement_get(obj, name, thread, limit_speed)
-
-
+    implement_get(obj, name, dest, workers)
 
 
 if __name__ == "__main__":
     cli()
diff --git a/opendatalab/cli/get.py b/opendatalab/cli/get.py
index ede4fae..c273026 100644
--- a/opendatalab/cli/get.py
+++ b/opendatalab/cli/get.py
@@ -2,26 +2,26 @@
 #
 # Copyright 2022 Shanghai AI Lab. Licensed under MIT License.
 #
-import logging
 import os
 import sys
-import threading
-import time
-from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from typing import List
 
 import click
-import oss2
 from tqdm import tqdm
 
-from opendatalab.cli.policy import service_agreement_url, private_policy_url
+from opendatalab.cli.policy import private_policy_url, service_agreement_url
 from opendatalab.cli.utility import ContextInfo, exception_handler
+from opendatalab.client import downloader
 from opendatalab.exception import OdlDataNotExistsError
 
-oss2.set_stream_logger(level=logging.CRITICAL)
-key_to_get_size_map = {}
-
+STATUS_DICT = {
+    # "noAuthRequired": "No authorization needed",
+    "pendingRequirement": "In order to download this dataset, please fill in an application form via our website.",
+    "waiting": "Authorization submitted, please wait for the application result.",
+    "rejected": "Authorization submitted, but rejected. Please contact us for more information.",
+    # "accepted": "Authorization submitted, download available."
+}
 
 def handler(dwCtrlType):
     if dwCtrlType == 0:  # CTRL_C_EVENT
@@ -32,177 +32,133 @@ def handler(dwCtrlType):
 if sys.platform == "win32":
     import win32api
     win32api.SetConsoleCtrlHandler(handler, True)
-
-
-def get_oss_traffic_limit(limit_speed):
-    if limit_speed <= 0:
-        return 0
-    if limit_speed < 245760:
-        return 245760
-    if limit_speed > 838860800:
-        return 838860800
-    return limit_speed
-
-
-def download_object(
-    bucket: oss2.Bucket,
-    obj_key: str,
-    lock: threading.RLock,
-    root: str,
-    pbar: tqdm,
-    limit_speed: int,
-):
-    def progress_callback(bytes_consumed, _):
-        with lock:
-            global key_to_get_size_map
-            if obj_key not in key_to_get_size_map:
-                key_to_get_size_map[obj_key] = 0
-
-            # sys.stdout.flush()
-            pbar.update(bytes_consumed - key_to_get_size_map[obj_key])
-            key_to_get_size_map[obj_key] = bytes_consumed
-
-    try:
-        headers = dict()
-        if limit_speed > 0:
-            headers[oss2.models.OSS_TRAFFIC_LIMIT] = str(limit_speed)
-
-        filename = os.path.join(root, obj_key.split("/")[-1])
-
-        oss2.resumable_download(
-            bucket,
-            obj_key,
-            filename,
-            multiget_threshold=50 * 1024 * 1024,  # 50M -> 500G(cdn)
-            part_size=10 * 1024 * 1024,  # 10M
-            progress_callback=progress_callback,
-            num_threads=1,
-            headers=headers,
-        )
-        return True, None
-    except oss2.exceptions.InconsistentError as e:
-        return False, e
-    except oss2.exceptions.ServerError as e:
-        return False, e
-    except Exception as e:
-        return False, e
-
-
+
+
 @exception_handler
-def implement_get(obj: ContextInfo, name: str, thread: int, limit_speed: int, compressed: bool = True) -> None:
+def implement_get(obj: ContextInfo, name: str, destination: str, num_workers: int):
     """
     implementation for getting dataset files
     Args:
         obj (ContextInfo):
         name (str):
-        thread (int):
-        limit_speed (int):
-        compressed (bool):
+        destination (str): desired local store path
+        num_workers (int): number of download workers
 
     Returns:
 
     """
+    # process dataset name: split off sub-directory / file parts
     ds_split = name.split("/")
-    if len(ds_split) > 1:
-        dataset_name = ds_split[0]
-        sub_dir = "/".join(ds_split[1:])
-    else:
-        dataset_name = name
-        sub_dir = ""
-
+    dataset_name = ds_split[0]
+    if ds_split[-1] == '':
+        ds_split.pop()
+    single_file_flag = False
+    sub_dir = "/".join(ds_split[1:])
+
+    if len(ds_split) >= 2 and ('.' in ds_split[-1]):
+        single_file_flag = True
+        if len(ds_split) == 2:
+            # a top-level file, e.g. dataset/README.md
+            file_name = ds_split[-1]
+            sub_dir = ''
+        else:
+            sub_dir = "/".join(ds_split[1:-1])
+            file_name = sub_dir + '/' + ds_split[-1]
+
+    # client init
     client = obj.get_client()
-    info_data_name = client.get_api().get_info(dataset_name)['name']
-    dataset = client.get_dataset(info_data_name)
-    prefix = dataset.get_object_key_prefix(compressed)
-    bucket = dataset.get_oss_bucket()
-
-    total_files, total_size = 0, 0
-    obj_info_list = []
-    download_info_body = []
-
-    for info in oss2.ObjectIteratorV2(bucket, prefix):
-        if not info.is_prefix() and not info.key.endswith("/"):
-            file_name = "/".join(info.key.split("/")[2:])
-            f_name = Path(file_name).name
-            if not sub_dir:
-                obj_info_list.append(info.key)
-                total_files = total_files + 1
-                total_size = total_size + info.size
-                download_info_body.append({"name": f_name, "size": info.size})
-            elif sub_dir and file_name.startswith(sub_dir):
-                obj_info_list.append(info.key)
-                total_files = total_files + 1
-                total_size = total_size + info.size
-                download_info_body.append({"name": f_name, "size": info.size})
-            else:
-                pass
-
-    if len(download_info_body) == 0:
-        raise OdlDataNotExistsError(error_msg=f"{name} not exists!")
-
-    client.get_api().call_download_log(dataset_name, download_info_body)
-    click.echo(f"Scan done, total files: {len(obj_info_list)}, total size: {tqdm.format_sizeof(total_size, divisor=1024)}")
+    data_info = client.get_api().get_info(dataset_name)
+
+    # basic info of dataset
+    info_dataset_name = data_info['name']
+    info_dataset_id = data_info['id']
+
+    # check the download authorization status
+    auth_status = client.get_api().get_auth_status(dataset_id=info_dataset_id)
+    if auth_status['state'] in STATUS_DICT.keys():
+        click.echo(f"{STATUS_DICT[auth_status['state']]}")
+        sys.exit(1)
+
+    # get risk level
+    info_dataset_risk = data_info['attrs'].get('riskLevel', 0)
+    info_dataset_url = data_info['attrs']['publishUrl']
+    if info_dataset_risk > 3:
+        click.echo(f"Direct download for {dataset_name} is currently not available."
+                   f"\nPlease visit the homepage {info_dataset_url} for more information.")
+        sys.exit(1)
 
-    download_data = client.get_api().get_download_record(dataset_name)
+    dataset_res_dict = client.get_api().get_dataset_files(dataset_name=info_dataset_name,
+                                                          prefix=sub_dir)
+    if not single_file_flag:
+        total_object = dataset_res_dict['total']
+        # object list construction
+        obj_info_list = []
+        for info in dataset_res_dict['list']:
+            if not info['isDir']:
+                curr_dict = {'size': info['size'], 'name': info['path']}
+                obj_info_list.append(curr_dict)
+    else:
+        total_object = 1
+        obj_info_list = []
+        for info in dataset_res_dict['list']:
+            if info['path'] == str(file_name):
+                curr_dict = {'size': info['size'], 'name': info['path']}
+                obj_info_list.append(curr_dict)
+
+    local_dir = destination
+
+    download_data = client.get_api().get_download_record(info_dataset_name)
     has_download = download_data['hasDownload']
 
     if not has_download:
         if click.confirm(f"<>: {service_agreement_url}"
                          f"\n<>: {private_policy_url}"
                          f"\n[Warning]: Before downloading, please agree above content."):
-            client.get_api().submit_download_record(dataset_name, download_data)
+            client.get_api().submit_download_record(info_dataset_name, download_data)
         else:
-            click.secho('bye~')
+            click.secho('See you next time~!')
             sys.exit(1)
-
-    limit_speed_per_thread = get_oss_traffic_limit(int(limit_speed * 1024 * 8 / thread))
-
-    local_dir = Path.cwd().joinpath(dataset_name)
+
     if click.confirm(f"Download files into local directory: {local_dir} ?", default=True):
         if not Path(local_dir).exists():
             Path(local_dir).mkdir(parents=True)
             print(f"create local dir: {local_dir}")
     else:
-        click.secho('bye~')
+        click.secho('See you next time~!')
         sys.exit(1)
 
-    pbar = tqdm(total=total_size, unit="B", unit_divisor=1024, unit_scale=True, position=0)
-
-    index = 0
-    is_running = True
-    while is_running:
-        global key_to_get_size_map
-        bucket = dataset.refresh_oss_bucket()
-        error_object_list = get_objects_retry(bucket=bucket,
-                                              local_dir=local_dir,
-                                              obj_info_list=obj_info_list,
-                                              pbar=pbar,
-                                              limit_speed_per_thread=limit_speed_per_thread,
-                                              thread=thread)
-        index = index + 1
-        time.sleep(1)
-        if len(error_object_list) > 0:
-            obj_info_list = error_object_list
-            is_running = True
-            continue
-        else:
-            is_running = False
-            break
-
-    pbar.close()
-    print(f"{dataset_name} ,download completed!")
-
-
-def get_objects_retry(bucket, local_dir, obj_info_list, pbar, limit_speed_per_thread, thread) -> List:
-    lock = threading.RLock()
-    error_object_list = []
-    with ThreadPoolExecutor(max_workers=thread) as executor:
-        future_to_obj = {executor.submit(
-            download_object, bucket, obj, lock, local_dir, pbar, limit_speed_per_thread
-        ): obj for obj in obj_info_list}
-
-        for future in as_completed(future_to_obj):
-            obj = future_to_obj[future]
-            success, _ = future.result()
-            if not success:
-                error_object_list.append(obj)
-    return error_object_list
+    with tqdm(total=total_object) as pbar:
+        for idx in range(total_object):
+            dataset_seg_list = [obj_info_list[idx]]
+            download_urls_list = client.get_api().get_dataset_download_urls(
+                dataset_id=info_dataset_id,
+                dataset_list=dataset_seg_list)
+            url_download = download_urls_list[0]['url']
+            filename = download_urls_list[0]['name']
+            click.echo(f"Downloading No.{idx + 1} of total {total_object} files")
+            if os.path.exists(os.path.join(destination, info_dataset_name, filename)):
+                click.echo('target already exists, skipping to next!')
+                pbar.update(1)
+                continue
+
+            downloader.Downloader(url=url_download,
+                                  filename=filename,
+                                  download_dir=os.path.join(destination, info_dataset_name),
+                                  blocks_num=num_workers).start()
+            pbar.update(1)
+
+    click.echo("\nDownload complete!")
\ No newline at end of file
diff --git a/opendatalab/cli/info.py b/opendatalab/cli/info.py
index afff4d4..67e9870 100644
--- a/opendatalab/cli/info.py
+++ b/opendatalab/cli/info.py
@@ -11,90 +11,82 @@
 from opendatalab.utils import bytes2human
 
 
-@exception_handler
-def implement_info(obj: ContextInfo, dataset: str) -> None:
-    """
-    implement for displaying dataset info
-    Args:
-        obj (ContextInfo): context object
-        dataset (str): dataset name
+def _format_types(info_data, type_name):
+    types_str = ""
+    if type_name in info_data['attrs'].keys():
+        types_list = info_data['attrs'][type_name]
+        if types_list and len(types_list) > 0:
+            types_str = ", ".join([x['name']['en'] for x in types_list])
 
-    Returns:
+    return types_str
 
-    """
-    client = obj.get_client()
-    odl_api = client.get_api()
-    info_data = odl_api.get_info(dataset)
-    similar_data_list = odl_api.get_similar_dataset(dataset)
 
-    data_introd = info_data['introduction']
+def reformat_info_data(info_data):
+    license_str = _format_types(info_data, 'license')
+    publisher_str = _format_types(info_data, 'publisher')
+    media_types_str = _format_types(info_data, 'mediaTypes')
+    label_types_str = _format_types(info_data, 'labelTypes')
+    task_types_str = _format_types(info_data, 'taskTypes')
+    tags_str = _format_types(info_data, 'tags')
+
+    data_introduction = info_data['introduction']['en']
     introduction_str = ""
-    if data_introd and len(data_introd) > 0:
-        introduction_str = data_introd[:97] + '...'
-
-    license_list = info_data['licenses']
-    license_str = ""
-    if license_list and len(license_list) > 0:
-        license_str = ", ".join([x['name'] for x in license_list])
-
-    publisher_list = info_data['publisher']
-    publisher_str = ""
-    if publisher_list and len(publisher_list) > 0:
-        publisher_str = ", ".join([x['name'] for x in publisher_list])
-
-    media_types_list = info_data['mediaTypes']
-    media_types_str = ""
-    if media_types_list and len(media_types_list) > 0:
-        media_types_str = ", ".join([x['name'] for x in media_types_list])
-
-    label_types_list = info_data['labelTypes']
-    label_types_str = ""
-    if label_types_list and len(label_types_list) > 0:
-        label_types_str = ", ".join([x['name'] for x in label_types_list])
-
-    task_types_list = info_data['taskTypes']
-    task_types_str = ""
-    if label_types_list and len(task_types_list) > 0:
-        task_types_str = ", ".join([x['name'] for x in task_types_list])
-
-    tags_list = info_data['tags']
-    tags_str = ""
-    if tags_list and len(tags_list) > 0:
-        tags_str = ", ".join([x['name'] for x in tags_list])
-
-    citation_data = info_data['citation']
+    if data_introduction and len(data_introduction) > 0:
+        introduction_str = data_introduction[:97] + '...'
+
+    citation_data = info_data['attrs']['citation']
     citation_str = ""
     if citation_data and len(citation_data) > 0:
         citation_str = citation_data.strip("```").replace('\r', '').replace('\n', '')
 
+    similar_data_list = info_data['similar']
     similar_ds_str = ""
     if similar_data_list and len(similar_data_list) > 0:
         similar_ds_str = ", ".join([x['name'] for x in similar_data_list])
 
-    info_data = {
+    info_data_result = {
         'Name': info_data['name'],
-        'File Bytes': str(bytes2human(info_data['fileBytes'])),
-        'File Count': str(info_data['fileCount']),
+        'File Bytes': str(bytes2human(info_data['attrs']['fileBytes'])),
+        'File Count': str(info_data['attrs']['fileCount']),
         'Introduction': introduction_str,
-        'Issue Time': info_data['publishDate'],
+        'Issue Time': info_data['attrs']['publishDate'],
         'License': license_str,
         'Author': publisher_str,
         'Data Type': media_types_str,
         'Label Type': label_types_str,
         'Task Type': task_types_str,
         'Tags': tags_str,
-        'HomePage': info_data['publishUrl'],
+        'HomePage': info_data['attrs']['publishUrl'],
         'Citation': citation_str,
         'Similar Datasets': similar_ds_str,
     }
+    return info_data_result
+
+
+@exception_handler
+def implement_info(obj: ContextInfo, dataset: str) -> None:
+    """
+    implement for displaying dataset info
+    Args:
+        obj (ContextInfo): context object
+        dataset (str): dataset name
+
+    Returns:
+
+    """
+    client = obj.get_client()
+    odl_api = client.get_api()
+    info_data = odl_api.get_info(dataset)
+
+    info_data_result = reformat_info_data(info_data)
 
     console = Console()
     table = Table(show_header=True, header_style='bold cyan', box=box.ASCII2)
     table.add_column("Field", width=20, justify='full', overflow='fold')
     table.add_column("Content", width=120, justify='full', overflow='fold')
 
-    for key in info_data.keys():
-        val = info_data[key]
+    for key in info_data_result.keys():
+        val = info_data_result[key]
         val = "" if not val else val
         table.add_row(key, val, end_section=True)
diff --git a/opendatalab/cli/login.py b/opendatalab/cli/login.py
index b61c570..1aee09a 100644
--- a/opendatalab/cli/login.py
+++ b/opendatalab/cli/login.py
@@ -2,6 +2,7 @@
 # Copyright 2022 Shanghai AI Lab. Licensed under MIT License.
 #
 import sys
+
 from opendatalab.cli.utility import ContextInfo, exception_handler
 
@@ -10,7 +11,6 @@ def implement_login(obj: ContextInfo, username: str, password: str) -> None:
     try:
         client = obj.get_client()
         odl_api = client.get_api()
-        # config_json = odl_api.login(username=username, password=password)
         config_json = odl_api.odl_auth(account=username, password=password)
         obj.update_config(config_json)
 
@@ -19,4 +19,3 @@ def implement_login(obj: ContextInfo, username: str, password: str) -> None:
         sys.exit(-1)
 
     print(f"Login successfully as {username}")
-
diff --git a/opendatalab/cli/logout.py b/opendatalab/cli/logout.py
index b6557bf..ed756d8 100644
--- a/opendatalab/cli/logout.py
+++ b/opendatalab/cli/logout.py
@@ -7,7 +7,7 @@
 
 @exception_handler
 def implement_logout(obj: ContextInfo) -> None:
-
+    # TODO: add /api/users/sync/logout
     config_content = obj.get_config_content()
     username = ""
     if 'user.email' in config_content.keys():
diff --git a/opendatalab/cli/ls.py b/opendatalab/cli/ls.py
index a0aa1e8..4c249e1 100644
--- a/opendatalab/cli/ls.py
+++ b/opendatalab/cli/ls.py
@@ -3,8 +3,8 @@
 #
 import sys
 
-import oss2
 from rich import box
+from rich import print as rprint
 from rich.console import Console
 from rich.table import Table
 
@@ -14,7 +14,7 @@
 
 
 @exception_handler
-def implement_ls(obj: ContextInfo, dataset: str) -> None:
+def implement_ls(obj: ContextInfo, dataset: str):
     """
     implementation for show dataset files
     Args:
@@ -33,27 +33,19 @@ def implement_ls(obj: ContextInfo, dataset: str) -> None:
         sub_dir = ""
 
     client = obj.get_client()
-    info_data_name = client.get_api().get_info(dataset_name)['name']
-    dataset_instance = client.get_dataset(dataset_name=info_data_name)
-
-    bucket = dataset_instance.get_oss_bucket()
-    prefix = dataset_instance.get_object_key_prefix(compressed=True)
+    info_dataset_name = client.get_api().get_info(dataset_name)['name']
+    dataset_instance = client.get_dataset(dataset_name=info_dataset_name)
+    dataset_res_dict = client.get_api().get_dataset_files(dataset_name=info_dataset_name, prefix=sub_dir)
 
+    # generate output info dict
     object_info_dict = {}
     total_files, total_size = 0, 0
-    for info in oss2.ObjectIteratorV2(bucket, prefix):
-        if not info.is_prefix() and not info.key.endswith("/"):
-            file_name = "/".join(info.key.split("/")[2:])
-            if not sub_dir:
-                object_info_dict[file_name] = bytes2human(info.size)
-                total_files = total_files + 1
-                total_size = total_size + info.size
-            elif sub_dir and file_name.startswith(sub_dir):
-                object_info_dict[file_name] = bytes2human(info.size)
-                total_files = total_files + 1
-                total_size = total_size + info.size
-            else:
-                pass
+    total_files = dataset_res_dict['total']
+    for info in dataset_res_dict['list']:
+        object_info_dict[info['path']] = bytes2human(info['size'])
+        total_size += info['size']
 
     if len(object_info_dict) == 0:
         raise OdlAccessDeniedError()
@@ -66,8 +58,7 @@ def implement_ls(obj: ContextInfo, dataset: str) -> None:
     table.add_column("File Name", min_width=20, justify='left')
     table.add_column("Size", width=12, justify='left')
 
-    print(f"total: {total_files}, size: {bytes2human(total_size)}")
+    print(f"Total file count: {total_files}, Size: {bytes2human(total_size)}")
     for key, val in sorted_object_info_dict.items():
         table.add_row(key, val, end_section=True)
-
     console.print(table)
diff --git a/opendatalab/cli/search.py b/opendatalab/cli/search.py
index fd22a3a..b9ed42b 100644
--- a/opendatalab/cli/search.py
+++ b/opendatalab/cli/search.py
@@ -2,9 +2,11 @@
 # Copyright 2022 Shanghai AI Lab. Licensed under MIT License.
 #
 import re
+
 from rich import box
 from rich.console import Console
 from rich.table import Table
+
 from opendatalab.cli.utility import ContextInfo, exception_handler
 from opendatalab.utils import bytes2human
 
@@ -48,8 +50,7 @@ def implement_search(obj: ContextInfo, keywords: str) -> None:
     """
     client = obj.get_client()
     odl_api = client.get_api()
     result_list = odl_api.search_dataset(keywords)
-
     console = Console()
     table = Table(show_header=True, header_style='bold cyan', box=box.ASCII2)
     table.add_column("Name", min_width=10, justify='left', overflow='fold')
@@ -66,18 +67,33 @@ def implement_search(obj: ContextInfo, keywords: str) -> None:
     for _, res in enumerate(result_list):
         ds_name = res['name']
         ds_name_rich = rich_content_str(keywords=keywords, content=ds_name)
-        ds_data_types = ','.join([dmt['name'] for dmt in res['mediaTypes']])
-        ds_file_byte = bytes2human(res['fileBytes'])
-        ds_file_count = res['fileCount']
-        ds_task_types = ','.join([dtt['name'] for dtt in res['taskTypes']])
-        ds_task_types_rich = rich_content_str(keywords=keywords, content=ds_task_types)
-        ds_label_types = ','.join([dlt['name'] for dlt in res['labelTypes']])
-        ds_label_types_rich = rich_content_str(keywords=keywords, content=ds_label_types)
         ds_view_count = res['viewCount']
-        ds_desc = res['introductionText'][:97] + '...'
+        ds_desc = res['introduction']['en'][:97] + '...'
         ds_desc_rich = rich_content_str(keywords=keywords, content=ds_desc)
 
+        ds_attr_info = res['attrs']
+        ds_file_byte = bytes2human(ds_attr_info.get('fileBytes', 0))
+        ds_file_count = ds_attr_info.get('fileCount', 0)
+
+        ds_data_types = _get_complex_types_str(ds_attr_info, 'mediaTypes')
+        ds_task_types = _get_complex_types_str(ds_attr_info, 'taskTypes')
+        ds_label_types = _get_complex_types_str(ds_attr_info, 'labelTypes')
+
+        ds_task_types_rich = rich_content_str(keywords=keywords, content=ds_task_types)
+        ds_label_types_rich = rich_content_str(keywords=keywords, content=ds_label_types)
+
         table.add_row(ds_name_rich, ds_data_types, str(ds_file_byte), str(ds_file_count),
                       ds_task_types_rich, ds_label_types_rich, str(ds_view_count), ds_desc_rich,
                       end_section=True)
 
-    console.print(table)
\ No newline at end of file
+    console.print(table)
+
+
+def _get_complex_types_str(ds_attr_info, type_name):
+    if not ds_attr_info or not type_name:
+        return ""
+
+    if type_name in ds_attr_info.keys():
+        type_list = ds_attr_info[type_name]
+        return ','.join([d['name']['en'] for d in type_list])
+    else:
+        return ""
diff --git a/opendatalab/cli/upgrade.py b/opendatalab/cli/upgrade.py
index d23eb58..49412ed 100644
--- a/opendatalab/cli/upgrade.py
+++ b/opendatalab/cli/upgrade.py
@@ -1,8 +1,9 @@
 #
 # Copyright 2022 Shanghai AI Lab. Licensed under MIT License.
 #
-import sys
 import operator
+import sys
+
 import click
 
 from opendatalab.__version__ import __version__
diff --git a/opendatalab/cli/utility.py b/opendatalab/cli/utility.py
index 699d7c7..d671766 100644
--- a/opendatalab/cli/utility.py
+++ b/opendatalab/cli/utility.py
@@ -4,17 +4,19 @@
 #
 """OpenDataLab CLI utility functions."""
 
-import sys
 import json
+import sys
 from functools import wraps
 from typing import Any, Callable, TypeVar
+
 import click
-from opendatalab.__version__ import __version__
+from rich import print as rprint
 
+from opendatalab.__version__ import __version__
 from opendatalab.cli.config import config as client_config
 from opendatalab.client import Client
-from opendatalab.utils import UUID
 from opendatalab.exception import OpenDataLabError
+from opendatalab.utils import UUID
 
 _Callable = TypeVar("_Callable", bound=Callable[..., None])
 
@@ -34,6 +36,8 @@ def __init__(self, url: str, token: str):
         self.check_ret = 0
         self.install_version = __version__
         self.latest_version = None
+        self.warning = "[red]WARNING[/red]: This CLI tool is deprecated and will be removed in a future release.\nThe [bold yellow]opendatalab(odl)[/bold yellow] pkg is deprecated and will no longer be supported in a few weeks.\nWe recommend that you switch to the [bold yellow]openxlab[/bold yellow] pkg, which accepts the same username/password,\nprovides the same functionality, and offers other enhanced, AI-friendly features.\nFor more details, please refer to [blue]https://openxlab.org.cn/datasets[/blue]\n"
+        rprint(self.warning)
 
     def get_client(self) -> Client:
         return Client(self.url, self.token, self.cookie)
diff --git a/opendatalab/client/api.py b/opendatalab/client/api.py
index 189e5fa..ac968d7 100644
--- a/opendatalab/client/api.py
+++ b/opendatalab/client/api.py
@@ -19,7 +19,103 @@ def __init__(self, host, token, odl_cookie):
         self.host = host
         self.token = token
         self.odl_cookie = odl_cookie
+
+    def get_dataset_files(self, dataset_name: str, prefix: str):
+        """https request to retrieve dataset files
+        Args:
+            dataset_name (str): dataset name
+            prefix (str): sub-directory prefix to filter by
+
+        Returns:
+            result_dict: dict with 2 keys:
+                result_dict['list']: list of files
+                result_dict['total']: file count
+        """
+        header_dict = {"X-OPENDATALAB-API-TOKEN": self.token,
+                       "Cookie": f"opendatalab_session={self.odl_cookie}",
+                       "User-Agent": UUID,
+                       "accept": "application/json"
+                       }
+        data = {"recursive": True,
+                "prefix": prefix}
+        resp = requests.get(
+            url=f"{self.host}/api/datasets/{dataset_name}/files",
+            params=data,
+            headers=header_dict
+        )
+        if resp.status_code != 200:
+            if resp.status_code == 404:
+                raise OdlDataNotExistsError()
+            elif resp.status_code == 401:
+                raise OdlAuthError()
+            elif resp.status_code == 403:
+                raise OdlAccessDeniedError()
+            elif resp.status_code == 412:
+                raise OdlAccessCdnError()
+            elif resp.status_code == 500:
+                raise OdlAccessDeniedError()
+            else:
+                raise RespError(resp_code=resp.status_code, error_msg=resp.reason)
+
+        result_dict = resp.json()['data']
+
+        return result_dict
+
+    def get_dataset_download_urls(self, dataset_id: int, dataset_list: list):
+        """get downloadable urls for dataset segments
+
+        Args:
+            dataset_id (int): dataset id
+            dataset_list (list): list of dicts containing segment size and name
+
+        Returns:
+            download_url_list: list of dicts containing segment name and downloadable url
+        """
+        resp = requests.post(
+            f"{self.host}/api/track/datasets/download/{dataset_id}",
+            data=json.dumps(dataset_list),
+            headers={
+                "Content-Type": "application/json",
+                "Cookie": f"opendatalab_session={self.odl_cookie}",
+                "User-Agent": f"opendatalab-python-sdk/{__version__}",
+                "accept": "application/json"
+            }
+        )
+        if resp.status_code != 200:
+            print(f"{OpenDataLabError(resp.status_code, resp.text)}")
+            sys.exit(-1)
+
+        download_url_list = resp.json()['data']
+        if not download_url_list:
+            click.secho("No datasets matched!", fg='red')
+            sys.exit(-1)
+
+        return download_url_list
+
+    def get_auth_status(self, dataset_id: int):
+        """Get dataset download authorization status.
+
+        Args:
+            dataset_id (int): dataset id
+        """
+        resp = requests.get(
+            f"{self.host}/api/datasets/{dataset_id}/downloadAuth",
+            headers={
+                "X-OPENDATALAB-API-TOKEN": self.token,
+                "Cookie": f"opendatalab_session={self.odl_cookie}",
+                "User-Agent": UUID,
+                "accept": "application/json"
+            }
+        )
+        if resp.status_code != 200:
+            click.echo(f"{OpenDataLabError(resp.status_code, resp.text)}")
+            sys.exit(-1)
+        result_status = resp.json()['data']
+
+        return result_status
+
     def get_dataset_sts(self, dataset, expires=900):
         """Get dataset sts by dataset_name
         Args:
@@ -58,43 +154,20 @@ def get_dataset_sts(self, dataset, expires=900):
         # print(f"sts api, headers: {resp.headers}, text: {resp.text}")
         return resp.json()["data"]
 
-    @DeprecationWarning
-    def login(self, username: str, password: str):
-        data = {
-            "email": username,
-            "password": password,
-        }
-        data = json.dumps(data)
-        resp = requests.post(
-            f"{self.host}/api/users/login",
-            data=data,
-            headers={"Content-Type": "application/json"},
-        )
-        if resp.status_code != 200:
-            raise OdlAuthError(resp.status_code, resp.text)
-
-        cookies_dict = requests.utils.dict_from_cookiejar(resp.cookies)
-
-        if 'opendatalab_session' in cookies_dict.keys():
-            opendatalab_session = cookies_dict['opendatalab_session']
-        else:
-            raise OpenDataLabError(resp.status_code, "No opendatalab_session")
-
-        config_json = {
-            'user.email': username,
-            'user.token': opendatalab_session,
-        }
-
-        return config_json
-
     def search_dataset(self, keywords):
-        resp = requests.get(
-            f"{self.host}/api/datasets/?pageSize=25&keywords={keywords}",
+        resp = requests.post(
+            f"{self.host}/api/datasets/list",
             headers={"X-OPENDATALAB-API-TOKEN": self.token,
                      "Cookie": f"opendatalab_session={self.odl_cookie}",
                      "User-Agent": f"opendatalab-python-sdk/{__version__}",
+                     "Content-Type": "application/json"
                      },
+            data=json.dumps({
+                "backend": False,
+                "keywords": keywords,
+                "pageSize": 25,
+                "state": ["online"],
+            })
         )
-
         if resp.status_code != 200:
             print(f"{OpenDataLabError(resp.status_code, resp.text)}")
             sys.exit(-1)
@@ -113,10 +186,10 @@ def get_similar_dataset(self, dataset):
             headers={"X-OPENDATALAB-API-TOKEN": self.token,
                      "Cookie": f"opendatalab_session={self.odl_cookie}",
                      "User-Agent": f"opendatalab-python-sdk/{__version__}",
+                     "Content-Type": "application/json"
                      },
         )
         if resp.status_code != 200:
-            # print(f"{(resp.status_code, resp.text)}")
             sys.exit(-1)
 
         data = resp.json()['data']
@@ -124,10 +197,11 @@ def get_similar_dataset(self, dataset):
 
     def get_info(self, dataset):
         resp = requests.get(
-            f"{self.host}/api/datasets/{dataset}",
+            f"{self.host}/api/datasets/{dataset}?backend=false",
             headers={"X-OPENDATALAB-API-TOKEN": self.token,
                      "Cookie": f"opendatalab_session={self.odl_cookie}",
                      "User-Agent": f"opendatalab-python-sdk/{__version__}",
f"opendatalab-python-sdk/{__version__}", + "Content-Type": "application/json" }, ) if resp.status_code != 200: @@ -151,6 +227,7 @@ def call_download_log(self, dataset, download_info): headers={"Content-Type": "application/json", "Cookie": f"opendatalab_session={self.odl_cookie}", "User-Agent": f"opendatalab-python-sdk/{__version__}", + "Content-Type": "application/json" }, ) @@ -164,6 +241,7 @@ def get_download_record(self, dataset): headers={"Content-Type": "application/json", "Cookie": f"opendatalab_session={self.odl_cookie}", "User-Agent": f"opendatalab-python-sdk/{__version__}", + "Content-Type": "application/json" }, ) @@ -192,6 +270,7 @@ def submit_download_record(self, dataset, download_data): headers={"Content-Type": "application/json", "Cookie": f"opendatalab_session={self.odl_cookie}", "User-Agent": f"opendatalab-python-sdk/{__version__}", + "Content-Type": "application/json" }, ) @@ -211,7 +290,6 @@ def odl_auth(self, account, password): data=data, headers={"Content-Type": "application/json"}, ) - if resp.status_code != 200: raise OdlAuthError(resp.status_code, resp.text) diff --git a/opendatalab/client/client.py b/opendatalab/client/client.py index cf4a3f3..8b6db9a 100644 --- a/opendatalab/client/client.py +++ b/opendatalab/client/client.py @@ -2,10 +2,10 @@ # # Copyright 2022 Shanghai AI Lab. Licensed under MIT License. # +from opendatalab.__version__ import __url__ from opendatalab.client.api import OpenDataLabAPI from opendatalab.dataset.dataset import Dataset from opendatalab.utils import get_api_token_from_env -from opendatalab.__version__ import __url__ class Client: @@ -31,10 +31,6 @@ def get_dataset(self, dataset_name: str) -> Dataset: f"{self.host}/datasets/{dataset_name}", self.token, self.odl_cookie) return self.dataset_map[dataset_name] - def get(self, dataset_name: int, filepath: str): - dataset = self.get_dataset(dataset_name) - return dataset.get(filepath) - def get_api(self): self.odl_api = OpenDataLabAPI(self.host, self.token, self.odl_cookie) return self.odl_api diff --git a/opendatalab/client/downloader.py b/opendatalab/client/downloader.py new file mode 100644 index 0000000..6370597 --- /dev/null +++ b/opendatalab/client/downloader.py @@ -0,0 +1,371 @@ +# -*- coding: utf-8 -*- +import glob +import hashlib +import os +import sys +import threading +import time + +import requests + + +class Worker: + def __init__(self, name: str, url: str, range_start, range_end, cache_dir, finish_callback): + self.name = name + self.url = url + self.cache_filename = os.path.join(cache_dir, name + ".odl") + self.range_start = range_start # fixed + self.range_end = range_end # fixed + self.range_curser = range_start # curser dynamic + self.finish_callback = finish_callback + self.terminate_flag = False + self.FINISH_TYPE = "" # DONE\HELP\RETIRE + + def __run(self): + chunk_size = 1 * 1024 + header = { + 'Range': f'Bytes={self.range_curser}-{self.range_end}', + } + req = requests.get(self.url, stream=True, headers=header) + + if 200 <= req.status_code <= 299: + with open(self.cache_filename, "wb") as cache: + for chunk in req.iter_content(chunk_size=chunk_size): + if self.terminate_flag: + break + cache.write(chunk) + self.range_curser += len(chunk) + if not self.terminate_flag: + self.FINISH_TYPE = "DONE" + req.close() + self.finish_callback(self) + + def start(self): + threading.Thread(target=self.__run).start() + + def help(self): + self.FINISH_TYPE = "HELP" + self.terminate_flag = True + + def retire(self): + self.FINISH_TYPE = "RETIRE" + self.terminate_flag = True + + def 
__lt__(self, another): + return self.range_start < another.range_start + + def get_progress(self): + """progress for each worker""" + _progress = { + "curser": self.range_curser, + "start": self.range_start, + "end": self.range_end + } + return _progress + + +class Downloader: + def __init__(self, url: str, filename:str, download_dir: str, blocks_num: int = 8): + assert 0 <= blocks_num <= 32 + self.prefix_flag = False + if len(filename.split('/')) == 1: + self.filename = filename + self.prefix = '' + else: + self.filename = filename.split('/')[-1] + self.prefix_flag = True + self.prefix = os.path.dirname(filename) + self.url = url + self.download_dir = download_dir + + + # self.download_dir = os.path.join(download_dir, f".{os.sep}odl{os.sep}") + self.blocks_num = blocks_num + self.__bad_url_flag = False + self.file_size = self.__get_size() + if self.file_size <= 1: + return + if not self.__bad_url_flag: + # make download dir + if not os.path.exists(self.download_dir): + os.makedirs(self.download_dir) + + # make cache dir + if self.prefix_flag: + self.cache_dir = os.path.join(self.download_dir,self.prefix,'.cache/') + else: + self.cache_dir = os.path.join(self.download_dir,'.cache/') + if not os.path.exists(self.cache_dir): + os.makedirs(self.cache_dir) + # print(self.url, self.file_size) + # slicing + self.start_since = time.time() + # worker container + self.workers = [] + self.LOG = self.__get_log_from_cache() + self.__done = threading.Event() + self.__download_record = [] + threading.Thread(target=self.__supervise).start() + # main + self.__main_thread_done = threading.Event() + # + readable_size = self.__get_readable_size(self.file_size) + pathfilename = os.path.join(self.download_dir, self.prefix,self.filename) + + def __get_size(self): + try: + # print(self.url) + # req = requests.head(self.url) + # print(req.headers) + # content_length = req.headers["Content-Length"] + resp = requests.get(self.url,stream=True) + content_length = resp.headers["Content-Length"] + # print(f"-------------{content_length}--------------") + resp.close() + # print(req.headers) + # print(req.headers["Content-Length"]) + return int(content_length) + except Exception as err: + self.__bad_url_flag = True + self.__whistleblower(f"[Error] {err}") + return 0 + + def __get_readable_size(self, size): + units = ["B", "KB", "MB", "GB", "TB", "PB"] + unit_index = 0 + K = 1024.0 + while size >= K: + size = size / K + unit_index += 1 + return "%.1f %s" % (size, units[unit_index]) + + def __get_cache_filenames(self): + return glob.glob(f"{self.cache_dir}{self.filename}.*.odl") + + def __get_ranges_from_cache(self): + # like ./cache/filename.1120.odl + ranges = [] + for filename in self.__get_cache_filenames(): + size = os.path.getsize(filename) + if size > 0: + cache_start = int(filename.split(".")[-2]) + cache_end = cache_start + size - 1 + ranges.append((cache_start, cache_end)) + ranges.sort(key=lambda x: x[0]) + return ranges + + def __get_log_from_cache(self): + ranges = self.__get_ranges_from_cache() + LOG = [] + if len(ranges) == 0: + LOG.append((0, self.file_size - 1)) + else: + for i, (start, end) in enumerate(ranges): + if i == 0: + if start > 0: + LOG.append((0, start - 1)) + next_start = self.file_size if i == len(ranges) - 1 else ranges[i + 1][0] + if end < next_start - 1: + LOG.append((end + 1, next_start - 1)) + return LOG + + def __increase_ranges_slice(self, ranges: list, minimum_size=1024 * 1024): + assert len(ranges) > 0 + block_size = [end - start + 1 for start, end in ranges] + index_of_max = 
block_size.index(max(block_size)) + start, end = ranges[index_of_max] + halfsize = block_size[index_of_max] // 2 + if halfsize >= minimum_size: + new_ranges = [x for i, x in enumerate(ranges) if i != index_of_max] + new_ranges.append((start, start + halfsize)) + new_ranges.append((start + halfsize + 1, end)) + else: + new_ranges = ranges + return new_ranges + + def __ask_for_work(self, worker_num: int): + """ask for work, return[work_range],update self.LOG""" + assert worker_num > 0 + task = [] + LOG_num = len(self.LOG) + # no work now, ask for new work + if LOG_num == 0: + self.__share_the_burdern() + return [] + # enough work, consume + if LOG_num >= worker_num: + for _ in range(worker_num): + task.append(self.LOG.pop(0)) + # too much work + else: + slice_num = worker_num - LOG_num + task = self.LOG + self.LOG = [] + for _ in range(slice_num): + task = self.__increase_ranges_slice(task) + task.sort(key=lambda x: x[0]) + return task + + def __share_the_burdern(self, minimum_size=1024 * 1024): + """Find the heavy worker, and introduce helper""" + max_size = 0 + max_size_name = "" + for w in self.workers: + p = w.get_progress() + size = p["end"] - p["curser"] + 1 + if size > max_size: + max_size = size + max_size_name = w.name + if max_size >= minimum_size: + for w in self.workers: + if w.name == max_size_name: + w.help() + break + + def __give_back_work(self, worker: Worker): + """Take unfinished work""" + progress = worker.get_progress() + curser = progress["curser"] + end = progress["end"] + if curser <= end: + self.LOG.append((curser, end)) + self.LOG.sort(key=lambda x: x[0]) + + def __give_me_a_worker(self, start, end): + worker = Worker(name=f"{self.filename}.{start}", + url=self.url, range_start=start, range_end=end, cache_dir=self.cache_dir, + finish_callback=self.__on_worker_finish, + ) + return worker + + def __whip(self, worker: Worker): + """assign new job""" + self.workers.append(worker) + self.workers.sort() + worker.start() + + def __on_worker_finish(self, worker: Worker): + assert worker.FINISH_TYPE != "" + self.workers.remove(worker) + # need helper + if worker.FINISH_TYPE == "HELP": + self.__give_back_work(worker) + self.workaholic(2) + # job done + elif worker.FINISH_TYPE == "DONE": + # get one more job + self.workaholic(1) + elif worker.FINISH_TYPE == "RETIRE": + self.__give_back_work(worker) + # Job Done, Sewing! + if self.workers == [] and self.__get_log_from_cache() == []: + self.__sew() + + def start(self): + # workers assembly + for start, end in self.__ask_for_work(self.blocks_num): + worker = self.__give_me_a_worker(start, end) + self.__whip(worker) + # wait till done + self.__main_thread_done.wait() + + def stop(self): + for w in self.workers: + w.retire() + while len(self.workers) != 0: + time.sleep(0.5) + self.LOG = self.__get_log_from_cache() + + def workaholic(self, n=1): + """ no work no life""" + for s, e in self.__ask_for_work(n): + worker = self.__give_me_a_worker(s, e) + self.__whip(worker) + + def restart(self): + self.stop() + # worker assembly again! 
+        for start, end in self.__ask_for_work(self.blocks_num):
+            worker = self.__give_me_a_worker(start, end)
+            self.__whip(worker)
+
+    def __supervise(self):
+        """worker and download status supervisor"""
+        REFRESH_INTERVAL = 2
+        # serves as the time-window length
+        LAG_COUNT = 5
+        WAIT_TIMES_BEFORE_RESTART = 30
+        SPEED_DEGRADATION_PERCENTAGE = 0.3
+        EPSILON = 1e-5
+        self.__download_record = []
+        maxspeed = 0
+        wait_times = WAIT_TIMES_BEFORE_RESTART
+        while not self.__done.is_set():
+            dwn_size = sum([os.path.getsize(cachefile) for cachefile in self.__get_cache_filenames()])
+            self.__download_record.append({"timestamp": time.time(), "size": dwn_size})
+            if len(self.__download_record) > LAG_COUNT:
+                self.__download_record.pop(0)
+            s = self.__download_record[-1]["size"] - self.__download_record[0]["size"]
+            t = self.__download_record[-1]["timestamp"] - self.__download_record[0]["timestamp"]
+            if not t == 0:
+                speed = s / t
+                readable_speed = self.__get_readable_size(speed)
+                percentage = self.__download_record[-1]["size"] / self.file_size * 100
+                status_msg = f"\r[Current File Download Info] File Progress: {percentage:.2f} % | " \
+                             f"Speed: {readable_speed}/s | Number of Workers: {len(self.workers)} | " \
+                             f"Time Elapsed: {(time.time() - self.start_since):.0f}s | " \
+                             f"ETA: {((self.file_size - dwn_size) / (speed + EPSILON)):.2f}s"
+                self.__whistleblower(status_msg)
+                # speed monitor
+                maxspeed = max(maxspeed, speed)
+                # tolerance reached
+                time_over = wait_times < 0
+                # not finished yet
+                not_finished = not self.__done.is_set()
+                # speed dropped significantly from its peak
+                speed_drops_significantly = (maxspeed - speed + EPSILON) / (maxspeed + EPSILON) > SPEED_DEGRADATION_PERCENTAGE
+                speed_under_threshold = speed < 1024 * 1024
+                scene_1 = speed_drops_significantly and speed_under_threshold
+                # running slow
+                scene_2 = speed < 100 * 1024
+                if time_over and not_finished and (scene_1 or scene_2):
+                    self.__whistleblower("\r[info] speed degradation, restarting...")
+                    self.restart()
+                    maxspeed = 0
+                    wait_times = WAIT_TIMES_BEFORE_RESTART
+                else:
+                    wait_times -= 1
+            time.sleep(REFRESH_INTERVAL)
+
+    def __sew(self):
+        self.__done.set()
+        chunk_size = 10 * 1024 * 1024
+        with open(f"{os.path.join(self.download_dir, self.prefix, self.filename)}", "wb") as f:
+            for start, _ in self.__get_ranges_from_cache():
+                cache_filename = f"{self.cache_dir}{self.filename}.{start}.odl"
+                with open(cache_filename, "rb") as cache_file:
+                    data = cache_file.read(chunk_size)
+                    while data:
+                        f.write(data)
+                        f.flush()
+                        data = cache_file.read(chunk_size)
+        self.clear()
+        self.__whistleblower("\r")
+        self.__main_thread_done.set()
+
+    def __whistleblower(self, saying: str):
+        sys.stdout.write(saying)
+
+    def md5(self):
+        chunk_size = 1024 * 1024
+        filename = f"{os.path.join(self.download_dir, self.prefix, self.filename)}"
+        md5 = hashlib.md5()
+        with open(filename, "rb") as f:
+            data = f.read(chunk_size)
+            while data:
+                md5.update(data)
+                data = f.read(chunk_size)
+        return md5.hexdigest()
+
+    def clear(self):
+        for filename in self.__get_cache_filenames():
+            os.remove(filename)
\ No newline at end of file
diff --git a/opendatalab/client/uaa.py b/opendatalab/client/uaa.py
index be720e5..00f90f8 100644
--- a/opendatalab/client/uaa.py
+++ b/opendatalab/client/uaa.py
@@ -1,14 +1,14 @@
+import json
 import sys
+import time
+from base64 import b64decode, b64encode
 
 import click
 import requests
-import json
-import time
-from Crypto.PublicKey import RSA
 from Crypto.Cipher import PKCS1_v1_5
-from base64 import b64encode, b64decode
+from Crypto.PublicKey import RSA
 
-from opendatalab.__version__ import uaa_url_prefix, odl_clientId
+from opendatalab.__version__ import odl_clientId, uaa_url_prefix
 
 api_login = "/api/v1/login/byClientSdk"
 api_public_key = "/api/v1/cipher/getPubKey"
diff --git a/opendatalab/dataset/dataset.py b/opendatalab/dataset/dataset.py
index 22c9463..7630577 100644
--- a/opendatalab/dataset/dataset.py
+++ b/opendatalab/dataset/dataset.py
@@ -6,13 +6,12 @@
 import sys
 
 import click
-import oss2
 import requests
 from requests.adapters import HTTPAdapter
 
 from opendatalab.client.api import OpenDataLabAPI
 from opendatalab.exception import OpenDataLabError
-from opendatalab.utils import parse_url, get_api_token_from_env
+from opendatalab.utils import get_api_token_from_env, parse_url
 
 
 class Dataset:
@@ -25,74 +24,4 @@ def __init__(self, url: str, token: str = "", odl_cookie: str = "") -> None:
         self.oss_bucket = None
         self.bucket_name = None
-        self.oss_path_prefix = ""
-        self.init_oss_bucket()
-
-    def get(self, filepath: str, compressed: bool = True):
-        object_key = self.get_object_key_prefix(compressed) + filepath
-        try:
-            return self.oss_bucket.get_object(object_key)
-        except oss2.exceptions.ServerError as e:
-            if "InvalidAccessKeyId" not in str(e):
-                raise e
-
-            self.init_oss_bucket()
-            return self.oss_bucket.get_object(object_key)
-
-    def init_oss_bucket(self, expires=3600):
-        sts = self.open_data_lab_api.get_dataset_sts(self.dataset_name, expires=expires)
-
-        if sts:
-            path_info = sts["path"].replace("oss://", "").split("/")
-            bucket_name = path_info[0]
-            sts_point, sts_use_cname = self.select_endpoint(sts)
-
-            if sts_point:
-                auth = oss2.StsAuth(sts["accessKeyId"], sts["accessKeySecret"], sts["securityToken"])
-                self.oss_bucket = oss2.Bucket(auth, sts_point, bucket_name, is_cname=sts_use_cname)
-                self.oss_path_prefix = "/".join(path_info[1:])
-        else:
-            raise OpenDataLabError(1001, "access to bucket error")
-
-    def get_oss_bucket(self) -> oss2.Bucket:
-        if self.oss_bucket is None:
-            self.init_oss_bucket()
-        return self.oss_bucket
-
-    def refresh_oss_bucket(self) -> oss2.Bucket:
-        self.init_oss_bucket()
-        return self.get_oss_bucket()
-
-    def get_object_key_prefix(self, compressed: bool = True) -> str:
-        if compressed:
-            return f"{self.oss_path_prefix}/raw/"
-        else:
-            return f"{self.oss_path_prefix}/"
-
-    @classmethod
-    def select_endpoint(cls, sts):
-        s = requests.Session()
-        sts_endpoints = sts["endpoints"]
-        path_info = sts["path"].replace("oss://", "").split("/")
-        bucket_name = path_info[0]
-
-        # use general endpoint
-        if len(sts_endpoints) > 0:
-            endpoint = sts_endpoints[-1]
-            sts_endpoint = endpoint['url']
-            sts_use_cname = endpoint['useCname']
-
-            url_splitter = "://"
-            url_split_arr = str(sts_endpoint).split(url_splitter)
-            url_prefix = url_split_arr[0]
-            url_body = url_split_arr[1]
-            check_url = url_prefix + url_splitter + bucket_name + "." + url_body + "/check_connected"
-            s.mount(check_url, HTTPAdapter(max_retries=0))
-
-            try:
-                resp = s.get(check_url, timeout=(3, 1))  # 0.5
-                if resp.status_code == http.HTTPStatus.OK:
-                    return sts_endpoint, sts_use_cname
-            except Exception as e:
-                click.secho(f"ConnectionError occurs, please check network!", fg='red')
-                sys.exit(-1)
\ No newline at end of file
+        self.oss_path_prefix = ""
\ No newline at end of file
diff --git a/opendatalab/exception.py b/opendatalab/exception.py
index 8c7b434..f064ec0 100644
--- a/opendatalab/exception.py
+++ b/opendatalab/exception.py
@@ -27,7 +27,7 @@ class RespError(OpenDataLabError):
     """
 
     STATUS_CODE: int
-    _INDENT = " " * len(__qualname__)
+    _INDENT = " " * len(__qualname__)  # type: ignore
 
     def __init__(self, resp_code: Optional[int] = None, error_msg: str = ""):
         super().__init__(resp_code, error_msg)
diff --git a/setup.cfg b/setup.cfg
index 2412b11..25c840d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -4,12 +4,12 @@
 [metadata]
 name = opendatalab
-url = https://github.com/opendatalab/opendatalab-python-sdk/-/tree/main
-author = opendatalab
-author_email = wangrui@pjlab.org.cn
+url = https://github.com/opendatalab/opendatalab-python-sdk
+author = OpenDataLab
+author_email = OpenDataLab@pjlab.org.cn
 license = MIT
 license_file = LICENSE
 keywords = opendatalab, dataset
 description = OpenDataLab Python SDK
 long_description = file: README.md
 long_description_content_type = text/markdown
@@ -22,6 +22,7 @@ classifiers =
     Programming Language :: Python
     Programming Language :: Python :: 3
     Programming Language :: Python :: 3 :: Only
+    Programming Language :: Python :: 3.7
     Programming Language :: Python :: 3.8
     Programming Language :: Python :: 3.9
    Programming Language :: Python :: 3.10
@@ -30,14 +31,15 @@ classifiers =
 
 [options]
 packages = find:
-python_requires = >=3.8
+python_requires = >=3.7
 install_requires =
+    pycryptodome
     click >= 7.0.0
     requests >= 2.4.2
-    tqdm >= 4.14.0
-    oss2
+    tqdm
     colorama
     rich
+    openxlab
     pywin32; platform_system == "Windows"
 
 [options.packages.find]
diff --git a/setup.py b/setup.py
index 345b545..d23e84f 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,8 @@
 import os
+
 import setuptools
 
-about = {}
+about = {}  # type: ignore
 here = os.path.abspath(os.path.dirname(__file__))
 with open(
     os.path.join(here, "opendatalab", "__version__.py"), "r", encoding="utf-8"
@@ -14,6 +15,6 @@
 setuptools.setup(
     version=about["__version__"],
     project_urls={
-        "Bug Tracker": "https://github.com/opendatalab/opendatalab-python-sdk/-/tree/main/-/issues",
+        "Bug Tracker": "https://github.com/opendatalab/opendatalab-python-sdk/issues",
     },
 )
diff --git a/test-requirements.txt b/test-requirements.txt
deleted file mode 100644
index e079f8a..0000000
--- a/test-requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-pytest
diff --git a/tests/demo.py b/tests/demo.py
index 5cd0484..ab14338 100644
--- a/tests/demo.py
+++ b/tests/demo.py
@@ -18,37 +18,32 @@
     odl_api = client.get_api()
 
     # 0. login with account
-    # account = "xxxxx"  # your username
-    # pw = "xxxxx"  # your password
-    # print(f'*****'*8)
-    # implement_login(ctx, account, pw)
+    account = "xxxxx"  # your username
+    pw = "xxxxx"  # your password
+    print(f'*****'*8)
+    implement_login(ctx, account, pw)
 
     # 1. search demo
-    res_list = odl_api.search_dataset("coco")
+    res_list = odl_api.search_dataset("mnist")
     # for index, res in enumerate(res_list):
     #     print(f"-->index: {index}, result: {res['name']}")
-    # implement_search("coco")
+    implement_search(ctx, "coco")
     print(f'*****'*8)
 
     # 2. list demo
     implement_ls(ctx, 'TAO')
     print(f'*****' * 8)
 
-    # 3. read file online demo
-    dataset = client.get_dataset('FB15k')
-    with dataset.get('meta/info.json', compressed=False) as fd:
-        content = json.load(fd)
-        print(f"{content}")
-    print(f'*****'*8)
 
     # 4. get dataset info
     implement_info(ctx, 'FB15k')
+    implement_info(ctx, 'COCO_1')
 
     # 5. download
     # get all files of dataset
     # implement_get(ctx, "MNIST", 4, 0)
     # get partial files of dataset
-    implement_get(ctx, "GOT-10k/data/test_data.zip", 4, 0)  # 139, zip 1.16G GOT-10k
+    implement_get(ctx, "MNIST", "", 8)
     print(f'*****' * 5)
diff --git a/tests/uaa_test.py b/tests/uaa_test.py
index e827c76..ce416a5 100644
--- a/tests/uaa_test.py
+++ b/tests/uaa_test.py
@@ -1,14 +1,16 @@
-import requests
 import json
 import time
-from Crypto.PublicKey import RSA
+from base64 import b64decode, b64encode
+
+import requests
 from Crypto.Cipher import PKCS1_v1_5
-from base64 import b64encode, b64decode
+from Crypto.PublicKey import RSA
 
-odl_dev_clientId = "qja9jy5wnjyqwvylmeqw"
+odl_dev_clientId = "ypkl8bwo0eb5ao1b96no"
 odl_prd_clientId = "kmz3bkwzlaa3wrq8pvwa"
-uaa_dev_url_prefix = "https://uaa-dev.openmmlab.com/gw/uaa-be"
+
+uaa_dev_url_prefix = "https://sso.staging.openxlab.org.cn/gw/uaa-be"
 uaa_prd_url_prefix = "https://sso.openxlab.org.cn/gw/uaa-be"
 
 api_login = "/api/v1/login/byAccount"
@@ -114,14 +116,14 @@ def get_auth_code(ssouid):
     )
     if resp.status_code == 200:
         result = resp.json()['data']
         # print(result)
 
     return result
 
 
 def main():
-    account = "191637988@qq.com"  # "191637988@qq.com" "chenlu@pjlab.org.cn"
-    pw = "qq11111111"
+    account = "your_account"  # your username
+    pw = "your_password"  # your password
 
     authorization = get_account(account=account, password=pw)
     sso_uid = get_user_info(authorization=authorization)
@@ -129,6 +131,5 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
-
-
+    main()
\ No newline at end of file
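
Note on the new download path: `opendatalab/client/downloader.py` replaces the old oss2-based transfer with a plain-HTTP, multi-worker range downloader. Each `Worker` fetches one byte range via an HTTP `Range` request into a `.cache` part file, and the parts are concatenated once every range completes, so interrupted downloads resume from the cached parts. A minimal standalone sketch of driving the class directly (the URL is a placeholder; in the SDK it comes from `OpenDataLabAPI.get_dataset_download_urls()`):

```python
from opendatalab.client.downloader import Downloader

# Placeholder URL -- any server that honors HTTP Range requests works.
url = "https://example.com/files/train.zip"

dl = Downloader(
    url=url,
    filename="raw/train.zip",  # the 'raw/' prefix is recreated under download_dir
    download_dir="./MNIST",    # part files are cached in ./MNIST/raw/.cache/
    blocks_num=8,              # number of concurrent range workers (0-32)
)
dl.start()                     # blocks until the parts are sewn together
print(dl.md5())                # checksum of the assembled file
```

`start()` blocks because the main thread waits on an internal event that is set only after all cached ranges are merged into the final file and the cache is cleared.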
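Likewise, `implement_get` now composes the web API instead of OSS: `get_info` → `get_auth_status` → `get_dataset_files` → `get_dataset_download_urls` → `Downloader`. A hedged sketch of that call chain using `OpenDataLabAPI` directly (token/cookie values are placeholders; real ones are produced by `odl login` and normally supplied via `ContextInfo`/`Client`):

```python
from opendatalab.client.api import OpenDataLabAPI

# Placeholder credentials; normally ContextInfo/Client supply these.
api = OpenDataLabAPI(host="https://opendatalab.org.cn", token="", odl_cookie="")

info = api.get_info("MNIST")                   # basic info: name, id, attrs
files = api.get_dataset_files(dataset_name=info['name'], prefix="")
objs = [{"name": f['path'], "size": f['size']}
        for f in files['list'] if not f['isDir']]
urls = api.get_dataset_download_urls(dataset_id=info['id'], dataset_list=objs[:1])
print(urls[0]['name'], urls[0]['url'])         # feed this into Downloader
```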