diff --git a/superbench/benchmarks/model_benchmarks/megatron_gpt3.py b/superbench/benchmarks/model_benchmarks/megatron_gpt3.py index 37d27bf1a..7541aee4b 100644 --- a/superbench/benchmarks/model_benchmarks/megatron_gpt3.py +++ b/superbench/benchmarks/model_benchmarks/megatron_gpt3.py @@ -651,13 +651,40 @@ def _generate_dataset(self): if self._args.dataset_url: self._raw_data_path = str(Path(self._args.data_home) / 'data.json') download_file(self._args.dataset_url, self._raw_data_path) + + # Megatron's preprocess_data.py appends '_text_document' to --output-prefix + # when producing the .bin/.idx files. Derive the output-prefix from + # data_prefix (stripping the '_text_document' suffix when present) so that + # the generated files match the existence checks above for any custom + # data_prefix value. + output_prefix_basename = self._args.data_prefix + if output_prefix_basename.endswith('_text_document'): + stripped = output_prefix_basename[:-len('_text_document')] + # Guard against data_prefix == '_text_document' which would + # leave an empty basename and produce a malformed --output-prefix + # ending in '/'. Fall back to the original value in that case. + output_prefix_basename = stripped or output_prefix_basename + output_prefix = os.path.join(self._args.data_home, output_prefix_basename) + + # num_workers=0 is valid for DataLoader (main process loads data), + # but preprocess_data.py requires workers>=1 for multiprocessing.Pool. 
+ preprocess_workers = max(1, self._args.num_workers) + if preprocess_workers != self._args.num_workers: + logger.warning( + 'preprocess_data.py requires --workers >= 1; ' + 'overriding num_workers={} to {} for dataset preprocessing only ' + '(DataLoader still uses num_workers={}).'.format( + self._args.num_workers, preprocess_workers, self._args.num_workers + ) + ) + command = ( 'python3 ' f'{os.path.join(self._args.code_base, "tools/preprocess_data.py")} ' f'--input {self._raw_data_path} ' f'--tokenizer-type {self._args.tokenizer_type} ' - f'--output-prefix {os.path.join(self._args.data_home, "dataset")} ' - f'--workers {str(self._args.num_workers)} ' + f'--output-prefix {output_prefix} ' + f'--workers {preprocess_workers} ' f'--vocab-file {self._vocab_path} ' f'--merge-file {self._merges_path}' ) diff --git a/tests/benchmarks/model_benchmarks/test_megatron_gpt.py b/tests/benchmarks/model_benchmarks/test_megatron_gpt.py index b7c588677..e856a8010 100644 --- a/tests/benchmarks/model_benchmarks/test_megatron_gpt.py +++ b/tests/benchmarks/model_benchmarks/test_megatron_gpt.py @@ -174,6 +174,90 @@ def test_megatron_gpt_dataset(self): ret = benchmark._generate_dataset() assert (ret is True) + @mock.patch('superbench.benchmarks.model_benchmarks.megatron_gpt3.run_command') + @mock.patch('superbench.benchmarks.model_benchmarks.megatron_gpt3.download_file') + def test_megatron_gpt_dataset_generate_command(self, mock_download_file, mock_run_command): + """Verify _generate_dataset clamps --workers to >=1 and derives --output-prefix from data_prefix.""" + (benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA) + assert (benchmark_cls) + os.environ['OMPI_COMM_WORLD_SIZE'] = '1' + os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'] = '1' + os.environ['OMPI_COMM_WORLD_RANK'] = '0' + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12345' + + # Use a real, valid code_base so _preprocess() can validate it (avoid 
hardcoded /root path). + self.createMockFiles(['pretrain_gpt.py']) + + # Helper: make run_command's side_effect create the expected .bin/.idx files + # so _generate_dataset() (invoked from within _preprocess()) succeeds. + created_files = [] + + def _make_dataset_files(prefix): + def _side_effect(*_args, **_kwargs): + for ext in ('.bin', '.idx'): + p = Path(self._tmp_dir) / f'{prefix}{ext}' + p.touch() + created_files.append(p) + return _side_effect + + self.addCleanup(lambda: [p.unlink() for p in created_files if p.is_file()]) + + def _run_case(extra_params, expected_workers, expected_prefix_basename, expected_data_prefix): + mock_run_command.reset_mock() + mock_run_command.side_effect = _make_dataset_files(expected_data_prefix) + benchmark = benchmark_cls( + self.benchmark_name, + parameters=( + f'--code_base {self._tmp_dir} --data_home {self._tmp_dir} ' + f'--batch_size 2048 --dataset_url http://example.com/data.json ' + f'{extra_params}' + ), + ) + assert benchmark._preprocess() is True + assert mock_run_command.call_count >= 1 + cmd = mock_run_command.call_args_list[0].args[0] + units = normalize_command(cmd) + assert f'--workers {expected_workers}' in units, units + expected_output_prefix = os.path.join(self._tmp_dir, expected_prefix_basename) + assert f'--output-prefix {expected_output_prefix}' in units, units + + # Case 1: num_workers=0 with default data_prefix should produce '--workers 1' (clamped) + # and '--output-prefix <tmp_dir>/dataset' (default 'dataset_text_document' suffix stripped). + _run_case( + extra_params='--num_workers 0', + expected_workers=1, + expected_prefix_basename='dataset', + expected_data_prefix='dataset_text_document', + ) + + # Case 2: num_workers=4 with custom data_prefix='custom_text_document' should produce + # '--workers 4' and '--output-prefix <tmp_dir>/custom'. 
+ _run_case( + extra_params='--num_workers 4 --data_prefix custom_text_document', + expected_workers=4, + expected_prefix_basename='custom', + expected_data_prefix='custom_text_document', + ) + + # Case 3: data_prefix without the '_text_document' suffix is used as-is. + _run_case( + extra_params='--num_workers 2 --data_prefix mydata', + expected_workers=2, + expected_prefix_basename='mydata', + expected_data_prefix='mydata', + ) + + # Case 4: edge case - data_prefix == '_text_document' should NOT strip down + # to an empty basename (which would produce '--output-prefix <tmp_dir>/'). + # Fall back to using '_text_document' as the basename. + _run_case( + extra_params='--num_workers 1 --data_prefix _text_document', + expected_workers=1, + expected_prefix_basename='_text_document', + expected_data_prefix='_text_document', + ) + + @mock.patch('superbench.benchmarks.model_benchmarks.MegatronGPT._generate_dataset') def test_megatron_gpt_command(self, mock_generate_dataset): """Test command generation."""