diff --git a/dlio_benchmark/utils/config.py b/dlio_benchmark/utils/config.py index 2f8bb8fb..ceae850a 100644 --- a/dlio_benchmark/utils/config.py +++ b/dlio_benchmark/utils/config.py @@ -182,6 +182,12 @@ class ConfigArguments: record_element_bytes: int = 4 record_element_dtype: ClassVar[np.dtype] = np.dtype("uint8") + ## dataset: parquet-only + parquet_columns: ClassVar[List[Dict[str, Any]]] = [] + parquet_row_group_size: int = 1024 + parquet_partition_by: Optional[str] = None + parquet_generation_batch_size: int = 0 + ## dataset: hdf5-only num_dset_per_record: int = 1 chunk_dims: ClassVar[List[int]] = [] @@ -1123,6 +1129,19 @@ def LoadConfig(args, config): if 'record_dims' in config['dataset']: args.record_dims = list(config['dataset']['record_dims']) + # parquet only config + if 'parquet' in config['dataset']: + pq_cfg = config['dataset']['parquet'] + if 'columns' in pq_cfg: + cols = pq_cfg['columns'] + args.parquet_columns = [dict(c) if hasattr(c, 'items') else c for c in cols] + if 'row_group_size' in pq_cfg: + args.parquet_row_group_size = int(pq_cfg['row_group_size']) + if 'partition_by' in pq_cfg: + args.parquet_partition_by = str(pq_cfg['partition_by']) + if 'generation_batch_size' in pq_cfg: + args.parquet_generation_batch_size = int(pq_cfg['generation_batch_size']) + # hdf5 only config if 'hdf5' in config['dataset']: if 'chunk_dims' in config['dataset']['hdf5']: diff --git a/dlio_benchmark/utils/statscounter.py b/dlio_benchmark/utils/statscounter.py index e085541a..7caef6f8 100644 --- a/dlio_benchmark/utils/statscounter.py +++ b/dlio_benchmark/utils/statscounter.py @@ -50,7 +50,6 @@ def __init__(self): self.my_rank = self.args.my_rank self.comm_size = self.args.comm_size self.output_folder = self.args.output_folder - self.record_size = self.args.record_length self.batch_size = self.args.batch_size self.batch_size_eval = self.args.batch_size_eval self.checkpoint_size = 0.0 @@ -121,7 +120,7 @@ def __init__(self): self.eval_au = [] self.train_throughput = [] self.eval_throughput = [] - data_per_node = self.MPI.npernode()*self.args.num_samples_per_file * self.args.num_files_train//self.MPI.size()*self.args.record_length + data_per_node = self.MPI.npernode()*self.args.num_samples_per_file * self.args.num_files_train//self.MPI.size()*self.record_size self.summary['data_size_per_host_GB'] = data_per_node/1024./1024./1024. if self.MPI.rank() == 0 and self.args.do_train: self.logger.info(f"Total amount of data each host will consume is {data_per_node/1024./1024./1024} GiB; each host has {self.summary['host_memory_GB']} GiB memory") @@ -137,6 +136,27 @@ def __init__(self): potential_caching.append(1) self.summary['potential_caching'] = potential_caching + @property + def record_size(self): + """Return the effective per-sample size in bytes. + + Uses parquet column specs when available, otherwise falls back to + the legacy record_length field. + """ + parquet_cols = getattr(self.args, 'parquet_columns', []) + if parquet_cols: + _DTYPE_BYTES = { + 'float64': 8, 'int64': 8, 'uint64': 8, + 'float32': 4, 'int32': 4, 'uint32': 4, + 'float16': 2, 'int16': 2, 'uint16': 2, + 'uint8': 1, 'int8': 1, 'bool': 1, + } + return sum( + int(c.get('size', 1)) * _DTYPE_BYTES.get(c.get('dtype', 'float32'), 4) + for c in parquet_cols + ) + return self.args.record_length + def start_run(self): self.start_run_timestamp = time() def end_run(self):