From a2156a8c02a626a0866e1e98dee4eab5bdc4883b Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 24 Dec 2025 02:52:04 -0500 Subject: [PATCH 01/35] move _scan_csv_tsv_gz into ../tmp folder --- pyhealth/datasets/base_dataset.py | 50 ++++++++++++++++++------------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index edbb19e8..8145eb72 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -304,6 +304,12 @@ def __init__( @property def cache_dir(self) -> Path: """Returns the cache directory path. + The cache structure is as follows:: + + tmp/ # Temporary files during processing + global_event_df.parquet/ # Cached global event dataframe + tasks/ # Cached task-specific data, please see set_task method + Returns: Path: The cache directory path. """ @@ -330,12 +336,24 @@ def cache_dir(self) -> Path: self._cache_dir = cache_dir return Path(self._cache_dir) - @property - def temp_dir(self) -> Path: - return self.cache_dir / "temp" + def create_tmpdir(self) -> Path: + """Creates and returns a new temporary directory within the cache. + + Returns: + Path: The path to the new temporary directory. + """ + tmp_dir = self.cache_dir / "tmp" / str(uuid.uuid4()) + tmp_dir.mkdir(parents=True, exist_ok=True) + return tmp_dir + + def clean_tmpdir(self) -> None: + """Cleans up the temporary directory within the cache.""" + tmp_dir = self.cache_dir / "tmp" + if tmp_dir.exists(): + shutil.rmtree(tmp_dir) def _scan_csv_tsv_gz( - self, table_name: str, source_path: str | None = None + self, source_path: str ) -> dd.DataFrame: """Scans a CSV/TSV file (possibly gzipped) and returns a Dask DataFrame. @@ -343,9 +361,7 @@ def _scan_csv_tsv_gz( to Parquet and saves it to the cache. Args: - table_name (str): The name of the table. - source_path (str | None): The source CSV/TSV file path. If None, assumes the - Parquet file already exists in the cache. + source_path (str): The source CSV/TSV file path. Returns: dd.DataFrame: The Dask DataFrame loaded from the cached Parquet file. @@ -356,22 +372,14 @@ def _scan_csv_tsv_gz( ValueError: If the path does not have an expected extension. """ # Ensure the tables cache directory exists - (self.temp_dir / "tables").mkdir(parents=True, exist_ok=True) - ret_path = str(self.temp_dir / "tables" / f"{table_name}.parquet") - - if not path_exists(ret_path): - if source_path is None: - raise FileNotFoundError( - f"Table {table_name} not found in cache and no source_path provided." 
- ) + ret_path = self.create_tmpdir() / "table.parquet" + if not ret_path.exists(): source_path = _csv_tsv_gz_path(source_path) if is_url(source_path): local_filename = os.path.basename(source_path) - download_dir = self.temp_dir / "downloads" - download_dir.mkdir(parents=True, exist_ok=True) - local_path = download_dir / local_filename + local_path = self.create_tmpdir() / local_filename if not local_path.exists(): logger.info(f"Downloading {source_path} to {local_path}") urlretrieve(source_path, local_path) @@ -495,7 +503,7 @@ def load_table(self, table_name: str) -> dd.DataFrame: csv_path = clean_path(csv_path) logger.info(f"Scanning table: {table_name} from {csv_path}") - df = self._scan_csv_tsv_gz(table_name, csv_path) + df = self._scan_csv_tsv_gz(csv_path) # Convert column names to lowercase before calling preprocess_func df = df.rename(columns=str.lower) @@ -510,11 +518,11 @@ def load_table(self, table_name: str) -> dd.DataFrame: df = preprocess_func(nw.from_native(df)).to_native() # type: ignore # Handle joins - for i, join_cfg in enumerate(table_cfg.join): + for join_cfg in table_cfg.join: other_csv_path = f"{self.root}/{join_cfg.file_path}" other_csv_path = clean_path(other_csv_path) logger.info(f"Joining with table: {other_csv_path}") - join_df = self._scan_csv_tsv_gz(f"{table_name}_join_{i}", other_csv_path) + join_df = self._scan_csv_tsv_gz(other_csv_path) join_df = join_df.rename(columns=str.lower) join_key = join_cfg.on columns = join_cfg.columns From d7b2edc70f0a2fee7d544850e2264f39abc4e0cd Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 24 Dec 2025 02:59:41 -0500 Subject: [PATCH 02/35] clean up global_event_df.parquet if failed, clean up tmp dir --- pyhealth/datasets/base_dataset.py | 66 +++++++++++++++---------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 8145eb72..5e95c3b4 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -427,43 +427,43 @@ def global_event_df(self) -> pl.LazyFrame: Returns: Path: The path to the cached event dataframe. """ - if not multiprocessing.current_process().name == "MainProcess": - logger.warning( - "global_event_df property accessed from a non-main process. This may lead to unexpected behavior.\n" - + "Consider use __name__ == '__main__' guard when using multiprocessing." 
- ) - return None # type: ignore + self._main_guard(type(self).global_event_df.fget.__name__) # type: ignore if self._global_event_df is None: ret_path = self.cache_dir / "global_event_df.parquet" if not ret_path.exists(): - # Use cache_dir for Dask's scratch space to avoid filling up /tmp or home directory - dask_scratch_dir = self.cache_dir / "dask_scratch" - dask_scratch_dir.mkdir(parents=True, exist_ok=True) - - with DaskCluster( - n_workers=self.num_workers, - threads_per_worker=1, - processes=not in_notebook(), - local_directory=str(dask_scratch_dir), - ) as cluster: - with DaskClient(cluster) as client: - df: dd.DataFrame = self.load_data() - if self.dev: - logger.info("Dev mode enabled: limiting to 1000 patients") - patients = df["patient_id"].unique().head(1000).tolist() - filter = df["patient_id"].isin(patients) - df = df[filter] - - logger.info(f"Caching event dataframe to {ret_path}...") - collection = df.sort_values("patient_id").to_parquet( - ret_path, - write_index=False, - compute=False, - ) - handle = client.compute(collection) - dask_progress(handle) - handle.result() # type: ignore + try: + with DaskCluster( + n_workers=self.num_workers, + threads_per_worker=1, + processes=not in_notebook(), + # Use cache_dir for Dask's scratch space to avoid filling up /tmp or home directory + local_directory=str(self.create_tmpdir()), + ) as cluster: + with DaskClient(cluster) as client: + df: dd.DataFrame = self.load_data() + if self.dev: + logger.info("Dev mode enabled: limiting to 1000 patients") + patients = df["patient_id"].unique().head(1000).tolist() + filter = df["patient_id"].isin(patients) + df = df[filter] + + logger.info(f"Caching event dataframe to {ret_path}...") + collection = df.sort_values("patient_id").to_parquet( + ret_path, + write_index=False, + compute=False, + ) + handle = client.compute(collection) + dask_progress(handle) + handle.result() # type: ignore + except Exception as e: + if ret_path.exists(): + logger.error(f"Error during caching, removing incomplete file {ret_path}") + ret_path.unlink() + raise e + finally: + self.clean_tmpdir() self._global_event_df = ret_path return pl.scan_parquet( From 01553c1b4dadfbc96497a53e65c9f61754fd1d45 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 24 Dec 2025 03:04:59 -0500 Subject: [PATCH 03/35] support context manager for SampleDataset --- pyhealth/datasets/sample_dataset.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index e14e02ca..906b06b8 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -1,6 +1,7 @@ from collections.abc import Sequence from pathlib import Path import pickle +import shutil import tempfile from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type import inspect @@ -355,6 +356,20 @@ def subset(self, indices: Union[Sequence[int], slice]) -> "SampleDataset": new_dataset.reset() return new_dataset + + def close(self) -> None: + """Cleans up any temporary directories used by the dataset.""" + if self.input_dir.path is not None and Path(self.input_dir.path).exists(): + shutil.rmtree(self.input_dir.path) + + # -------------------------------------------------------------- + # Context manager support + # -------------------------------------------------------------- + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() class InMemorySampleDataset(SampleDataset): @@ -464,6 +479,8 @@ def subset(self, 
indices: Union[Sequence[int], slice]) -> SampleDataset: new_dataset._data = samples return new_dataset + def close(self) -> None: + pass # No temporary directories to clean up for in-memory dataset def create_sample_dataset( samples: List[Dict[str, Any]], From 44add05e7759b93935894a01486ba689611dcafa Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 24 Dec 2025 03:21:41 -0500 Subject: [PATCH 04/35] Fix set_task cache_dir --- pyhealth/datasets/base_dataset.py | 49 +++++++++++++++++++------------ 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 5e95c3b4..608dfb6d 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -306,9 +306,9 @@ def cache_dir(self) -> Path: """Returns the cache directory path. The cache structure is as follows:: - tmp/ # Temporary files during processing + tmp/ # Temporary files during processing global_event_df.parquet/ # Cached global event dataframe - tasks/ # Cached task-specific data, please see set_task method + tasks/ # Cached task-specific data, please see set_task method Returns: Path: The cache directory path. @@ -737,14 +737,20 @@ def set_task( output_processors: Optional[Dict[str, FeatureProcessor]] = None, ) -> SampleDataset: """Processes the base dataset to generate the task-specific sample dataset. + The cache structure is as follows:: + + task_df.parquet/ # Intermediate task dataframe after task transformation + samples_{uuid}/ # Final processed samples after applying processors + schema.pkl # Saved SampleBuilder schema + *.parquet # Processed sample files + samples_{uuid}/ + ... Args: task (Optional[BaseTask]): The task to set. Uses default task if None. - num_workers (int): Number of workers for multi-threading. Default is 1. - This is because the task function is usually CPU-bound. And using - multi-threading may not speed up the task function. - cache_dir (Optional[str]): Directory to cache processed samples. - Default is None (no caching). + num_workers (int): Number of workers for multi-threading. Default is `self.num_workers`. + cache_dir (Optional[str]): Directory to cache samples after task transformation, + but without applying processors. Default is {self.cache_dir}/tasks/{task_name}. cache_format (str): Deprecated. Only "parquet" is supported now. input_processors (Optional[Dict[str, FeatureProcessor]]): Pre-fitted input processors. 
If provided, these will be used @@ -783,22 +789,23 @@ def set_task( cache_dir = Path(cache_dir) cache_dir.mkdir(parents=True, exist_ok=True) - path = Path(cache_dir) + task_df_path = Path(cache_dir) / "task_df.parquet" + samples_path = Path(cache_dir) / f"samples_{uuid.uuid4()}" # Check if index.json exists to verify cache integrity, this # is the standard file for litdata.StreamingDataset - if not (path / "index.json").exists(): - with tempfile.TemporaryDirectory() as tmp_dir: + if not (task_df_path / "index.json").exists(): + try: self._task_transform( task, - Path(tmp_dir), + task_df_path, num_workers, ) # Build processors and fit on the dataset logger.info(f"Fitting processors on the dataset...") dataset = litdata.StreamingDataset( - tmp_dir, + str(task_df_path), item_loader=ParquetLoader(), transform=lambda x: pickle.loads(x["sample"]), ) @@ -809,12 +816,12 @@ def set_task( output_processors=output_processors, ) builder.fit(dataset) - builder.save(str(path / "schema.pkl")) + builder.save(str(samples_path / "schema.pkl")) # Apply processors and save final samples to cache_dir - logger.info(f"Processing samples and saving to {path}...") + logger.info(f"Processing samples and saving to {samples_path}...") dataset = litdata.StreamingDataset( - tmp_dir, + str(task_df_path), item_loader=ParquetLoader(), ) litdata.optimize( @@ -824,14 +831,20 @@ def set_task( batch_size=1, collate_fn=_uncollate, ), - output_dir=str(path), + output_dir=str(samples_path), chunk_bytes="64MB", num_workers=num_workers, ) - logger.info(f"Cached processed samples to {path}") + logger.info(f"Cached processed samples to {samples_path}") + except Exception as e: + logger.error(f"Error during set_task, cleaning up cache directory: {cache_dir}") + shutil.rmtree(cache_dir) + raise e + finally: + self.clean_tmpdir() return SampleDataset( - path=str(path), + path=str(samples_path), dataset_name=self.dataset_name, task_name=task.task_name, ) From c0cc77d8b475b270dd41060de65f72f10e61c920 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 24 Dec 2025 04:08:39 -0500 Subject: [PATCH 05/35] move function to _event_transform --- pyhealth/datasets/base_dataset.py | 68 ++++++++++++++++--------------- 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 608dfb6d..9bb88a09 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -420,6 +420,40 @@ def _scan_csv_tsv_gz( ) return df.replace("", pd.NA) # Replace empty strings with NaN + def _event_transform(self, df: dd.DataFrame, output_dir: Path) -> None: + try: + with DaskCluster( + n_workers=self.num_workers, + threads_per_worker=1, + processes=not in_notebook(), + # Use cache_dir for Dask's scratch space to avoid filling up /tmp or home directory + local_directory=str(self.create_tmpdir()), + ) as cluster: + with DaskClient(cluster) as client: + if self.dev: + logger.info("Dev mode enabled: limiting to 1000 patients") + patients = df["patient_id"].unique().head(1000).tolist() + filter = df["patient_id"].isin(patients) + df = df[filter] + + logger.info(f"Caching event dataframe to {output_dir}...") + collection = df.sort_values("patient_id").to_parquet( + output_dir, + write_index=False, + compute=False, + ) + handle = client.compute(collection) + dask_progress(handle) + handle.result() # type: ignore + except Exception as e: + if output_dir.exists(): + logger.error(f"Error during caching, removing incomplete file {output_dir}") + shutil.rmtree(output_dir) + raise 
e + finally: + self.clean_tmpdir() + pass + @property def global_event_df(self) -> pl.LazyFrame: """Returns the path to the cached event dataframe. @@ -432,38 +466,7 @@ def global_event_df(self) -> pl.LazyFrame: if self._global_event_df is None: ret_path = self.cache_dir / "global_event_df.parquet" if not ret_path.exists(): - try: - with DaskCluster( - n_workers=self.num_workers, - threads_per_worker=1, - processes=not in_notebook(), - # Use cache_dir for Dask's scratch space to avoid filling up /tmp or home directory - local_directory=str(self.create_tmpdir()), - ) as cluster: - with DaskClient(cluster) as client: - df: dd.DataFrame = self.load_data() - if self.dev: - logger.info("Dev mode enabled: limiting to 1000 patients") - patients = df["patient_id"].unique().head(1000).tolist() - filter = df["patient_id"].isin(patients) - df = df[filter] - - logger.info(f"Caching event dataframe to {ret_path}...") - collection = df.sort_values("patient_id").to_parquet( - ret_path, - write_index=False, - compute=False, - ) - handle = client.compute(collection) - dask_progress(handle) - handle.result() # type: ignore - except Exception as e: - if ret_path.exists(): - logger.error(f"Error during caching, removing incomplete file {ret_path}") - ret_path.unlink() - raise e - finally: - self.clean_tmpdir() + self._event_transform(self.load_data(), ret_path) self._global_event_df = ret_path return pl.scan_parquet( @@ -824,6 +827,7 @@ def set_task( str(task_df_path), item_loader=ParquetLoader(), ) + # TODO: use our own implementation to have more control over the processing litdata.optimize( fn=builder.transform, inputs=litdata.StreamingDataLoader( From 8d66343d4dde741567b4051d91713dd855fed69a Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 24 Dec 2025 04:13:21 -0500 Subject: [PATCH 06/35] Fix set_task --- pyhealth/datasets/base_dataset.py | 72 +++++++++++++++---------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 9bb88a09..0aec40d4 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -804,42 +804,6 @@ def set_task( task_df_path, num_workers, ) - - # Build processors and fit on the dataset - logger.info(f"Fitting processors on the dataset...") - dataset = litdata.StreamingDataset( - str(task_df_path), - item_loader=ParquetLoader(), - transform=lambda x: pickle.loads(x["sample"]), - ) - builder = SampleBuilder( - input_schema=task.input_schema, # type: ignore - output_schema=task.output_schema, # type: ignore - input_processors=input_processors, - output_processors=output_processors, - ) - builder.fit(dataset) - builder.save(str(samples_path / "schema.pkl")) - - # Apply processors and save final samples to cache_dir - logger.info(f"Processing samples and saving to {samples_path}...") - dataset = litdata.StreamingDataset( - str(task_df_path), - item_loader=ParquetLoader(), - ) - # TODO: use our own implementation to have more control over the processing - litdata.optimize( - fn=builder.transform, - inputs=litdata.StreamingDataLoader( - dataset, - batch_size=1, - collate_fn=_uncollate, - ), - output_dir=str(samples_path), - chunk_bytes="64MB", - num_workers=num_workers, - ) - logger.info(f"Cached processed samples to {samples_path}") except Exception as e: logger.error(f"Error during set_task, cleaning up cache directory: {cache_dir}") shutil.rmtree(cache_dir) @@ -847,6 +811,42 @@ def set_task( finally: self.clean_tmpdir() + # Build processors and fit on the dataset + 
logger.info(f"Fitting processors on the dataset...") + dataset = litdata.StreamingDataset( + str(task_df_path), + item_loader=ParquetLoader(), + transform=lambda x: pickle.loads(x["sample"]), + ) + builder = SampleBuilder( + input_schema=task.input_schema, # type: ignore + output_schema=task.output_schema, # type: ignore + input_processors=input_processors, + output_processors=output_processors, + ) + builder.fit(dataset) + builder.save(str(samples_path / "schema.pkl")) + + # Apply processors and save final samples to cache_dir + logger.info(f"Processing samples and saving to {samples_path}...") + dataset = litdata.StreamingDataset( + str(task_df_path), + item_loader=ParquetLoader(), + ) + # TODO: use our own implementation to have more control over the processing + litdata.optimize( + fn=builder.transform, + inputs=litdata.StreamingDataLoader( + dataset, + batch_size=1, + collate_fn=_uncollate, + ), + output_dir=str(samples_path), + chunk_bytes="64MB", + num_workers=num_workers, + ) + logger.info(f"Cached processed samples to {samples_path}") + return SampleDataset( path=str(samples_path), dataset_name=self.dataset_name, From f1c55ae0b1c834849c5520598be291644b0f581d Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 24 Dec 2025 05:29:21 -0500 Subject: [PATCH 07/35] Refactor set_task --- pyhealth/datasets/base_dataset.py | 225 ++++++++++++++++++++++-------- 1 file changed, 165 insertions(+), 60 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 0aec40d4..57a29576 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -11,7 +11,6 @@ import json import uuid import platformdirs -import tempfile import multiprocessing import multiprocessing.queues import shutil @@ -19,6 +18,7 @@ import litdata from litdata.streaming.item_loader import ParquetLoader from litdata.processing.data_processor import in_notebook +from litdata.streaming.writer import BinaryWriter import pyarrow as pa import pyarrow.csv as pv import pyarrow.parquet as pq @@ -30,6 +30,8 @@ from dask.distributed import Client as DaskClient, LocalCluster as DaskCluster, progress as dask_progress import narwhals as nw import itertools +import numpy as np +import more_itertools from ..data import Patient from ..tasks import BaseTask @@ -219,37 +221,109 @@ def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, P class _FakeQueue: def put(self, x): pass - - # Use a batch size 128 can reduce runtime by 30%. - BATCH_SIZE = 128 - + + BATCH_SIZE = 128 # Use a batch size 128 can reduce runtime by 30%. + worker_id, task, patient_ids, global_event_df, output_dir = args + os.environ["DATA_OPTIMIZER_GLOBAL_RANK"] = str(worker_id) # For BinaryWriter to determine the rank. logger.info(f"Worker {args[0]} started processing {len(list(args[2]))} patients. 
(Polars threads: {pl.thread_pool_size()})") + + writer = BinaryWriter( + cache_dir=str(output_dir), + chunk_size=67_108_864, # 64 MB + ) + progress = _task_transform_queue or _FakeQueue() + + writer_index = 0 + batches = itertools.batched(patient_ids, BATCH_SIZE) + for batch in batches: + complete = 0 + patients = ( + global_event_df.filter(pl.col("patient_id").is_in(batch)) + .collect(engine="streaming") + .partition_by("patient_id", as_dict=True) + ) + for patient_id, patient_df in patients.items(): + patient_id = patient_id[0] # Extract string from single-element list + patient = Patient(patient_id=patient_id, data_source=patient_df) + for sample in task(patient): + writer.add_item(writer_index, {"sample": pickle.dumps(sample)}) + writer_index += 1 + complete += 1 + progress.put(complete) + writer.done() + + logger.info(f"Worker {args[0]} finished processing patients.") - worker_id, task, patient_ids, global_event_df, output_dir = args - queue = _task_transform_queue or _FakeQueue() +_proc_transform_queue: multiprocessing.queues.Queue | None = None - with _ParquetWriter( - output_dir / f"chunk_{worker_id:03d}.parquet", - pa.schema([("sample", pa.binary())]), - ) as writer: - batches = itertools.batched(patient_ids, BATCH_SIZE) +def _proc_transform_init(queue: multiprocessing.queues.Queue) -> None: + """ + Initializer for worker processes to set up a global queue. + + Args: + queue (multiprocessing.queues.Queue): The queue for progress tracking. + """ + global _proc_transform_queue + _proc_transform_queue = queue - for batch in batches: - count = 0 - patients = ( - global_event_df.filter(pl.col("patient_id").is_in(batch)) - .collect(engine="streaming") - .partition_by("patient_id", as_dict=True) - ) - for patient_id, patient_df in patients.items(): - patient_id = patient_id[0] # Extract string from single-element list - patient = Patient(patient_id=patient_id, data_source=patient_df) - for sample in task(patient): - writer.append({"sample": pickle.dumps(sample)}) - count += 1 - queue.put(count) +def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None: + """ + Worker function to apply processors on a chunk of samples. + + Args: + args (tuple): A tuple containing: + worker_id (int): The ID of the worker. + task_df (Path): The path to the task dataframe. + start_idx (int): The start index of samples to process. + end_idx (int): The end index of samples to process. + output_dir (Path): The output directory to save results. + """ + class _FakeQueue: + def put(self, x): + pass + + BATCH_SIZE = 128 + worker_id, task_df, start_idx, end_idx, output_dir = args + os.environ["DATA_OPTIMIZER_GLOBAL_RANK"] = str(worker_id) # For BinaryWriter to determine the rank. + logger.info(f"Worker {args[0]} started processing {end_idx - start_idx} samples. 
({start_idx} to {end_idx})") + + writer = BinaryWriter( + cache_dir=str(output_dir), + chunk_size=67_108_864, # 64 MB + ) + progress = _proc_transform_queue or _FakeQueue() + + dataset = litdata.StreamingDataset(str(task_df)) + complete = 0 + with open(f"{output_dir}/schema.pkl", "rb") as f: + metadata = pickle.load(f) + + input_processors = metadata["input_processors"] + output_processors = metadata["output_processors"] + + writer_index = 0 + + for i in range(start_idx, end_idx): + transformed: Dict[str, Any] = {} + for key, value in pickle.loads(dataset[i]["sample"]).items(): + if key in input_processors: + transformed[key] = input_processors[key].process(value) + elif key in output_processors: + transformed[key] = output_processors[key].process(value) + else: + transformed[key] = value + writer.add_item(writer_index, transformed) + writer_index += 1 + complete += 1 + + if complete >= BATCH_SIZE: + progress.put(complete) + complete = 0 + + if complete > 0: + progress.put(complete) + writer.done() - logger.info(f"Worker {args[0]} finished processing patients.") class BaseDataset(ABC): """Abstract base class for all PyHealth datasets. @@ -685,14 +759,14 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> if in_notebook(): logger.info("Detected Jupyter notebook environment, setting num_workers to 1") num_workers = 1 - + num_workers = min(num_workers, len(patient_ids)) # Avoid spawning empty workers + if num_workers == 1: logger.info("Single worker mode, processing sequentially") _task_transform_fn((0, task, patient_ids, global_event_df, output_dir)) litdata.index_parquet_dataset(str(output_dir)) return - - num_workers = min(num_workers, len(patient_ids)) # Avoid spawning empty workers + batch_size = len(patient_ids) // num_workers + 1 # spwan is required for polars in multiprocessing, see https://docs.pola.rs/user-guide/misc/multiprocessing/#summary @@ -716,8 +790,8 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> while not queue.empty(): progress.update(queue.get()) result.get() # ensure exceptions are raised + BinaryWriter(cache_dir=str(output_dir)).merge(num_workers) - litdata.index_parquet_dataset(str(output_dir)) logger.info(f"Task transformation completed and saved to {output_dir}") except Exception as e: logger.error(f"Error during task transformation, cleaning up output directory: {output_dir}") @@ -728,7 +802,57 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> os.environ["POLARS_MAX_THREADS"] = old_polars_max_threads else: os.environ.pop("POLARS_MAX_THREADS", None) - + + def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> None: + self._main_guard(self._proc_transform.__name__) + try: + logger.info(f"Applying processors on data with {num_workers} workers...") + num_samples = len(litdata.StreamingDataset( + str(task_df), + item_loader=ParquetLoader(), + )) + + if in_notebook(): + logger.info("Detected Jupyter notebook environment, setting num_workers to 1") + num_workers = 1 + + num_workers = min(num_workers, num_samples) # Avoid spawning empty workers + if num_workers == 1: + logger.info("Single worker mode, processing sequentially") + _proc_transform_fn((0, task_df, 0, num_samples, output_dir)) + BinaryWriter(cache_dir=str(output_dir)).merge(num_workers) + return + + ctx = multiprocessing.get_context("spawn") + queue = ctx.Queue() + linspace = more_itertools.sliding_window(np.linspace(0, num_samples, num_workers + 1, dtype=int), 2) + args_list = [( + 
worker_id, + task_df, + start, + end, + output_dir, + ) for worker_id, (start, end) in enumerate(linspace)] + with ctx.Pool(processes=num_workers, initializer=_proc_transform_init, initargs=(queue,)) as pool: + result = pool.map_async(_proc_transform_fn, args_list) # type: ignore + with tqdm(total=num_samples) as progress: + while not result.ready(): + while not queue.empty(): + progress.update(queue.get()) + + # remaining items + while not queue.empty(): + progress.update(queue.get()) + result.get() # ensure exceptions are raised + BinaryWriter(cache_dir=str(output_dir)).merge(num_workers) + + logger.info(f"Processor transformation completed and saved to {output_dir}") + except Exception as e: + logger.error(f"Error during processor transformation.") + shutil.rmtree(output_dir) + raise e + finally: + self.clean_tmpdir() def set_task( self, @@ -798,24 +922,16 @@ def set_task( # Check if index.json exists to verify cache integrity, this # is the standard file for litdata.StreamingDataset if not (task_df_path / "index.json").exists(): - try: - self._task_transform( - task, - task_df_path, - num_workers, - ) - except Exception as e: - logger.error(f"Error during set_task, cleaning up cache directory: {cache_dir}") - shutil.rmtree(cache_dir) - raise e - finally: - self.clean_tmpdir() + self._task_transform( + task, + task_df_path, + num_workers, + ) # Build processors and fit on the dataset logger.info(f"Fitting processors on the dataset...") dataset = litdata.StreamingDataset( str(task_df_path), - item_loader=ParquetLoader(), transform=lambda x: pickle.loads(x["sample"]), ) builder = SampleBuilder( @@ -829,21 +945,10 @@ def set_task( # Apply processors and save final samples to cache_dir logger.info(f"Processing samples and saving to {samples_path}...") - dataset = litdata.StreamingDataset( - str(task_df_path), - item_loader=ParquetLoader(), - ) - # TODO: use our own implementation to have more control over the processing - litdata.optimize( - fn=builder.transform, - inputs=litdata.StreamingDataLoader( - dataset, - batch_size=1, - collate_fn=_uncollate, - ), - output_dir=str(samples_path), - chunk_bytes="64MB", - num_workers=num_workers, + self._proc_transform( + task_df_path, + samples_path, + num_workers, ) logger.info(f"Cached processed samples to {samples_path}") From cc91385066f6a2577df9b52740ebdcac5e8d68d5 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 24 Dec 2025 05:55:03 -0500 Subject: [PATCH 08/35] Fix up --- pyhealth/datasets/base_dataset.py | 28 +++++++++++----------------- pyproject.toml | 1 + 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 57a29576..ca1cd54a 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -227,10 +227,7 @@ def put(self, x): os.environ["DATA_OPTIMIZER_GLOBAL_RANK"] = str(worker_id) # For BinaryWriter to determine the rank. logger.info(f"Worker {args[0]} started processing {len(list(args[2]))} patients. (Polars threads: {pl.thread_pool_size()})") - writer = BinaryWriter( - cache_dir=str(output_dir), - chunk_size=67_108_864, # 64 MB - ) + writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") progress = _task_transform_queue or _FakeQueue() writer_index = 0 @@ -287,10 +284,7 @@ def put(self, x): os.environ["DATA_OPTIMIZER_GLOBAL_RANK"] = str(worker_id) # For BinaryWriter to determine the rank. logger.info(f"Worker {args[0]} started processing {end_idx - start_idx} samples. 
({start_idx} to {end_idx})") - writer = BinaryWriter( - cache_dir=str(output_dir), - chunk_size=67_108_864, # 64 MB - ) + writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") progress = _proc_transform_queue or _FakeQueue() dataset = litdata.StreamingDataset(str(task_df)) @@ -820,7 +814,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> if num_workers == 1: logger.info("Single worker mode, processing sequentially") _proc_transform_fn((0, task_df, 0, num_samples, output_dir)) - BinaryWriter(cache_dir=str(output_dir)).merge(num_workers) + BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) return ctx = multiprocessing.get_context("spawn") @@ -844,7 +838,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> while not queue.empty(): progress.update(queue.get()) result.get() # ensure exceptions are raised - BinaryWriter(cache_dir=str(output_dir)).merge(num_workers) + BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) logger.info(f"Processor transformation completed and saved to {output_dir}") except Exception as e: @@ -866,11 +860,11 @@ def set_task( """Processes the base dataset to generate the task-specific sample dataset. The cache structure is as follows:: - task_df.parquet/ # Intermediate task dataframe after task transformation - samples_{uuid}/ # Final processed samples after applying processors - schema.pkl # Saved SampleBuilder schema - *.parquet # Processed sample files - samples_{uuid}/ + task_df.ld/ # Intermediate task dataframe after task transformation + samples_{uuid}.ld/ # Final processed samples after applying processors + schema.pkl # Saved SampleBuilder schema + *.bin # Processed sample files + samples_{uuid}.ld/ ... 
Args: @@ -916,8 +910,8 @@ def set_task( cache_dir = Path(cache_dir) cache_dir.mkdir(parents=True, exist_ok=True) - task_df_path = Path(cache_dir) / "task_df.parquet" - samples_path = Path(cache_dir) / f"samples_{uuid.uuid4()}" + task_df_path = Path(cache_dir) / "task_df.ld" + samples_path = Path(cache_dir) / f"samples_{uuid.uuid4()}.ld" # Check if index.json exists to verify cache integrity, this # is the standard file for litdata.StreamingDataset diff --git a/pyproject.toml b/pyproject.toml index 81e1a14b..06157b83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "litdata~=0.2.58", "pyarrow~=22.0.0", "narwhals~=2.13.0", + "more-itertools~=10.8.0", ] license = "BSD-3-Clause" license-files = ["LICENSE.md"] From decd35e1534e009941bb190414e2edaec48df4d2 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 24 Dec 2025 06:17:23 -0500 Subject: [PATCH 09/35] Fixup --- pyhealth/datasets/base_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index ca1cd54a..010bf5e1 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -784,7 +784,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> while not queue.empty(): progress.update(queue.get()) result.get() # ensure exceptions are raised - BinaryWriter(cache_dir=str(output_dir)).merge(num_workers) + BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) logger.info(f"Task transformation completed and saved to {output_dir}") except Exception as e: From a456f12294f576a29d2c71a9932ace3ebce8cbe9 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 24 Dec 2025 06:24:17 -0500 Subject: [PATCH 10/35] rename --- pyhealth/datasets/base_dataset.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 010bf5e1..c417cffb 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -230,7 +230,7 @@ def put(self, x): writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") progress = _task_transform_queue or _FakeQueue() - writer_index = 0 + write_index = 0 batches = itertools.batched(patient_ids, BATCH_SIZE) for batch in batches: complete = 0 @@ -243,8 +243,8 @@ def put(self, x): patient_id = patient_id[0] # Extract string from single-element list patient = Patient(patient_id=patient_id, data_source=patient_df) for sample in task(patient): - writer.add_item(writer_index, {"sample": pickle.dumps(sample)}) - writer_index += 1 + writer.add_item(write_index, {"sample": pickle.dumps(sample)}) + write_index += 1 complete += 1 progress.put(complete) writer.done() @@ -295,8 +295,7 @@ def put(self, x): input_processors = metadata["input_processors"] output_processors = metadata["output_processors"] - writer_index = 0 - + write_index = 0 for i in range(start_idx, end_idx): transformed: Dict[str, Any] = {} for key, value in pickle.loads(dataset[i]["sample"]).items(): @@ -306,8 +305,8 @@ def put(self, x): transformed[key] = output_processors[key].process(value) else: transformed[key] = value - writer.add_item(writer_index, transformed) - writer_index += 1 + writer.add_item(write_index, transformed) + write_index += 1 complete += 1 if complete >= BATCH_SIZE: From ef39415b7d7738dc71cfa1ae080fcb2926f35207 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 24 Dec 2025 07:17:30 -0500 Subject: [PATCH 11/35] fix 
_task_transform_fn --- pyhealth/datasets/base_dataset.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index c417cffb..5dc9dd45 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -206,7 +206,7 @@ def _task_transform_init(queue: multiprocessing.queues.Queue) -> None: global _task_transform_queue _task_transform_queue = queue -def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, Path]) -> None: +def _task_transform_fn(args: tuple[int, int, BaseTask, Iterable[str], pl.LazyFrame, Path]) -> None: """ Worker function to apply task transformation on a chunk of patients. @@ -223,9 +223,10 @@ def put(self, x): pass BATCH_SIZE = 128 # Use a batch size 128 can reduce runtime by 30%. - worker_id, task, patient_ids, global_event_df, output_dir = args + worker_id, num_workers, task, patient_ids, global_event_df, output_dir = args os.environ["DATA_OPTIMIZER_GLOBAL_RANK"] = str(worker_id) # For BinaryWriter to determine the rank. - logger.info(f"Worker {args[0]} started processing {len(list(args[2]))} patients. (Polars threads: {pl.thread_pool_size()})") + os.environ["DATA_OPTIMIZER_NUM_WORKERS"] = str(num_workers) + logger.info(f"Worker {worker_id} started processing {len(list(patient_ids))} patients. (Polars threads: {pl.thread_pool_size()})") writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") progress = _task_transform_queue or _FakeQueue() @@ -244,6 +245,7 @@ def put(self, x): patient = Patient(patient_id=patient_id, data_source=patient_df) for sample in task(patient): writer.add_item(write_index, {"sample": pickle.dumps(sample)}) + logger.error(f"Worker {args[0]}, {writer._min_index}, {writer._max_index}, {writer._chunk_index}, {writer._per_sample_num_bytes}, {writer._per_sample_num_items}, {len(writer._serialized_items)}") write_index += 1 complete += 1 progress.put(complete) @@ -756,7 +758,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> if num_workers == 1: logger.info("Single worker mode, processing sequentially") - _task_transform_fn((0, task, patient_ids, global_event_df, output_dir)) + _task_transform_fn((0, num_workers,task, patient_ids, global_event_df, output_dir)) litdata.index_parquet_dataset(str(output_dir)) return @@ -767,6 +769,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> queue = ctx.Queue() args_list = [( worker_id, + num_workers, task, pids, global_event_df, From a7a06c18f073ec0f9ca0283eb52826c431c07140 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 24 Dec 2025 07:19:15 -0500 Subject: [PATCH 12/35] fixup --- pyhealth/datasets/base_dataset.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 5dc9dd45..fd5ee9f0 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -914,6 +914,9 @@ def set_task( task_df_path = Path(cache_dir) / "task_df.ld" samples_path = Path(cache_dir) / f"samples_{uuid.uuid4()}.ld" + + task_df_path.mkdir(parents=True, exist_ok=True) + samples_path.mkdir(parents=True, exist_ok=True) # Check if index.json exists to verify cache integrity, this # is the standard file for litdata.StreamingDataset From b466fa040a030a2ac1494fc449bf8e5c54faacde Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 24 Dec 2025 07:36:34 -0500 Subject: [PATCH 13/35] Fixup --- pyhealth/datasets/base_dataset.py | 19 
+++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index fd5ee9f0..9a35a057 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -224,7 +224,7 @@ def put(self, x): BATCH_SIZE = 128 # Use a batch size 128 can reduce runtime by 30%. worker_id, num_workers, task, patient_ids, global_event_df, output_dir = args - os.environ["DATA_OPTIMIZER_GLOBAL_RANK"] = str(worker_id) # For BinaryWriter to determine the rank. + os.environ["DATA_OPTIMIZER_GLOBAL_RANK"] = str(worker_id) os.environ["DATA_OPTIMIZER_NUM_WORKERS"] = str(num_workers) logger.info(f"Worker {worker_id} started processing {len(list(patient_ids))} patients. (Polars threads: {pl.thread_pool_size()})") @@ -265,7 +265,7 @@ def _proc_transform_init(queue: multiprocessing.queues.Queue) -> None: global _proc_transform_queue _proc_transform_queue = queue -def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None: +def _proc_transform_fn(args: tuple[int, int,Path, int, int, Path]) -> None: """ Worker function to apply processors on a chunk of samples. @@ -282,9 +282,10 @@ def put(self, x): pass BATCH_SIZE = 128 - worker_id, task_df, start_idx, end_idx, output_dir = args - os.environ["DATA_OPTIMIZER_GLOBAL_RANK"] = str(worker_id) # For BinaryWriter to determine the rank. - logger.info(f"Worker {args[0]} started processing {end_idx - start_idx} samples. ({start_idx} to {end_idx})") + worker_id, num_workers, task_df, start_idx, end_idx, output_dir = args + os.environ["DATA_OPTIMIZER_GLOBAL_RANK"] = str(worker_id) + os.environ["DATA_OPTIMIZER_NUM_WORKERS"] = str(num_workers) + logger.info(f"Worker {worker_id} started processing {end_idx - start_idx} samples. ({start_idx} to {end_idx})") writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") progress = _proc_transform_queue or _FakeQueue() @@ -803,10 +804,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> self._main_guard(self._proc_transform.__name__) try: logger.info(f"Applying processors on data with {num_workers} workers...") - num_samples = len(litdata.StreamingDataset( - str(task_df), - item_loader=ParquetLoader(), - )) + num_samples = len(litdata.StreamingDataset(str(task_df))) if in_notebook(): logger.info("Detected Jupyter notebook environment, setting num_workers to 1") @@ -815,7 +813,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> num_workers = min(num_workers, num_samples) # Avoid spawning empty workers if num_workers == 1: logger.info("Single worker mode, processing sequentially") - _proc_transform_fn((0, task_df, 0, num_samples, output_dir)) + _proc_transform_fn((0, num_workers, task_df, 0, num_samples, output_dir)) BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) return @@ -824,6 +822,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> linspace = more_itertools.sliding_window(np.linspace(0, num_samples, num_workers + 1, dtype=int), 2) args_list = [( worker_id, + num_workers, task_df, start, end, From d08f66fdfa09a51dd65854f3ae7432e78c1deb62 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 24 Dec 2025 07:38:44 -0500 Subject: [PATCH 14/35] Fixup --- pyhealth/datasets/base_dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 9a35a057..61c7f828 100644 --- 
a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -245,13 +245,12 @@ def put(self, x): patient = Patient(patient_id=patient_id, data_source=patient_df) for sample in task(patient): writer.add_item(write_index, {"sample": pickle.dumps(sample)}) - logger.error(f"Worker {args[0]}, {writer._min_index}, {writer._max_index}, {writer._chunk_index}, {writer._per_sample_num_bytes}, {writer._per_sample_num_items}, {len(writer._serialized_items)}") write_index += 1 complete += 1 progress.put(complete) writer.done() - logger.info(f"Worker {args[0]} finished processing patients.") + logger.info(f"Worker {worker_id} finished processing patients.") _proc_transform_queue: multiprocessing.queues.Queue | None = None @@ -319,6 +318,8 @@ def put(self, x): if complete > 0: progress.put(complete) writer.done() + + logger.info(f"Worker {worker_id} finished processing samples.") class BaseDataset(ABC): @@ -397,7 +398,7 @@ def cache_dir(self) -> Path: uuid.uuid5(uuid.NAMESPACE_DNS, id_str) ) cache_dir.mkdir(parents=True, exist_ok=True) - print(f"No cache_dir provided. Using default cache dir: {cache_dir}") + logger.info(f"No cache_dir provided. Using default cache dir: {cache_dir}") self._cache_dir = cache_dir else: # Ensure the explicitly provided cache_dir exists From 83a45b7150c365c2fc4ee573768d68c9f299a05e Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Wed, 24 Dec 2025 08:03:51 -0500 Subject: [PATCH 15/35] Fix mimic4 cache dir --- pyhealth/datasets/mimic4.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/pyhealth/datasets/mimic4.py b/pyhealth/datasets/mimic4.py index 39835580..6a01033b 100644 --- a/pyhealth/datasets/mimic4.py +++ b/pyhealth/datasets/mimic4.py @@ -264,12 +264,11 @@ def __init__( logger.info( f"Initializing MIMIC4EHRDataset with tables: {ehr_tables} (dev mode: {dev})" ) - ehr_cache_dir = None if cache_dir is None else f"{cache_dir}/ehr" self.sub_datasets["ehr"] = MIMIC4EHRDataset( root=ehr_root, tables=ehr_tables, config_path=ehr_config_path, - cache_dir=ehr_cache_dir, + cache_dir=str(self.cache_dir), dev=dev, num_workers=num_workers, ) @@ -280,12 +279,11 @@ def __init__( logger.info( f"Initializing MIMIC4NoteDataset with tables: {note_tables} (dev mode: {dev})" ) - note_cache_dir = None if cache_dir is None else f"{cache_dir}/note" self.sub_datasets["note"] = MIMIC4NoteDataset( root=note_root, tables=note_tables, config_path=note_config_path, - cache_dir=note_cache_dir, + cache_dir=str(self.cache_dir), dev=dev, num_workers=num_workers, ) @@ -296,12 +294,11 @@ def __init__( logger.info( f"Initializing MIMIC4CXRDataset with tables: {cxr_tables} (dev mode: {dev})" ) - cxr_cache_dir = None if cache_dir is None else f"{cache_dir}/cxr" self.sub_datasets["cxr"] = MIMIC4CXRDataset( root=cxr_root, tables=cxr_tables, config_path=cxr_config_path, - cache_dir=cxr_cache_dir, + cache_dir=str(self.cache_dir), dev=dev, num_workers=num_workers, ) From 5430c8fc4cbd5082380c72c314196fd948dfa79d Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 06:19:34 -0500 Subject: [PATCH 16/35] update memtest --- examples/memtest.py | 129 ++++++++++++++++++++------------------------ 1 file changed, 58 insertions(+), 71 deletions(-) diff --git a/examples/memtest.py b/examples/memtest.py index 056b21f8..5635c639 100644 --- a/examples/memtest.py +++ b/examples/memtest.py @@ -8,7 +8,6 @@ 4. 
Training a StageNet model """ -# %% if __name__ == "__main__": from pyhealth.datasets import ( MIMIC4Dataset, @@ -20,7 +19,6 @@ from pyhealth.trainer import Trainer import torch - # %% STEP 1: Load MIMIC-IV base dataset base_dataset = MIMIC4Dataset( ehr_root="/home/logic/physionet.org/files/mimiciv/3.1/", ehr_tables=[ @@ -33,74 +31,63 @@ dev=False, ) - # %% # STEP 2: Apply StageNet mortality prediction task - sample_dataset = base_dataset.set_task( + with base_dataset.set_task( MortalityPredictionStageNetMIMIC4(), num_workers=4, - ) - - print(f"Total samples: {len(sample_dataset)}") - print(f"Input schema: {sample_dataset.input_schema}") - print(f"Output schema: {sample_dataset.output_schema}") - - # %% Inspect a sample - sample = next(iter(sample_dataset)) - print("\nSample structure:") - print(f" Patient ID: {sample['patient_id']}") - print(f"ICD Codes: {sample['icd_codes']}") - print(f" Labs shape: {len(sample['labs'][0])} timesteps") - print(f" Mortality: {sample['mortality']}") - - # %% STEP 3: Split dataset - train_dataset, val_dataset, test_dataset = split_by_patient( - sample_dataset, [0.8, 0.1, 0.1] - ) - - # Create dataloaders - train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True) - val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False) - test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False) - - # %% STEP 4: Initialize StageNet model - model = StageNet( - dataset=sample_dataset, - embedding_dim=128, - chunk_size=128, - levels=3, - dropout=0.3, - ) - - num_params = sum(p.numel() for p in model.parameters()) - print(f"\nModel initialized with {num_params} parameters") - - # %% STEP 5: Train the model - trainer = Trainer( - model=model, - device="cpu", # or "cpu" - metrics=["pr_auc", "roc_auc", "accuracy", "f1"], - ) - - trainer.train( - train_dataloader=train_loader, - val_dataloader=val_loader, - epochs=5, - monitor="roc_auc", - optimizer_params={"lr": 1e-5}, - ) - - # %% STEP 6: Evaluate on test set - results = trainer.evaluate(test_loader) - print("\nTest Results:") - for metric, value in results.items(): - print(f" {metric}: {value:.4f}") - - # %% STEP 7: Inspect model predictions - sample_batch = next(iter(test_loader)) - with torch.no_grad(): - output = model(**sample_batch) - - print("\nSample predictions:") - print(f" Predicted probabilities: {output['y_prob'][:5]}") - print(f" True labels: {output['y_true'][:5]}") - - # %% + ) as sample_dataset: + print(f"Total samples: {len(sample_dataset)}") + print(f"Input schema: {sample_dataset.input_schema}") + print(f"Output schema: {sample_dataset.output_schema}") + + sample = next(iter(sample_dataset)) + print("\nSample structure:") + print(f" Patient ID: {sample['patient_id']}") + print(f"ICD Codes: {sample['icd_codes']}") + print(f" Labs shape: {len(sample['labs'][0])} timesteps") + print(f" Mortality: {sample['mortality']}") + + train_dataset, val_dataset, test_dataset = split_by_patient( + sample_dataset, [0.8, 0.1, 0.1] + ) + + train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True) + val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False) + test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False) + + model = StageNet( + dataset=sample_dataset, + embedding_dim=128, + chunk_size=128, + levels=3, + dropout=0.3, + ) + + num_params = sum(p.numel() for p in model.parameters()) + print(f"\nModel initialized with {num_params} parameters") + + trainer = Trainer( + model=model, + device="cuda:0", # or "cpu" + 
metrics=["pr_auc", "roc_auc", "accuracy", "f1"], + ) + + trainer.train( + train_dataloader=train_loader, + val_dataloader=val_loader, + epochs=5, + monitor="roc_auc", + optimizer_params={"lr": 1e-5}, + ) + + results = trainer.evaluate(test_loader) + print("\nTest Results:") + for metric, value in results.items(): + print(f" {metric}: {value:.4f}") + + sample_batch = next(iter(test_loader)) + with torch.no_grad(): + output = model(**sample_batch) + + print("\nSample predictions:") + print(f" Predicted probabilities: {output['y_prob'][:5]}") + print(f" True labels: {output['y_true'][:5]}") From ba4c4f2e260a5f79e5bb23d09e137a6c016500cb Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 06:41:03 -0500 Subject: [PATCH 17/35] more workers --- examples/memtest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/memtest.py b/examples/memtest.py index 5635c639..234cf9c7 100644 --- a/examples/memtest.py +++ b/examples/memtest.py @@ -29,11 +29,11 @@ "labevents", ], dev=False, + num_workers=8, ) with base_dataset.set_task( MortalityPredictionStageNetMIMIC4(), - num_workers=4, ) as sample_dataset: print(f"Total samples: {len(sample_dataset)}") print(f"Input schema: {sample_dataset.input_schema}") From 74fce5f121a4f07064918f67cea6325dece2a67a Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 07:08:53 -0500 Subject: [PATCH 18/35] Fix signle thread --- pyhealth/datasets/base_dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 61c7f828..5d31a6ea 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -761,7 +761,6 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> if num_workers == 1: logger.info("Single worker mode, processing sequentially") _task_transform_fn((0, num_workers,task, patient_ids, global_event_df, output_dir)) - litdata.index_parquet_dataset(str(output_dir)) return batch_size = len(patient_ids) // num_workers + 1 From 5ad6ca0377da1a4692b3c0ea5b668d58b22bea5c Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 07:09:39 -0500 Subject: [PATCH 19/35] Fix single thread --- pyhealth/datasets/base_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 5d31a6ea..147617d9 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -761,6 +761,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> if num_workers == 1: logger.info("Single worker mode, processing sequentially") _task_transform_fn((0, num_workers,task, patient_ids, global_event_df, output_dir)) + BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) return batch_size = len(patient_ids) // num_workers + 1 From 06048133379504c8f1484fe10309f3edc1cf1b55 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 07:13:29 -0500 Subject: [PATCH 20/35] Fixup --- examples/memtest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/memtest.py b/examples/memtest.py index 234cf9c7..b5e0b604 100644 --- a/examples/memtest.py +++ b/examples/memtest.py @@ -67,7 +67,7 @@ trainer = Trainer( model=model, - device="cuda:0", # or "cpu" + device="cpu", # or "cpu" metrics=["pr_auc", "roc_auc", "accuracy", "f1"], ) From 99da1a9e36b38a2a5ceb1e1aebaa17efe176936e Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 07:17:55 -0500 Subject: 
[PATCH 21/35] Fix test --- tests/core/test_caching.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/core/test_caching.py b/tests/core/test_caching.py index 86508433..638832e7 100644 --- a/tests/core/test_caching.py +++ b/tests/core/test_caching.py @@ -132,8 +132,8 @@ def test_set_task_writes_cache_and_metadata(self): self.assertEqual(self.task.call_count, 2) # Cache artifacts should be present for StreamingDataset - self.assertTrue((cache_dir / "index.json").exists()) - self.assertTrue((cache_dir / "schema.pkl").exists()) + self.assertTrue((cache_dir / "task_df.ld" / "index.json").exists()) + self.assertTrue((cache_dir / "task_df.ld" / "schema.pkl").exists()) # Check processed sample structure and metadata persisted sample = sample_dataset[0] @@ -155,7 +155,7 @@ def test_default_cache_dir_is_used(self): sample_dataset = self.dataset.set_task(self.task) self.assertTrue(task_cache.exists()) - self.assertTrue((task_cache / "index.json").exists()) + self.assertTrue((task_cache / "task_df.ld" / "index.json").exists()) self.assertTrue((self.dataset.cache_dir / "global_event_df.parquet").exists()) self.assertEqual(len(sample_dataset), 4) From 647ce7de349330757ca976a1f943fb015c873bd6 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 07:25:59 -0500 Subject: [PATCH 22/35] Fix up environ. --- pyhealth/datasets/base_dataset.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 147617d9..9f8c1f2a 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -206,7 +206,7 @@ def _task_transform_init(queue: multiprocessing.queues.Queue) -> None: global _task_transform_queue _task_transform_queue = queue -def _task_transform_fn(args: tuple[int, int, BaseTask, Iterable[str], pl.LazyFrame, Path]) -> None: +def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, Path]) -> None: """ Worker function to apply task transformation on a chunk of patients. @@ -223,9 +223,8 @@ def put(self, x): pass BATCH_SIZE = 128 # Use a batch size 128 can reduce runtime by 30%. - worker_id, num_workers, task, patient_ids, global_event_df, output_dir = args + worker_id, task, patient_ids, global_event_df, output_dir = args os.environ["DATA_OPTIMIZER_GLOBAL_RANK"] = str(worker_id) - os.environ["DATA_OPTIMIZER_NUM_WORKERS"] = str(num_workers) logger.info(f"Worker {worker_id} started processing {len(list(patient_ids))} patients. (Polars threads: {pl.thread_pool_size()})") writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") @@ -264,7 +263,7 @@ def _proc_transform_init(queue: multiprocessing.queues.Queue) -> None: global _proc_transform_queue _proc_transform_queue = queue -def _proc_transform_fn(args: tuple[int, int,Path, int, int, Path]) -> None: +def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None: """ Worker function to apply processors on a chunk of samples. @@ -281,9 +280,8 @@ def put(self, x): pass BATCH_SIZE = 128 - worker_id, num_workers, task_df, start_idx, end_idx, output_dir = args + worker_id, task_df, start_idx, end_idx, output_dir = args os.environ["DATA_OPTIMIZER_GLOBAL_RANK"] = str(worker_id) - os.environ["DATA_OPTIMIZER_NUM_WORKERS"] = str(num_workers) logger.info(f"Worker {worker_id} started processing {end_idx - start_idx} samples. 
({start_idx} to {end_idx})") writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") @@ -757,10 +755,11 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> logger.info("Detected Jupyter notebook environment, setting num_workers to 1") num_workers = 1 num_workers = min(num_workers, len(patient_ids)) # Avoid spawning empty workers - + + os.environ["DATA_OPTIMIZER_NUM_WORKERS"] = str(num_workers) if num_workers == 1: logger.info("Single worker mode, processing sequentially") - _task_transform_fn((0, num_workers,task, patient_ids, global_event_df, output_dir)) + _task_transform_fn((0, task, patient_ids, global_event_df, output_dir)) BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) return @@ -771,7 +770,6 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> queue = ctx.Queue() args_list = [( worker_id, - num_workers, task, pids, global_event_df, @@ -796,6 +794,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> shutil.rmtree(output_dir) raise e finally: + os.environ.pop("DATA_OPTIMIZER_NUM_WORKERS", None) if old_polars_max_threads is not None: os.environ["POLARS_MAX_THREADS"] = old_polars_max_threads else: @@ -812,9 +811,11 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> num_workers = 1 num_workers = min(num_workers, num_samples) # Avoid spawning empty workers + + os.environ["DATA_OPTIMIZER_NUM_WORKERS"] = str(num_workers) if num_workers == 1: logger.info("Single worker mode, processing sequentially") - _proc_transform_fn((0, num_workers, task_df, 0, num_samples, output_dir)) + _proc_transform_fn((0, task_df, 0, num_samples, output_dir)) BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) return @@ -823,7 +824,6 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> linspace = more_itertools.sliding_window(np.linspace(0, num_samples, num_workers + 1, dtype=int), 2) args_list = [( worker_id, - num_workers, task_df, start, end, @@ -848,6 +848,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> shutil.rmtree(output_dir) raise e finally: + os.environ.pop("DATA_OPTIMIZER_NUM_WORKERS", None) self.clean_tmpdir() def set_task( From 5d22c4d0c41b76e43db56d9b84cd34b63b20786f Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 07:34:29 -0500 Subject: [PATCH 23/35] Fix test --- tests/core/test_caching.py | 60 ++++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/tests/core/test_caching.py b/tests/core/test_caching.py index 638832e7..e5ec990c 100644 --- a/tests/core/test_caching.py +++ b/tests/core/test_caching.py @@ -121,33 +121,43 @@ def test_set_task_signature(self): def test_set_task_writes_cache_and_metadata(self): """Ensure set_task materializes cache files and schema metadata.""" cache_dir = self._task_cache_dir() - sample_dataset = self.dataset.set_task( + with self.dataset.set_task( self.task, cache_dir=cache_dir, cache_format="parquet" - ) - - self.assertIsInstance(sample_dataset, SampleDataset) - self.assertEqual(sample_dataset.dataset_name, "TestDataset") - self.assertEqual(sample_dataset.task_name, self.task.task_name) - self.assertEqual(len(sample_dataset), 4) - self.assertEqual(self.task.call_count, 2) - - # Cache artifacts should be present for StreamingDataset + ) as sample_dataset: + self.assertIsInstance(sample_dataset, SampleDataset) + 
self.assertEqual(sample_dataset.dataset_name, "TestDataset") + self.assertEqual(sample_dataset.task_name, self.task.task_name) + self.assertEqual(len(sample_dataset), 4) + self.assertEqual(self.task.call_count, 2) + + # Ensure intermediate cache files are created + self.assertTrue((cache_dir / "task_df.ld" / "index.json").exists()) + + # Cache artifacts should be present for StreamingDataset + assert sample_dataset.input_dir.path is not None + sample_dir = Path(sample_dataset.input_dir.path) + self.assertTrue((sample_dir / "index.json").exists()) + self.assertTrue((sample_dir / "schema.pkl").exists()) + + # Check processed sample structure and metadata persisted + sample = sample_dataset[0] + self.assertIn("test_attribute", sample) + self.assertIn("test_label", sample) + self.assertIn("patient_id", sample) + self.assertIsInstance(sample["test_label"], torch.Tensor) + self.assertIn("test_attribute", sample_dataset.input_processors) + self.assertIn("test_label", sample_dataset.output_processors) + self.assertEqual(set(sample_dataset.patient_to_index), {"1", "2"}) + self.assertTrue( + all(len(indexes) == 2 for indexes in sample_dataset.patient_to_index.values()) + ) + self.assertEqual(sample_dataset.record_to_index, {}) + # Ensure directory is cleaned up after context exit + self.assertFalse((sample_dir / "index.json").exists()) + self.assertFalse((sample_dir / "schema.pkl").exists()) + # Ensure intermediate cache files are still present self.assertTrue((cache_dir / "task_df.ld" / "index.json").exists()) - self.assertTrue((cache_dir / "task_df.ld" / "schema.pkl").exists()) - - # Check processed sample structure and metadata persisted - sample = sample_dataset[0] - self.assertIn("test_attribute", sample) - self.assertIn("test_label", sample) - self.assertIn("patient_id", sample) - self.assertIsInstance(sample["test_label"], torch.Tensor) - self.assertIn("test_attribute", sample_dataset.input_processors) - self.assertIn("test_label", sample_dataset.output_processors) - self.assertEqual(set(sample_dataset.patient_to_index), {"1", "2"}) - self.assertTrue( - all(len(indexes) == 2 for indexes in sample_dataset.patient_to_index.values()) - ) - self.assertEqual(sample_dataset.record_to_index, {}) + def test_default_cache_dir_is_used(self): """When cache_dir is omitted, default cache dir should be used.""" From d02b3ccb89055087e71ef17dfdfb1d856e586b0a Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 07:57:23 -0500 Subject: [PATCH 24/35] better env_var management --- pyhealth/datasets/base_dataset.py | 222 ++++++++++++++---------------- pyhealth/utils.py | 23 ++++ 2 files changed, 129 insertions(+), 116 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 9f8c1f2a..60ca2610 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -38,6 +38,7 @@ from ..processors.base_processor import FeatureProcessor from .configs import load_yaml_config from .sample_dataset import SampleDataset, SampleBuilder +from ..utils import set_env # Set logging level for distributed to ERROR to reduce verbosity logging.getLogger("distributed").setLevel(logging.ERROR) @@ -224,30 +225,30 @@ def put(self, x): BATCH_SIZE = 128 # Use a batch size 128 can reduce runtime by 30%. worker_id, task, patient_ids, global_event_df, output_dir = args - os.environ["DATA_OPTIMIZER_GLOBAL_RANK"] = str(worker_id) logger.info(f"Worker {worker_id} started processing {len(list(patient_ids))} patients. 
(Polars threads: {pl.thread_pool_size()})") - writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") - progress = _task_transform_queue or _FakeQueue() - - write_index = 0 - batches = itertools.batched(patient_ids, BATCH_SIZE) - for batch in batches: - complete = 0 - patients = ( - global_event_df.filter(pl.col("patient_id").is_in(batch)) - .collect(engine="streaming") - .partition_by("patient_id", as_dict=True) - ) - for patient_id, patient_df in patients.items(): - patient_id = patient_id[0] # Extract string from single-element list - patient = Patient(patient_id=patient_id, data_source=patient_df) - for sample in task(patient): - writer.add_item(write_index, {"sample": pickle.dumps(sample)}) - write_index += 1 - complete += 1 - progress.put(complete) - writer.done() + with set_env(DATA_OPTIMIZER_GLOBAL_RANK=str(worker_id)): + writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") + progress = _task_transform_queue or _FakeQueue() + + write_index = 0 + batches = itertools.batched(patient_ids, BATCH_SIZE) + for batch in batches: + complete = 0 + patients = ( + global_event_df.filter(pl.col("patient_id").is_in(batch)) + .collect(engine="streaming") + .partition_by("patient_id", as_dict=True) + ) + for patient_id, patient_df in patients.items(): + patient_id = patient_id[0] # Extract string from single-element list + patient = Patient(patient_id=patient_id, data_source=patient_df) + for sample in task(patient): + writer.add_item(write_index, {"sample": pickle.dumps(sample)}) + write_index += 1 + complete += 1 + progress.put(complete) + writer.done() logger.info(f"Worker {worker_id} finished processing patients.") @@ -733,122 +734,111 @@ def default_task(self) -> Optional[BaseTask]: def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> None: self._main_guard(self._task_transform.__name__) + logger.info(f"Applying task transformations on data with {num_workers} workers...") + global_event_df = task.pre_filter(self.global_event_df) + patient_ids = ( + global_event_df.select("patient_id") + .unique() + .collect(engine="streaming") + .to_series() + # .sort can reduce runtime by 5%. + .sort() + ) + + if in_notebook(): + logger.info("Detected Jupyter notebook environment, setting num_workers to 1") + num_workers = 1 + num_workers = min(num_workers, len(patient_ids)) # Avoid spawning empty workers + # This ensures worker's polars threads are limited to avoid oversubscription, # which can lead to additional 75% speedup when num_workers is large. - old_polars_max_threads = os.environ.get("POLARS_MAX_THREADS") threads_per_worker = max(1, (os.cpu_count() or 1) // num_workers) - os.environ["POLARS_MAX_THREADS"] = str(threads_per_worker) try: - logger.info(f"Applying task transformations on data with {num_workers} workers...") - global_event_df = task.pre_filter(self.global_event_df) - patient_ids = ( - global_event_df.select("patient_id") - .unique() - .collect(engine="streaming") - .to_series() - # .sort can reduce runtime by 5%. 
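# The thread cap computed in this hunk divides the machine's cores across the
# spawned workers so their Polars thread pools do not oversubscribe the CPU,
# and set_env() (the helper this patch adds to pyhealth/utils.py) restores the
# previous environment afterwards. A minimal sketch of the same idea; the
# worker count below is illustrative, not taken from the patch:
import os
from pyhealth.utils import set_env

num_workers = 4
threads_per_worker = max(1, (os.cpu_count() or 1) // num_workers)

with set_env(POLARS_MAX_THREADS=str(threads_per_worker),
             DATA_OPTIMIZER_NUM_WORKERS=str(num_workers)):
    # spawn the worker pool here; each child reads POLARS_MAX_THREADS when it
    # imports polars, so its thread pool stays within its share of the cores
    print(os.environ["POLARS_MAX_THREADS"])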
- .sort() - ) - - if in_notebook(): - logger.info("Detected Jupyter notebook environment, setting num_workers to 1") - num_workers = 1 - num_workers = min(num_workers, len(patient_ids)) # Avoid spawning empty workers - - os.environ["DATA_OPTIMIZER_NUM_WORKERS"] = str(num_workers) - if num_workers == 1: - logger.info("Single worker mode, processing sequentially") - _task_transform_fn((0, task, patient_ids, global_event_df, output_dir)) - BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) - return - - batch_size = len(patient_ids) // num_workers + 1 - - # spwan is required for polars in multiprocessing, see https://docs.pola.rs/user-guide/misc/multiprocessing/#summary - ctx = multiprocessing.get_context("spawn") - queue = ctx.Queue() - args_list = [( - worker_id, - task, - pids, - global_event_df, - output_dir, - ) for worker_id, pids in enumerate(itertools.batched(patient_ids, batch_size))] - with ctx.Pool(processes=num_workers, initializer=_task_transform_init, initargs=(queue,)) as pool: - result = pool.map_async(_task_transform_fn, args_list) # type: ignore - with tqdm(total=len(patient_ids)) as progress: - while not result.ready(): + with set_env(POLARS_MAX_THREADS=str(threads_per_worker), DATA_OPTIMIZER_NUM_WORKERS=str(num_workers)): + if num_workers == 1: + logger.info("Single worker mode, processing sequentially") + _task_transform_fn((0, task, patient_ids, global_event_df, output_dir)) + BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) + return + + # spwan is required for polars in multiprocessing, see https://docs.pola.rs/user-guide/misc/multiprocessing/#summary + ctx = multiprocessing.get_context("spawn") + queue = ctx.Queue() + args_list = [( + worker_id, + task, + pids, + global_event_df, + output_dir, + ) for worker_id, pids in enumerate(itertools.batched(patient_ids, len(patient_ids) // num_workers + 1))] + with ctx.Pool(processes=num_workers, initializer=_task_transform_init, initargs=(queue,)) as pool: + result = pool.map_async(_task_transform_fn, args_list) # type: ignore + with tqdm(total=len(patient_ids)) as progress: + while not result.ready(): + while not queue.empty(): + progress.update(queue.get()) + + # remaining items while not queue.empty(): progress.update(queue.get()) - - # remaining items - while not queue.empty(): - progress.update(queue.get()) - result.get() # ensure exceptions are raised - BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) - - logger.info(f"Task transformation completed and saved to {output_dir}") + result.get() # ensure exceptions are raised + BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) + + logger.info(f"Task transformation completed and saved to {output_dir}") except Exception as e: logger.error(f"Error during task transformation, cleaning up output directory: {output_dir}") shutil.rmtree(output_dir) raise e - finally: - os.environ.pop("DATA_OPTIMIZER_NUM_WORKERS", None) - if old_polars_max_threads is not None: - os.environ["POLARS_MAX_THREADS"] = old_polars_max_threads - else: - os.environ.pop("POLARS_MAX_THREADS", None) def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> None: self._main_guard(self._proc_transform.__name__) - try: - logger.info(f"Applying processors on data with {num_workers} workers...") - num_samples = len(litdata.StreamingDataset(str(task_df))) - - if in_notebook(): - logger.info("Detected Jupyter notebook environment, setting num_workers to 1") - num_workers = 1 - - 
num_workers = min(num_workers, num_samples) # Avoid spawning empty workers - - os.environ["DATA_OPTIMIZER_NUM_WORKERS"] = str(num_workers) - if num_workers == 1: - logger.info("Single worker mode, processing sequentially") - _proc_transform_fn((0, task_df, 0, num_samples, output_dir)) - BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) - return + + logger.info(f"Applying processors on data with {num_workers} workers...") + num_samples = len(litdata.StreamingDataset(str(task_df))) - ctx = multiprocessing.get_context("spawn") - queue = ctx.Queue() - linspace = more_itertools.sliding_window(np.linspace(0, num_samples, num_workers + 1, dtype=int), 2) - args_list = [( - worker_id, - task_df, - start, - end, - output_dir, - ) for worker_id, (start, end) in enumerate(linspace)] - with ctx.Pool(processes=num_workers, initializer=_proc_transform_init, initargs=(queue,)) as pool: - result = pool.map_async(_proc_transform_fn, args_list) # type: ignore - with tqdm(total=num_samples) as progress: - while not result.ready(): + if in_notebook(): + logger.info("Detected Jupyter notebook environment, setting num_workers to 1") + num_workers = 1 + + num_workers = min(num_workers, num_samples) # Avoid spawning empty workers + try: + with set_env(POLARS_MAX_THREADS=str(num_workers)): + if num_workers == 1: + logger.info("Single worker mode, processing sequentially") + _proc_transform_fn((0, task_df, 0, num_samples, output_dir)) + BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) + return + + ctx = multiprocessing.get_context("spawn") + queue = ctx.Queue() + linspace = more_itertools.sliding_window(np.linspace(0, num_samples, num_workers + 1, dtype=int), 2) + args_list = [( + worker_id, + task_df, + start, + end, + output_dir, + ) for worker_id, (start, end) in enumerate(linspace)] + with ctx.Pool(processes=num_workers, initializer=_proc_transform_init, initargs=(queue,)) as pool: + result = pool.map_async(_proc_transform_fn, args_list) # type: ignore + with tqdm(total=num_samples) as progress: + while not result.ready(): + while not queue.empty(): + progress.update(queue.get()) + + # remaining items while not queue.empty(): progress.update(queue.get()) - - # remaining items - while not queue.empty(): - progress.update(queue.get()) - result.get() # ensure exceptions are raised - BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) - - logger.info(f"Processor transformation completed and saved to {output_dir}") + result.get() # ensure exceptions are raised + BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) + + logger.info(f"Processor transformation completed and saved to {output_dir}") except Exception as e: logger.error(f"Error during processor transformation.") shutil.rmtree(output_dir) raise e finally: - os.environ.pop("DATA_OPTIMIZER_NUM_WORKERS", None) self.clean_tmpdir() def set_task( diff --git a/pyhealth/utils.py b/pyhealth/utils.py index 2a57a0f9..b4af8980 100644 --- a/pyhealth/utils.py +++ b/pyhealth/utils.py @@ -2,6 +2,7 @@ import os import pickle import random +import contextlib import numpy as np import torch @@ -43,3 +44,25 @@ def save_json(data, filename): with open(filename, "w") as f: json.dump(data, f) +@contextlib.contextmanager +def set_env(**environ): + """ + Temporarily set the process environment variables. + + >>> with set_env(PLUGINS_DIR='test/plugins'): + ... 
"PLUGINS_DIR" in os.environ + True + + >>> "PLUGINS_DIR" in os.environ + False + + :type environ: dict[str, unicode] + :param environ: Environment variables to set + """ + old_environ = dict(os.environ) + os.environ.update(environ) + try: + yield + finally: + os.environ.clear() + os.environ.update(old_environ) \ No newline at end of file From 078707d669e16384711b61d847d1e70557e6eb63 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 08:04:58 -0500 Subject: [PATCH 25/35] Fixup --- pyhealth/datasets/base_dataset.py | 66 +++++++++++++++---------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 60ca2610..5b282c72 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -282,42 +282,42 @@ def put(self, x): BATCH_SIZE = 128 worker_id, task_df, start_idx, end_idx, output_dir = args - os.environ["DATA_OPTIMIZER_GLOBAL_RANK"] = str(worker_id) logger.info(f"Worker {worker_id} started processing {end_idx - start_idx} samples. ({start_idx} to {end_idx})") - writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") - progress = _proc_transform_queue or _FakeQueue() + with set_env(DATA_OPTIMIZER_GLOBAL_RANK=str(worker_id)): + writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") + progress = _proc_transform_queue or _FakeQueue() - dataset = litdata.StreamingDataset(str(task_df)) - complete = 0 - with open(f"{output_dir}/schema.pkl", "rb") as f: - metadata = pickle.load(f) + dataset = litdata.StreamingDataset(str(task_df)) + complete = 0 + with open(f"{output_dir}/schema.pkl", "rb") as f: + metadata = pickle.load(f) + + input_processors = metadata["input_processors"] + output_processors = metadata["output_processors"] + + write_index = 0 + for i in range(start_idx, end_idx): + transformed: Dict[str, Any] = {} + for key, value in pickle.loads(dataset[i]["sample"]).items(): + if key in input_processors: + transformed[key] = input_processors[key].process(value) + elif key in output_processors: + transformed[key] = output_processors[key].process(value) + else: + transformed[key] = value + writer.add_item(write_index, transformed) + write_index += 1 + complete += 1 + + if complete >= BATCH_SIZE: + progress.put(complete) + complete = 0 + + if complete > 0: + progress.put(complete) + writer.done() - input_processors = metadata["input_processors"] - output_processors = metadata["output_processors"] - - write_index = 0 - for i in range(start_idx, end_idx): - transformed: Dict[str, Any] = {} - for key, value in pickle.loads(dataset[i]["sample"]).items(): - if key in input_processors: - transformed[key] = input_processors[key].process(value) - elif key in output_processors: - transformed[key] = output_processors[key].process(value) - else: - transformed[key] = value - writer.add_item(write_index, transformed) - write_index += 1 - complete += 1 - - if complete >= BATCH_SIZE: - progress.put(complete) - complete = 0 - - if complete > 0: - progress.put(complete) - writer.done() - logger.info(f"Worker {worker_id} finished processing samples.") @@ -803,7 +803,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> num_workers = min(num_workers, num_samples) # Avoid spawning empty workers try: - with set_env(POLARS_MAX_THREADS=str(num_workers)): + with set_env(DATA_OPTIMIZER_NUM_WORKERS=str(num_workers)): if num_workers == 1: logger.info("Single worker mode, processing sequentially") _proc_transform_fn((0, task_df, 0, 
num_samples, output_dir)) From dc36543d4ea13bec5e3997e3974049b32f3329cd Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 09:25:54 -0500 Subject: [PATCH 26/35] correct result scope --- pyhealth/datasets/base_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 5b282c72..2db395db 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -782,7 +782,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> # remaining items while not queue.empty(): progress.update(queue.get()) - result.get() # ensure exceptions are raised + result.get() # ensure exceptions are raised BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) logger.info(f"Task transformation completed and saved to {output_dir}") @@ -830,7 +830,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> # remaining items while not queue.empty(): progress.update(queue.get()) - result.get() # ensure exceptions are raised + result.get() # ensure exceptions are raised BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) logger.info(f"Processor transformation completed and saved to {output_dir}") From 01199ba42091cca41f5399b5b8c4794150e073c7 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 09:28:00 -0500 Subject: [PATCH 27/35] fix litdata A newer version of litdata is available (0.2.59) --- pyhealth/datasets/base_dataset.py | 3 ++- pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 2db395db..ac63f4a6 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -43,7 +43,8 @@ # Set logging level for distributed to ERROR to reduce verbosity logging.getLogger("distributed").setLevel(logging.ERROR) logger = logging.getLogger(__name__) - +# Remove LitData version check to avoid unnecessary warnings +os.environ["LITDATA_DISABLE_VERSION_CHECK"] = "1" def is_url(path: str) -> bool: """URL detection.""" diff --git a/pyproject.toml b/pyproject.toml index 06157b83..f91b712a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "pandarallel~=1.6.5", "pydantic~=2.11.7", "dask[complete]~=2025.11.0", - "litdata~=0.2.58", + "litdata~=0.2.59", "pyarrow~=22.0.0", "narwhals~=2.13.0", "more-itertools~=10.8.0", From 95d8b470a237c05910036bb2835eb3823e8530ca Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 09:31:50 -0500 Subject: [PATCH 28/35] Fix incorrect tmpdir cleanup --- pyhealth/datasets/base_dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index ac63f4a6..9c71be90 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -491,8 +491,9 @@ def _scan_csv_tsv_gz( ) return df.replace("", pd.NA) # Replace empty strings with NaN - def _event_transform(self, df: dd.DataFrame, output_dir: Path) -> None: + def _event_transform(self, output_dir: Path) -> None: try: + df = self.load_data() with DaskCluster( n_workers=self.num_workers, threads_per_worker=1, @@ -537,7 +538,7 @@ def global_event_df(self) -> pl.LazyFrame: if self._global_event_df is None: ret_path = self.cache_dir / "global_event_df.parquet" if not ret_path.exists(): - self._event_transform(self.load_data(), ret_path) + 
self._event_transform(ret_path) self._global_event_df = ret_path return pl.scan_parquet( From 8b33eed5d41791063f75ef4b952de8c18f847403 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 09:37:14 -0500 Subject: [PATCH 29/35] Add TODO --- pyhealth/datasets/base_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 9c71be90..60652177 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -777,6 +777,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> with ctx.Pool(processes=num_workers, initializer=_task_transform_init, initargs=(queue,)) as pool: result = pool.map_async(_task_transform_fn, args_list) # type: ignore with tqdm(total=len(patient_ids)) as progress: + # TODO: this is busy-waiting, can we do better? while not result.ready(): while not queue.empty(): progress.update(queue.get()) From d6f0cf5594e2b6dd029753b448d1284faca2b711 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 11:58:42 -0500 Subject: [PATCH 30/35] Clear unused code --- pyhealth/datasets/base_dataset.py | 81 ------------------------------- 1 file changed, 81 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 60652177..fedc455f 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -115,87 +115,6 @@ def _csv_tsv_gz_path(path: str) -> str: raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}") - -def _uncollate(x: list[Any]) -> Any: - return x[0] if isinstance(x, list) and len(x) == 1 else x - - -class _ParquetWriter: - """ - Stream-write rows into a Parquet file in chunked (row-group) fashion. - - Usage: - writer = StreamingParquetWriter(Path("out.parquet"), schema, chunk_size=10000) - writer.append({"id": 1, "val": 3.14}) - writer.append({"id": 2, "val": 1.23}) - writer.close() - """ - - def __init__(self, path: Path | str, schema: pa.Schema, chunk_size: int = 8_192): - """ - Args: - path: output Parquet file path - schema: pyarrow.Schema (required) - chunk_size: flush buffer every N rows - """ - self.path = Path(path) - self.schema = schema - self.chunk_size = chunk_size - - if self.schema is None: - raise ValueError( - "schema must be provided — no automatic inference allowed." 
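# The hand-written streaming Parquet writer removed here has been superseded by
# litdata's BinaryWriter, which the worker functions in this series already use.
# A minimal sketch of that replacement flow (the cache path, the import path for
# BinaryWriter, and the single-rank merge count are assumptions for illustration;
# the real code also sets the DATA_OPTIMIZER_* rank variables around this):
import pickle
import litdata
from litdata.streaming.writer import BinaryWriter

cache = "cache/task_df.ld"
writer = BinaryWriter(cache_dir=cache, chunk_bytes="64MB")
for i, sample in enumerate([{"patient_id": "1"}, {"patient_id": "2"}]):
    writer.add_item(i, {"sample": pickle.dumps(sample)})  # one record per sample
writer.done()
BinaryWriter(cache_dir=cache, chunk_bytes="64MB").merge(1)  # merge one writer rank

dataset = litdata.StreamingDataset(cache)
print(pickle.loads(dataset[0]["sample"]))  # {'patient_id': '1'}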
- ) - - self._writer: pq.ParquetWriter | None = None - self._buffer: list[dict] = [] - self._closed = False - - # -------------------------------------------------------------- - # Public API - # -------------------------------------------------------------- - def append(self, row: dict) -> None: - """Append a single row (a Python dict).""" - if self._closed: - raise RuntimeError("Cannot append to a closed StreamingParquetWriter") - - self._buffer.append(row) - if len(self._buffer) >= self.chunk_size: - self.flush() - - def flush(self) -> None: - """Flush buffered rows into a Parquet row-group.""" - if not self._buffer: - return - - # Convert list[dict] → Arrow RecordBatch - batch = pa.RecordBatch.from_pylist(self._buffer, schema=self.schema) - - # Lazy-initialize writer - if self._writer is None: - self._writer = pq.ParquetWriter(self.path, self.schema) - - self._writer.write_batch(batch) - self._buffer.clear() - - def close(self) -> None: - """Flush and close the Parquet writer.""" - if self._closed: - return - self.flush() - if self._writer is not None: - self._writer.close() - self._closed = True - - # -------------------------------------------------------------- - # Context manager support - # -------------------------------------------------------------- - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - self.close() - _task_transform_queue: multiprocessing.queues.Queue | None = None def _task_transform_init(queue: multiprocessing.queues.Queue) -> None: From 043905947c824935728ed9c7fb50da635bc71115 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 11:59:29 -0500 Subject: [PATCH 31/35] rename queue to progress to better reflect it's usage. --- pyhealth/datasets/base_dataset.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index fedc455f..45812cb5 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -115,7 +115,7 @@ def _csv_tsv_gz_path(path: str) -> str: raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}") -_task_transform_queue: multiprocessing.queues.Queue | None = None +_task_transform_progress: multiprocessing.queues.Queue | None = None def _task_transform_init(queue: multiprocessing.queues.Queue) -> None: """ @@ -124,8 +124,8 @@ def _task_transform_init(queue: multiprocessing.queues.Queue) -> None: Args: queue (multiprocessing.queues.Queue): The queue for progress tracking. 
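# The renamed globals above are populated through the Pool initializer, so each
# worker process reports progress on a queue it never receives as an argument.
# A self-contained sketch of that pattern with illustrative names (the spawn
# context mirrors the requirement for Polars workers noted earlier in the series):
import multiprocessing

_progress = None

def _init(queue):
    global _progress
    _progress = queue

def _work(n):
    _progress.put(n)  # report n completed items to the parent
    return n

if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    q = ctx.Queue()
    with ctx.Pool(processes=2, initializer=_init, initargs=(q,)) as pool:
        pool.map_async(_work, [10, 20]).get()
    total = 0
    while not q.empty():
        total += q.get()
    print(total)  # 30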
""" - global _task_transform_queue - _task_transform_queue = queue + global _task_transform_progress + _task_transform_progress = queue def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, Path]) -> None: """ @@ -149,7 +149,7 @@ def put(self, x): with set_env(DATA_OPTIMIZER_GLOBAL_RANK=str(worker_id)): writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") - progress = _task_transform_queue or _FakeQueue() + progress = _task_transform_progress or _FakeQueue() write_index = 0 batches = itertools.batched(patient_ids, BATCH_SIZE) @@ -172,7 +172,7 @@ def put(self, x): logger.info(f"Worker {worker_id} finished processing patients.") -_proc_transform_queue: multiprocessing.queues.Queue | None = None +_proc_transform_progress: multiprocessing.queues.Queue | None = None def _proc_transform_init(queue: multiprocessing.queues.Queue) -> None: """ @@ -181,8 +181,8 @@ def _proc_transform_init(queue: multiprocessing.queues.Queue) -> None: Args: queue (multiprocessing.queues.Queue): The queue for progress tracking. """ - global _proc_transform_queue - _proc_transform_queue = queue + global _proc_transform_progress + _proc_transform_progress = queue def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None: """ @@ -206,7 +206,7 @@ def put(self, x): with set_env(DATA_OPTIMIZER_GLOBAL_RANK=str(worker_id)): writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") - progress = _proc_transform_queue or _FakeQueue() + progress = _proc_transform_progress or _FakeQueue() dataset = litdata.StreamingDataset(str(task_df)) complete = 0 From 5ba4c1dcf1e52766fe02d7a36c57e1363044a87e Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 12:07:46 -0500 Subject: [PATCH 32/35] Fix tqdm for single worker --- pyhealth/datasets/base_dataset.py | 39 ++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 45812cb5..d7cc1d03 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -115,6 +115,24 @@ def _csv_tsv_gz_path(path: str) -> str: raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}") +class _TqdmQueue: + def __init__(self, total: int, **kwargs) -> None: + self.total = total + self.kwargs = kwargs + self.progress = None + + def __enter__(self): + self.progress = tqdm(total=self.total, **self.kwargs) + return self + + def __exit__(self, *args, **kwargs): + if self.progress: + self.progress.close() + + def put(self, n: int) -> None: + if self.progress: + self.progress.update(n) + _task_transform_progress: multiprocessing.queues.Queue | None = None def _task_transform_init(queue: multiprocessing.queues.Queue) -> None: @@ -139,17 +157,14 @@ def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, P global_event_df (pl.LazyFrame): The global event dataframe. output_dir (Path): The output directory to save results. """ - class _FakeQueue: - def put(self, x): - pass - BATCH_SIZE = 128 # Use a batch size 128 can reduce runtime by 30%. worker_id, task, patient_ids, global_event_df, output_dir = args - logger.info(f"Worker {worker_id} started processing {len(list(patient_ids))} patients. (Polars threads: {pl.thread_pool_size()})") + total_patients = len(list(patient_ids)) + logger.info(f"Worker {worker_id} started processing {total_patients} patients. 
(Polars threads: {pl.thread_pool_size()})") with set_env(DATA_OPTIMIZER_GLOBAL_RANK=str(worker_id)): writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") - progress = _task_transform_progress or _FakeQueue() + progress = _task_transform_progress or _TqdmQueue(total=total_patients) write_index = 0 batches = itertools.batched(patient_ids, BATCH_SIZE) @@ -171,7 +186,7 @@ def put(self, x): writer.done() logger.info(f"Worker {worker_id} finished processing patients.") - + _proc_transform_progress: multiprocessing.queues.Queue | None = None def _proc_transform_init(queue: multiprocessing.queues.Queue) -> None: @@ -196,17 +211,15 @@ def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None: end_idx (int): The end index of samples to process. output_dir (Path): The output directory to save results. """ - class _FakeQueue: - def put(self, x): - pass - BATCH_SIZE = 128 worker_id, task_df, start_idx, end_idx, output_dir = args - logger.info(f"Worker {worker_id} started processing {end_idx - start_idx} samples. ({start_idx} to {end_idx})") + total_samples = end_idx - start_idx + logger.info(f"Worker {worker_id} started processing {total_samples} samples. ({start_idx} to {end_idx})") with set_env(DATA_OPTIMIZER_GLOBAL_RANK=str(worker_id)): writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") - progress = _proc_transform_progress or _FakeQueue() + # Use TqdmQueue for single worker to show progress bar + progress = _proc_transform_progress or _TqdmQueue(total=total_samples) dataset = litdata.StreamingDataset(str(task_df)) complete = 0 From acd6308d08728d735b59b7f1a25e492f50eef269 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 12:09:17 -0500 Subject: [PATCH 33/35] Fix busy waiting --- pyhealth/datasets/base_dataset.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index d7cc1d03..44cdc064 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -709,10 +709,11 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> with ctx.Pool(processes=num_workers, initializer=_task_transform_init, initargs=(queue,)) as pool: result = pool.map_async(_task_transform_fn, args_list) # type: ignore with tqdm(total=len(patient_ids)) as progress: - # TODO: this is busy-waiting, can we do better? 
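# The change below replaces the tight polling loop flagged by the TODO with a
# blocking get() that times out after one second, so the parent waits inside the
# queue instead of spinning. A standalone sketch of the same drain loop; it
# catches the specific queue.Empty raised on timeout, where the patch itself
# uses a bare except:
from queue import Empty

def drain(result, queue, progress):
    # result: AsyncResult from pool.map_async; queue: multiprocessing queue of
    # per-batch completion counts; progress: a tqdm progress bar
    while not result.ready():
        try:
            progress.update(queue.get(timeout=1))
        except Empty:
            pass
    while not queue.empty():   # pick up counts posted after the result completed
        progress.update(queue.get())
    result.get()               # re-raise any worker exception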
while not result.ready(): - while not queue.empty(): - progress.update(queue.get()) + try: + progress.update(queue.get(timeout=1)) + except: + pass # remaining items while not queue.empty(): @@ -759,8 +760,10 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> result = pool.map_async(_proc_transform_fn, args_list) # type: ignore with tqdm(total=num_samples) as progress: while not result.ready(): - while not queue.empty(): - progress.update(queue.get()) + try: + progress.update(queue.get(timeout=1)) + except: + pass # remaining items while not queue.empty(): From 71e3f54ce0d0d444b8f036cec50fc0866cd21f25 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 12:17:23 -0500 Subject: [PATCH 34/35] delete outdated test --- tests/core/test_streaming_parquet_writer.py | 55 --------------------- 1 file changed, 55 deletions(-) delete mode 100644 tests/core/test_streaming_parquet_writer.py diff --git a/tests/core/test_streaming_parquet_writer.py b/tests/core/test_streaming_parquet_writer.py deleted file mode 100644 index c7c212bd..00000000 --- a/tests/core/test_streaming_parquet_writer.py +++ /dev/null @@ -1,55 +0,0 @@ -import tempfile -from pathlib import Path -import unittest - -import pyarrow as pa -import pyarrow.parquet as pq - -from pyhealth.datasets.base_dataset import _ParquetWriter -from tests.base import BaseTestCase - - -class TestStreamingParquetWriter(BaseTestCase): - def setUp(self): - self.tmpdir = tempfile.TemporaryDirectory() - self.schema = pa.schema( - [ - ("id", pa.int64()), - ("value", pa.string()), - ] - ) - self.output_path = Path(self.tmpdir.name) / "stream.parquet" - - def tearDown(self): - self.tmpdir.cleanup() - - def test_append_flush_close_and_context_manager(self): - rows = [ - {"id": 1, "value": "a"}, - {"id": 2, "value": "b"}, - {"id": 3, "value": "c"}, - {"id": 4, "value": "d"}, - ] - - with _ParquetWriter( - self.output_path, self.schema, chunk_size=2 - ) as writer: - # First two appends trigger an automatic flush due to chunk_size=2. - writer.append(rows[0]) - writer.append(rows[1]) - - # Flush again after adding a third row to ensure flushing appends - # rather than overwriting previous row groups. - writer.append(rows[2]) - writer.flush() - - # Leave data in the buffer to verify close() flushes it. - writer.append(rows[3]) - - # Context manager should have closed and flushed remaining buffered rows. - self.assertTrue(self.output_path.exists()) - - written_rows = pq.read_table(self.output_path).to_pylist() - - # Every append should be present as a distinct row in order. - self.assertEqual(written_rows, rows) From 8e180672a852e03d756d8d537bccfc069eebdfbe Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Thu, 25 Dec 2025 13:17:30 -0500 Subject: [PATCH 35/35] Fix signle thread context --- pyhealth/datasets/base_dataset.py | 49 +++++++++++++++++++------------ 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 44cdc064..aca46ab9 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -115,23 +115,33 @@ def _csv_tsv_gz_path(path: str) -> str: raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}") -class _TqdmQueue: - def __init__(self, total: int, **kwargs) -> None: +class _ProgressContext: + def __init__(self, queue: multiprocessing.queues.Queue | None, total: int, **kwargs): + """ + :param queue: An existing queue (e.g., from multiprocessing). If provided, + this class acts as a passthrough. 
+ :param total: Total items for the progress bar (only used if queue is None). + :param kwargs: Extra arguments for tqdm (e.g., desc="Processing"). + """ + self.queue = queue self.total = total self.kwargs = kwargs self.progress = None - + + def put(self, n): + if self.progress: + self.progress.update(n) + def __enter__(self): + if self.queue: + return self.queue + self.progress = tqdm(total=self.total, **self.kwargs) return self - - def __exit__(self, *args, **kwargs): + + def __exit__(self, exc_type, exc_val, exc_tb): if self.progress: self.progress.close() - - def put(self, n: int) -> None: - if self.progress: - self.progress.update(n) _task_transform_progress: multiprocessing.queues.Queue | None = None @@ -161,11 +171,13 @@ def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, P worker_id, task, patient_ids, global_event_df, output_dir = args total_patients = len(list(patient_ids)) logger.info(f"Worker {worker_id} started processing {total_patients} patients. (Polars threads: {pl.thread_pool_size()})") - - with set_env(DATA_OPTIMIZER_GLOBAL_RANK=str(worker_id)): + + with ( + set_env(DATA_OPTIMIZER_GLOBAL_RANK=str(worker_id)), + _ProgressContext(_task_transform_progress, total=total_patients) as progress + ): writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") - progress = _task_transform_progress or _TqdmQueue(total=total_patients) - + write_index = 0 batches = itertools.batched(patient_ids, BATCH_SIZE) for batch in batches: @@ -215,12 +227,13 @@ def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None: worker_id, task_df, start_idx, end_idx, output_dir = args total_samples = end_idx - start_idx logger.info(f"Worker {worker_id} started processing {total_samples} samples. ({start_idx} to {end_idx})") - - with set_env(DATA_OPTIMIZER_GLOBAL_RANK=str(worker_id)): - writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") - # Use TqdmQueue for single worker to show progress bar - progress = _proc_transform_progress or _TqdmQueue(total=total_samples) + with ( + set_env(DATA_OPTIMIZER_GLOBAL_RANK=str(worker_id)), + _ProgressContext(_proc_transform_progress, total=total_samples) as progress + ): + writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") + dataset = litdata.StreamingDataset(str(task_df)) complete = 0 with open(f"{output_dir}/schema.pkl", "rb") as f:
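# _ProgressContext unifies the two reporting paths: handed a queue it returns
# the queue itself from __enter__ (multi-worker, the parent owns the tqdm bar),
# and without one it owns a local tqdm bar (single-worker). A short usage sketch
# of both modes, assuming the class as defined in this patch; the totals are
# illustrative:
import multiprocessing
from pyhealth.datasets.base_dataset import _ProgressContext

# single-worker mode: no queue, so put() drives a local tqdm bar
with _ProgressContext(None, total=4) as progress:
    for _ in range(4):
        progress.put(1)

# multi-worker mode: the queue passes through unchanged and put() feeds the parent
q = multiprocessing.get_context("spawn").Queue()
with _ProgressContext(q, total=4) as progress:
    progress.put(4)
print(q.get())  # 4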