Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions superbench/benchmarks/model_benchmarks/megatron_gpt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,13 +651,40 @@ def _generate_dataset(self):
if self._args.dataset_url:
self._raw_data_path = str(Path(self._args.data_home) / 'data.json')
download_file(self._args.dataset_url, self._raw_data_path)

# Megatron's preprocess_data.py appends '_text_document' to --output-prefix
# when producing the .bin/.idx files. Derive the output-prefix from
# data_prefix (stripping the '_text_document' suffix when present) so that
# the generated files match the existence checks above for any custom
# data_prefix value.
output_prefix_basename = self._args.data_prefix
if output_prefix_basename.endswith('_text_document'):
stripped = output_prefix_basename[:-len('_text_document')]
# Guard against data_prefix == '_text_document' which would
# leave an empty basename and produce a malformed --output-prefix
# ending in '/'. Fall back to the original value in that case.
output_prefix_basename = stripped or output_prefix_basename
Comment on lines +656 to +666
output_prefix = os.path.join(self._args.data_home, output_prefix_basename)

# num_workers=0 is valid for DataLoader (main process loads data),
# but preprocess_data.py requires workers>=1 for multiprocessing.Pool.
preprocess_workers = max(1, self._args.num_workers)
if preprocess_workers != self._args.num_workers:
logger.warning(
'preprocess_data.py requires --workers >= 1; '
'overriding num_workers={} to {} for dataset preprocessing only '
'(DataLoader still uses num_workers={}).'.format(
self._args.num_workers, preprocess_workers, self._args.num_workers
)
)

command = (
'python3 '
f'{os.path.join(self._args.code_base, "tools/preprocess_data.py")} '
f'--input {self._raw_data_path} '
f'--tokenizer-type {self._args.tokenizer_type} '
f'--output-prefix {os.path.join(self._args.data_home, "dataset")} '
f'--workers {str(self._args.num_workers)} '
f'--output-prefix {output_prefix} '
f'--workers {preprocess_workers} '
f'--vocab-file {self._vocab_path} '
Comment thread
polarG marked this conversation as resolved.
f'--merge-file {self._merges_path}'
)
Expand Down
84 changes: 84 additions & 0 deletions tests/benchmarks/model_benchmarks/test_megatron_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,90 @@ def test_megatron_gpt_dataset(self):
ret = benchmark._generate_dataset()
assert (ret is True)

@mock.patch('superbench.benchmarks.model_benchmarks.megatron_gpt3.run_command')
@mock.patch('superbench.benchmarks.model_benchmarks.megatron_gpt3.download_file')
def test_megatron_gpt_dataset_generate_command(self, mock_download_file, mock_run_command):
    """Check the '--workers' and '--output-prefix' flags of the first run_command call.

    For each case the benchmark is built with a dataset URL, ``_preprocess()``
    is expected to return True, and the first command handed to the patched
    ``run_command`` must contain the expected ``--workers <n>`` and
    ``--output-prefix <data_home>/<basename>`` fragments.
    """
    (benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA)
    assert (benchmark_cls)

    # Single-process distributed environment for the benchmark under test.
    os.environ.update(
        {
            'OMPI_COMM_WORLD_SIZE': '1',
            'OMPI_COMM_WORLD_LOCAL_SIZE': '1',
            'OMPI_COMM_WORLD_RANK': '0',
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        }
    )

    # Provide a real code_base inside the temp dir instead of relying on a hardcoded path.
    self.createMockFiles(['pretrain_gpt.py'])

    # Every file touched by the fake run_command is tracked here and removed on teardown.
    touched = []

    def _remove_touched():
        for path in touched:
            if path.is_file():
                path.unlink()

    self.addCleanup(_remove_touched)

    def _fake_preprocess(data_prefix):
        """Build a run_command side effect that touches <data_prefix>.bin/.idx in the temp dir."""
        def _touch_outputs(*_args, **_kwargs):
            for suffix in ('.bin', '.idx'):
                target = Path(self._tmp_dir) / f'{data_prefix}{suffix}'
                target.touch()
                touched.append(target)

        return _touch_outputs

    # Each row: (extra_params, expected_workers, expected_prefix_basename, expected_data_prefix)
    cases = [
        # Case 1: num_workers=0 with the default data_prefix -> '--workers 1' (clamped) and
        # '--output-prefix <data_home>/dataset' ('_text_document' suffix stripped).
        ('--num_workers 0', 1, 'dataset', 'dataset_text_document'),
        # Case 2: num_workers=4 with data_prefix='custom_text_document' -> '--workers 4'
        # and '--output-prefix <data_home>/custom'.
        ('--num_workers 4 --data_prefix custom_text_document', 4, 'custom', 'custom_text_document'),
        # Case 3: a data_prefix without the '_text_document' suffix is used as-is.
        ('--num_workers 2 --data_prefix mydata', 2, 'mydata', 'mydata'),
        # Case 4: data_prefix == '_text_document' must not be stripped down to an empty
        # basename; '_text_document' itself is expected as the basename.
        ('--num_workers 1 --data_prefix _text_document', 1, '_text_document', '_text_document'),
    ]

    for extra_params, expected_workers, expected_prefix_basename, expected_data_prefix in cases:
        mock_run_command.reset_mock()
        mock_run_command.side_effect = _fake_preprocess(expected_data_prefix)

        parameters = (
            f'--code_base {self._tmp_dir} --data_home {self._tmp_dir} '
            f'--batch_size 2048 --dataset_url http://example.com/data.json '
            f'{extra_params}'
        )
        benchmark = benchmark_cls(self.benchmark_name, parameters=parameters)

        assert benchmark._preprocess() is True
        assert mock_run_command.call_count >= 1

        # Only the first recorded run_command invocation is inspected.
        first_cmd = mock_run_command.call_args_list[0].args[0]
        units = normalize_command(first_cmd)

        expected_output_prefix = os.path.join(self._tmp_dir, expected_prefix_basename)
        assert f'--workers {expected_workers}' in units, units
        assert f'--output-prefix {expected_output_prefix}' in units, units

Comment on lines +243 to +260
@mock.patch('superbench.benchmarks.model_benchmarks.MegatronGPT._generate_dataset')
def test_megatron_gpt_command(self, mock_generate_dataset):
"""Test command generation."""
Expand Down
Loading