Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 98 additions & 30 deletions dlio_benchmark/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,75 @@
from typing import Optional, Dict

dlp = Profile(MODULE_CONFIG)


class VirtualIndexMap:
    """Memory-efficient sample index map that computes file mappings on demand.

    Instead of materializing a Python dict with billions of entries (each ~200
    bytes), this class stores only:
    - A shuffled permutation array (numpy int64, ~8 bytes/sample)
    - The file list reference (small)
    - num_samples_per_file (scalar)

    For the DLRM workload with 1.74 billion samples this reduces memory from
    ~350 GB (materialized dict) to ~14 GB (permutation array only).

    Provides dict-like __getitem__, __contains__, get(), items() interface for
    drop-in compatibility with the existing code paths in reader_handler.py
    and indexed_binary_*_reader.py.
    """

    def __init__(self, file_list, num_samples_per_file, start_sample, end_sample,
                 shuffle_seed=None, storage_type=None):
        self._num_samples_per_file = num_samples_per_file
        self._start = start_sample

        # Build the permutation array — this is the only large allocation.
        # Force int64: the default np.arange dtype is platform int (int32 on
        # Windows), which would overflow for multi-billion-sample workloads
        # and contradict the 8-bytes/sample memory estimate above.
        self._sample_list = np.arange(start_sample, end_sample + 1, dtype=np.int64)
        if shuffle_seed is not None:
            # Legacy global-state seeding kept deliberately: it reproduces the
            # exact permutation the original dict-based code path produced.
            np.random.seed(shuffle_seed)
            np.random.shuffle(self._sample_list)

        # Pre-resolve absolute paths once (only num_files entries).
        # Check `None` first so StorageType is never evaluated when no
        # storage type was supplied.
        if storage_type is None:
            self._abs_paths = list(file_list)
        elif storage_type == StorageType.LOCAL_FS:
            self._abs_paths = [os.path.abspath(f) for f in file_list]
        else:
            self._abs_paths = list(file_list)

    def _resolve(self, global_sample_index):
        """Compute (filename, sample_index) from a global sample index."""
        file_index = int(global_sample_index // self._num_samples_per_file)
        sample_index = int(global_sample_index % self._num_samples_per_file)
        return (self._abs_paths[file_index], sample_index)

    def __getitem__(self, global_sample_index):
        # A real dict raises KeyError for absent keys; preserve that contract
        # instead of silently resolving out-of-range indices.
        if not self.__contains__(global_sample_index):
            raise KeyError(global_sample_index)
        return self._resolve(global_sample_index)

    def get(self, global_sample_index, default=None):
        """Dict-compatible lookup that returns *default* for absent keys."""
        if self.__contains__(global_sample_index):
            return self._resolve(global_sample_index)
        return default

    def __contains__(self, key):
        return self._start <= key < self._start + len(self._sample_list)

    def __len__(self):
        return len(self._sample_list)

    def __iter__(self):
        return iter(self._sample_list)

    def items(self):
        """Yield (global_sample_index, (filename, sample_index)) pairs.

        Used by indexed_binary_reader and indexed_binary_mmap_reader to
        pre-load index files. Computes mappings on-the-fly.
        """
        for idx in self._sample_list:
            yield int(idx), self._resolve(int(idx))

    def __repr__(self):
        return (f"VirtualIndexMap(samples={len(self._sample_list)}, "
                f"files={len(self._abs_paths)}, "
                f"samples_per_file={self._num_samples_per_file})")


@dataclass
class ConfigArguments:
__instance = None
Expand Down Expand Up @@ -703,37 +772,36 @@ def build_sample_map_iter(self, file_list, total_samples, epoch_number):

@dlp.log
def get_global_map_index(self, file_list, total_samples, epoch_number):
    """Build this rank's dict-like map from global sample index -> (file, local index).

    Partitions [0, total_samples) evenly across comm_size ranks and returns a
    VirtualIndexMap over this rank's slice, optionally shuffled with a seed
    derived from self.seed (and epoch_number when seed_change_epoch is set).

    Args:
        file_list: list of dataset file paths; each holds
            self.num_samples_per_file samples.
        total_samples: total number of samples across all files.
        epoch_number: current epoch, mixed into the shuffle seed when
            seed_change_epoch is enabled.

    Returns:
        (index_map, samples_sum) — index_map is dict-like (VirtualIndexMap,
        or {} when file_list is empty); samples_sum is the sum of the global
        indices owned by this rank, used as a cross-rank checksum.
    """
    num_files = len(file_list)
    if num_files == 0:
        return {}, 0

    samples_per_proc = int(math.ceil(total_samples / self.comm_size))
    start_sample = self.my_rank * samples_per_proc
    end_sample = min((self.my_rank + 1) * samples_per_proc - 1, total_samples - 1)
    self.logger.debug(f"my_rank: {self.my_rank}, start_sample: {start_sample}, end_sample: {end_sample}")

    # Determine shuffle seed (None = no shuffle)
    shuffle_seed = None
    if self.sample_shuffle is not Shuffle.OFF:
        shuffle_seed = (self.seed + epoch_number) if self.seed_change_epoch else self.seed

    vmap = VirtualIndexMap(
        file_list, self.num_samples_per_file,
        start_sample, end_sample,
        shuffle_seed=shuffle_seed,
        storage_type=self.storage_type,
    )

    # Shuffling permutes the indices but never changes their set, so the sum
    # is the arithmetic series over start_sample..end_sample — O(1) instead
    # of an O(n) scan, and no reach into vmap's private permutation array.
    # Guard count <= 0: ranks past the end of the data own an empty slice.
    count = end_sample - start_sample + 1
    samples_sum = (start_sample + end_sample) * count // 2 if count > 0 else 0

    self.logger.info(
        f"{utcnow()} VirtualIndexMap: {len(vmap)} samples, "
        f"~{len(vmap) * 8 / 1e9:.1f} GB permutation array "
        f"(saved ~{len(vmap) * 200 / 1e9:.0f} GB vs materialized dict)"
    )
    return vmap, samples_sum

@dlp.log
def reconfigure(self, epoch_number):
Expand Down
Loading
Loading