diff --git a/examples/memtest.py b/examples/memtest.py index 056b21f8..b5e0b604 100644 --- a/examples/memtest.py +++ b/examples/memtest.py @@ -8,7 +8,6 @@ 4. Training a StageNet model """ -# %% if __name__ == "__main__": from pyhealth.datasets import ( MIMIC4Dataset, @@ -20,7 +19,6 @@ from pyhealth.trainer import Trainer import torch - # %% STEP 1: Load MIMIC-IV base dataset base_dataset = MIMIC4Dataset( ehr_root="/home/logic/physionet.org/files/mimiciv/3.1/", ehr_tables=[ @@ -31,76 +29,65 @@ "labevents", ], dev=False, + num_workers=8, ) - # %% # STEP 2: Apply StageNet mortality prediction task - sample_dataset = base_dataset.set_task( + with base_dataset.set_task( MortalityPredictionStageNetMIMIC4(), - num_workers=4, - ) - - print(f"Total samples: {len(sample_dataset)}") - print(f"Input schema: {sample_dataset.input_schema}") - print(f"Output schema: {sample_dataset.output_schema}") - - # %% Inspect a sample - sample = next(iter(sample_dataset)) - print("\nSample structure:") - print(f" Patient ID: {sample['patient_id']}") - print(f"ICD Codes: {sample['icd_codes']}") - print(f" Labs shape: {len(sample['labs'][0])} timesteps") - print(f" Mortality: {sample['mortality']}") - - # %% STEP 3: Split dataset - train_dataset, val_dataset, test_dataset = split_by_patient( - sample_dataset, [0.8, 0.1, 0.1] - ) - - # Create dataloaders - train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True) - val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False) - test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False) - - # %% STEP 4: Initialize StageNet model - model = StageNet( - dataset=sample_dataset, - embedding_dim=128, - chunk_size=128, - levels=3, - dropout=0.3, - ) - - num_params = sum(p.numel() for p in model.parameters()) - print(f"\nModel initialized with {num_params} parameters") - - # %% STEP 5: Train the model - trainer = Trainer( - model=model, - device="cpu", # or "cpu" - metrics=["pr_auc", "roc_auc", "accuracy", 
"f1"], - ) - - trainer.train( - train_dataloader=train_loader, - val_dataloader=val_loader, - epochs=5, - monitor="roc_auc", - optimizer_params={"lr": 1e-5}, - ) - - # %% STEP 6: Evaluate on test set - results = trainer.evaluate(test_loader) - print("\nTest Results:") - for metric, value in results.items(): - print(f" {metric}: {value:.4f}") - - # %% STEP 7: Inspect model predictions - sample_batch = next(iter(test_loader)) - with torch.no_grad(): - output = model(**sample_batch) - - print("\nSample predictions:") - print(f" Predicted probabilities: {output['y_prob'][:5]}") - print(f" True labels: {output['y_true'][:5]}") - - # %% + ) as sample_dataset: + print(f"Total samples: {len(sample_dataset)}") + print(f"Input schema: {sample_dataset.input_schema}") + print(f"Output schema: {sample_dataset.output_schema}") + + sample = next(iter(sample_dataset)) + print("\nSample structure:") + print(f" Patient ID: {sample['patient_id']}") + print(f"ICD Codes: {sample['icd_codes']}") + print(f" Labs shape: {len(sample['labs'][0])} timesteps") + print(f" Mortality: {sample['mortality']}") + + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] + ) + + train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False) + + model = StageNet( + dataset=sample_dataset, + embedding_dim=128, + chunk_size=128, + levels=3, + dropout=0.3, + ) + + num_params = sum(p.numel() for p in model.parameters()) + print(f"\nModel initialized with {num_params} parameters") + + trainer = Trainer( + model=model, + device="cpu", # or "cpu" + metrics=["pr_auc", "roc_auc", "accuracy", "f1"], + ) + + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=5, + monitor="roc_auc", + optimizer_params={"lr": 1e-5}, + ) + + results = trainer.evaluate(test_loader) + print("\nTest 
Results:") + for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + + sample_batch = next(iter(test_loader)) + with torch.no_grad(): + output = model(**sample_batch) + + print("\nSample predictions:") + print(f" Predicted probabilities: {output['y_prob'][:5]}") + print(f" True labels: {output['y_true'][:5]}") diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index edbb19e8..aca46ab9 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -11,7 +11,6 @@ import json import uuid import platformdirs -import tempfile import multiprocessing import multiprocessing.queues import shutil @@ -19,6 +18,7 @@ import litdata from litdata.streaming.item_loader import ParquetLoader from litdata.processing.data_processor import in_notebook +from litdata.streaming.writer import BinaryWriter import pyarrow as pa import pyarrow.csv as pv import pyarrow.parquet as pq @@ -30,17 +30,21 @@ from dask.distributed import Client as DaskClient, LocalCluster as DaskCluster, progress as dask_progress import narwhals as nw import itertools +import numpy as np +import more_itertools from ..data import Patient from ..tasks import BaseTask from ..processors.base_processor import FeatureProcessor from .configs import load_yaml_config from .sample_dataset import SampleDataset, SampleBuilder +from ..utils import set_env # Set logging level for distributed to ERROR to reduce verbosity logging.getLogger("distributed").setLevel(logging.ERROR) logger = logging.getLogger(__name__) - +# Remove LitData version check to avoid unnecessary warnings +os.environ["LITDATA_DISABLE_VERSION_CHECK"] = "1" def is_url(path: str) -> bool: """URL detection.""" @@ -111,88 +115,35 @@ def _csv_tsv_gz_path(path: str) -> str: raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}") - -def _uncollate(x: list[Any]) -> Any: - return x[0] if isinstance(x, list) and len(x) == 1 else x - - -class _ParquetWriter: - """ - Stream-write 
rows into a Parquet file in chunked (row-group) fashion. - - Usage: - writer = StreamingParquetWriter(Path("out.parquet"), schema, chunk_size=10000) - writer.append({"id": 1, "val": 3.14}) - writer.append({"id": 2, "val": 1.23}) - writer.close() - """ - - def __init__(self, path: Path | str, schema: pa.Schema, chunk_size: int = 8_192): +class _ProgressContext: + def __init__(self, queue: multiprocessing.queues.Queue | None, total: int, **kwargs): """ - Args: - path: output Parquet file path - schema: pyarrow.Schema (required) - chunk_size: flush buffer every N rows + :param queue: An existing queue (e.g., from multiprocessing). If provided, + this class acts as a passthrough. + :param total: Total items for the progress bar (only used if queue is None). + :param kwargs: Extra arguments for tqdm (e.g., desc="Processing"). """ - self.path = Path(path) - self.schema = schema - self.chunk_size = chunk_size + self.queue = queue + self.total = total + self.kwargs = kwargs + self.progress = None - if self.schema is None: - raise ValueError( - "schema must be provided — no automatic inference allowed." 
- ) + def put(self, n): + if self.progress: + self.progress.update(n) - self._writer: pq.ParquetWriter | None = None - self._buffer: list[dict] = [] - self._closed = False - - # -------------------------------------------------------------- - # Public API - # -------------------------------------------------------------- - def append(self, row: dict) -> None: - """Append a single row (a Python dict).""" - if self._closed: - raise RuntimeError("Cannot append to a closed StreamingParquetWriter") - - self._buffer.append(row) - if len(self._buffer) >= self.chunk_size: - self.flush() - - def flush(self) -> None: - """Flush buffered rows into a Parquet row-group.""" - if not self._buffer: - return - - # Convert list[dict] → Arrow RecordBatch - batch = pa.RecordBatch.from_pylist(self._buffer, schema=self.schema) - - # Lazy-initialize writer - if self._writer is None: - self._writer = pq.ParquetWriter(self.path, self.schema) - - self._writer.write_batch(batch) - self._buffer.clear() - - def close(self) -> None: - """Flush and close the Parquet writer.""" - if self._closed: - return - self.flush() - if self._writer is not None: - self._writer.close() - self._closed = True - - # -------------------------------------------------------------- - # Context manager support - # -------------------------------------------------------------- def __enter__(self): + if self.queue: + return self.queue + + self.progress = tqdm(total=self.total, **self.kwargs) return self - def __exit__(self, exc_type, exc, tb): - self.close() + def __exit__(self, exc_type, exc_val, exc_tb): + if self.progress: + self.progress.close() -_task_transform_queue: multiprocessing.queues.Queue | None = None +_task_transform_progress: multiprocessing.queues.Queue | None = None def _task_transform_init(queue: multiprocessing.queues.Queue) -> None: """ @@ -201,8 +152,8 @@ def _task_transform_init(queue: multiprocessing.queues.Queue) -> None: Args: queue (multiprocessing.queues.Queue): The queue for progress 
tracking. """ - global _task_transform_queue - _task_transform_queue = queue + global _task_transform_progress + _task_transform_progress = queue def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, Path]) -> None: """ @@ -216,26 +167,21 @@ def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, P global_event_df (pl.LazyFrame): The global event dataframe. output_dir (Path): The output directory to save results. """ - class _FakeQueue: - def put(self, x): - pass - - # Use a batch size 128 can reduce runtime by 30%. - BATCH_SIZE = 128 - - logger.info(f"Worker {args[0]} started processing {len(list(args[2]))} patients. (Polars threads: {pl.thread_pool_size()})") - + BATCH_SIZE = 128 # Using a batch size of 128 can reduce runtime by 30%. worker_id, task, patient_ids, global_event_df, output_dir = args - queue = _task_transform_queue or _FakeQueue() - - with _ParquetWriter( - output_dir / f"chunk_{worker_id:03d}.parquet", - pa.schema([("sample", pa.binary())]), - ) as writer: - batches = itertools.batched(patient_ids, BATCH_SIZE) + total_patients = len(list(patient_ids)) + logger.info(f"Worker {worker_id} started processing {total_patients} patients. 
(Polars threads: {pl.thread_pool_size()})") + with ( + set_env(DATA_OPTIMIZER_GLOBAL_RANK=str(worker_id)), + _ProgressContext(_task_transform_progress, total=total_patients) as progress + ): + writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") + + write_index = 0 + batches = itertools.batched(patient_ids, BATCH_SIZE) for batch in batches: - count = 0 + complete = 0 patients = ( global_event_df.filter(pl.col("patient_id").is_in(batch)) .collect(engine="streaming") @@ -245,11 +191,81 @@ def put(self, x): patient_id = patient_id[0] # Extract string from single-element list patient = Patient(patient_id=patient_id, data_source=patient_df) for sample in task(patient): - writer.append({"sample": pickle.dumps(sample)}) - count += 1 - queue.put(count) + writer.add_item(write_index, {"sample": pickle.dumps(sample)}) + write_index += 1 + complete += 1 + progress.put(complete) + writer.done() + + logger.info(f"Worker {worker_id} finished processing patients.") + +_proc_transform_progress: multiprocessing.queues.Queue | None = None + +def _proc_transform_init(queue: multiprocessing.queues.Queue) -> None: + """ + Initializer for worker processes to set up a global queue. + + Args: + queue (multiprocessing.queues.Queue): The queue for progress tracking. + """ + global _proc_transform_progress + _proc_transform_progress = queue + +def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None: + """ + Worker function to apply processors on a chunk of samples. + + Args: + args (tuple): A tuple containing: + worker_id (int): The ID of the worker. + task_df (Path): The path to the task dataframe. + start_idx (int): The start index of samples to process. + end_idx (int): The end index of samples to process. + output_dir (Path): The output directory to save results. 
+ """ + BATCH_SIZE = 128 + worker_id, task_df, start_idx, end_idx, output_dir = args + total_samples = end_idx - start_idx + logger.info(f"Worker {worker_id} started processing {total_samples} samples. ({start_idx} to {end_idx})") + + with ( + set_env(DATA_OPTIMIZER_GLOBAL_RANK=str(worker_id)), + _ProgressContext(_proc_transform_progress, total=total_samples) as progress + ): + writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") + + dataset = litdata.StreamingDataset(str(task_df)) + complete = 0 + with open(f"{output_dir}/schema.pkl", "rb") as f: + metadata = pickle.load(f) + + input_processors = metadata["input_processors"] + output_processors = metadata["output_processors"] + + write_index = 0 + for i in range(start_idx, end_idx): + transformed: Dict[str, Any] = {} + for key, value in pickle.loads(dataset[i]["sample"]).items(): + if key in input_processors: + transformed[key] = input_processors[key].process(value) + elif key in output_processors: + transformed[key] = output_processors[key].process(value) + else: + transformed[key] = value + writer.add_item(write_index, transformed) + write_index += 1 + complete += 1 + + if complete >= BATCH_SIZE: + progress.put(complete) + complete = 0 + + if complete > 0: + progress.put(complete) + writer.done() + + logger.info(f"Worker {worker_id} finished processing samples.") - logger.info(f"Worker {args[0]} finished processing patients.") class BaseDataset(ABC): """Abstract base class for all PyHealth datasets. @@ -304,6 +320,12 @@ def __init__( @property def cache_dir(self) -> Path: """Returns the cache directory path. + The cache structure is as follows:: + + tmp/ # Temporary files during processing + global_event_df.parquet/ # Cached global event dataframe + tasks/ # Cached task-specific data, please see set_task method + Returns: Path: The cache directory path. 
""" @@ -321,7 +343,7 @@ def cache_dir(self) -> Path: uuid.uuid5(uuid.NAMESPACE_DNS, id_str) ) cache_dir.mkdir(parents=True, exist_ok=True) - print(f"No cache_dir provided. Using default cache dir: {cache_dir}") + logger.info(f"No cache_dir provided. Using default cache dir: {cache_dir}") self._cache_dir = cache_dir else: # Ensure the explicitly provided cache_dir exists @@ -330,12 +352,24 @@ def cache_dir(self) -> Path: self._cache_dir = cache_dir return Path(self._cache_dir) - @property - def temp_dir(self) -> Path: - return self.cache_dir / "temp" + def create_tmpdir(self) -> Path: + """Creates and returns a new temporary directory within the cache. + + Returns: + Path: The path to the new temporary directory. + """ + tmp_dir = self.cache_dir / "tmp" / str(uuid.uuid4()) + tmp_dir.mkdir(parents=True, exist_ok=True) + return tmp_dir + + def clean_tmpdir(self) -> None: + """Cleans up the temporary directory within the cache.""" + tmp_dir = self.cache_dir / "tmp" + if tmp_dir.exists(): + shutil.rmtree(tmp_dir) def _scan_csv_tsv_gz( - self, table_name: str, source_path: str | None = None + self, source_path: str ) -> dd.DataFrame: """Scans a CSV/TSV file (possibly gzipped) and returns a Dask DataFrame. @@ -343,9 +377,7 @@ def _scan_csv_tsv_gz( to Parquet and saves it to the cache. Args: - table_name (str): The name of the table. - source_path (str | None): The source CSV/TSV file path. If None, assumes the - Parquet file already exists in the cache. + source_path (str): The source CSV/TSV file path. Returns: dd.DataFrame: The Dask DataFrame loaded from the cached Parquet file. @@ -356,22 +388,14 @@ def _scan_csv_tsv_gz( ValueError: If the path does not have an expected extension. 
""" # Ensure the tables cache directory exists - (self.temp_dir / "tables").mkdir(parents=True, exist_ok=True) - ret_path = str(self.temp_dir / "tables" / f"{table_name}.parquet") - - if not path_exists(ret_path): - if source_path is None: - raise FileNotFoundError( - f"Table {table_name} not found in cache and no source_path provided." - ) + ret_path = self.create_tmpdir() / "table.parquet" + if not ret_path.exists(): source_path = _csv_tsv_gz_path(source_path) if is_url(source_path): local_filename = os.path.basename(source_path) - download_dir = self.temp_dir / "downloads" - download_dir.mkdir(parents=True, exist_ok=True) - local_path = download_dir / local_filename + local_path = self.create_tmpdir() / local_filename if not local_path.exists(): logger.info(f"Downloading {source_path} to {local_path}") urlretrieve(source_path, local_path) @@ -412,6 +436,41 @@ def _scan_csv_tsv_gz( ) return df.replace("", pd.NA) # Replace empty strings with NaN + def _event_transform(self, output_dir: Path) -> None: + try: + df = self.load_data() + with DaskCluster( + n_workers=self.num_workers, + threads_per_worker=1, + processes=not in_notebook(), + # Use cache_dir for Dask's scratch space to avoid filling up /tmp or home directory + local_directory=str(self.create_tmpdir()), + ) as cluster: + with DaskClient(cluster) as client: + if self.dev: + logger.info("Dev mode enabled: limiting to 1000 patients") + patients = df["patient_id"].unique().head(1000).tolist() + filter = df["patient_id"].isin(patients) + df = df[filter] + + logger.info(f"Caching event dataframe to {output_dir}...") + collection = df.sort_values("patient_id").to_parquet( + output_dir, + write_index=False, + compute=False, + ) + handle = client.compute(collection) + dask_progress(handle) + handle.result() # type: ignore + except Exception as e: + if output_dir.exists(): + logger.error(f"Error during caching, removing incomplete file {output_dir}") + shutil.rmtree(output_dir) + raise e + finally: + 
self.clean_tmpdir() + pass + @property def global_event_df(self) -> pl.LazyFrame: """Returns the path to the cached event dataframe. @@ -419,43 +478,12 @@ def global_event_df(self) -> pl.LazyFrame: Returns: Path: The path to the cached event dataframe. """ - if not multiprocessing.current_process().name == "MainProcess": - logger.warning( - "global_event_df property accessed from a non-main process. This may lead to unexpected behavior.\n" - + "Consider use __name__ == '__main__' guard when using multiprocessing." - ) - return None # type: ignore + self._main_guard(type(self).global_event_df.fget.__name__) # type: ignore if self._global_event_df is None: ret_path = self.cache_dir / "global_event_df.parquet" if not ret_path.exists(): - # Use cache_dir for Dask's scratch space to avoid filling up /tmp or home directory - dask_scratch_dir = self.cache_dir / "dask_scratch" - dask_scratch_dir.mkdir(parents=True, exist_ok=True) - - with DaskCluster( - n_workers=self.num_workers, - threads_per_worker=1, - processes=not in_notebook(), - local_directory=str(dask_scratch_dir), - ) as cluster: - with DaskClient(cluster) as client: - df: dd.DataFrame = self.load_data() - if self.dev: - logger.info("Dev mode enabled: limiting to 1000 patients") - patients = df["patient_id"].unique().head(1000).tolist() - filter = df["patient_id"].isin(patients) - df = df[filter] - - logger.info(f"Caching event dataframe to {ret_path}...") - collection = df.sort_values("patient_id").to_parquet( - ret_path, - write_index=False, - compute=False, - ) - handle = client.compute(collection) - dask_progress(handle) - handle.result() # type: ignore + self._event_transform(ret_path) self._global_event_df = ret_path return pl.scan_parquet( @@ -495,7 +523,7 @@ def load_table(self, table_name: str) -> dd.DataFrame: csv_path = clean_path(csv_path) logger.info(f"Scanning table: {table_name} from {csv_path}") - df = self._scan_csv_tsv_gz(table_name, csv_path) + df = self._scan_csv_tsv_gz(csv_path) # Convert 
column names to lowercase before calling preprocess_func df = df.rename(columns=str.lower) @@ -510,11 +538,11 @@ def load_table(self, table_name: str) -> dd.DataFrame: df = preprocess_func(nw.from_native(df)).to_native() # type: ignore # Handle joins - for i, join_cfg in enumerate(table_cfg.join): + for join_cfg in table_cfg.join: other_csv_path = f"{self.root}/{join_cfg.file_path}" other_csv_path = clean_path(other_csv_path) logger.info(f"Joining with table: {other_csv_path}") - join_df = self._scan_csv_tsv_gz(f"{table_name}_join_{i}", other_csv_path) + join_df = self._scan_csv_tsv_gz(other_csv_path) join_df = join_df.rename(columns=str.lower) join_key = join_cfg.on columns = join_cfg.columns @@ -653,71 +681,116 @@ def default_task(self) -> Optional[BaseTask]: def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> None: self._main_guard(self._task_transform.__name__) + logger.info(f"Applying task transformations on data with {num_workers} workers...") + global_event_df = task.pre_filter(self.global_event_df) + patient_ids = ( + global_event_df.select("patient_id") + .unique() + .collect(engine="streaming") + .to_series() + # .sort can reduce runtime by 5%. + .sort() + ) + + if in_notebook(): + logger.info("Detected Jupyter notebook environment, setting num_workers to 1") + num_workers = 1 + num_workers = min(num_workers, len(patient_ids)) # Avoid spawning empty workers + # This ensures worker's polars threads are limited to avoid oversubscription, # which can lead to additional 75% speedup when num_workers is large. 
- old_polars_max_threads = os.environ.get("POLARS_MAX_THREADS") threads_per_worker = max(1, (os.cpu_count() or 1) // num_workers) - os.environ["POLARS_MAX_THREADS"] = str(threads_per_worker) try: - logger.info(f"Applying task transformations on data with {num_workers} workers...") - global_event_df = task.pre_filter(self.global_event_df) - patient_ids = ( - global_event_df.select("patient_id") - .unique() - .collect(engine="streaming") - .to_series() - # .sort can reduce runtime by 5%. - .sort() - ) - - if in_notebook(): - logger.info("Detected Jupyter notebook environment, setting num_workers to 1") - num_workers = 1 + with set_env(POLARS_MAX_THREADS=str(threads_per_worker), DATA_OPTIMIZER_NUM_WORKERS=str(num_workers)): + if num_workers == 1: + logger.info("Single worker mode, processing sequentially") + _task_transform_fn((0, task, patient_ids, global_event_df, output_dir)) + BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) + return + + # spawn is required for polars in multiprocessing, see https://docs.pola.rs/user-guide/misc/multiprocessing/#summary + ctx = multiprocessing.get_context("spawn") + queue = ctx.Queue() + args_list = [( + worker_id, + task, + pids, + global_event_df, + output_dir, + ) for worker_id, pids in enumerate(itertools.batched(patient_ids, len(patient_ids) // num_workers + 1))] + with ctx.Pool(processes=num_workers, initializer=_task_transform_init, initargs=(queue,)) as pool: + result = pool.map_async(_task_transform_fn, args_list) # type: ignore + with tqdm(total=len(patient_ids)) as progress: + while not result.ready(): + try: + progress.update(queue.get(timeout=1)) + except: + pass + + # remaining items + while not queue.empty(): + progress.update(queue.get()) + result.get() # ensure exceptions are raised + BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) + + logger.info(f"Task transformation completed and saved to {output_dir}") + except Exception as e: + logger.error(f"Error 
during task transformation, cleaning up output directory: {output_dir}") + shutil.rmtree(output_dir) + raise e - if num_workers == 1: - logger.info("Single worker mode, processing sequentially") - _task_transform_fn((0, task, patient_ids, global_event_df, output_dir)) - litdata.index_parquet_dataset(str(output_dir)) - return - - num_workers = min(num_workers, len(patient_ids)) # Avoid spawning empty workers - batch_size = len(patient_ids) // num_workers + 1 + def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> None: + self._main_guard(self._proc_transform.__name__) + + logger.info(f"Applying processors on data with {num_workers} workers...") + num_samples = len(litdata.StreamingDataset(str(task_df))) - # spwan is required for polars in multiprocessing, see https://docs.pola.rs/user-guide/misc/multiprocessing/#summary - ctx = multiprocessing.get_context("spawn") - queue = ctx.Queue() - args_list = [( - worker_id, - task, - pids, - global_event_df, - output_dir, - ) for worker_id, pids in enumerate(itertools.batched(patient_ids, batch_size))] - with ctx.Pool(processes=num_workers, initializer=_task_transform_init, initargs=(queue,)) as pool: - result = pool.map_async(_task_transform_fn, args_list) # type: ignore - with tqdm(total=len(patient_ids)) as progress: - while not result.ready(): + if in_notebook(): + logger.info("Detected Jupyter notebook environment, setting num_workers to 1") + num_workers = 1 + + num_workers = min(num_workers, num_samples) # Avoid spawning empty workers + try: + with set_env(DATA_OPTIMIZER_NUM_WORKERS=str(num_workers)): + if num_workers == 1: + logger.info("Single worker mode, processing sequentially") + _proc_transform_fn((0, task_df, 0, num_samples, output_dir)) + BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) + return + + ctx = multiprocessing.get_context("spawn") + queue = ctx.Queue() + linspace = more_itertools.sliding_window(np.linspace(0, num_samples, num_workers + 1, 
dtype=int), 2) + args_list = [( + worker_id, + task_df, + start, + end, + output_dir, + ) for worker_id, (start, end) in enumerate(linspace)] + with ctx.Pool(processes=num_workers, initializer=_proc_transform_init, initargs=(queue,)) as pool: + result = pool.map_async(_proc_transform_fn, args_list) # type: ignore + with tqdm(total=num_samples) as progress: + while not result.ready(): + try: + progress.update(queue.get(timeout=1)) + except: + pass + + # remaining items while not queue.empty(): progress.update(queue.get()) - - # remaining items - while not queue.empty(): - progress.update(queue.get()) - result.get() # ensure exceptions are raised - - litdata.index_parquet_dataset(str(output_dir)) - logger.info(f"Task transformation completed and saved to {output_dir}") + result.get() # ensure exceptions are raised + BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) + + logger.info(f"Processor transformation completed and saved to {output_dir}") except Exception as e: - logger.error(f"Error during task transformation, cleaning up output directory: {output_dir}") + logger.error(f"Error during processor transformation.") shutil.rmtree(output_dir) raise e finally: - if old_polars_max_threads is not None: - os.environ["POLARS_MAX_THREADS"] = old_polars_max_threads - else: - os.environ.pop("POLARS_MAX_THREADS", None) - + self.clean_tmpdir() def set_task( self, @@ -729,14 +802,20 @@ def set_task( output_processors: Optional[Dict[str, FeatureProcessor]] = None, ) -> SampleDataset: """Processes the base dataset to generate the task-specific sample dataset. + The cache structure is as follows:: + + task_df.ld/ # Intermediate task dataframe after task transformation + samples_{uuid}.ld/ # Final processed samples after applying processors + schema.pkl # Saved SampleBuilder schema + *.bin # Processed sample files + samples_{uuid}.ld/ + ... Args: task (Optional[BaseTask]): The task to set. Uses default task if None. 
- num_workers (int): Number of workers for multi-threading. Default is 1. - This is because the task function is usually CPU-bound. And using - multi-threading may not speed up the task function. - cache_dir (Optional[str]): Directory to cache processed samples. - Default is None (no caching). + num_workers (int): Number of worker processes for multiprocessing. Default is `self.num_workers`. + cache_dir (Optional[str]): Directory to cache samples after task transformation, + but without applying processors. Default is {self.cache_dir}/tasks/{task_name}. cache_format (str): Deprecated. Only "parquet" is supported now. input_processors (Optional[Dict[str, FeatureProcessor]]): Pre-fitted input processors. If provided, these will be used @@ -775,55 +854,47 @@ def set_task( cache_dir = Path(cache_dir) cache_dir.mkdir(parents=True, exist_ok=True) - path = Path(cache_dir) + task_df_path = Path(cache_dir) / "task_df.ld" + samples_path = Path(cache_dir) / f"samples_{uuid.uuid4()}.ld" + + task_df_path.mkdir(parents=True, exist_ok=True) + samples_path.mkdir(parents=True, exist_ok=True) # Check if index.json exists to verify cache integrity, this # is the standard file for litdata.StreamingDataset - if not (path / "index.json").exists(): - with tempfile.TemporaryDirectory() as tmp_dir: - self._task_transform( - task, - Path(tmp_dir), - num_workers, - ) + if not (task_df_path / "index.json").exists(): + self._task_transform( + task, + task_df_path, + num_workers, + ) - # Build processors and fit on the dataset - logger.info(f"Fitting processors on the dataset...") - dataset = litdata.StreamingDataset( - tmp_dir, - item_loader=ParquetLoader(), - transform=lambda x: pickle.loads(x["sample"]), - ) - builder = SampleBuilder( - input_schema=task.input_schema, # type: ignore - output_schema=task.output_schema, # type: ignore - input_processors=input_processors, - output_processors=output_processors, - ) - builder.fit(dataset) - builder.save(str(path / "schema.pkl")) - - # Apply processors and 
save final samples to cache_dir - logger.info(f"Processing samples and saving to {path}...") - dataset = litdata.StreamingDataset( - tmp_dir, - item_loader=ParquetLoader(), - ) - litdata.optimize( - fn=builder.transform, - inputs=litdata.StreamingDataLoader( - dataset, - batch_size=1, - collate_fn=_uncollate, - ), - output_dir=str(path), - chunk_bytes="64MB", - num_workers=num_workers, - ) - logger.info(f"Cached processed samples to {path}") + # Build processors and fit on the dataset + logger.info(f"Fitting processors on the dataset...") + dataset = litdata.StreamingDataset( + str(task_df_path), + transform=lambda x: pickle.loads(x["sample"]), + ) + builder = SampleBuilder( + input_schema=task.input_schema, # type: ignore + output_schema=task.output_schema, # type: ignore + input_processors=input_processors, + output_processors=output_processors, + ) + builder.fit(dataset) + builder.save(str(samples_path / "schema.pkl")) + + # Apply processors and save final samples to cache_dir + logger.info(f"Processing samples and saving to {samples_path}...") + self._proc_transform( + task_df_path, + samples_path, + num_workers, + ) + logger.info(f"Cached processed samples to {samples_path}") return SampleDataset( - path=str(path), + path=str(samples_path), dataset_name=self.dataset_name, task_name=task.task_name, ) diff --git a/pyhealth/datasets/mimic4.py b/pyhealth/datasets/mimic4.py index 39835580..6a01033b 100644 --- a/pyhealth/datasets/mimic4.py +++ b/pyhealth/datasets/mimic4.py @@ -264,12 +264,11 @@ def __init__( logger.info( f"Initializing MIMIC4EHRDataset with tables: {ehr_tables} (dev mode: {dev})" ) - ehr_cache_dir = None if cache_dir is None else f"{cache_dir}/ehr" self.sub_datasets["ehr"] = MIMIC4EHRDataset( root=ehr_root, tables=ehr_tables, config_path=ehr_config_path, - cache_dir=ehr_cache_dir, + cache_dir=str(self.cache_dir), dev=dev, num_workers=num_workers, ) @@ -280,12 +279,11 @@ def __init__( logger.info( f"Initializing MIMIC4NoteDataset with tables: 
{note_tables} (dev mode: {dev})" ) - note_cache_dir = None if cache_dir is None else f"{cache_dir}/note" self.sub_datasets["note"] = MIMIC4NoteDataset( root=note_root, tables=note_tables, config_path=note_config_path, - cache_dir=note_cache_dir, + cache_dir=str(self.cache_dir), dev=dev, num_workers=num_workers, ) @@ -296,12 +294,11 @@ def __init__( logger.info( f"Initializing MIMIC4CXRDataset with tables: {cxr_tables} (dev mode: {dev})" ) - cxr_cache_dir = None if cache_dir is None else f"{cache_dir}/cxr" self.sub_datasets["cxr"] = MIMIC4CXRDataset( root=cxr_root, tables=cxr_tables, config_path=cxr_config_path, - cache_dir=cxr_cache_dir, + cache_dir=str(self.cache_dir), dev=dev, num_workers=num_workers, ) diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index e14e02ca..906b06b8 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -1,6 +1,7 @@ from collections.abc import Sequence from pathlib import Path import pickle +import shutil import tempfile from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type import inspect @@ -355,6 +356,20 @@ def subset(self, indices: Union[Sequence[int], slice]) -> "SampleDataset": new_dataset.reset() return new_dataset + + def close(self) -> None: + """Cleans up any temporary directories used by the dataset.""" + if self.input_dir.path is not None and Path(self.input_dir.path).exists(): + shutil.rmtree(self.input_dir.path) + + # -------------------------------------------------------------- + # Context manager support + # -------------------------------------------------------------- + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() class InMemorySampleDataset(SampleDataset): @@ -464,6 +479,8 @@ def subset(self, indices: Union[Sequence[int], slice]) -> SampleDataset: new_dataset._data = samples return new_dataset + def close(self) -> None: + pass # No temporary directories to clean up for 
in-memory dataset def create_sample_dataset( samples: List[Dict[str, Any]], diff --git a/pyhealth/utils.py b/pyhealth/utils.py index 2a57a0f9..b4af8980 100644 --- a/pyhealth/utils.py +++ b/pyhealth/utils.py @@ -2,6 +2,7 @@ import os import pickle import random +import contextlib import numpy as np import torch @@ -43,3 +44,25 @@ def save_json(data, filename): with open(filename, "w") as f: json.dump(data, f) +@contextlib.contextmanager +def set_env(**environ): + """ + Temporarily set the process environment variables. + + >>> with set_env(PLUGINS_DIR='test/plugins'): + ... "PLUGINS_DIR" in os.environ + True + + >>> "PLUGINS_DIR" in os.environ + False + + :type environ: dict[str, unicode] + :param environ: Environment variables to set + """ + old_environ = dict(os.environ) + os.environ.update(environ) + try: + yield + finally: + os.environ.clear() + os.environ.update(old_environ) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 81e1a14b..f91b712a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,9 +40,10 @@ dependencies = [ "pandarallel~=1.6.5", "pydantic~=2.11.7", "dask[complete]~=2025.11.0", - "litdata~=0.2.58", + "litdata~=0.2.59", "pyarrow~=22.0.0", "narwhals~=2.13.0", + "more-itertools~=10.8.0", ] license = "BSD-3-Clause" license-files = ["LICENSE.md"] diff --git a/tests/core/test_caching.py b/tests/core/test_caching.py index 86508433..e5ec990c 100644 --- a/tests/core/test_caching.py +++ b/tests/core/test_caching.py @@ -121,33 +121,43 @@ def test_set_task_signature(self): def test_set_task_writes_cache_and_metadata(self): """Ensure set_task materializes cache files and schema metadata.""" cache_dir = self._task_cache_dir() - sample_dataset = self.dataset.set_task( + with self.dataset.set_task( self.task, cache_dir=cache_dir, cache_format="parquet" - ) - - self.assertIsInstance(sample_dataset, SampleDataset) - self.assertEqual(sample_dataset.dataset_name, "TestDataset") - self.assertEqual(sample_dataset.task_name, 
self.task.task_name) - self.assertEqual(len(sample_dataset), 4) - self.assertEqual(self.task.call_count, 2) + ) as sample_dataset: + self.assertIsInstance(sample_dataset, SampleDataset) + self.assertEqual(sample_dataset.dataset_name, "TestDataset") + self.assertEqual(sample_dataset.task_name, self.task.task_name) + self.assertEqual(len(sample_dataset), 4) + self.assertEqual(self.task.call_count, 2) + + # Ensure intermediate cache files are created + self.assertTrue((cache_dir / "task_df.ld" / "index.json").exists()) + + # Cache artifacts should be present for StreamingDataset + assert sample_dataset.input_dir.path is not None + sample_dir = Path(sample_dataset.input_dir.path) + self.assertTrue((sample_dir / "index.json").exists()) + self.assertTrue((sample_dir / "schema.pkl").exists()) + + # Check processed sample structure and metadata persisted + sample = sample_dataset[0] + self.assertIn("test_attribute", sample) + self.assertIn("test_label", sample) + self.assertIn("patient_id", sample) + self.assertIsInstance(sample["test_label"], torch.Tensor) + self.assertIn("test_attribute", sample_dataset.input_processors) + self.assertIn("test_label", sample_dataset.output_processors) + self.assertEqual(set(sample_dataset.patient_to_index), {"1", "2"}) + self.assertTrue( + all(len(indexes) == 2 for indexes in sample_dataset.patient_to_index.values()) + ) + self.assertEqual(sample_dataset.record_to_index, {}) + # Ensure directory is cleaned up after context exit + self.assertFalse((sample_dir / "index.json").exists()) + self.assertFalse((sample_dir / "schema.pkl").exists()) + # Ensure intermediate cache files are still present + self.assertTrue((cache_dir / "task_df.ld" / "index.json").exists()) - # Cache artifacts should be present for StreamingDataset - self.assertTrue((cache_dir / "index.json").exists()) - self.assertTrue((cache_dir / "schema.pkl").exists()) - - # Check processed sample structure and metadata persisted - sample = sample_dataset[0] - 
self.assertIn("test_attribute", sample) - self.assertIn("test_label", sample) - self.assertIn("patient_id", sample) - self.assertIsInstance(sample["test_label"], torch.Tensor) - self.assertIn("test_attribute", sample_dataset.input_processors) - self.assertIn("test_label", sample_dataset.output_processors) - self.assertEqual(set(sample_dataset.patient_to_index), {"1", "2"}) - self.assertTrue( - all(len(indexes) == 2 for indexes in sample_dataset.patient_to_index.values()) - ) - self.assertEqual(sample_dataset.record_to_index, {}) def test_default_cache_dir_is_used(self): """When cache_dir is omitted, default cache dir should be used.""" @@ -155,7 +165,7 @@ def test_default_cache_dir_is_used(self): sample_dataset = self.dataset.set_task(self.task) self.assertTrue(task_cache.exists()) - self.assertTrue((task_cache / "index.json").exists()) + self.assertTrue((task_cache / "task_df.ld" / "index.json").exists()) self.assertTrue((self.dataset.cache_dir / "global_event_df.parquet").exists()) self.assertEqual(len(sample_dataset), 4) diff --git a/tests/core/test_streaming_parquet_writer.py b/tests/core/test_streaming_parquet_writer.py deleted file mode 100644 index c7c212bd..00000000 --- a/tests/core/test_streaming_parquet_writer.py +++ /dev/null @@ -1,55 +0,0 @@ -import tempfile -from pathlib import Path -import unittest - -import pyarrow as pa -import pyarrow.parquet as pq - -from pyhealth.datasets.base_dataset import _ParquetWriter -from tests.base import BaseTestCase - - -class TestStreamingParquetWriter(BaseTestCase): - def setUp(self): - self.tmpdir = tempfile.TemporaryDirectory() - self.schema = pa.schema( - [ - ("id", pa.int64()), - ("value", pa.string()), - ] - ) - self.output_path = Path(self.tmpdir.name) / "stream.parquet" - - def tearDown(self): - self.tmpdir.cleanup() - - def test_append_flush_close_and_context_manager(self): - rows = [ - {"id": 1, "value": "a"}, - {"id": 2, "value": "b"}, - {"id": 3, "value": "c"}, - {"id": 4, "value": "d"}, - ] - - with 
_ParquetWriter( - self.output_path, self.schema, chunk_size=2 - ) as writer: - # First two appends trigger an automatic flush due to chunk_size=2. - writer.append(rows[0]) - writer.append(rows[1]) - - # Flush again after adding a third row to ensure flushing appends - # rather than overwriting previous row groups. - writer.append(rows[2]) - writer.flush() - - # Leave data in the buffer to verify close() flushes it. - writer.append(rows[3]) - - # Context manager should have closed and flushed remaining buffered rows. - self.assertTrue(self.output_path.exists()) - - written_rows = pq.read_table(self.output_path).to_pylist() - - # Every append should be present as a distinct row in order. - self.assertEqual(written_rows, rows)