From 04e6e41ff3d6e0bb8845b9596b4dc5c71a11296e Mon Sep 17 00:00:00 2001 From: mchochowski Date: Fri, 13 Feb 2026 02:54:21 -0800 Subject: [PATCH 1/2] gpt-oss 20b support Signed-off-by: mchochowski --- .../gptoss-20b.yaml | 110 ++++ .../gptoss-20b_remove_experts_memory.yaml | 22 + .../pruning/ffn_pruning.yaml | 21 + .../pruning/pruning_defaults.yaml | 34 ++ .../validate_model_defaults.yaml | 18 + .../validate_solutions_defaults.yaml | 11 + .../anymodel/converter/converter.py | 17 +- .../gpt_oss_20b/gpt_oss_20b_converter.py | 1 + .../gpt_oss_20b_model_descriptor.py | 25 +- .../gpt_oss_20b/gpt_oss_pruned_to_mxfp4.py | 506 ++++++++++++++++++ 10 files changed, 759 insertions(+), 6 deletions(-) create mode 100644 examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml create mode 100644 examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml create mode 100644 examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml create mode 100644 examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml create mode 100644 examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml create mode 100644 examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml create mode 100644 modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_pruned_to_mxfp4.py diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml new file mode 100644 index 000000000..7de281e78 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml @@ -0,0 +1,110 @@ +defaults: + - pruning: ffn_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: llama +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? 
# path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +build_replacement_library: + add_ffn_no_ops: true + add_attention_no_ops: true + +calc_subblock_stats: + batch_sizes: [64, 96, 128] + prefill_seq_len: 4096 + generation_seq_len: 4096 + num_active_tokens_override: # Optional override for sequence lengths + prefill_queue_size: 0 + allocate_prefill_query: false + benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking + merge_with_existing_stats: false + subblock_stats_filename: "subblock_stats.json" + moe_stats_filename: "moe_stats.json" + runtime_stats: + backend: trt_torch + +scoring: + descriptor: ${descriptor} + solutions_to_validate: + skip_existing_solutions: true + + replacement_library_path: ${replacement_library_path} + solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json} + teacher_dir: ${to_path:${teacher_dir}} + output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation + + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +mip: + single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}} + subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}} + output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions} + gathered_metrics_path: + puzzle_profile: + + # puzzle_profile: + objective: metrics.cosine_embedding_loss_hidden_states + bigger_is_better: false + + subblock_stats_args: + - batch_size: 96 + weights_dtype: torch.bfloat16 + activations_dtype: torch.bfloat16 + kv_cache_dtype: torch.bfloat16 + + report_additional_costs: + - stats.memory_mib + - stats.num_params + - stats.num_kv_heads + - stats.has_attention + - stats.has_ffn + - stats.kv_cache_memory_mib + - stats.attention_memory_mib + - stats.ffn_memory_mib + - stats.ffn_num_params + - stats.attention_num_params + + human_constraints: + target_memory: 45_000 + num_params: 3_000_000_000 + + mip_constraints: + metric_overrides: + max_seconds_per_solution: 60 + +realize_model: + descriptor: ${descriptor} + teacher_dir: ${to_path:${teacher_dir}} + tokenizer_name: ${to_path:${teacher_dir}} + replacement_library_path: ${replacement_library_path} + save_models: true + solutions_path: # Filled dynamically + + # Validate params + skip_validation: false # To enable validation of the model solution set `skip_validation` as False + eval_samples: 128 + micro_batch_size: 1 + seed: 42 + shuffle_seed: 444 + dataset_path: ${dataset_path} + +nccl_timeout_minutes: ${timedelta_minutes:10} + +# This section redirects Hydra outputs +hydra: + run: + dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} + diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml new file mode 100644 index 000000000..979803939 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b_remove_experts_memory.yaml @@ -0,0 +1,22 @@ +defaults: + - gptoss-20b + - _self_ + +# Input Hugging Face model to compress +input_hf_model_path: /workspace/hf_models/openai/gpt-oss-20b + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for compression outputs +puzzle_dir: /workspace/puzzle_dir + +# MIP memory constraint (in MiB) +mip: + human_constraints: + target_memory: 45_000 # 45 GiB + +# FFN 
intermediate sizes to search over (heterogeneous architecture)
+# teacher_intermediate_size is 8192, so we use proportionally smaller values
+pruning:
+  intermediate_size_list: [2048, 4096, 6144]
diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml
new file mode 100644
index 000000000..e9e15db32
--- /dev/null
+++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/ffn_pruning.yaml
@@ -0,0 +1,21 @@
+defaults:
+  - pruning_defaults
+
+eval_samples: 2500
+activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruning.experiment_id}
+
+pruning_mixin:
+  _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn
+  layer_descriptor:
+    _target_: modelopt.torch.puzzletron.anymodel.models.gpt_oss_20b.gpt_oss_20b_model_descriptor.GptOss20bExpertRemovalLayerDescriptor
+    target_name: "mlp.router"
+
+hook_class: ${get_object:utils.activation_hooks.hooks.RankedChoiceVotingHook}
+activation_hooks_kwargs: # Additional kwargs to pass to the hook init
+
+num_experts_to_keep_list: [24, 16, 8] # num_experts in teacher is 32
+mlp_init_mode: "ExpertRemoval"
+mlp_init_config_yaml:
+  expert_scores_key: "expert_ranks"
+  layer_prefix_template: "model.layers.{layer_idx}.mlp.router"
+
diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml
new file mode 100644
index 000000000..cec781465
--- /dev/null
+++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/pruning/pruning_defaults.yaml
@@ -0,0 +1,34 @@
+defaults:
+  - /validate_model_defaults
+
+model_name_or_path: ${teacher_dir}
+experiment_id: ${pruning.eval_samples}samples_diverse_mini
+activations_log_dir: ???
+activation_hooks_kwargs: ???
+ +descriptor: ${descriptor} + +# Data: +eval_samples: 10_000 +micro_batch_size: 1 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_outpt_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" # PruneByActivationsLog + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} \ No newline at end of file diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml new file mode 100644 index 000000000..b80faea5f --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml @@ -0,0 +1,18 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} + diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml new file mode 100644 index 000000000..ab8c89218 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml @@ -0,0 +1,11 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false + diff --git a/modelopt/torch/puzzletron/anymodel/converter/converter.py b/modelopt/torch/puzzletron/anymodel/converter/converter.py index 67ed74ed9..e241e72b6 100644 --- a/modelopt/torch/puzzletron/anymodel/converter/converter.py +++ b/modelopt/torch/puzzletron/anymodel/converter/converter.py @@ -27,6 +27,7 @@ from safetensors.torch import load_file, save_file from tqdm import tqdm from transformers import PretrainedConfig +from transformers.integrations.mxfp4 import convert_moe_packed_tensors from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig @@ -61,8 +62,9 @@ def _get_weight_map(input_dir: Path) -> Dict[str, str]: f"Neither {index_path} nor {single_file_path} found. Cannot determine model format." 
) - @staticmethod + @classmethod def convert_model_weights( + cls, input_dir: Path, output_dir: Path, descriptor: ModelDescriptor, num_hidden_layers: int ): """Convert model weights to subblock format.""" @@ -95,7 +97,18 @@ def convert_model_weights( data = load_file(os.path.join(input_dir, file)) for name in param_names: if param_to_file[name] == file and name in data: - tensors[name] = data[name] + converted_name = cls.convert_weight_name(name) + # Convert MoE packed tensors if quantized is mxfp4 //gpt-oss-20b + if getattr(cls, 'quantized', None) == 'mxfp4': + if name.endswith("_blocks"): + converted_name = converted_name.replace("_blocks", "") + tensors[converted_name] = convert_moe_packed_tensors(data[converted_name+"_blocks"], data[converted_name+"_scales"]) + elif name.endswith("_scales"): + continue + else: + tensors[converted_name] = data[name] + else: + tensors[converted_name] = data[name] # Save this subblock print(f"\n✅ Group: {subblock} ({len(tensors)} layers)") diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_20b_converter.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_20b_converter.py index b7e83dcec..d35c004c1 100644 --- a/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_20b_converter.py +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_20b_converter.py @@ -36,6 +36,7 @@ class GptOss20bConverter(Converter): GPT-OSS-20B is a pure MoE model with 32 experts per layer and 4 active experts. All layers use MoE FFN (no standard dense FFN layers). """ + quantized = 'mxfp4' @staticmethod def create_block_configs_from_main_config(config: PretrainedConfig) -> List[BlockConfig]: diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_20b_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_20b_model_descriptor.py index fd5edc063..644da802c 100644 --- a/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_20b_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_20b_model_descriptor.py @@ -50,6 +50,13 @@ class GptOss20bModelDescriptor(ModelDescriptor): _DECODER_LAYER_CLS: Type[nn.Module] = None + @classmethod + def create_dummy_block(cls, original_layer: GptOssDecoderLayer, block_index: int) -> nn.Module: + dummy_block = DummyBlock(block_index=block_index) + # Required by `GptOssModel.forward`. + dummy_block.attention_type = original_layer.attention_type + return dummy_block + @staticmethod def decoder_layer_cls(): """Get the decoder layer class for GPT-OSS models. 
@@ -132,7 +139,7 @@ def build_ffn_predicates() -> Dict[str, re.Pattern]: r"(post_attention_layernorm\.weight" r"|mlp\.router\.weight" r"|mlp\.router\.bias" - r"|mlp\.experts\.((\d+\.)?(gate_up_proj|down_proj)(\.(weight|bias|blocks|scales))?|gate_up_proj_(bias|blocks|scales)|down_proj_(bias|blocks|scales)))$" + r"|mlp\.experts\.(gate_up_proj|down_proj)(_(bias|blocks|scales))?)$" ) for layer_idx in range(num_layers) } @@ -190,12 +197,15 @@ class GptOss20bExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): target_name: str = "mlp" moe_prefix_name: str = "model.layers.{layer_idx}.mlp" - expert_prefix_name: str = "experts.{expert_idx}" + expert_prefix_name: str = "experts" # Router has both weight and bias router_weights: List[str] = field(default_factory=lambda: ["router.weight"]) router_biases: List[str] = field(default_factory=lambda: ["router.bias"]) + # Fused format: experts stored as single tensors + is_fused_experts: bool = True + # Fused format: single tensors containing all experts (test models) fused_expert_weights: List[str] = field( default_factory=lambda: [ @@ -212,5 +222,12 @@ class GptOss20bExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): default_factory=lambda: ["gate_up_proj_bias", "down_proj_bias"] ) - # Fused format: experts stored as single tensors - is_fused_experts: bool = True + def get_modules_names_to_hook(self, model) -> List[Tuple[int, str]]: + target_class_name = "GptOssTopKRouter" + + module_names_to_hook = [] + for module_name, module in model.named_modules(): + if module_name.endswith(self.target_name) and module.__class__.__name__ == target_class_name: + module_names_to_hook.append((self.block_idx_from_module_name(module_name), module_name)) + return module_names_to_hook + diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_pruned_to_mxfp4.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_pruned_to_mxfp4.py new file mode 100644 index 000000000..8e993573d --- /dev/null +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_pruned_to_mxfp4.py @@ -0,0 +1,506 @@ +#!/usr/bin/env python3 +""" +Create a HuggingFace checkpoint with MXFP4 MoE weights from the original gpt-oss-120b model. + +This script: +1. Copies non-MoE weights from the student model (trained attention, embeddings, etc.) +2. Extracts MoE expert weights from the original gpt-oss-120b in MXFP4 format +3. Either loads experts_to_keep.json or deduces expert mappings by comparing weights +4. Outputs a new checkpoint in decihf format with PACKED MXFP4 expert weights +""" + +import argparse +import json +import os +import shutil +from typing import Dict, List, Any, Tuple, Optional + +import torch +from safetensors import safe_open +from safetensors.torch import save_file +from tqdm import tqdm + +from transformers.integrations.mxfp4 import convert_moe_packed_tensors + + +def deduce_experts_for_layer( + layer: int, + original_path: str, + original_index: Dict, + student_path: str, +) -> Tuple[List[int], int, int]: + """ + Deduce which original experts match the student experts by comparing weights. + + Compares dequantized MXFP4 weights from the original model against the student + model's BF16 weights using L2 distance. Finds the best 1-to-1 matching. 
+ + Args: + layer: Layer index + original_path: Path to original model + original_index: Original model's safetensors index + student_path: Path to student model + num_student_experts: Number of experts in student model (if None, auto-detect) + + Returns: + Tuple of (expert_indices, num_student_experts, num_original_experts) + """ + # Load original tensors + orig_tensors = load_layer_tensors(original_path, layer, original_index) + mlp1_blocks = orig_tensors[f'model.layers.{layer}.mlp.experts.gate_up_proj_blocks'] + mlp1_scales = orig_tensors[f'model.layers.{layer}.mlp.experts.gate_up_proj_scales'] + mlp2_blocks = orig_tensors[f'model.layers.{layer}.mlp.experts.down_proj_blocks'] + mlp2_scales = orig_tensors[f'model.layers.{layer}.mlp.experts.down_proj_scales'] + + num_original_experts = mlp1_blocks.shape[0] + + # Load student tensors + student_subblocks = os.path.join(student_path, 'subblocks_safetensors') + student_ffn = os.path.join(student_subblocks, f'block_{layer}_ffn.safetensors') + if not os.path.exists(student_ffn): + print(f"FFN file not found at {student_ffn} - fallback to no_op") + return [], 0, num_original_experts + + student_experts = {} + with safe_open(student_ffn, framework='pt') as f: + for key in f.keys(): + if 'experts' in key or 'router' in key: + student_experts[key] = f.get_tensor(key) + + # Auto-detect number of student experts + num_student_experts = student_experts[f'model.layers.{layer}.mlp.experts.gate_up_proj'].size(0) + print(f" Layer {layer}: Comparing {num_student_experts} student experts against {num_original_experts} original experts") + + # Pre-dequantize all original experts once (optimization) + print(f" Pre-dequantizing {num_original_experts} original experts...") + deqexpert_mlp1 = convert_moe_packed_tensors(mlp1_blocks, mlp1_scales).cpu() + deqexpert_mlp2 = convert_moe_packed_tensors(mlp2_blocks, mlp2_scales).cpu() + original_experts_dequant = [] + for orig_idx in range(num_original_experts): + + original_experts_dequant.append({ + 'up': deqexpert_mlp1[orig_idx], + 'down': deqexpert_mlp2[orig_idx] + }) + + # For each student expert, find best matching original expert + experts_to_keep = [] + used_original_indices = set() + + # Number of values to use for quick comparison (tune this) + quick_compare_size = 8 + # Number of candidates to keep for full comparison + top_k_candidates = min(10, num_original_experts) + + for student_idx in range(num_student_experts): + # Get student expert weights + prefix = f'model.layers.{layer}.mlp' + student_up = student_experts.get(f'{prefix}.experts.gate_up_proj')[student_idx] + student_down = student_experts.get(f'{prefix}.experts.down_proj')[student_idx] + + # if student_gate is None or student_up is None or student_down is None: + if student_up is None or student_down is None: + raise ValueError(f"Missing student expert weights for layer {layer} expert {student_idx}") + + # Step 1: Quick filtering using first N values + candidate_scores = [] + for orig_idx in range(num_original_experts): + if orig_idx in used_original_indices: + continue + + orig_expert = original_experts_dequant[orig_idx] + + up_quick = (orig_expert['up'].flatten()[:quick_compare_size] - + student_up.float().flatten()[:quick_compare_size]).pow(2).mean().sqrt() + down_quick = (orig_expert['down'].flatten()[:quick_compare_size] - + student_down.float().flatten()[:quick_compare_size]).pow(2).mean().sqrt() + + quick_score = (up_quick + down_quick) / 2.0 + candidate_scores.append((orig_idx, quick_score.item())) + + # Step 2: Get top-k candidates based 
on quick comparison + candidate_scores.sort(key=lambda x: x[1]) + top_candidates = [idx for idx, _ in candidate_scores[:top_k_candidates]] + + # Step 3: Full comparison only on top candidates + best_match_idx = None + best_match_score = float('inf') + + for orig_idx in top_candidates: + orig_expert = original_experts_dequant[orig_idx] + + # Full comparison across all values + up_diff = (orig_expert['up'] - student_up.float()).pow(2).mean().sqrt() + down_diff = (orig_expert['down'] - student_down.float()).pow(2).mean().sqrt() + + score = (up_diff + down_diff) / 2.0 + + if score < best_match_score: + best_match_score = score + best_match_idx = orig_idx + + if best_match_idx is None: + raise ValueError(f"Could not find match for student expert {student_idx} in layer {layer}") + + experts_to_keep.append(best_match_idx) + used_original_indices.add(best_match_idx) + print(f" Student expert {student_idx} -> Original expert {best_match_idx} (RMSE: {best_match_score:.6f})") + + return experts_to_keep, num_student_experts, num_original_experts + + +def load_original_index(path: str) -> Dict[str, Any]: + """Load the original model's safetensors index.""" + with open(path, 'r') as f: + return json.load(f) + + +def load_layer_tensors(original_path: str, layer: int, index: Dict) -> Dict[str, torch.Tensor]: + """Load all MoE-related tensors for a layer, potentially from multiple files.""" + keys_to_load = [ + f'model.layers.{layer}.mlp.experts.gate_up_proj_blocks', + f'model.layers.{layer}.mlp.experts.gate_up_proj_scales', + f'model.layers.{layer}.mlp.experts.gate_up_proj_bias', + f'model.layers.{layer}.mlp.experts.down_proj_blocks', + f'model.layers.{layer}.mlp.experts.down_proj_scales', + f'model.layers.{layer}.mlp.experts.down_proj_bias', + f'model.layers.{layer}.mlp.router.weight', # Router weight + f'model.layers.{layer}.mlp.router.bias', # Router bias + ] + + # Group by file + file_to_keys = {} + for key in keys_to_load: + if key in index['weight_map']: + filename = index['weight_map'][key] + if filename not in file_to_keys: + file_to_keys[filename] = [] + file_to_keys[filename].append(key) + + # Load from each file + tensors = {} + for filename, keys in file_to_keys.items(): + filepath = os.path.join(original_path, filename) + with safe_open(filepath, framework='pt') as f: + for key in keys: + tensors[key] = f.get_tensor(key) + + return tensors + + +def copy_non_moe_weights( + student_path: str, + output_path: str, + num_layers: int +) -> Dict[str, str]: + """ + Copy non-MoE weights from student model. + Returns weight_map for the new index. 
+ """ + weight_map = {} + subblocks_dir = os.path.join(output_path, 'subblocks_safetensors') + os.makedirs(subblocks_dir, exist_ok=True) + + student_subblocks = os.path.join(student_path, 'subblocks_safetensors') + + # Copy embeddings + src_emb = os.path.join(student_subblocks, 'embeddings.safetensors') + dst_emb = os.path.join(subblocks_dir, 'embeddings.safetensors') + shutil.copy2(src_emb, dst_emb) + with safe_open(src_emb, framework='pt') as f: + for key in f.keys(): + weight_map[key] = 'subblocks_safetensors/embeddings.safetensors' + + # Copy lm_head + src_head = os.path.join(student_subblocks, 'lm_head.safetensors') + dst_head = os.path.join(subblocks_dir, 'lm_head.safetensors') + shutil.copy2(src_head, dst_head) + with safe_open(src_head, framework='pt') as f: + for key in f.keys(): + weight_map[key] = 'subblocks_safetensors/lm_head.safetensors' + + # Copy attention blocks + for layer in range(num_layers): + src_attn = os.path.join(student_subblocks, f'block_{layer}_attention.safetensors') + dst_attn = os.path.join(subblocks_dir, f'block_{layer}_attention.safetensors') + shutil.copy2(src_attn, dst_attn) + with safe_open(src_attn, framework='pt') as f: + for key in f.keys(): + weight_map[key] = f'subblocks_safetensors/block_{layer}_attention.safetensors' + + return weight_map + + + + +def process_single_layer( + layer: int, + original_path: str, + original_index: Dict, + student_path: str, + output_path: str, + experts_to_keep: List[int], +) -> Tuple[Dict[str, str], List[str]]: + """ + Process a single layer - loads tensors from potentially multiple files. + Returns (weight_map, verification_errors). + """ + weight_map = {} + verification_errors = [] + subblocks_dir = os.path.join(output_path, 'subblocks_safetensors') + student_subblocks = os.path.join(student_path, 'subblocks_safetensors') + + # Load all tensors for this layer (may come from multiple files) + orig_tensors = load_layer_tensors(original_path, layer, original_index) + + # Load student FFN file + student_ffn = os.path.join(student_subblocks, f'block_{layer}_ffn.safetensors') + + tensors_to_save = {} + student_tensors = {} + + with safe_open(student_ffn, framework='pt') as f: + for key in f.keys(): + tensor = f.get_tensor(key) + if 'experts' not in key and 'router' not in key: + # Copy norm weights + tensors_to_save[key] = tensor + + # Get router from original model, sliced to kept experts + orig_router_weight = orig_tensors[f'model.layers.{layer}.mlp.router.weight'] + orig_router_bias = orig_tensors[f'model.layers.{layer}.mlp.router.bias'] + + kept_indices_tensor = torch.tensor(experts_to_keep, dtype=torch.long) + sliced_router_weight = orig_router_weight[kept_indices_tensor] + sliced_router_bias = orig_router_bias[kept_indices_tensor] + + tensors_to_save[f'model.layers.{layer}.mlp.router.weight'] = sliced_router_weight + tensors_to_save[f'model.layers.{layer}.mlp.router.bias'] = sliced_router_bias + + # Get MoE tensors + mlp1_blocks = orig_tensors[f'model.layers.{layer}.mlp.experts.gate_up_proj_blocks'] + mlp1_scales = orig_tensors[f'model.layers.{layer}.mlp.experts.gate_up_proj_scales'] + mlp2_blocks = orig_tensors[f'model.layers.{layer}.mlp.experts.down_proj_blocks'] + mlp2_scales = orig_tensors[f'model.layers.{layer}.mlp.experts.down_proj_scales'] + mlp1_bias = orig_tensors[f'model.layers.{layer}.mlp.experts.gate_up_proj_bias'] + mlp2_bias = orig_tensors[f'model.layers.{layer}.mlp.experts.down_proj_bias'] + + tensors_to_save[f'model.layers.{layer}.mlp.experts.gate_up_proj_blocks'] = mlp1_blocks[kept_indices_tensor] + 
tensors_to_save[f'model.layers.{layer}.mlp.experts.gate_up_proj_scales'] = mlp1_scales[kept_indices_tensor] + tensors_to_save[f'model.layers.{layer}.mlp.experts.gate_up_proj_bias'] = mlp1_bias[kept_indices_tensor] + + tensors_to_save[f'model.layers.{layer}.mlp.experts.down_proj_blocks'] = mlp2_blocks[kept_indices_tensor] + tensors_to_save[f'model.layers.{layer}.mlp.experts.down_proj_scales'] = mlp2_scales[kept_indices_tensor] + tensors_to_save[f'model.layers.{layer}.mlp.experts.down_proj_bias'] = mlp2_bias[kept_indices_tensor] + + # Save the FFN file + output_file = os.path.join(subblocks_dir, f'block_{layer}_ffn.safetensors') + save_file(tensors_to_save, output_file) + + # Build weight map + for key in tensors_to_save.keys(): + weight_map[key] = f'subblocks_safetensors/block_{layer}_ffn.safetensors' + + return weight_map, verification_errors + + +def copy_config_files(student_path: str, output_path: str): + """Copy configuration files from student model and update config.json.""" + files_to_copy = [ + 'tokenizer.json', + 'tokenizer_config.json', + 'special_tokens_map.json', + 'chat_template.jinja', + ] + + # Also copy transformers compatibility files + if os.path.exists(student_path): + for f in os.listdir(student_path): + if f.startswith('transformers_'): + files_to_copy.append(f) + + for filename in files_to_copy: + src = os.path.join(student_path, filename) + dst = os.path.join(output_path, filename) + + # Try student path first + if os.path.exists(src): + try: + shutil.copy2(src, dst) + continue + except PermissionError: + pass + + # If we get here, file doesn't exist or permission denied + if not os.path.exists(dst): + print(f" Warning: Could not copy {filename}") + + # Update config.json for DeciGptOssForCausalLM with MXFP4 + src_config = os.path.join(student_path, 'config.json') + if not os.path.exists(src_config): + raise FileNotFoundError(f"config.json not found at {src_config}") + + with open(src_config, 'r') as f: + config = json.load(f) + + # Set architecture to DeciGptOssForCausalLM for MXFP4 support + config['architectures'] = ['DeciGptOssForCausalLM'] + + # Add quantization_config so vllm calls _load_weights_mxfp4 + config['quantization_config'] = { + "quant_method": "mxfp4", + "modules_to_not_convert": [ + "model.layers.*.self_attn", + "model.layers.*.mlp.router", + "model.embed_tokens", + "lm_head", + ] + } + + dst_config = os.path.join(output_path, 'config.json') + with open(dst_config, 'w') as f: + json.dump(config, f, indent=2) + + +def main(): + parser = argparse.ArgumentParser(description='Create MXFP4 checkpoint from student model') + parser.add_argument( + '--student-path', + type=str, + required=True, + help='Path to student model checkpoint' + ) + parser.add_argument( + '--original-path', + type=str, + required=True, + help='Path to original gpt-oss-120b model with MXFP4 weights' + ) + parser.add_argument( + '--output-path', + type=str, + required=True, + help='Output path for the new checkpoint' + ) + parser.add_argument( + '--num-layers', + type=int, + default=36, + help='Number of transformer layers' + ) + args = parser.parse_args() + + print(f"Creating MXFP4 checkpoint...") + print(f" Student model: {args.student_path}") + print(f" Original model: {args.original_path}") + print(f" Output: {args.output_path}") + + + # Load original model index + original_index = load_original_index( + os.path.join(args.original_path, 'model.safetensors.index.json') + ) + + print("\nDeducing expert mappings by comparing weights...") + experts_to_keep = [] + layer_statistics = 
[] # Store (num_student, num_original) for each layer + + for layer in range(args.num_layers): + layer_experts, num_student, num_original = deduce_experts_for_layer( + layer, + args.original_path, + original_index, + args.student_path, + ) + experts_to_keep.append(layer_experts) + layer_statistics.append((num_student, num_original)) + + # Print statistics + print(f"\n{'='*70}") + print("EXPERT DEDUCTION STATISTICS") + print(f"{'='*70}") + print(f"{'Layer':<8} {'Student Experts':<18} {'Original Experts':<18} {'Kept %':<10}") + print(f"{'-'*70}") + + total_student = 0 + total_original = 0 + for layer, (num_student, num_original) in enumerate(layer_statistics): + percentage = (num_student / num_original * 100) if num_original > 0 else 0 + print(f"{layer:<8} {num_student:<18} {num_original:<18} {percentage:<10.2f}") + total_student += num_student + total_original += num_original + + print(f"{'-'*70}") + avg_percentage = (total_student / total_original * 100) if total_original > 0 else 0 + print(f"{'TOTAL':<8} {total_student:<18} {total_original:<18} {avg_percentage:<10.2f}") + print(f"{'='*70}") + print(f"\n Deduced experts_to_keep mapping for {len(experts_to_keep)} layers") + + # Create output directory + os.makedirs(args.output_path, exist_ok=True) + os.makedirs(os.path.join(args.output_path, 'subblocks_safetensors'), exist_ok=True) + + # Copy config files + print("Copying configuration files...") + copy_config_files(args.student_path, args.output_path) + + # Save experts_to_keep.json + experts_to_keep_output = os.path.join(args.output_path, 'experts_to_keep.json') + with open(experts_to_keep_output, 'w') as f: + json.dump(experts_to_keep, f, indent=2) + print(f" Saved experts_to_keep mapping to {experts_to_keep_output}") + + # Copy non-MoE weights (embeddings, attention, lm_head) + print("Copying non-MoE weights...") + weight_map = copy_non_moe_weights( + args.student_path, + args.output_path, + args.num_layers + ) + + # Load weights per layer (handles multi-file loading) + print(f"Processing {args.num_layers} layers...") + + all_verification_errors = [] + + # Process each layer + for layer in tqdm(range(args.num_layers), desc="Processing layers"): + if len(experts_to_keep[layer]) == 0: + print(f"Layer {layer} has no experts to keep - ffn->no_op") + continue + layer_weight_map, layer_errors = process_single_layer( + layer, + args.original_path, + original_index, + args.student_path, + args.output_path, + experts_to_keep[layer], + ) + weight_map.update(layer_weight_map) + all_verification_errors.extend(layer_errors) + + # Calculate total size + total_size = 0 + subblocks_dir = os.path.join(args.output_path, 'subblocks_safetensors') + for filename in os.listdir(subblocks_dir): + filepath = os.path.join(subblocks_dir, filename) + total_size += os.path.getsize(filepath) + + # Create model.safetensors.index.json + index = { + 'metadata': { + 'total_size': total_size + }, + 'weight_map': weight_map + } + + index_path = os.path.join(args.output_path, 'model.safetensors.index.json') + with open(index_path, 'w') as f: + json.dump(index, f, indent=2) + + print(f"\nCheckpoint created successfully at: {args.output_path}") + print(f"Total size: {total_size / 1e9:.2f} GB") + + +if __name__ == '__main__': + main() From 5ad56c5a05c12d4b3e7184fa864f98ac53b64d91 Mon Sep 17 00:00:00 2001 From: mchochowski Date: Fri, 13 Feb 2026 03:17:51 -0800 Subject: [PATCH 2/2] added paragraph in readme Signed-off-by: mchochowski --- examples/puzzletron/README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git 
a/examples/puzzletron/README.md b/examples/puzzletron/README.md
index 619fb619b..40d6bcfe0 100644
--- a/examples/puzzletron/README.md
+++ b/examples/puzzletron/README.md
@@ -285,3 +285,40 @@ python -m nemo_export/convert_nemo_to_hf --input-ckpt-path path/to/nemo-model --
 ## Advanced Usage
 
 Modify `llama-3_1-8B_pruneffn_memory.yaml` file for advanced compression scenarios.
+
+## GPT-OSS-20B
+
+With this release, the Puzzle algorithm supports only expert removal for GPT-OSS-20B. The model ships as a quantized checkpoint, i.e. its MoE expert matrices are stored in the MXFP4 format. During the pruning steps Puzzle works on a decompressed model (cast back to bf16) for statistics and score computation, so the conversion to the Puzzle format decompresses the model and stores it in bf16. Once pruning is finished, i.e. the experts to remove have been identified, you may want to restore the MXFP4 format of the checkpoint. An additional script takes the original and the pruned checkpoints and writes the pruned checkpoint in MXFP4 format:
+```bash
+python modelopt/torch/puzzletron/anymodel/models/gpt_oss_20b/gpt_oss_pruned_to_mxfp4.py --student-path /workspaces/any_model_gpt_oss_20b/mip/puzzle_solutions/stats_num_params_18014757184/solutions--checkpoints/solution_0/ --original-path /workspaces/source_model_checkpoints/openai_gpt-oss-20b/ --output-path /workspaces/any_model_gpt_oss_20b/mip/puzzle_solutions/stats_num_params_18014757184/solutions--checkpoints/mxfp4-ckpt/ --num-layers 24
+```
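+
+For reference, the bf16 decompression applied during conversion can be reproduced with `convert_moe_packed_tensors` from `transformers` (the helper the converter and the repacking script already import). The sketch below is illustrative only, assuming the stock gpt-oss-20b safetensors layout with packed `*_blocks` / `*_scales` tensors; the checkpoint path, layer index, and projection key are placeholders.
+
+```python
+import json
+import os
+
+from safetensors import safe_open
+from transformers.integrations.mxfp4 import convert_moe_packed_tensors
+
+ckpt_dir = "/workspace/hf_models/openai/gpt-oss-20b"  # placeholder path
+with open(os.path.join(ckpt_dir, "model.safetensors.index.json")) as f:
+    weight_map = json.load(f)["weight_map"]
+
+
+def load_tensor(name: str):
+    # A tensor may live in any shard, so resolve its file through the index.
+    shard = os.path.join(ckpt_dir, weight_map[name])
+    with safe_open(shard, framework="pt") as f:
+        return f.get_tensor(name)
+
+
+key = "model.layers.0.mlp.experts.gate_up_proj"  # illustrative layer/projection
+blocks = load_tensor(key + "_blocks")
+scales = load_tensor(key + "_scales")
+
+# One dequantized weight matrix per expert (bf16 by default), the form Puzzle scores against.
+dense = convert_moe_packed_tensors(blocks, scales)
+print(dense.shape, dense.dtype)
+```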