From fb600a435fd2534342722219ad762910f0639229 Mon Sep 17 00:00:00 2001 From: willg-nv Date: Thu, 12 Feb 2026 19:22:19 +0800 Subject: [PATCH 01/11] Integrate Automated QDQ placement tool - part 2.2 (#845) ## What does this PR do? This PR implements RegionSearch class. RegionSearch could help partition big ONNX model into small region. QDQ autouning will be performed on the regions. **Overview:** ? ## Usage ```python # Add a code snippet demonstrating how to use this ``` ## Testing ## Before your PR is "*Ready for review*" - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes - **Did you write any new necessary tests?**: Yes - **Did you add or update any necessary documentation?**: No, document updates is in Part 4. - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: CHANGELOG will be updated when all changes are ready. ## Additional Information ## Summary by CodeRabbit ## Release Notes **Refactor** * Improved ONNX quantization backend with new optimization framework and extensive test coverage to enhance internal graph processing capabilities. --------- Signed-off-by: Will Guo Signed-off-by: Hung-Yueh --- .../quantization/autotune/region_search.py | 1083 +++++++++++++++++ .../autotune/test_region_search.py | 345 ++++++ 2 files changed, 1428 insertions(+) create mode 100644 modelopt/onnx/quantization/autotune/region_search.py create mode 100644 tests/unit/onnx/quantization/autotune/test_region_search.py diff --git a/modelopt/onnx/quantization/autotune/region_search.py b/modelopt/onnx/quantization/autotune/region_search.py new file mode 100644 index 000000000..02f8282a0 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/region_search.py @@ -0,0 +1,1083 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hierarchical region discovery and partitioning for ONNX graphs.""" + +import sys +from collections import defaultdict, deque + +import onnx_graphsurgeon as gs + +from modelopt.onnx.logging_config import logger +from modelopt.onnx.quantization.autotune.common import Region, RegionType +from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern +from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices + +DEFAULT_MAX_STEPS = 10 +DEFAULT_MAX_NODES_TO_SHOW = 20 +MAX_PROBE_STEPS_AFTER_CONVERGE = 3 + + +class RegionSearchBase: + """Base class for region search algorithms providing common graph analysis utilities. + + This class serves as a foundation for region-based graph analysis algorithms by + providing essential data structures and methods for: + - Graph traversal and reachability analysis + - Divergence/convergence pattern detection + - Region boundary computation + - Tensor flow tracking + + For large graphs, initialization may take significant time but enables + efficient queries during region formation. + """ + + def __init__( + self, + graph: gs.Graph, + root: Region | None = None, + max_steps: int = DEFAULT_MAX_STEPS, + tensor_users_map: dict[str, list[int]] | None = None, + forward_reachable_nodes_map: dict[int, dict[int, int]] | None = None, + ): + """Initialize the base region search with graph analysis. + + Performs pre-computation of essential data structures for efficient + region analysis: + 1. Creates or validates root region containing all nodes + 2. Builds tensor-to-users mapping for divergence detection + 3. Pre-computes forward reachability for convergence detection + """ + self.graph = graph + if tensor_users_map is None: + tensor_users_map = get_tensor_consumer_node_indices(self.graph) + self.tensor_users_map = tensor_users_map + if root is None: + root = self._build_root_region() + self.root = root + if forward_reachable_nodes_map is None: + forward_reachable_nodes_map = self._build_forward_reachable_nodes_map( + max_steps=max_steps + ) + self.forward_reachable_nodes_map = forward_reachable_nodes_map + + def _build_root_region(self) -> Region: + """Create a root region containing all nodes in the graph. + + The root region serves as the universal search space for region + formation algorithms. It represents the entire computation graph + as a single region before any partitioning. + + Returns: + Region of type ROOT containing all graph nodes. + """ + root = Region(region_id=0, level=0, region_type=RegionType.ROOT) + root.nodes.update(range(len(self.graph.nodes))) + self.compute_region_boundaries(root) + return root + + def _is_tensor_divergent(self, tensor_name: str) -> bool: + """Check if a tensor is consumed by multiple nodes (divergent). + + A divergent tensor indicates branching in the computation graph, + where one operation's output feeds into multiple downstream operations. + + Args: + tensor_name: Name of the tensor to check + + Returns: + True if tensor has more than one consumer, False otherwise + """ + return len(self.tensor_users_map.get(tensor_name, [])) > 1 + + def _is_node_divergent(self, node_idx: int) -> bool: + """Check if a node has outputs that branch to multiple consumers. + + A divergent node is one that produces outputs consumed by multiple + downstream nodes, creating branches in the computation graph. These + nodes are important boundaries for region formation. + + Args: + node_idx: Index of the node to check + + Returns: + True if the node has at least one output consumed by multiple nodes, + False otherwise or if node is not in root region. + """ + if node_idx not in self.root.get_nodes(): + logger.debug(f"Node {node_idx} not in root region") + return False + + node = self.graph.nodes[node_idx] + divergent_outputs = [ + out.name for out in node.outputs if self._is_tensor_divergent(out.name) + ] + is_divergent = len(divergent_outputs) > 0 + + if is_divergent: + logger.debug( + f"Divergent node {node_idx} ({node.op}): {len(divergent_outputs)} branches" + ) + + return is_divergent + + def _compute_forward_reachable_nodes( + self, start_node_idx: int, max_steps: int + ) -> dict[int, int]: + """Compute all nodes reachable forward from a starting node with distances. + + Uses breadth-first search (BFS) to find all nodes reachable by following + forward edges (data flow direction) from the start node, up to a maximum + distance. Records the shortest-path distance to each reachable node. + + Args: + start_node_idx: Index of node to start search from + max_steps: Maximum forward distance to explore + + Returns: + Dictionary mapping reachable node indices to their distances from start. + Includes start_node_idx mapped to distance 0. + """ + reachable: dict[int, int] = {start_node_idx: 0} + queue: deque[tuple[int, int]] = deque([(start_node_idx, 0)]) + while queue: + current_node_idx, distance = queue.popleft() + if distance >= max_steps: + continue + for output in self.graph.nodes[current_node_idx].outputs: + for next_node_idx in self.tensor_users_map.get(output.name, ()): + if next_node_idx not in reachable: + reachable[next_node_idx] = distance + 1 + queue.append((next_node_idx, distance + 1)) + return reachable + + def _build_forward_reachable_nodes_map(self, max_steps: int) -> dict[int, dict[int, int]]: + """Pre-compute forward reachability for all nodes in the graph. + + This is a key optimization that enables efficient convergence detection. + By pre-computing forward reachability once, we can quickly answer queries + like "Can node A reach node B?" and "What is the distance from A to B?" + + Args: + max_steps: Maximum forward distance to pre-compute for each node. + Limits both time and space complexity. + + Returns: + Nested dictionary where outer key is start node index, inner key is + reachable node index, and value is shortest-path distance. + """ + logger.debug(f"Building forward reachability map (max_steps={max_steps})...") + forward_reachable_nodes_map: dict[int, dict[int, int]] = {} + for node_idx in self.root.get_nodes(): + forward_reachable_nodes_map[node_idx] = self._compute_forward_reachable_nodes( + node_idx, max_steps + ) + + total_reachable = sum(len(reachable) for reachable in forward_reachable_nodes_map.values()) + avg_reachable = total_reachable / len(self.root.get_nodes()) if self.root.get_nodes() else 0 + logger.debug(f"Reachability map complete: avg {avg_reachable:.1f} reachable nodes per node") + return forward_reachable_nodes_map + + def _find_common_reachable_nodes( + self, node_idx: int, branches: list[int] + ) -> tuple[list[dict], set[int]]: + """Find common reachable nodes from all branches (potential convergence points). + + Used as STEP 1 of convergence detection in _find_converge_nodes. + + Args: + node_idx: Index of the divergent node (excluded from common_nodes). + branches: List of branch head node indices. + + Returns: + (branch_reachable, common_nodes) + """ + branch_reachable = [self.forward_reachable_nodes_map.get(b, {}) for b in branches] + + if not branch_reachable: + logger.debug(" No reachable nodes from branches") + return [], set() + + common_nodes = set.intersection(*[set(r.keys()) for r in branch_reachable]) + logger.debug(f" {len(common_nodes)} common nodes found") + common_nodes.discard(node_idx) + + if not common_nodes: + logger.debug(" No valid convergence candidates") + return [], set() + + return branch_reachable, common_nodes + + def _evaluate_convergence_candidate( + self, + candidate_idx: int, + reachable_from_start: dict, + branch_reachable: list, + ) -> tuple[bool, int]: + r"""Check if a candidate convergence node forms a valid region and return its max distance. + + A valid region has no \"escaping\" edges: no node inside the region may reach a node + outside the region before reaching the candidate convergence point. + + Args: + candidate_idx: Candidate convergence node index. + reachable_from_start: Forward reachability from the divergent node. + branch_reachable: Per-branch reachability dicts (for max distance). + + Returns: + (is_valid, max_distance). max_distance is only meaningful when is_valid is True. + """ + region_nodes: set[int] = set(reachable_from_start.keys()) + reachable_from_candidate = self.forward_reachable_nodes_map.get(candidate_idx, {}) + region_nodes = region_nodes - set(reachable_from_candidate.keys()) + + for rnode_index in region_nodes: + reachable_from_rnode = self.forward_reachable_nodes_map.get(rnode_index, {}) + rnode_to_candidate_distance = reachable_from_rnode.get(candidate_idx, float("inf")) + for test_node_idx in reachable_from_rnode: + if test_node_idx in region_nodes: + continue + rnode_to_test_distance = reachable_from_rnode.get(test_node_idx, float("inf")) + if any( + d == float("inf") for d in (rnode_to_test_distance, rnode_to_candidate_distance) + ): + return False, 0 + + max_distance = max(reachable[candidate_idx] for reachable in branch_reachable) + return True, max_distance + + def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]: + """Find convergence point and intermediate nodes for a divergent node. + + Given a divergent node (where computation branches), this method finds: + 1. The convergence node: Where the branches rejoin + 2. All nodes between divergence and convergence + + Args: + node_idx: Index of the divergent node to find convergence for + + Returns: + Tuple containing: + - Convergence node index (None if no convergence found) + - Set of nodes between divergence and convergence + """ + node = self.graph.nodes[node_idx] + logger.debug(f"Finding convergence for node {node_idx} ({node.op})") + + branches: list[int] = [] + for output in node.outputs: + branches.extend(self.tensor_users_map.get(output.name, [])) + + branches = list(dict.fromkeys(branches)) + + logger.debug(f" {len(branches)} unique branches found") + + if len(branches) <= 1: + logger.debug(" Insufficient branches for convergence") + return None, set() + + branch_reachable, common_nodes = self._find_common_reachable_nodes(node_idx, branches) + if not branch_reachable or not common_nodes: + return None, set() + + # Select Best Convergence Node with Region Validity Check + converge_node_idx: int | None = None + min_max_distance = float("inf") + + reachable_from_start = self.forward_reachable_nodes_map.get(node_idx, {}) + + for candidate_idx in common_nodes: + valid, max_distance = self._evaluate_convergence_candidate( + candidate_idx, reachable_from_start, branch_reachable + ) + if not valid: + continue + if max_distance < min_max_distance: + min_max_distance = max_distance + converge_node_idx = candidate_idx + + # If no valid convergence found, this divergence has no convergence + if converge_node_idx is None: + logger.debug(" No valid convergence found") + return None, set() + + converge_node = self.graph.nodes[converge_node_idx] + logger.debug( + f" Convergence at node {converge_node_idx} ({converge_node.op}), distance {min_max_distance}" + ) + + # Compute All Nodes Between Divergence and Convergence + visited_nodes: set[int] = set() + for candidate_idx in reachable_from_start: + if candidate_idx == converge_node_idx: + continue + reachable_from_candidate = self.forward_reachable_nodes_map.get(candidate_idx, {}) + if converge_node_idx in reachable_from_candidate: + visited_nodes.add(candidate_idx) + logger.debug(f" {len(visited_nodes)} nodes between divergence and convergence") + return converge_node_idx, visited_nodes + + def _max_distance_to_nodes(self, src_idx: int, dst_indices: set[int]) -> int: + """Compute maximum distance from a source node to a set of destination nodes. + + Uses pre-computed forward reachability to efficiently find the maximum + shortest-path distance from src_idx to any node in dst_indices. + + Args: + src_idx: Index of the source node + dst_indices: Set of destination node indices + + Returns: + Maximum distance from src_idx to any node in dst_indices + """ + max_distance = 0 + for dst_idx in dst_indices: + reachable = self.forward_reachable_nodes_map.get(src_idx, {}) + if dst_idx in reachable: + max_distance = max(max_distance, reachable[dst_idx]) + + logger.debug( + f"Max distance from node {src_idx}: {max_distance} steps to {len(dst_indices)} nodes" + ) + return max_distance + + def compute_region_boundaries(self, region: Region, include_constant: bool = False) -> None: + """Compute input and output tensor boundaries for a region. + + Args: + region: The region to compute boundaries for + include_constant: Whether to include constant tensors in inputs + """ + node_indices = region.get_region_nodes_and_descendants() + + consumed_tensors: set[str] = set() + produced_tensors: set[str] = set() + region_outputs: set[str] = set() + + for node_idx in node_indices: + if node_idx >= len(self.graph.nodes): + continue + node = self.graph.nodes[node_idx] + + # Collect consumed tensors (potential inputs) + for input_tensor in node.inputs: + if isinstance(input_tensor, gs.Constant) and not include_constant: + continue + consumed_tensors.add(input_tensor.name) + + # Collect produced tensors and determine outputs + for output_tensor in node.outputs: + tensor_name = output_tensor.name + produced_tensors.add(tensor_name) + + consumer_indices = self.tensor_users_map.get(tensor_name, []) + has_external_consumer = any(idx not in node_indices for idx in consumer_indices) + is_graph_output = output_tensor in self.graph.outputs + + if has_external_consumer or is_graph_output or not consumer_indices: + region_outputs.add(tensor_name) + + # Region inputs = consumed tensors not produced internally + region.inputs = sorted(consumed_tensors - produced_tensors) + region.outputs = sorted(region_outputs) + + logger.debug( + f"Computed boundaries: {len(region.inputs)} inputs, {len(region.outputs)} outputs" + ) + + def print_tree( + self, + region: Region | None = None, + indent: int = 0, + max_items: int = DEFAULT_MAX_NODES_TO_SHOW, + file=None, + ) -> None: + """Print hierarchical region tree in human-readable text format.""" + region = region or self.root + file = file or sys.stdout + p = " " * indent + + def truncated(items, fmt=str): + """Yield formatted items, truncating with count if needed.""" + items = list(items) + yield from (fmt(x) for x in items[:max_items]) + if len(items) > max_items: + yield f"... and {len(items) - max_items} more" + + direct_nodes = region.get_nodes() + children = region.get_children() + # Header + counts + print( + f"{p}├─ Region {region.id} (Level {region.level}, Type: {region.type.value})", file=file + ) + print(f"{p}│ ├─ Direct nodes: {len(direct_nodes)}", file=file) + print(f"{p}│ ├─ Total nodes: {len(region.get_region_nodes_and_descendants())}", file=file) + print(f"{p}│ ├─ Children: {len(children)}", file=file) + # I/O + for label, items in [("Inputs", region.inputs), ("Outputs", region.outputs)]: + print(f"{p}│ ├─ {label}: {len(items)}", file=file) + for line in truncated(items): + print(f"{p}│ │ - {line}", file=file) + # Direct nodes + if direct_nodes: + print(f"{p}│\n{p}│ Nodes in this region:", file=file) + + def node_fmt(i: int) -> str: + return f"Node {i}: {self.graph.nodes[i].op} ({self.graph.nodes[i].name})" + + for line in truncated(sorted(direct_nodes), node_fmt): + print(f"{p}│ - {line}", file=file) + # Children + if children: + print(f"{p}│\n{p}│ Child regions:", file=file) + for child in children: + print(f"{p}│", file=file) + self.print_tree(child, indent + 1, max_items, file) + + +class RegionPartitioner(RegionSearchBase): + """Bottom-up graph partitioner that creates initial regions based on divergence patterns. + + This class implements Phase 1 of the combined region search strategy. It performs + a systematic traversal of the computation graph from inputs to outputs, identifying + natural boundaries for region formation based on computation flow patterns. + + **Core Strategy:** + Partitions the graph by analyzing three types of computational patterns: + + 1. **Divergent Nodes with Convergence:** + - Nodes whose outputs branch to multiple paths (divergence) + - Paths that eventually rejoin at a common node (convergence) + - Creates a single region encompassing divergence + branches + convergence + - Example: A → (B,C) → D creates region containing {A, B, C, D} + + 2. **Divergent Nodes without Convergence:** + - Nodes whose outputs branch but never rejoin + - Creates a single-node "orphan" region for the divergent node + - Example: A → (B,C) with no convergence creates region {A} + + 3. **Linear Sequences:** + - Chains of non-divergent nodes (simple sequential computation) + - Groups entire sequence into one region + - Example: A → B → C → D creates region {A, B, C, D} + """ + + def __init__( + self, + graph: gs.Graph, + tensor_users_map: dict[str, list[int]] | None = None, + forward_reachable_nodes_map: dict[int, dict[int, int]] | None = None, + ): + """Initialize the partitioner with a computation graph. + + Sets up necessary data structures and inherits graph analysis utilities + from RegionSearchBase (tensor users map, reachability, etc.). + + Args: + graph: The ONNX computation graph (onnx_graphsurgeon.Graph) + tensor_users_map: Mapping from tensor names to consuming node indices + forward_reachable_nodes_map: Pre-computed forward reachability for all nodes + """ + super().__init__( + graph, + root=None, + tensor_users_map=tensor_users_map, + forward_reachable_nodes_map=forward_reachable_nodes_map, + ) + self.regions: list[Region] = [] + self.current_region: Region | None = None + self.current_region_id: int = 0 + self.visited_nodes: set[int] = set() + + def _append_node_to_region(self, node_idx: int): + """Add a node to the current region, creating a new region if needed. + + This is the primary method for building regions incrementally. If no + region is currently active, creates a new LEAF region. Then adds the + specified node to that region. + + Args: + node_idx: Index of the node to add to the current region + + Returns: + None + """ + node = self.graph.nodes[node_idx] + if self.current_region is None: + self.current_region = Region( + region_id=self.current_region_id, level=0, region_type=RegionType.LEAF + ) + logger.debug(f"Started region {self.current_region_id}") + self.current_region_id += 1 + + self.current_region.nodes.add(node_idx) + logger.debug( + f" Added node {node_idx} ({node.op}), region size: {len(self.current_region.nodes)}" + ) + + def _commit_region(self): + """Finalize and store the current region being built. + + Completes region construction by: + 1. Computing input/output tensor boundaries + 2. Adding region to the completed regions list + 3. Resetting current_region to None for next region + + **Post-Conditions:** + - current_region is added to regions list + - current_region is reset to None + - Region has computed input/output tensor lists + + Side Effects: + - Appends current_region to self.regions + - Sets current_region to None + - Logs region commit with size info + """ + if self.current_region is not None: + region_size = len(self.current_region.nodes) + region_id = self.current_region.id + + self.compute_region_boundaries(self.current_region) + + self.regions.append(self.current_region) + logger.debug( + f"Committed region {region_id}: {region_size} nodes (total: {len(self.regions)})" + ) + self.current_region = None + else: + logger.debug("No region to commit") + + def _build_sequence_from_node(self, node_idx: int, max_nodes: int = -1): + """Build a region from a linear sequence of nodes. + + Starting from a non-divergent node, follows the forward chain of nodes, + adding each non-divergent node to the current region. Stops when hitting: + - A divergent node (branches to multiple paths) + - A node already visited + - End of graph + + Args: + node_idx: Index of the starting node + max_nodes: Maximum number of nodes to add to the region (-1 for no limit) + + Returns: + None + """ + logger.debug(f"Building sequence from node {node_idx} ({self.graph.nodes[node_idx].op})") + + queue: deque[int] = deque([node_idx]) + nodes_added = 0 + + while queue: + current_idx = queue.popleft() + node = self.graph.nodes[current_idx] + + self._append_node_to_region(current_idx) + self.visited_nodes.add(current_idx) + nodes_added += 1 + + if self._is_node_divergent(current_idx): + logger.debug(f" Stopped at divergent node {current_idx} ({node.op})") + else: + # Queue successors for non-divergent nodes + for output in node.outputs: + if output.name in self.tensor_users_map: + queue.extend(self.tensor_users_map[output.name]) + + if 0 < max_nodes <= nodes_added: + logger.debug(" Max nodes reached") + break + + logger.debug(f"Sequence complete: {nodes_added} nodes") + + def _build_small_converged_region( + self, start_node_idx: int, converge_node_idx: int, visited_nodes: set[int] + ): + r"""Create a region encompassing divergence, branches, and convergence. + + Builds a single region containing: + - The divergent node (where branches split) + - All nodes in the branches + - The convergence node (where branches rejoin) + + This creates a "diamond" or "funnel" shaped region that captures + parallel computation paths and their merge point. + + **Structure:** + ``` + start (divergent) + / \ + path1 path2 (visited_nodes) + \\ / + convergence + ``` + """ + visited_nodes.remove(start_node_idx) + for node_idx in sorted(visited_nodes): + self._append_node_to_region(node_idx) + self.visited_nodes.add(node_idx) + if not self._is_node_divergent(converge_node_idx): + self._append_node_to_region(converge_node_idx) + self.visited_nodes.add(converge_node_idx) + self._build_sequence_from_node(converge_node_idx, max_nodes=MAX_PROBE_STEPS_AFTER_CONVERGE) + + def _build_region_from_node(self, node_idx: int): + """Process a single node and create appropriate region(s) based on its pattern. + + This is the core dispatch method that determines how to handle each node based on whether + it's divergent (branches) or sequential. + + - Pattern 1: Divergent with Convergence (Ideal Case) + - Pattern 2: Divergent without Convergence (Boundary Case) + - Pattern 3: Sequential Chain (Common Case) + + Args: + node_idx: Index of node to process + + Side Effects: + - Marks processed nodes as visited + - Creates and commits region(s) via helper methods + - May recursively process successor nodes (in sequence building) + """ + node = self.graph.nodes[node_idx] + + # Skip nodes already assigned to regions + if node_idx in self.visited_nodes: + logger.debug(f"Skipping node {node_idx} ({node.op}): already visited") + return + + logger.debug(f"Processing node {node_idx} ({node.op})") + + # Pattern 1 & 2: Handle divergent nodes + if self._is_node_divergent(node_idx): + logger.debug(" Divergent node, searching for convergence") + # Attempt to find where branches rejoin + converge_node_idx, visited_nodes = self._find_converge_nodes(node_idx) + # Check if convergence creates a reasonable-sized region + max_distance = self._max_distance_to_nodes(node_idx, visited_nodes) + # Pattern 1: Convergence found and region size is acceptable + if converge_node_idx is not None and max_distance < DEFAULT_MAX_STEPS: + converge_node = self.graph.nodes[converge_node_idx] + logger.debug( + f" Creating converged region: {len(visited_nodes)} nodes, " + f"convergence at {converge_node_idx} ({converge_node.op}), distance {max_distance}" + ) + # Create region containing: divergence + all branches + convergence + self._build_small_converged_region(node_idx, converge_node_idx, visited_nodes) + self._commit_region() + # Pattern 2: No convergence or region would be too large + else: + logger.debug(" Creating orphan region for divergent node") + # Create single-node region for this divergent node + # Its successors will be processed separately + self._append_node_to_region(node_idx) + self.visited_nodes.add(node_idx) + self._commit_region() + else: + # Pattern 3: Handle non-divergent (sequential) nodes + logger.debug(" Non-divergent node, building sequence") + # Build region by following the linear chain forward + self._build_sequence_from_node(node_idx) + self._commit_region() + + def partition_graph(self): + """Partition the entire graph into non-overlapping LEAF regions. + + This is the main entry point for bottom-up graph partitioning. Performs + a single pass over all nodes in graph order, creating regions based on + divergence/convergence patterns and sequential chains. + + Returns: + List of non-overlapping LEAF regions created from the graph. + + """ + logger.info(f"Partitioning graph ({len(self.graph.nodes)} nodes)") + logger.debug( + f"Initial state: {len(self.visited_nodes)} visited, {len(self.regions)} regions" + ) + + for node_idx in range(len(self.graph.nodes)): + self._build_region_from_node(node_idx) + + coverage_pct = ( + 100 * len(self.visited_nodes) / len(self.graph.nodes) if self.graph.nodes else 0 + ) + logger.info( + f"Partitioning complete: {len(self.regions)} regions, " + f"{len(self.visited_nodes)}/{len(self.graph.nodes)} nodes ({coverage_pct:.1f}%)" + ) + + if self.regions: + region_sizes = [len(r.nodes) for r in self.regions] + avg_size = sum(region_sizes) / len(region_sizes) + min_size = min(region_sizes) + max_size = max(region_sizes) + logger.debug(f"Region sizes: min={min_size}, max={max_size}, avg={avg_size:.1f}") + + return self.regions + + +class TopDownRegionBuilder(RegionSearchBase): + """Top-down region refiner that creates hierarchical structure from initial regions. + + This class implements Phase 2 of the combined region search strategy. It takes + a region created by RegionPartitioner and refines it by: + 1. Identifying and merging converged sub-patterns + 2. Splitting long sequences into optimal sub-regions + 3. Creating a hierarchical COMPOSITE region structure + """ + + def __init__( + self, + graph: gs.Graph, + root: Region, + next_region_id: int = 0, + maximum_sequence_region_size: int = 10, + tensor_users_map: dict[str, list[int]] | None = None, + forward_reachable_nodes_map: dict[int, dict[int, int]] | None = None, + ): + """Initialize the refiner with a region to refine. + + Args: + graph: The ONNX graph (onnx_graphsurgeon.Graph) + root: The region to refine (typically from RegionPartitioner) + next_region_id: Starting ID for new regions created during refinement + maximum_sequence_region_size: Maximum nodes per sequence region during merging (default: 10) + """ + super().__init__( + graph, + root=root, + tensor_users_map=tensor_users_map, + forward_reachable_nodes_map=forward_reachable_nodes_map, + ) + self.regions: list[Region] = [] + self.next_region_id = next_region_id + self.maximum_sequence_region_size = maximum_sequence_region_size + self.boundary_op_types = { + "Conv", + "ConvTranspose", + "Gemm", + "MatMul", + "AveragePool", + "MaxPool", + "GlobalAveragePool", + "GlobalMaxPool", + "Resize", + } + + def _create_leaf_region(self, node_indices: set[int]) -> Region: + """Create a new LEAF region containing specified nodes. + + Args: + node_indices: Set of node indices to add to the region + + Returns: + New LEAF region containing the specified nodes + """ + region = Region( + region_id=self.next_region_id, level=self.root.level + 1, region_type=RegionType.LEAF + ) + self.next_region_id += 1 + for node_idx in node_indices: + region.nodes.add(node_idx) + self.compute_region_boundaries(region) + return region + + def _build_region_usage_map(self, regions: list[Region]) -> dict[str, list[Region]]: + """Build mapping from tensor names to regions that consume them. + + Similar to tensor_users_map but at the region level instead of node level. + This enables efficient traversal of region dependencies for merging decisions. + + Args: + regions: List of regions to build the usage map for + + Returns: + Mapping from tensor names to regions that consume them + """ + region_usage_map: dict[str, list[Region]] = defaultdict(list) + for region in regions: + for input_tensor in region.inputs: + region_usage_map[input_tensor].append(region) + return region_usage_map + + def _split_sequence_regions(self, root: Region) -> list[Region]: + """Split a region into smaller sub-regions by merging producer-consumer chains. + + Takes a region and creates optimal sub-regions by: + 1. Initially splitting into individual single-node regions + 2. Traversing in data flow order (following tensor dependencies) + 3. Merging adjacent regions that form simple producer-consumer chains + 4. Respecting boundary operations and size limits + + Args: + root: The region to split + + Returns: + List of smaller sub-regions + """ + result_regions: list[Region] = [] + removed_regions: set[int] = set() + + # PHASE 1: Split into Single-Node Regions + for node_idx in root.get_nodes(): + region = Region( + region_id=self.next_region_id, level=root.level + 1, region_type=RegionType.LEAF + ) + region.nodes.add(node_idx) + self.compute_region_boundaries(region) + result_regions.append(region) + self.next_region_id += 1 + + region_usage_map = self._build_region_usage_map(result_regions) + + # PHASE 2: Merge Regions in Data Flow Order + queue = deque(root.inputs) + + while len(queue) > 0: + tensor_name = queue.popleft() + # Skip tensors not produced by any region in our scope + if tensor_name not in region_usage_map: + continue + # Process each region consuming this tensor (potential merge targets) + consumers = region_usage_map[tensor_name] + for consumer in consumers: + # Skip regions already merged into others + if consumer.id in removed_regions: + continue + # Merging criteria: ALL outputs go to same single region + common_use_region = None + can_merge = True + # Check all outputs of the consumer region + for output_tensor in consumer.outputs: + queue.append(output_tensor) + if output_tensor not in region_usage_map: + can_merge = False + break + use_regions = region_usage_map[output_tensor] + if len(use_regions) != 1: + can_merge = False + break + if common_use_region is None: + common_use_region = use_regions[0] + elif common_use_region != use_regions[0]: + can_merge = False + break + # No valid downstream region to merge with + if common_use_region is None or common_use_region.id in removed_regions: + can_merge = False + continue + # Constraint 1: Limit the number of boundary operations after merge + nodes_after_merge = set() + nodes_after_merge.update(consumer.get_nodes()) + nodes_after_merge.update(common_use_region.get_nodes()) + node_ops = [self.graph.nodes[idx].op for idx in nodes_after_merge] + boundary_op_count = sum( + [1 if op in self.boundary_op_types else 0 for op in node_ops] + ) + if boundary_op_count > 3: + can_merge = False + continue + # Constraint 2: Size limits to avoid overly large regions + # Keep regions manageable for optimization passes + if ( + len(consumer.nodes) >= self.maximum_sequence_region_size + or len(common_use_region.nodes) >= self.maximum_sequence_region_size + ): + # One or both regions too large - don't merge + can_merge = False + continue + # All criteria met: merge consumer into its downstream region + if can_merge: + common_use_region.merge(consumer) + removed_regions.add(consumer.id) + # Remove regions that were merged into others + result_regions = [region for region in result_regions if region.id not in removed_regions] + # Recompute boundaries for all remaining regions + for region in result_regions: + self.compute_region_boundaries(region) + + return result_regions + + def _merge_converged_regions(self, root: Region): + """Identify and merge convergence patterns within a region. + + Traverses the region to find divergent nodes and their convergence points, + creating sub-regions that capture divergence→branches→convergence patterns. + Nodes not part of any convergence pattern are left for sequence splitting. + + Args: + root: The region to merge + + Returns: + List of merged regions + """ + result_regions: list[Region] = [] + removed_nodes: set[int] = set() + queue = deque(root.inputs) + while len(queue) > 0: + tensor_name = queue.popleft() + if tensor_name not in self.tensor_users_map: + continue + consumer_nodes = self.tensor_users_map[tensor_name] + for node_idx in consumer_nodes: + # stop at boundary nodes + if node_idx not in root.get_nodes(): + continue + consumer = self.graph.nodes[node_idx] + for output_tensor in consumer.outputs: + if output_tensor.name not in self.tensor_users_map: + continue + queue.append(output_tensor.name) + # if the node is already in a region, skip + if node_idx in removed_nodes: + continue + if not self._is_node_divergent(node_idx): + continue + converge_node_idx, visited_nodes = self._find_converge_nodes(node_idx) + visited_nodes = visited_nodes.intersection(root.get_region_nodes_and_descendants()) + # if no convergence found, skip + if converge_node_idx is None: + continue + # group converged nodes into a region + if converge_node_idx in root.get_nodes(): + converged_region = self._create_leaf_region(visited_nodes) + result_regions.append(converged_region) + removed_nodes.update(visited_nodes) + continue + # create a leaf region for the remaining nodes + remaining_nodes = set(root.get_nodes()) - removed_nodes + if len(remaining_nodes) > 0: + result_regions.append(self._create_leaf_region(remaining_nodes)) + # compute region boundaries for all regions + for region in result_regions: + self.compute_region_boundaries(region) + return result_regions + + def build_composite_region(self) -> Region: + """Refine a flat region into a hierarchical COMPOSITE region.""" + # merge converged regions into composite regions + regions = self._merge_converged_regions(self.root) + # split sequence regions into smaller regions + result_regions: list[Region] = [] + for region in regions: + result_regions.extend(self._split_sequence_regions(region)) + for region in result_regions: + self.compute_region_boundaries(region, include_constant=True) + regions = result_regions + # merge all regions into a single composite region + if len(regions) > 1: + composite = Region( + region_id=self.next_region_id, + level=self.root.level, + region_type=RegionType.COMPOSITE, + ) + self.next_region_id += 1 + regions = sorted( + regions, key=lambda x: RegionPattern.from_region(x, self.graph).signature + ) + for region in regions: + composite.add_child(region) + self.compute_region_boundaries(composite) + regions = [composite] + self.regions = regions + return self.regions[0] + + +class CombinedRegionSearch(RegionSearchBase): + """Two-phase region search combining bottom-up partitioning with top-down refinement. + + This class implements a sophisticated region discovery algorithm that combines two + complementary strategies to create well-formed, hierarchical regions from an ONNX + computation graph. + + """ + + def __init__( + self, + graph: gs.Graph, + maximum_sequence_region_size: int = 10, + minimum_topdown_search_size: int = 10, + ): + """Initialize CombinedRegionSearch for a given ONNX graph.""" + super().__init__(graph) + self.regions: list[Region] = [] + self.minimum_topdown_search_size = minimum_topdown_search_size + self.maximum_sequence_region_size = maximum_sequence_region_size + + def search_regions(self) -> list[Region]: + """Execute two-phase region search to partition the graph into hierarchical regions. + + 1. Bottom-up partitioning + 2. Top-down refinement + + Args: + None + + Returns: + List of hierarchical regions created from the graph + """ + logger.info("Phase 1: Bottom-up partitioning") + logger.debug("Initializing RegionPartitioner") + region_partitioner = RegionPartitioner(self.graph) + + # Execute the bottom-up partitioning algorithm. + self.regions = region_partitioner.partition_graph() + + coverage_pct = ( + 100 * len(region_partitioner.visited_nodes) / len(self.graph.nodes) + if self.graph.nodes + else 0 + ) + logger.info( + f"Phase 1 complete: {len(self.regions)} regions, " + f"{len(region_partitioner.visited_nodes)}/{len(self.graph.nodes)} nodes ({coverage_pct:.1f}%)" + ) + logger.debug("Proceeding to Phase 2: Top-down refinement") + + logger.info("Phase 2: Top-down refinement") + next_region_id = region_partitioner.current_region_id + + refined_count = 0 + for idx, region in enumerate(self.regions): + node_count = len(region.get_region_nodes_and_descendants()) + if node_count < self.minimum_topdown_search_size: + logger.debug(f"Skipping region {idx}: {node_count} nodes (below minimum)") + continue + + logger.debug(f"Refining region {idx}: {node_count} nodes") + region_builder = TopDownRegionBuilder( + self.graph, + region, + next_region_id=next_region_id, + maximum_sequence_region_size=self.maximum_sequence_region_size, + tensor_users_map=region_partitioner.tensor_users_map, + forward_reachable_nodes_map=region_partitioner.forward_reachable_nodes_map, + ) + + self.regions[idx] = region_builder.build_composite_region() + node_count_after = len(self.regions[idx].get_region_nodes_and_descendants()) + if node_count != node_count_after: + logger.warning( + f"Node count mismatch in region {idx}: {node_count} → {node_count_after}" + ) + + region_partitioner.compute_region_boundaries(self.regions[idx]) + next_region_id = region_builder.next_region_id + refined_count += 1 + + logger.info(f"Phase 2 complete: refined {refined_count}/{len(self.regions)} regions") + + return self.regions diff --git a/tests/unit/onnx/quantization/autotune/test_region_search.py b/tests/unit/onnx/quantization/autotune/test_region_search.py new file mode 100644 index 000000000..e2fb179fd --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_region_search.py @@ -0,0 +1,345 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for region search algorithms. + +Tests CombinedRegionSearch, RegionPartitioner, and TopDownRegionBuilder. +Note: Comprehensive integration tests with real ONNX graphs should be in separate integration test files. +""" + +import io + +import onnx +import onnx_graphsurgeon as gs +import pytest +from onnx import helper + +from modelopt.onnx.quantization.autotune.common import Region, RegionType +from modelopt.onnx.quantization.autotune.region_search import ( + CombinedRegionSearch, + RegionPartitioner, + TopDownRegionBuilder, +) + + +@pytest.fixture +def simple_linear_graph(): + """ + Create a simple linear graph: Input -> Conv -> Relu -> Output. + + This is the simplest possible graph for testing region discovery. + """ + # Input + input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224]) + + # Output + output_tensor = helper.make_tensor_value_info( + "output", onnx.TensorProto.FLOAT, [1, 64, 224, 224] + ) + + # Conv node + conv_node = helper.make_node( + "Conv", inputs=["input", "conv_weight"], outputs=["conv_out"], name="conv" + ) + + # Relu node + relu_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["output"], name="relu") + + # Create graph + graph = helper.make_graph( + [conv_node, relu_node], + "simple_linear", + [input_tensor], + [output_tensor], + initializer=[ + helper.make_tensor( + "conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3) + ) + ], + ) + + # Create model + model = helper.make_model(graph, producer_name="test") + + # Convert to GraphSurgeon + return gs.import_onnx(model) + + +@pytest.fixture +def divergent_graph(): + """ + Create a graph with divergence: Input -> Conv -> [Relu1, Relu2] -> Add -> Output. + + Tests divergence/convergence pattern detection. + """ + input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224]) + output_tensor = helper.make_tensor_value_info( + "output", onnx.TensorProto.FLOAT, [1, 64, 224, 224] + ) + + conv_node = helper.make_node( + "Conv", inputs=["input", "conv_weight"], outputs=["conv_out"], name="conv" + ) + relu1_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["relu1_out"], name="relu1") + relu2_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["relu2_out"], name="relu2") + add_node = helper.make_node( + "Add", inputs=["relu1_out", "relu2_out"], outputs=["output"], name="add" + ) + + graph = helper.make_graph( + [conv_node, relu1_node, relu2_node, add_node], + "divergent", + [input_tensor], + [output_tensor], + initializer=[ + helper.make_tensor( + "conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3) + ) + ], + ) + + model = helper.make_model(graph, producer_name="test") + return gs.import_onnx(model) + + +class TestRegionPartitioner: + """Test RegionPartitioner basic functionality.""" + + def test_partition_linear_graph(self, simple_linear_graph): + """Test partitioning a simple linear graph.""" + partitioner = RegionPartitioner(simple_linear_graph) + + regions = partitioner.partition_graph() + + # Should create at least one region + assert len(regions) > 0 + + # Check that regions cover most nodes (ONNX GS may add Constant nodes that aren't partitioned) + total_nodes = sum(len(r.get_region_nodes_and_descendants()) for r in regions) + assert total_nodes > 0 + assert total_nodes <= len(simple_linear_graph.nodes) + + def test_partition_divergent_graph(self, divergent_graph): + """Test partitioning a divergent graph.""" + partitioner = RegionPartitioner(divergent_graph) + + regions = partitioner.partition_graph() + + # Should create regions covering all nodes + assert len(regions) > 0 + + # Check that regions cover most nodes (ONNX GS may add Constant nodes that aren't partitioned) + total_nodes = sum(len(r.get_region_nodes_and_descendants()) for r in regions) + assert total_nodes > 0 + assert total_nodes <= len(divergent_graph.nodes) + + +class TestTopDownRegionBuilder: + """Test TopDownRegionBuilder basic functionality.""" + + def test_build_composite_region(self, simple_linear_graph): + """Test building a composite region.""" + # First partition to get initial regions + partitioner = RegionPartitioner(simple_linear_graph) + initial_regions = partitioner.partition_graph() + + if len(initial_regions) > 0: + # Use first region as root for top-down building + root_region = initial_regions[0] + + builder = TopDownRegionBuilder(simple_linear_graph, root_region, next_region_id=100) + + # Build composite region (may return LEAF or COMPOSITE depending on structure) + composite = builder.build_composite_region() + + assert composite is not None + # Region type depends on whether refinement created internal structure + # For simple linear graphs, may stay as LEAF + assert composite.type in [RegionType.LEAF, RegionType.COMPOSITE] + else: + pytest.skip("No initial regions to refine") + + +class TestCombinedRegionSearch: + """Test CombinedRegionSearch two-phase algorithm.""" + + def test_search_linear_graph(self, simple_linear_graph): + """Test searching regions in a simple linear graph.""" + search = CombinedRegionSearch(simple_linear_graph) + + regions = search.search_regions() + + # Should create regions + assert len(regions) > 0 + + # Check that regions cover most nodes (ONNX GS may add Constant nodes that aren't partitioned) + total_nodes = sum(len(r.get_region_nodes_and_descendants()) for r in regions) + assert total_nodes > 0 + assert total_nodes <= len(simple_linear_graph.nodes) + + # Each region should have valid inputs/outputs + for region in regions: + assert region.inputs is not None + assert region.outputs is not None + + def test_search_divergent_graph(self, divergent_graph): + """Test searching regions in a divergent graph.""" + search = CombinedRegionSearch(divergent_graph) + + regions = search.search_regions() + + # Should create regions + assert len(regions) > 0 + + # Check that regions cover most nodes (ONNX GS may add Constant nodes that aren't partitioned) + total_nodes = sum(len(r.get_region_nodes_and_descendants()) for r in regions) + assert total_nodes > 0 + assert total_nodes <= len(divergent_graph.nodes) + + def test_region_hierarchy(self, simple_linear_graph): + """Test that regions have proper hierarchical structure.""" + search = CombinedRegionSearch(simple_linear_graph) + + regions = search.search_regions() + + # Check that regions have children (hierarchical structure) + for region in regions: + if region.type == RegionType.COMPOSITE: + assert len(region.get_children()) > 0 + + # Verify parent-child relationships + for child in region.get_children(): + assert child.parent == region + + def test_parameters(self, simple_linear_graph): + """Test CombinedRegionSearch with custom parameters.""" + # Test with different parameter values + search = CombinedRegionSearch( + simple_linear_graph, + maximum_sequence_region_size=5, + minimum_topdown_search_size=5, + ) + + regions = search.search_regions() + + assert len(regions) > 0 + + +class TestPrintTree: + """Test print_tree functionality.""" + + def test_print_tree_output_content(self, simple_linear_graph): + """Test that print_tree output contains region, node, and I/O information.""" + search = CombinedRegionSearch(simple_linear_graph) + search.search_regions() + + output = io.StringIO() + search.print_tree(file=output) + result = output.getvalue() + + # Region information + assert "Region" in result + assert "Level" in result + assert "Type:" in result + + # Node counts + assert "Direct nodes:" in result + assert "Total nodes:" in result + assert "Children:" in result + + # I/O information + assert "Inputs:" in result + assert "Outputs:" in result + + def test_print_tree_divergent_graph(self, divergent_graph): + """Test print_tree on a divergent graph with more complex structure.""" + search = CombinedRegionSearch(divergent_graph) + search.search_regions() + + output = io.StringIO() + search.print_tree(file=output) + + result = output.getvalue() + + # Should produce valid output + assert "Region" in result + assert len(result) > 0 + + def test_print_tree_max_nodes_to_show(self, simple_linear_graph): + """Test print_tree with custom max_nodes_to_show parameter.""" + search = CombinedRegionSearch(simple_linear_graph) + search.search_regions() + + # Test with different max_nodes_to_show values + output1 = io.StringIO() + search.print_tree(max_items=1, file=output1) + + output2 = io.StringIO() + search.print_tree(max_items=10, file=output2) + + # Both should produce output + assert len(output1.getvalue()) > 0 + assert len(output2.getvalue()) > 0 + + def test_print_tree_specific_region(self, simple_linear_graph): + """Test print_tree with a specific region instead of root.""" + search = CombinedRegionSearch(simple_linear_graph) + regions = search.search_regions() + + if len(regions) > 0: + # Print a specific region + output = io.StringIO() + search.print_tree(region=regions[0], file=output) + + result = output.getvalue() + assert "Region" in result + assert f"Region {regions[0].id}" in result + + def test_print_tree_partitioner(self, simple_linear_graph): + """Test print_tree on RegionPartitioner.""" + partitioner = RegionPartitioner(simple_linear_graph) + partitioner.partition_graph() + + output = io.StringIO() + partitioner.print_tree(file=output) + + result = output.getvalue() + assert "Region" in result + assert len(result) > 0 + + def test_print_tree_top_down_builder(self, simple_linear_graph): + """Test print_tree on TopDownRegionBuilder.""" + # Create a root region with all nodes + root = Region(region_id=0, level=0, region_type=RegionType.LEAF) + root.nodes.update(range(len(simple_linear_graph.nodes))) + + builder = TopDownRegionBuilder(simple_linear_graph, root) + # Compute region I/O boundaries before building + builder.compute_region_boundaries(root) + builder.build_composite_region() + + output = io.StringIO() + builder.print_tree(file=output) + + result = output.getvalue() + print("\n" + "=" * 60) + print("Region Tree Structure:") + print("=" * 60) + print(result) + print("=" * 60) + + assert "Region" in result + assert len(result) > 0 From 55b7962b5e5cc798f5a2f1439d95b70b8a1e8e93 Mon Sep 17 00:00:00 2001 From: willg-nv Date: Thu, 12 Feb 2026 21:22:14 +0800 Subject: [PATCH 02/11] Integrate Automated QDQ placement tool - part 2.3 (#846) ## What does this PR do? This PR implement RegionInspect tool. This tool could be used to visualize the regions parititioned by RegionSearch classes. This tool could be used to analyze if the partitioned regions match the fusion patterns. **Overview:** ? ## Usage ```python # Add a code snippet demonstrating how to use this ``` ## Testing ## Before your PR is "*Ready for review*" - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes - **Did you write any new necessary tests?**: Yes - **Did you add or update any necessary documentation?**: No, document update is in Part 4. - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: No, CHANGELOG will be updated when all changes are ready. ## Additional Information ## Summary by CodeRabbit ## New Features * Added a region inspection tool for ONNX models. Analyzes model structure and generates detailed reports including region statistics, hierarchical relationships, node coverage metrics, and size distribution analysis. Available through a command-line interface with configurable parameters. --------- Signed-off-by: Will Guo Co-authored-by: Ajinkya Rasane <131806219+ajrasane@users.noreply.github.com> Signed-off-by: Hung-Yueh --- .../quantization/autotune/region_inspect.py | 203 ++++++++++ .../autotune/test_region_inspect.py | 367 ++++++++++++++++++ 2 files changed, 570 insertions(+) create mode 100644 modelopt/onnx/quantization/autotune/region_inspect.py create mode 100644 tests/unit/onnx/quantization/autotune/test_region_inspect.py diff --git a/modelopt/onnx/quantization/autotune/region_inspect.py b/modelopt/onnx/quantization/autotune/region_inspect.py new file mode 100644 index 000000000..beb60268d --- /dev/null +++ b/modelopt/onnx/quantization/autotune/region_inspect.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Region search inspection tool for ONNX models.""" + +import argparse +import logging +import sys +from collections import Counter + +import onnx +import onnx_graphsurgeon as gs + +from modelopt.onnx.logging_config import logger +from modelopt.onnx.quantization.autotune.common import Region, RegionType +from modelopt.onnx.quantization.autotune.insertion_points import has_quantizable_operations +from modelopt.onnx.quantization.autotune.region_search import ( + DEFAULT_MAX_STEPS, + CombinedRegionSearch, +) + + +def inspect_region_search( + onnx_path: str, + max_sequence_size: int = 10, + include_all_regions: bool = False, +) -> list[Region]: + """Inspect region search results for an ONNX model. + + This function loads an ONNX model, runs CombinedRegionSearch (which performs + both bottom-up partitioning and top-down refinement internally), and prints + detailed information about the discovered regions including their hierarchical + structure. + + **What it does:** + 1. Loads ONNX model and converts to GraphSurgeon format + 2. Creates CombinedRegionSearch instance with specified parameters + 3. Runs two-phase search (partitioning + refinement) via search_regions() + 4. Displays detailed region structure and statistics + 5. Returns the final list of refined regions + + **Output Sections:** + - Initialization: Shows search parameters + - Two-Phase Search: Runs automatically via CombinedRegionSearch.search_regions() + - Detailed Structure: Shows each region's hierarchy and properties + - Summary Statistics: Shows region counts and node coverage + + Args: + onnx_path: Path to the ONNX model file + max_sequence_size: Maximum size for sequence regions during refinement (default: 10) + include_all_regions: Include all regions, even those without major quantizable + operations (Conv, MatMul, etc.). Default: False (skips such regions) + + Returns: + List of discovered and refined regions (LEAF and COMPOSITE) + """ + # Load ONNX model + logger.info(f"Loading model: {onnx_path}") + onnx_model = onnx.load(onnx_path) + # Convert to onnx_graphsurgeon Graph + graph = gs.import_onnx(onnx_model) + graph.cleanup().toposort() + logger.info( + f"Loaded graph: {len(graph.nodes)} nodes, {len(graph.inputs)} inputs, {len(graph.outputs)} outputs" + ) + # Initialize CombinedRegionSearch (contains RegionPartitioner internally) + logger.debug( + f"Search parameters: max_steps={DEFAULT_MAX_STEPS}, max_sequence_size={max_sequence_size}" + ) + + combined_search = CombinedRegionSearch(graph, maximum_sequence_region_size=max_sequence_size) + + # Run complete two-phase region search + logger.info("Running region search") + regions = combined_search.search_regions() + # Show detailed region structure + logger.info("Analyzing region structure") + all_regions = [] + for i, region in enumerate(regions): + region.children = [ + c + for c in region.get_children() + if include_all_regions or has_quantizable_operations(c, graph) + ] + if not include_all_regions and not has_quantizable_operations(region, graph): + logger.debug(f"Filtered out region {i} (no quantizable operations)") + continue + logger.debug( + f"Region {i}: {region.type.value}, {len(region.get_region_nodes_and_descendants())} nodes, " + f"{len(region.inputs)} inputs, {len(region.outputs)} outputs" + ) + all_regions.append(region) + if region.type == RegionType.COMPOSITE: + logger.debug(f" {len(region.get_children())} child regions") + all_regions.extend(region.get_children()) + combined_search.print_tree(region, indent=2) + + # Summary statistics + type_counts = Counter(r.type for r in all_regions) + leaf_regions, composite_regions = ( + type_counts[RegionType.LEAF], + type_counts[RegionType.COMPOSITE], + ) + + all_nodes = {n for r in all_regions for n in r.get_region_nodes_and_descendants()} + total_nodes = len(all_nodes) + coverage_pct = 100 * total_nodes / len(graph.nodes) if graph.nodes else 0 + + logger.info( + f"Summary: {len(all_regions)} regions ({leaf_regions} LEAF, {composite_regions} COMPOSITE), " + f"{total_nodes}/{len(graph.nodes)} nodes ({coverage_pct:.1f}%)" + ) + + # Print histogram of region sizes + region_sizes = [ + len(r.get_region_nodes_and_descendants()) for r in all_regions if r.type == RegionType.LEAF + ] + + if region_sizes: + min_size = min(region_sizes) + max_size = max(region_sizes) + avg_size = sum(region_sizes) / len(region_sizes) + + logger.info(f"LEAF region sizes: min={min_size}, max={max_size}, avg={avg_size:.1f}") + size_counts = Counter(region_sizes) + logger.debug("Size distribution:") + for size in sorted(size_counts.keys()): + count = size_counts[size] + bar = "█" * min(count, 50) + logger.debug(f" {size:4d} nodes: {bar} ({count} regions)") + + return all_regions + + +def main(): + """Command-line entry point for region search inspection.""" + parser = argparse.ArgumentParser( + prog="modelopt.onnx.quantization.autotune.region_inspect", + description="Inspect region search results for ONNX models", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Basic inspection + python -m modelopt.onnx.quantization.autotune.region_inspect --model model.onnx + + # Verbose mode for debug logging + python -m modelopt.onnx.quantization.autotune.region_inspect \\ + --model model.onnx --verbose + + # Custom maximum sequence size + python -m modelopt.onnx.quantization.autotune.region_inspect \\ + --model model.onnx --max-sequence-size 20 + """, + ) + + parser.add_argument("--model", "-m", type=str, required=True, help="Path to ONNX model file") + parser.add_argument( + "--max-sequence-size", + type=int, + default=10, + help="Maximum size for sequence regions during refinement (default: 10)", + ) + parser.add_argument( + "--include-all-regions", + action="store_true", + help="Include all regions, even those without major quantizable operations. " + "Default: False (skips such regions)", + ) + parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose debug logging") + + args = parser.parse_args() + + log_level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s") + logger.setLevel(log_level) + + try: + regions = inspect_region_search( + onnx_path=args.model, + max_sequence_size=args.max_sequence_size, + include_all_regions=args.include_all_regions, + ) + logger.info(f"✓ Inspection complete: {len(regions)} regions discovered") + return 0 + except Exception as e: + logger.error(f"Inspection failed: {e}", exc_info=args.verbose) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/unit/onnx/quantization/autotune/test_region_inspect.py b/tests/unit/onnx/quantization/autotune/test_region_inspect.py new file mode 100644 index 000000000..a932fa3c2 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_region_inspect.py @@ -0,0 +1,367 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for region_inspect module.""" + +import os +from unittest.mock import Mock, patch + +import numpy as np +import onnx +import pytest +from onnx import TensorProto, helper, numpy_helper + + +def create_simple_onnx_model(): + """Create a simple ONNX model for testing. + + Creates a model with: Input -> Conv -> Relu -> MatMul -> Output + """ + # Create input + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 224, 224]) + output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 1000]) + + # Create weights for Conv + conv_weight = np.random.randn(64, 3, 7, 7).astype(np.float32) + conv_weight_tensor = numpy_helper.from_array(conv_weight, "conv_weight") + + # Create weights for MatMul + matmul_weight = np.random.randn(64, 1000).astype(np.float32) + matmul_weight_tensor = numpy_helper.from_array(matmul_weight, "matmul_weight") + + # Create nodes + conv_node = helper.make_node( + "Conv", + inputs=["input", "conv_weight"], + outputs=["conv_output"], + kernel_shape=[7, 7], + strides=[2, 2], + pads=[3, 3, 3, 3], + ) + + relu_node = helper.make_node( + "Relu", + inputs=["conv_output"], + outputs=["relu_output"], + ) + + flatten_node = helper.make_node( + "Flatten", + inputs=["relu_output"], + outputs=["flatten_output"], + axis=1, + ) + + matmul_node = helper.make_node( + "MatMul", + inputs=["flatten_output", "matmul_weight"], + outputs=["output"], + ) + + # Create graph + graph = helper.make_graph( + [conv_node, relu_node, flatten_node, matmul_node], + "test_model", + [input_tensor], + [output_tensor], + [conv_weight_tensor, matmul_weight_tensor], + ) + + # Create model + model = helper.make_model(graph, producer_name="test") + model.opset_import[0].version = 13 + + return model + + +@pytest.fixture +def simple_onnx_model(): + """Fixture that provides a simple ONNX model.""" + return create_simple_onnx_model() + + +@pytest.fixture +def onnx_model_file(tmp_path, simple_onnx_model): + """Fixture that provides a path to a saved ONNX model.""" + model_path = os.path.join(tmp_path, "test_model.onnx") + onnx.save(simple_onnx_model, model_path) + return model_path + + +class TestRegionInspectImports: + """Test that the region_inspect module can be imported.""" + + def test_module_imports(self): + """Test that the module imports without errors when dependencies exist.""" + # This test will skip if the required dependencies don't exist + try: + from modelopt.onnx.quantization.autotune import region_inspect + + assert hasattr(region_inspect, "inspect_region_search") + assert hasattr(region_inspect, "main") + except ImportError as e: + pytest.skip(f"Required dependencies not available: {e}") + + +class TestRegionInspectWithMocks: + """Test region_inspect functionality with mocked dependencies.""" + + @patch("modelopt.onnx.quantization.autotune.region_inspect.CombinedRegionSearch") + @patch("modelopt.onnx.quantization.autotune.region_inspect.has_quantizable_operations") + def test_inspect_region_search_basic( + self, mock_has_quantizable, mock_combined_search, onnx_model_file + ): + """Test basic functionality of inspect_region_search with mocked dependencies.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import inspect_region_search + except ImportError: + pytest.skip("Required dependencies not available") + + # Setup mocks + mock_region = Mock() + mock_region.type = Mock(value="LEAF") + mock_region.inputs = ["input1"] + mock_region.outputs = ["output1"] + mock_region.children = [] + mock_region.get_region_nodes_and_descendants.return_value = [Mock(), Mock()] + mock_region.get_children.return_value = [] + + mock_search_instance = Mock() + mock_search_instance.search_regions.return_value = [mock_region] + mock_search_instance.print_tree = Mock() + mock_combined_search.return_value = mock_search_instance + + mock_has_quantizable.return_value = True + + # Call the function + result = inspect_region_search( + onnx_path=onnx_model_file, max_sequence_size=10, include_all_regions=False + ) + + # Verify the function was called correctly + assert mock_combined_search.called + assert mock_search_instance.search_regions.called + assert isinstance(result, list) + + @patch("modelopt.onnx.quantization.autotune.region_inspect.CombinedRegionSearch") + @patch("modelopt.onnx.quantization.autotune.region_inspect.has_quantizable_operations") + def test_inspect_region_search_with_custom_params( + self, mock_has_quantizable, mock_combined_search, onnx_model_file + ): + """Test inspect_region_search with custom parameters.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import inspect_region_search + except ImportError: + pytest.skip("Required dependencies not available") + + # Setup mocks + mock_region = Mock() + mock_region.type = Mock(value="COMPOSITE") + mock_region.inputs = ["input1"] + mock_region.outputs = ["output1"] + mock_region.children = [] + mock_region.get_region_nodes_and_descendants.return_value = [Mock()] + mock_region.get_children.return_value = [] + + mock_search_instance = Mock() + mock_search_instance.search_regions.return_value = [mock_region] + mock_search_instance.print_tree = Mock() + mock_combined_search.return_value = mock_search_instance + + mock_has_quantizable.return_value = True + + # Call with custom parameters + result = inspect_region_search( + onnx_path=onnx_model_file, max_sequence_size=20, include_all_regions=True + ) + + # Verify custom parameters were used + assert mock_combined_search.called + call_kwargs = mock_combined_search.call_args[1] + assert call_kwargs.get("maximum_sequence_region_size") == 20 + assert isinstance(result, list) + + @patch("modelopt.onnx.quantization.autotune.region_inspect.CombinedRegionSearch") + @patch("modelopt.onnx.quantization.autotune.region_inspect.has_quantizable_operations") + def test_inspect_region_search_filtering( + self, mock_has_quantizable, mock_combined_search, onnx_model_file + ): + """Test that regions without quantizable operations are filtered out.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import inspect_region_search + except ImportError: + pytest.skip("Required dependencies not available") + + # Setup mocks - one region with quantizable ops, one without + mock_region_quantizable = Mock() + mock_region_quantizable.type = Mock(value="LEAF") + mock_region_quantizable.inputs = ["input1"] + mock_region_quantizable.outputs = ["output1"] + mock_region_quantizable.get_region_nodes_and_descendants.return_value = [Mock()] + mock_region_quantizable.get_children.return_value = [] + + mock_region_non_quantizable = Mock() + mock_region_non_quantizable.type = Mock(value="LEAF") + mock_region_non_quantizable.inputs = ["input2"] + mock_region_non_quantizable.outputs = ["output2"] + mock_region_non_quantizable.get_region_nodes_and_descendants.return_value = [Mock()] + mock_region_non_quantizable.get_children.return_value = [] + + mock_search_instance = Mock() + mock_search_instance.search_regions.return_value = [ + mock_region_quantizable, + mock_region_non_quantizable, + ] + mock_search_instance.print_tree = Mock() + mock_combined_search.return_value = mock_search_instance + + # First region has quantizable ops, second doesn't + mock_has_quantizable.side_effect = [True, False] + + # Call with filtering enabled + result = inspect_region_search( + onnx_path=onnx_model_file, max_sequence_size=10, include_all_regions=False + ) + + # Should only return the quantizable region + assert len(result) == 1 + + +class TestRegionInspectMain: + """Test the main CLI entry point.""" + + @patch("modelopt.onnx.quantization.autotune.region_inspect.inspect_region_search") + def test_main_success(self, mock_inspect, onnx_model_file): + """Test main function with successful execution.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import main + except ImportError: + pytest.skip("Required dependencies not available") + + mock_inspect.return_value = [Mock(), Mock()] + + with patch("sys.argv", ["region_inspect", "--model", onnx_model_file]): + exit_code = main() + assert exit_code == 0 + assert mock_inspect.called + + @patch("modelopt.onnx.quantization.autotune.region_inspect.inspect_region_search") + def test_main_with_verbose(self, mock_inspect, onnx_model_file): + """Test main function with verbose flag.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import main + except ImportError: + pytest.skip("Required dependencies not available") + + mock_inspect.return_value = [Mock()] + + with patch("sys.argv", ["region_inspect", "--model", onnx_model_file, "--verbose"]): + exit_code = main() + assert exit_code == 0 + + @patch("modelopt.onnx.quantization.autotune.region_inspect.inspect_region_search") + def test_main_with_custom_max_sequence_size(self, mock_inspect, onnx_model_file): + """Test main function with custom max_sequence_size.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import main + except ImportError: + pytest.skip("Required dependencies not available") + + mock_inspect.return_value = [Mock()] + + with patch( + "sys.argv", ["region_inspect", "--model", onnx_model_file, "--max-sequence-size", "20"] + ): + exit_code = main() + assert exit_code == 0 + # Verify max_sequence_size parameter was passed + call_kwargs = mock_inspect.call_args[1] + assert call_kwargs.get("max_sequence_size") == 20 + + @patch("modelopt.onnx.quantization.autotune.region_inspect.inspect_region_search") + def test_main_with_include_all_regions(self, mock_inspect, onnx_model_file): + """Test main function with include_all_regions flag.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import main + except ImportError: + pytest.skip("Required dependencies not available") + + mock_inspect.return_value = [Mock()] + + with patch( + "sys.argv", ["region_inspect", "--model", onnx_model_file, "--include-all-regions"] + ): + exit_code = main() + assert exit_code == 0 + # Verify include_all_regions parameter was passed + call_kwargs = mock_inspect.call_args[1] + assert call_kwargs.get("include_all_regions") is True + + @patch("modelopt.onnx.quantization.autotune.region_inspect.inspect_region_search") + def test_main_failure(self, mock_inspect, onnx_model_file): + """Test main function with execution failure.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import main + except ImportError: + pytest.skip("Required dependencies not available") + + mock_inspect.side_effect = Exception("Test error") + + with patch("sys.argv", ["region_inspect", "--model", onnx_model_file]): + exit_code = main() + assert exit_code == 1 + + +class TestRegionInspectModelLoading: + """Test model loading functionality.""" + + @patch("modelopt.onnx.quantization.autotune.region_inspect.CombinedRegionSearch") + @patch("modelopt.onnx.quantization.autotune.region_inspect.has_quantizable_operations") + def test_loads_valid_onnx_model( + self, mock_has_quantizable, mock_combined_search, onnx_model_file + ): + """Test that a valid ONNX model can be loaded.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import inspect_region_search + except ImportError: + pytest.skip("Required dependencies not available") + + # Setup minimal mocks + mock_region = Mock() + mock_region.type = Mock(value="LEAF") + mock_region.inputs = [] + mock_region.outputs = [] + mock_region.get_region_nodes_and_descendants.return_value = [] + mock_region.get_children.return_value = [] + + mock_search_instance = Mock() + mock_search_instance.search_regions.return_value = [mock_region] + mock_search_instance.print_tree = Mock() + mock_combined_search.return_value = mock_search_instance + mock_has_quantizable.return_value = False + + # Should not raise an exception + result = inspect_region_search(onnx_model_file) + assert isinstance(result, list) + + def test_fails_on_nonexistent_file(self): + """Test that loading a non-existent file raises an error.""" + try: + from modelopt.onnx.quantization.autotune.region_inspect import inspect_region_search + except ImportError: + pytest.skip("Required dependencies not available") + + with pytest.raises(Exception): # Could be FileNotFoundError or other + inspect_region_search("/nonexistent/path/to/model.onnx") From 3486ca991b73a9a34b88b68001ad22a22cac658a Mon Sep 17 00:00:00 2001 From: "Chenhan D. Yu" <5185878+ChenhanYu@users.noreply.github.com> Date: Thu, 12 Feb 2026 15:51:10 -0800 Subject: [PATCH 03/11] Chenhany/megatron export per layer (#881) ## What does this PR do? **Type of change:** ? Bug fix **Overview:** ? 1. Fixing megatron ignore module has additional `.` in the suffix 2. Change megatron export to safe per layer as a safetensor (avoid ghost safetensors) ## Usage ```python # Add a code snippet demonstrating how to use this ``` ## Testing ## Before your PR is "*Ready for review*" - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No ## Additional Information ## Summary by CodeRabbit ## Release Notes * **New Features** * Export workflow now supports additional model components (EAGLE/Medusa modules) * Per-layer model state organization for improved checkpoint management * **Bug Fixes** * More robust Hugging Face configuration, tokenizer, and image processor preservation * Enhanced multimodal component extraction and loading * **Refactor** * Optimized model export process with improved per-layer safetensors handling Signed-off-by: Chenhan Yu Signed-off-by: Hung-Yueh --- .../export/plugins/hf_checkpoint_utils.py | 123 ++++++++ modelopt/torch/export/plugins/mcore_custom.py | 53 ++++ .../export/plugins/vllm_fakequant_megatron.py | 4 +- .../torch/export/unified_export_megatron.py | 266 +++++++----------- 4 files changed, 279 insertions(+), 167 deletions(-) create mode 100644 modelopt/torch/export/plugins/hf_checkpoint_utils.py diff --git a/modelopt/torch/export/plugins/hf_checkpoint_utils.py b/modelopt/torch/export/plugins/hf_checkpoint_utils.py new file mode 100644 index 000000000..e89900cbb --- /dev/null +++ b/modelopt/torch/export/plugins/hf_checkpoint_utils.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hugging Face checkpoint utility.""" + +import json +import os +import shutil +from pathlib import Path + +import torch +from safetensors.torch import safe_open +from tqdm import tqdm + + +def copy_remote_code( + pretrained_model_path: str | os.PathLike, + save_directory: str | os.PathLike, +): + """Copy remote code from pretrained model to save directory. + + For models that keep configuration and modeling files as part of the checkpoint, + we need to copy them to the export directory for seamless integration with inference + frameworks. + + Args: + pretrained_model_path: Path to the pretrained model. + save_directory: Path to the save directory. + + Raises: + ValueError: If the pretrained model path is not a directory. + """ + hf_checkpoint_path = Path(pretrained_model_path) + save_dir = Path(save_directory) + + if not hf_checkpoint_path.is_dir(): + raise ValueError( + f"Invalid pretrained model path: {pretrained_model_path}. It should be a directory." + ) + + for py_file in hf_checkpoint_path.glob("*.py"): + if py_file.is_file(): + shutil.copy(py_file, save_dir / py_file.name) + + +def load_multimodal_components( + pretrained_model_path: str | os.PathLike, +) -> dict[str, torch.Tensor]: + """Load multimodal components from safetensors file. + + Args: + pretrained_model_path: Path to the pretrained model. + + Returns: + A dictionary of multimodal components. + """ + hf_checkpoint_path = Path(pretrained_model_path) + if not hf_checkpoint_path.is_dir(): + raise ValueError( + f"Invalid pretrained model path: {pretrained_model_path}. It should be a directory." + ) + + safetensors_file = Path(hf_checkpoint_path) / "model.safetensors" + safetensors_index_file = Path(hf_checkpoint_path) / "model.safetensors.index.json" + + multimodal_state_dict = {} + + if safetensors_file.is_file(): + print(f"Loading multimodal components from single file: {safetensors_file}") + with safe_open(safetensors_file, framework="pt") as f: + multimodal_keys = [ + key + for key in f.keys() # noqa: SIM118 + if key.startswith(("multi_modal_projector", "vision_model")) + ] + for key in tqdm(multimodal_keys, desc="Loading multimodal tensors"): + multimodal_state_dict[key] = f.get_tensor(key) + + elif safetensors_index_file.is_file(): + print(f"Loading multimodal components from sharded model: {hf_checkpoint_path}") + with open(safetensors_index_file) as f: + safetensors_index = json.load(f) + + # For multimodal models, vision_model and multi_modal_projector are in the first shard + all_shard_files = sorted(set(safetensors_index["weight_map"].values())) + first_shard_file = all_shard_files[0] # e.g., "model-00001-of-00050.safetensors" + + # Load multimodal components from the first shard file + safetensors_filepath = Path(hf_checkpoint_path) / first_shard_file + print(f"Loading multimodal components from {first_shard_file}") + + with safe_open(safetensors_filepath, framework="pt") as f: + shard_keys = list(f.keys()) + multimodal_keys_in_shard = [ + k for k in shard_keys if k.startswith(("multi_modal_projector", "vision_model")) + ] + + if multimodal_keys_in_shard: + print( + f"Found {len(multimodal_keys_in_shard)} multimodal tensors in {first_shard_file}" + ) + for key in tqdm(multimodal_keys_in_shard, desc="Loading multimodal tensors"): + multimodal_state_dict[key] = f.get_tensor(key) + else: + print(f"No multimodal components found in {first_shard_file}") + + else: + print(f"Warning: No safetensors files found in {hf_checkpoint_path}") + + print(f"Successfully loaded {len(multimodal_state_dict)} multimodal tensors") + return multimodal_state_dict diff --git a/modelopt/torch/export/plugins/mcore_custom.py b/modelopt/torch/export/plugins/mcore_custom.py index c269cef1d..90c523d84 100644 --- a/modelopt/torch/export/plugins/mcore_custom.py +++ b/modelopt/torch/export/plugins/mcore_custom.py @@ -274,6 +274,59 @@ def save_safetensors(state_dict, save_directory: str | os.PathLike): json.dump(safetensor_index, f, indent=4) +def save_safetensors_by_layer_index( + layer_state_dicts: dict[int, dict[str, torch.Tensor]], + total_layers: int, + save_directory: str | os.PathLike, + name_template: str = "model-{:05d}-of-{:05d}", +): + """Save safetensors by layer index. + + Args: + layer_state_dicts: A dictionary of layer state dictionaries. + total_layers: Total number of layers. + save_directory: Path to the save directory. + name_template: Template for the filename. + """ + for layer_index, layer_state_dict in layer_state_dicts.items(): + filename = name_template.format(layer_index, total_layers) + meta_filename = filename + ".json" + ckpt_filename = filename + ".safetensors" + + weight_map = {} + layer_total_size = 0 + for key, val in layer_state_dict.items(): + tensor_size = val.numel() * val.element_size() + layer_total_size += tensor_size + weight_map[key] = ckpt_filename + + with open(save_directory + "/" + meta_filename, "w") as f: + json.dump( + {"metadata": {"total_size": layer_total_size}, "weight_map": weight_map}, + f, + indent=4, + ) + save_file(layer_state_dict, save_directory + "/" + ckpt_filename, metadata={"format": "pt"}) + + # [TODO]: this global barrier needs to be replaced with something safer + torch.distributed.barrier() + + if torch.distributed.get_rank() == 0: + safetensor_index = { + "metadata": {"total_size": 0}, + "weight_map": {}, + } + for layer_index in range(total_layers): + meta_filename = name_template.format(layer_index + 1, total_layers) + ".json" + with open(save_directory + "/" + meta_filename) as f: + shard = json.load(f) + safetensor_index["metadata"]["total_size"] += shard["metadata"]["total_size"] + safetensor_index["weight_map"].update(shard["weight_map"]) + + with open(save_directory + "/model.safetensors.index.json", "w") as f: + json.dump(safetensor_index, f, indent=4) + + def _get_safetensors_file(pretrained_model_path: str | Path, key: str) -> Path | None: """Given a tensor key return the safetensors file that contains this tensor if exists. diff --git a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py index 95b194c3f..3f69271b0 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_megatron.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_megatron.py @@ -72,7 +72,7 @@ class VllmFqGPTModelExporter(GPTModelExporter): def save_pretrained( self, save_directory: str | os.PathLike, - pretrained_model_name_or_path: str | os.PathLike | None = None, + pretrained_model_name_or_path: str | os.PathLike, ): os.makedirs(save_directory, exist_ok=True) gather_mcore_vllm_fq_quantized_state_dict(self.model, self.state_dict, save_directory) @@ -91,7 +91,7 @@ def _get_quantization_format(self, module: torch.nn.Module): def export_mcore_gpt_to_hf_vllm_fq( model: torch.nn.Module, - pretrained_model_name_or_path: str | os.PathLike | None = None, + pretrained_model_name_or_path: str | os.PathLike, export_extra_modules: bool = False, dtype: torch.dtype = torch.bfloat16, export_dir: Path | str = tempfile.gettempdir(), diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 8a6d76b34..0567d0d1f 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -20,7 +20,6 @@ import json import os -import shutil import tempfile from collections import OrderedDict from pathlib import Path @@ -28,9 +27,8 @@ import torch import torch.distributed -from huggingface_hub import hf_hub_download, snapshot_download -from safetensors.torch import safe_open, save_file -from tqdm import tqdm +from huggingface_hub import hf_hub_download +from safetensors.torch import save_file from modelopt import __version__ from modelopt.torch.utils import import_plugin @@ -45,8 +43,13 @@ QUANTIZATION_NONE, QUANTIZATION_NVFP4, ) +from .plugins.hf_checkpoint_utils import copy_remote_code, load_multimodal_components from .plugins.mcore_common import all_mcore_hf_export_mapping -from .plugins.mcore_custom import CustomModuleMapping, get_safetensor, save_safetensors +from .plugins.mcore_custom import ( + CustomModuleMapping, + get_safetensor, + save_safetensors_by_layer_index, +) from .plugins.megatron_importer import GPTModelImporter from .quant_utils import ( get_activation_scaling_factor, @@ -119,6 +122,7 @@ def __init__( raise ValueError("Input to GPTModelExport must be a megatron.core.models.GPTModel!") self._state_dict = OrderedDict() + self._layer_state_dicts = OrderedDict() self._hf_pretrained_model_name = pretrained_model_name_or_path self._hf_config = transformers.AutoConfig.from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code @@ -221,10 +225,29 @@ def __init__( self._hf_extra_config.update(eagle_config_update) + def save_pretrained_extra_modules( + self, + save_directory: str | os.PathLike, + ): + """Save a EAGLE or Medusa checkpoints which can be deployed by vLLM and TensorRT-LLM.""" + # We use the last PP rank to write the config because + # medusa_heads and eagle_module only exist in the last stage. + pp_rank = get_pipeline_model_parallel_rank() + pp_size = get_pipeline_model_parallel_world_size() + is_last_stage_main_rank = pp_rank == pp_size - 1 + + state_dict = self.extra_state_dict + + if is_last_stage_main_rank and self._hf_extra_config is not None: + self._hf_extra_config.save_pretrained(save_directory) + save_file(state_dict, save_directory + "/model.safetensors", metadata={"format": "pt"}) + + torch.distributed.barrier() + def save_pretrained( self, save_directory: str | os.PathLike, - pretrained_model_name_or_path: str | os.PathLike | None = None, + pretrained_model_name_or_path: str | os.PathLike, ): """Save a unified checkpoint which can be deployed by vLLM and TensorRT-LLM. @@ -242,7 +265,7 @@ def save_pretrained( is_last_stage_main_rank = pp_rank == pp_size - 1 # Main export process - state_dict = self.extra_state_dict if self.export_extra_modules else self.state_dict + layer_state_dicts = self.layer_state_dicts quantization_format = self._get_quantization_format(self.model) quantization = None @@ -259,39 +282,36 @@ def save_pretrained( # We use the last PP rank and the 1st EP rank to write the config because # medusa_heads and eagle_module only exist in the last stage. if is_last_stage_main_rank: - if self.export_extra_modules and self._hf_extra_config is not None: - self._hf_extra_config.save_pretrained(save_directory) - else: - self._hf_config.save_pretrained(save_directory) - try: - generation_config = transformers.GenerationConfig.from_pretrained( - self._hf_pretrained_model_name - ) - generation_config.save_pretrained(save_directory) - except OSError: - pass - try: - tokenizer = transformers.AutoTokenizer.from_pretrained( - self._hf_pretrained_model_name - ) - tokenizer.save_pretrained(save_directory) - except OSError: - pass - except TypeError: - pass - try: - # Load and save preprocessor config from the original model - processor = AutoProcessor.from_pretrained( - self._hf_pretrained_model_name, trust_remote_code=self.trust_remote_code - ) - if hasattr(processor, "image_processor"): - processor.image_processor.save_pretrained(save_directory) - except (OSError, ValueError, ImportError): - pass + self._hf_config.save_pretrained(save_directory) + try: + generation_config = transformers.GenerationConfig.from_pretrained( + self._hf_pretrained_model_name + ) + generation_config.save_pretrained(save_directory) + except OSError: + pass + try: + tokenizer = transformers.AutoTokenizer.from_pretrained( + self._hf_pretrained_model_name + ) + tokenizer.save_pretrained(save_directory) + except OSError: + pass + except TypeError: + pass + try: + # Load and save preprocessor config from the original model + processor = AutoProcessor.from_pretrained( + self._hf_pretrained_model_name, trust_remote_code=self.trust_remote_code + ) + if hasattr(processor, "image_processor"): + processor.image_processor.save_pretrained(save_directory) + except (OSError, ValueError, ImportError): + pass mtp_state_dict = self._get_mtp_state_dict() if len(mtp_state_dict) > 0: - state_dict.update(mtp_state_dict) + layer_state_dicts[self.model.config.num_layers].update(mtp_state_dict) print(f"Successfully loaded {len(mtp_state_dict)} MTP tensors") combined_exclude_modules = self._gather_exclude_modules() @@ -314,121 +334,18 @@ def save_pretrained( with open(save_directory + "/hf_quant_config.json", "w") as f: json.dump(self._hf_quant_config, f, indent=4) - if ( - is_first_stage_main_rank - and self.is_multimodal - and pretrained_model_name_or_path is not None - ): - hf_checkpoint_path = Path(pretrained_model_name_or_path) - if not hf_checkpoint_path.is_dir(): - hf_checkpoint_path = tempfile.gettempdir() + "/" + pretrained_model_name_or_path - if not Path(hf_checkpoint_path).exists(): - snapshot_download( - repo_id=pretrained_model_name_or_path, - local_dir=hf_checkpoint_path, - ) - - safetensors_file = Path(hf_checkpoint_path) / "model.safetensors" - safetensors_index_file = Path(hf_checkpoint_path) / "model.safetensors.index.json" - - multimodal_state_dict = {} - - if safetensors_file.is_file(): - print(f"Loading multimodal components from single file: {safetensors_file}") - with safe_open(safetensors_file, framework="pt") as f: - multimodal_keys = [ - key - for key in f.keys() # noqa: SIM118 - if key.startswith(("multi_modal_projector", "vision_model")) - ] - for key in tqdm(multimodal_keys, desc="Loading multimodal tensors"): - multimodal_state_dict[key] = f.get_tensor(key) - - elif safetensors_index_file.is_file(): - print(f"Loading multimodal components from sharded model: {hf_checkpoint_path}") - with open(safetensors_index_file) as f: - safetensors_index = json.load(f) - - # For multimodal models, vision_model and multi_modal_projector are in the first shard - all_shard_files = sorted(set(safetensors_index["weight_map"].values())) - first_shard_file = all_shard_files[0] # e.g., "model-00001-of-00050.safetensors" - - # Load multimodal components from the first shard file - safetensors_filepath = Path(hf_checkpoint_path) / first_shard_file - print(f"Loading multimodal components from {first_shard_file}") - - with safe_open(safetensors_filepath, framework="pt") as f: - shard_keys = list(f.keys()) - multimodal_keys_in_shard = [ - k - for k in shard_keys - if k.startswith(("multi_modal_projector", "vision_model")) - ] - - if multimodal_keys_in_shard: - print( - f"Found {len(multimodal_keys_in_shard)} multimodal tensors in {first_shard_file}" - ) - for key in tqdm( - multimodal_keys_in_shard, desc="Loading multimodal tensors" - ): - multimodal_state_dict[key] = f.get_tensor(key) - else: - print(f"No multimodal components found in {first_shard_file}") - - else: - print(f"Warning: No safetensors files found in {hf_checkpoint_path}") - - print(f"Successfully loaded {len(multimodal_state_dict)} multimodal tensors") - # Add multimodal components to state_dict - state_dict.update(multimodal_state_dict) + # Add multimodal components to state_dict. Since only support decoder model quantization, + # no changes will be made to the multimodal components. We copy the multimodal components + # from the pretrained model directly to the state_dict to avoid implementing the export logic. + if is_first_stage_main_rank and self.is_multimodal: + multimodal_state_dict = load_multimodal_components(pretrained_model_name_or_path) + layer_state_dicts[0].update(multimodal_state_dict) # Barrier to ensure the export_dir has been created. torch.distributed.barrier() - if self.export_extra_modules: - if is_last_stage_main_rank: - save_file( - state_dict, save_directory + "/model.safetensors", metadata={"format": "pt"} - ) - torch.distributed.barrier() - return - - if ( - is_last_stage_main_rank - and self._hf_config is not None - and pretrained_model_name_or_path is not None - ): - # For models that keep configuration and modeling files as part of the checkpoint, - # we need to copy them to the export directory for seamless integration with inference - # frameworks. - hf_checkpoint_path = Path(pretrained_model_name_or_path) - model_type = getattr(self._hf_config, "model_type", None) - - if hf_checkpoint_path.is_dir(): - # Local directory - files should be there - config_file = hf_checkpoint_path / f"configuration_{model_type}.py" - modeling_file = hf_checkpoint_path / f"modeling_{model_type}.py" - else: - # Remote model ID - download from HuggingFace Hub (cached automatically) - try: - config_file = hf_hub_download( - repo_id=pretrained_model_name_or_path, - filename=f"configuration_{model_type}.py", - ) - except Exception: - config_file = "" - try: - modeling_file = hf_hub_download( - repo_id=pretrained_model_name_or_path, filename=f"modeling_{model_type}.py" - ) - except Exception: - modeling_file = "" - - if config_file and os.path.exists(config_file): - shutil.copy(config_file, f"{save_directory}/configuration_{model_type}.py") - if modeling_file and os.path.exists(modeling_file): - shutil.copy(modeling_file, f"{save_directory}/modeling_{model_type}.py") + if is_last_stage_main_rank and self._hf_config is not None: + copy_remote_code(pretrained_model_name_or_path, save_directory) # Newer versions of VLLM expect config.json with hf_quant_config config_json_file = save_directory + "/config.json" @@ -440,7 +357,13 @@ def save_pretrained( with open(config_json_file, "w") as f: json.dump(config_dict, f, indent=4) - save_safetensors(state_dict, save_directory) + # save_safetensors(state_dict, save_directory) + save_safetensors_by_layer_index( + layer_state_dicts=layer_state_dicts, + total_layers=self.model.config.num_layers, + save_directory=save_directory, + name_template="model-{:05d}-of-{:05d}", + ) @property def state_dict(self): @@ -449,6 +372,12 @@ def state_dict(self): self._get_state_dict() return self._state_dict + @property + def layer_state_dicts(self): + if len(self._layer_state_dicts) == 0: + self._get_state_dict() + return self._layer_state_dicts + @property def extra_state_dict(self): if len(self._state_dict) == 0: @@ -463,17 +392,6 @@ def _get_state_dict(self): if hasattr(model, "embedding"): self.rules["word_embeddings"](model.embedding.word_embeddings) - # Final layernorm - if hasattr(model.decoder, "final_layernorm") and model.decoder.final_layernorm: - self.rules["final_layernorm"](model.decoder.final_layernorm) - - if hasattr(model.decoder, "final_norm") and model.decoder.final_norm: - self.rules["final_norm"](model.decoder.final_norm) - - # Output layer - if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights: - self.rules["output_layer"](model.output_layer) - # Decoder layers for layer in model.decoder.layers: layer_id = layer.layer_number - 1 @@ -484,7 +402,20 @@ def _get_state_dict(self): else: raise ValueError("Only TransformerLayer or MambaLayer are supported.") - # TODO export MTP layer in the future + self._layer_state_dicts[layer.layer_number] = self._state_dict + if layer.layer_number != self.model.config.num_layers: + self._state_dict = OrderedDict() + + # Final layernorm + if hasattr(model.decoder, "final_layernorm") and model.decoder.final_layernorm: + self.rules["final_layernorm"](model.decoder.final_layernorm) + + if hasattr(model.decoder, "final_norm") and model.decoder.final_norm: + self.rules["final_norm"](model.decoder.final_norm) + + # Output layer + if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights: + self.rules["output_layer"](model.output_layer) def _get_transformer_layer_state_dict(self, layer, layer_id): if not isinstance(layer.input_layernorm, IdentityOp): @@ -761,8 +692,10 @@ def _get_quantized_state( """ name_to_value = {} qformat: str = self._get_quantization_format(module) - if qformat is None and "norm" not in prefix: # Add exclude layers for hf_quant_config - self.exclude_modules.append(prefix) + if qformat is None and "norm" not in prefix: + # Add exclude layers for hf_quant_config. Note that if the prefix is not an empty + # string then it usually ends with "." which needs to be removed. + self.exclude_modules.append(prefix.removesuffix(".")) block_size = get_weight_block_size(module) if hasattr(module, "weight") and module.weight is not None and module.weight.numel() > 0: @@ -1254,7 +1187,7 @@ def _gather_exclude_modules(self): def export_mcore_gpt_to_hf( model: torch.nn.Module, - pretrained_model_name_or_path: str | os.PathLike | None = None, + pretrained_model_name_or_path: str | os.PathLike, export_extra_modules: bool = False, dtype: torch.dtype = torch.bfloat16, export_dir: Path | str = tempfile.gettempdir(), @@ -1282,7 +1215,10 @@ def export_mcore_gpt_to_hf( trust_remote_code=trust_remote_code, moe_router_dtype=moe_router_dtype, ) - exporter.save_pretrained(export_dir, pretrained_model_name_or_path) + if exporter.export_extra_modules: + exporter.save_pretrained_extra_modules(export_dir) + else: + exporter.save_pretrained(export_dir, pretrained_model_name_or_path) def import_mcore_gpt_from_hf( From 2d67c0a2237e6aaa8d99700d4e44b0448860d178 Mon Sep 17 00:00:00 2001 From: Zhiyu Date: Thu, 12 Feb 2026 18:18:08 -0800 Subject: [PATCH 04/11] Add Nemotron parse PTQ support (#786) ## What does this PR do? **Type of change:** New model support **Overview:** Add PTQ support for https://huggingface.co/nvidia/NVIDIA-Nemotron-Parse-v1.1 ## Usage ```python python3 hf_ptq.py --pyt_ckpt_path /home/omniml_data_3/models/NVIDIA-Nemotron-Parse-v1.1 --qformat fp8 --export_path /home/omniml_data_3/zhiyuc/checkpoints/NVIDIA-Nemotron-Parse-v1.1-FP8 --trust_remote_code --kv_cache_qformat none --attn_implementation eager ``` By default, image-text data will be used in calibration for VLMs. ## Testing ## Before your PR is "*Ready for review*" - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Not yet ## Additional Information ## Summary by CodeRabbit * **New Features** * Added support for Nemotron-Parse multimodal models, including proper device mapping, processor loading, and generation handling. * **Improvements** * Enhanced quantization robustness with safer handling of quantization attributes and fallback logic. * Improved model loading with better device placement and encoder buffer management for vision-language models. --------- Signed-off-by: Zhiyu Cheng Signed-off-by: Hung-Yueh --- CHANGELOG.rst | 1 + examples/llm_ptq/example_utils.py | 47 ++++++++++---- examples/llm_ptq/hf_ptq.py | 73 +++++++++++++--------- examples/llm_ptq/vlm_utils.py | 67 +++++++++++++------- modelopt/torch/export/model_utils.py | 14 ++++- modelopt/torch/export/unified_export_hf.py | 16 ++--- 6 files changed, 145 insertions(+), 73 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9d7500e58..bbbe6ab9e 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -21,6 +21,7 @@ NVIDIA Model Optimizer Changelog (Linux) - Add LTX-2 and Wan2.2 (T2V) support in the diffusers quantization workflow. - Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and export as-is. - Add support for image-text data calibration in PTQ for Nemotron VL models. +- Add PTQ support for Nemotron Parse. 0.41 (2026-01-19) ^^^^^^^^^^^^^^^^^ diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 93687a8d0..71755a02f 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -31,6 +31,7 @@ from safetensors.torch import load_file from transformers import ( AutoConfig, + AutoModel, AutoModelForCausalLM, AutoProcessor, AutoTokenizer, @@ -75,19 +76,18 @@ def run_nemotron_vl_preview( "eos_token_id": tokenizer.eos_token_id, } - # Try text-only generation + # Try text-only generation (may fail for encoder-decoder models like Nemotron-Parse) text_response = run_text_only_generation( full_model, tokenizer, question, generation_config, pyt_ckpt_path ) + generated_ids = None if text_response is not None: print(f"✅ Text-only generation successful: {text_response[:100]}...") generated_ids = text_response elif allow_fallback: print("Text-only generation failed, falling back to standard generate...") generated_ids = full_model.generate(input_ids, max_new_tokens=100) - else: - generated_ids = None # Run additional VL test with images print(f"Running additional VL test with images ({stage_name})...") @@ -106,6 +106,10 @@ def _is_multimodal_config(config): or ( hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer") ) # Image embedding layers + or getattr(config, "is_encoder_decoder", False) # Encoder-decoder VL models + or any( # Architecture-based detection for custom VL models (e.g., Nemotron-Parse) + "conditionalgeneration" in arch.lower() for arch in getattr(config, "architectures", []) + ) ) @@ -158,9 +162,20 @@ def calibrate_loop(_model): ) allowed_keys = set(forward_params.keys()) + # Check if model is encoder-decoder (needs decoder_input_ids instead of input_ids) + is_enc_dec = getattr(full_model.config, "is_encoder_decoder", False) + full_model.eval() with torch.no_grad(): for batch in calib_dataloader: + # For encoder-decoder models, rename input_ids → decoder_input_ids + # and disable KV caching to avoid tuple index errors in decoder layers + if is_enc_dec and "input_ids" in batch and "pixel_values" in batch: + batch["decoder_input_ids"] = batch.pop("input_ids") + if "attention_mask" in batch: + batch["decoder_attention_mask"] = batch.pop("attention_mask") + batch["use_cache"] = False + # Filter batch to only include parameters the model accepts if accepts_kwargs: call_kwargs = batch @@ -172,10 +187,8 @@ def calibrate_loop(_model): # Use safe_nemotron_vl_forward for Nemotron Nano VL (embedding-injection style) # For other VLMs (like Nemotron-Parse), use standard forward if hasattr(full_model, "img_context_token_id"): - # Nemotron Nano VL style safe_nemotron_vl_forward(full_model, call_kwargs) else: - # Standard encoder-decoder or other VLM architectures full_model(**call_kwargs) return calibrate_loop @@ -312,8 +325,15 @@ def get_processor( ) return MllamaImageProcessor(processor, device) - - return None + else: + # Try to load AutoProcessor for other VL models (e.g., Nemotron-Parse) + try: + processor = AutoProcessor.from_pretrained(ckpt_path, **model_kwargs) + print(f"Loaded AutoProcessor for model type: {model_type}") + return processor + except Exception as e: + print(f"Could not load processor for {model_type}: {e}") + return None def load_mtp_weights( @@ -447,6 +467,7 @@ def get_model( # Load config once and handle VL model detection try: hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs) + if is_nemotron_vl(hf_config): print( "Detected Nemotron VL model from config. " @@ -466,8 +487,6 @@ def get_model( model_kwargs.setdefault("torch_dtype", "auto") if "vila" in ckpt_path.lower(): - from transformers import AutoModel - hf_vila = AutoModel.from_pretrained( ckpt_path, device_map=device_map, @@ -510,13 +529,17 @@ def get_model( if not hasattr(transformers, architecture): warnings.warn( f"Architecture {architecture} not found in transformers: {transformers.__version__}. " - "Falling back to AutoModelForCausalLM." + "Falling back to AutoModelForCausalLM (or AutoModel for non-causal architectures)." ) assert trust_remote_code, ( "Please set trust_remote_code to True if you want to use this architecture" ) - auto_model_module = AutoModelForCausalLM + # Use AutoModelForCausalLM for causal LMs, AutoModel for encoder-decoder models + if getattr(hf_config, "is_encoder_decoder", False): + auto_model_module = AutoModel + else: + auto_model_module = AutoModelForCausalLM from_config = auto_model_module.from_config else: auto_model_module = getattr(transformers, architecture) @@ -527,7 +550,7 @@ def get_model( # unless specified by the hf_config. torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) model_kwargs2 = model_kwargs.copy() - if auto_model_module != AutoModelForCausalLM: + if auto_model_module not in [AutoModelForCausalLM, AutoModel]: model_kwargs2.pop("trust_remote_code", None) model_kwargs2["torch_dtype"] = torch_dtype model_kwargs2.pop("max_memory", None) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index d9a6ca893..de434e1cf 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -361,6 +361,12 @@ def load_model(args: argparse.Namespace): default_pad_token = None is_nemotron_vl_model = is_nemotron_vl(full_model) + + # Default to image-text calibration for VLM models + if is_nemotron_vl_model and not args.calib_with_images: + print("Nemotron VL model detected. Enabling image-text calibration by default.") + args.calib_with_images = True + if model_type == "mllama": processor = get_processor( args.pyt_ckpt_path, @@ -499,9 +505,12 @@ def mono_quantize( print("Disabling quantization for vision components in Nemotron VL model") quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} quant_cfg["quant_cfg"]["*image*"] = {"enable": False} - # Also disable radio model components specifically + # Also disable radio model components specifically (for Nemotron-Parse) quant_cfg["quant_cfg"]["*radio*"] = {"enable": False} quant_cfg["quant_cfg"]["*visual*"] = {"enable": False} + quant_cfg["quant_cfg"]["*encoder*"] = {"enable": False} # Disable encoder + quant_cfg["quant_cfg"]["*model_encoder*"] = {"enable": False} # Nemotron-Parse specific + print("Quantization will only be applied to the decoder (text generation) component") if not model_is_already_quantized or calibration_only: if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only": @@ -686,7 +695,7 @@ def pre_quantize( preview_input_ids, args.pyt_ckpt_path, "before quantization", - allow_fallback=True, + allow_fallback=False, ) else: # Standard generation for non-Nemotron VL models @@ -800,36 +809,42 @@ def quantize_main( device: torch.device, ): if args.batch_size == 0: - # Calibration/sparsification will actually take much more memory than regular inference - # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio - # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference. - sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1 - # Whisper model expects mel-spectrogram input features of length 3000 - # Whisper model needs input of shape (batch_size, num_mel_bins, 3000) - # As the encoder of Whisper doesn't have embedding layer, input dtype has to be float - # For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size() - if model_type == "whisper": - max_sample_length = 3000 - num_mel_bins = language_model.config.num_mel_bins - sample_input_single_batch = ( - torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to( - language_model.device - ) - * 100 - ) + # For VL models with image-text calibration, skip automatic batch size detection + # since get_max_batch_size can't handle multimodal inputs + if args.calib_with_images: + print("Image-text calibration enabled. Using default batch_size=1 for calibration.") + args.batch_size = 1 else: - sample_input_single_batch = None + # Calibration/sparsification will actually take much more memory than regular inference + # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio + # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference. + sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1 + # Whisper model expects mel-spectrogram input features of length 3000 + # Whisper model needs input of shape (batch_size, num_mel_bins, 3000) + # As the encoder of Whisper doesn't have embedding layer, input dtype has to be float + # For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size() + if model_type == "whisper": + max_sample_length = 3000 + num_mel_bins = language_model.config.num_mel_bins + sample_input_single_batch = ( + torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to( + language_model.device + ) + * 100 + ) + else: + sample_input_single_batch = None - run_auto_quant = args.auto_quantize_bits is not None + run_auto_quant = args.auto_quantize_bits is not None - args.batch_size = get_max_batch_size( - language_model, - max_sample_length=args.calib_seq, - sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0, - sample_input_single_batch=sample_input_single_batch, - enable_grad=run_auto_quant, - ) - args.batch_size = min(args.batch_size, sum(args.calib_size)) + args.batch_size = get_max_batch_size( + language_model, + max_sample_length=args.calib_seq, + sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0, + sample_input_single_batch=sample_input_single_batch, + enable_grad=run_auto_quant, + ) + args.batch_size = min(args.batch_size, sum(args.calib_size)) print(f"Use calib batch_size {args.batch_size}") diff --git a/examples/llm_ptq/vlm_utils.py b/examples/llm_ptq/vlm_utils.py index 6c9d921b8..9919e405b 100644 --- a/examples/llm_ptq/vlm_utils.py +++ b/examples/llm_ptq/vlm_utils.py @@ -105,27 +105,31 @@ def run_vl_preview_generation(model, tokenizer, model_path, stage_name): else: processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) - messages = [ - {"role": "system", "content": "/no_think"}, - { - "role": "user", - "content": [ - { - "type": "image", - "image": "", - }, - { - "type": "text", - "text": question, - }, - ], - }, - ] - - # Apply chat template - prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) + # Use chat template if available, otherwise fall back to default task prompt + if hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None: + messages = [ + {"role": "system", "content": "/no_think"}, + { + "role": "user", + "content": [ + { + "type": "image", + "image": "", + }, + { + "type": "text", + "text": question, + }, + ], + }, + ] + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + else: + # For models without chat templates (e.g., encoder-decoder VL models), + # use the tokenizer's bos/eos tokens as a minimal prompt + prompt = (tokenizer.bos_token or "") + question # Process inputs using the processor with single image inputs = processor( @@ -139,6 +143,12 @@ def run_vl_preview_generation(model, tokenizer, model_path, stage_name): inputs = inputs.to(model_device) print(f" Moved inputs to {model_device}") + # Verify we have pixel_values for the vision encoder + if not hasattr(inputs, "pixel_values") or inputs.pixel_values is None: + raise ValueError( + "Processor did not generate pixel_values. Check processor configuration." + ) + # Generate response using model.generate generated_ids = model.generate( pixel_values=inputs.pixel_values, @@ -148,12 +158,23 @@ def run_vl_preview_generation(model, tokenizer, model_path, stage_name): ) # Decode the response (trim input tokens like in the working example) + if generated_ids is None: + raise ValueError("Model generate returned None") + generated_ids_trimmed = [ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] - output_text = processor.batch_decode( - generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + # Use processor.batch_decode if available, otherwise fall back to tokenizer + decoder = processor if hasattr(processor, "batch_decode") else tokenizer + output_text = decoder.batch_decode( + generated_ids_trimmed, + skip_special_tokens=True, + clean_up_tokenization_spaces=False, ) + + if output_text is None or len(output_text) == 0: + raise ValueError("Decoding returned empty output") + response = output_text[0] print(f"✅ VL generation {stage_name} successful!") diff --git a/modelopt/torch/export/model_utils.py b/modelopt/torch/export/model_utils.py index 5a24429ad..6cb5be9a5 100755 --- a/modelopt/torch/export/model_utils.py +++ b/modelopt/torch/export/model_utils.py @@ -85,6 +85,7 @@ def is_multimodal_model(model): - Vision LoRA configurations - Audio processing capabilities - Image embedding layers + - Nemotron-Parse conditional generation models Args: model: The HuggingFace model instance to check @@ -103,6 +104,10 @@ def is_multimodal_model(model): """ config = model.config + # Check for Nemotron-Parse encoder-decoder architecture + architectures = getattr(config, "architectures", []) + is_nemotron_parse = any("nemotronparse" in arch.lower() for arch in architectures) + return ( hasattr(config, "vision_config") # Standard vision config (e.g., Qwen2.5-VL) or hasattr(model, "language_model") # Language model attribute (e.g., LLaVA) @@ -112,6 +117,7 @@ def is_multimodal_model(model): or ( hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer") ) # Image embedding layers + or is_nemotron_parse # Nemotron-Parse conditional generation model ) @@ -141,5 +147,11 @@ def get_language_model_from_vl(model) -> list[nn.Module] | None: if hasattr(model, "language_model"): return [model, model.language_model] - # Pattern 3: No language_model found + # Pattern 3: For encoder-decoder VL models (e.g., Nemotron-Parse), the decoder is the language model. + # Only match if the model is detected as multimodal to avoid matching non-VLM encoder-decoder + # models like T5, Bart, Whisper which also have .decoder. + if hasattr(model, "decoder") and is_multimodal_model(model): + return [model, model.decoder] + + # Pattern 4: No language_model found return None diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 5703f4515..b6b92f6ff 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -316,27 +316,27 @@ def llm_dummy_forward(): [1, model.config.num_mel_bins, feature_extractor.nb_max_frames], dtype=model.dtype ).to(model.device) - if getattr(model.config, "is_encoder_decoder", False): - # For encoder-decoder models, we need to pass both the encoder and decoder input ids - model(fake_input, decoder_input_ids=decoder_fake_input) - elif is_vl_model and "nemotron" in model_type: - # For Nemotron VL models, try to run optimization on just the language model part + if is_vl_model and "nemotron" in model_type: + # For Nemotron VL models, run optimization on just the language model/decoder. + # This avoids needing pixel_values for the vision encoder. language_model_lineage = get_language_model_from_vl(model) if language_model_lineage is not None: - # Run optimization on just the language model with the same input format as regular LLMs - # Use the same fake_input tensor that regular LLMs use language_model = language_model_lineage[-1] print( f"Running optimization on language model with fake_input shape: {fake_input.shape}" ) - language_model(fake_input) + # Pass use_cache=False to avoid KV cache issues in encoder-decoder models + language_model(fake_input, use_cache=False) else: raise ValueError( f"Cannot extract language_model from Nemotron VL model (type: {model_type}). " "This is required for requantization/resmoothing optimization. " "Please ensure the model architecture is supported or file an issue." ) + elif getattr(model.config, "is_encoder_decoder", False): + # For other encoder-decoder models (non-VL), pass both encoder and decoder input ids + model(fake_input, decoder_input_ids=decoder_fake_input) else: model(fake_input) From 2e54d9d38a050fe2e3bccc2824fb11b994b88e5c Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Fri, 13 Feb 2026 11:58:56 -0800 Subject: [PATCH 05/11] MBridge pruning minor fix for saving pruned NemotronH (#887) ## What does this PR do? **Type of change:** Bug fix ## Testing Nemotron Nano v2 pruned can be saved ## Summary by CodeRabbit * **Bug Fixes** * Fixed Hugging Face model loading to properly respect the `trust_remote_code` parameter during model instantiation. * **Improvements** * Enhanced distributed training logging with rank-0 aware warning and logging mechanisms for cleaner, non-redundant output in multi-GPU and multi-node scenarios. Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Signed-off-by: Hung-Yueh --- examples/megatron_bridge/distill.py | 2 ++ examples/megatron_bridge/prune_minitron.py | 4 +++- modelopt/torch/opt/searcher.py | 8 +++----- modelopt/torch/utils/plugins/mbridge.py | 2 ++ 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/examples/megatron_bridge/distill.py b/examples/megatron_bridge/distill.py index c21bf7312..31f1cfc71 100644 --- a/examples/megatron_bridge/distill.py +++ b/examples/megatron_bridge/distill.py @@ -198,6 +198,8 @@ def _build_model_provider(hf_path): manual_gc=True, manual_gc_interval=100, ), + # TODO: Replace validation args in train with validation config in nemo:26.04 + # validation=ValidationConfig(eval_interval=args.eval_interval, eval_iters=args.eval_iters), optimizer=optimizer_config, scheduler=scheduler_config, ddp=DistributedDataParallelConfig( diff --git a/examples/megatron_bridge/prune_minitron.py b/examples/megatron_bridge/prune_minitron.py index 44eac3a31..c4da627f1 100644 --- a/examples/megatron_bridge/prune_minitron.py +++ b/examples/megatron_bridge/prune_minitron.py @@ -380,7 +380,9 @@ def score_func_mmlu(m): AutoModelForCausalLM.from_config( hf_cfg, trust_remote_code=args.trust_remote_code ).save_pretrained(args.output_hf_path, trust_remote_code=args.trust_remote_code) - pruned_bridge = AutoBridge.from_hf_pretrained(args.output_hf_path) + pruned_bridge = AutoBridge.from_hf_pretrained( + args.output_hf_path, trust_remote_code=args.trust_remote_code + ) pruned_bridge.save_hf_weights(model, args.output_hf_path) print_rank_0(f"Saved pruned model to {args.output_hf_path} in HF checkpoint format") diff --git a/modelopt/torch/opt/searcher.py b/modelopt/torch/opt/searcher.py index 9e73b143c..ab3930c20 100644 --- a/modelopt/torch/opt/searcher.py +++ b/modelopt/torch/opt/searcher.py @@ -27,7 +27,6 @@ from collections.abc import Callable from contextlib import nullcontext from typing import Any, final -from warnings import warn import numpy as np import pulp @@ -35,7 +34,7 @@ import torch.nn as nn from modelopt.torch.utils import distributed as dist -from modelopt.torch.utils import no_stdout, print_rank_0, run_forward_loop +from modelopt.torch.utils import no_stdout, print_rank_0, run_forward_loop, warn_rank_0 LimitsTuple = tuple[float, float] ConstraintsDict = dict[str, str | float | dict | None] @@ -244,12 +243,11 @@ def load_search_checkpoint(self) -> bool: if checkpoint is None: return False if not os.path.exists(checkpoint): - if dist.is_master(): - warn(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.") + warn_rank_0(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.") return False # iterate through state dict and load keys - print(f"Loading searcher state from {checkpoint}...") + print_rank_0(f"Loading searcher state from {checkpoint}...") # Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input state_dict = torch.load(checkpoint, weights_only=False) assert state_dict.keys() == self.state_dict().keys(), "Keys in checkpoint don't match!" diff --git a/modelopt/torch/utils/plugins/mbridge.py b/modelopt/torch/utils/plugins/mbridge.py index ed2551ee1..94cdf87cf 100644 --- a/modelopt/torch/utils/plugins/mbridge.py +++ b/modelopt/torch/utils/plugins/mbridge.py @@ -191,6 +191,8 @@ def get_hf_mbridge_calibration_loop( eval_iters=num_iters, skip_train=True, ), + # TODO: Replace validation args in train with validation config in nemo:26.04 + # validation=ValidationConfig(eval_iters=num_iters, eval_interval=1, skip_train=True), dataset=_get_dataset_cfg( dataset_name, num_samples, From eb699d8f838fa3a04a4bd94eb689a7654f18d2f7 Mon Sep 17 00:00:00 2001 From: Hung-Yueh Date: Sat, 14 Feb 2026 01:54:53 +0000 Subject: [PATCH 06/11] add qwen3vl Signed-off-by: Hung-Yueh --- modelopt/torch/export/plugins/mcore_common.py | 6 + .../torch/export/plugins/mcore_qwen3vl.py | 120 ++++++++++++++++++ 2 files changed, 126 insertions(+) create mode 100644 modelopt/torch/export/plugins/mcore_qwen3vl.py diff --git a/modelopt/torch/export/plugins/mcore_common.py b/modelopt/torch/export/plugins/mcore_common.py index d5bab9b4e..660e4eac9 100644 --- a/modelopt/torch/export/plugins/mcore_common.py +++ b/modelopt/torch/export/plugins/mcore_common.py @@ -39,6 +39,10 @@ qwen25_causal_lm_export, qwen25_causal_lm_import, ) +from .mcore_qwen3vl import ( + qwen3vl_causal_lm_export, + qwen3vl_causal_lm_import, +) all_mcore_hf_export_mapping: dict[str, Any] = { "DeepseekV2ForCausalLM": deepseek_causal_lm_export, @@ -54,6 +58,7 @@ "Qwen3MoeForCausalLM": qwen3_causal_lm_export, "Qwen2ForCausalLM": qwen25_causal_lm_export, "GptOssForCausalLM": gptoss_causal_lm_export, + "Qwen3VLForConditionalGeneration": qwen3vl_causal_lm_export, } all_mcore_hf_import_mapping: dict[str, Any] = { @@ -66,4 +71,5 @@ "Qwen3MoeForCausalLM": qwen3_causal_lm_import, "Qwen2ForCausalLM": qwen25_causal_lm_import, "GptOssForCausalLM": gptoss_causal_lm_import, + "Qwen3VLForConditionalGeneration": qwen3vl_causal_lm_import, } diff --git a/modelopt/torch/export/plugins/mcore_qwen3vl.py b/modelopt/torch/export/plugins/mcore_qwen3vl.py new file mode 100644 index 000000000..40eb99adb --- /dev/null +++ b/modelopt/torch/export/plugins/mcore_qwen3vl.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom mapping from Qwen3-VL Hugging Face models to Megatron Core models. + +Qwen3-VL model structure differs from Qwen3: +- Language model weights are under `model.language_model.` prefix +- Visual encoder weights are under `model.visual.` prefix + +This module handles the language model conversion for PTQ/QAT workflows. +Visual components are typically kept in full precision. + +HuggingFace Qwen3-VL-8B structure: +- model.language_model.embed_tokens.weight +- model.language_model.layers.{L}.input_layernorm.weight +- model.language_model.layers.{L}.self_attn.q_proj.weight +- model.language_model.layers.{L}.self_attn.k_proj.weight +- model.language_model.layers.{L}.self_attn.v_proj.weight +- model.language_model.layers.{L}.self_attn.q_norm.weight +- model.language_model.layers.{L}.self_attn.k_norm.weight +- model.language_model.layers.{L}.self_attn.o_proj.weight +- model.language_model.layers.{L}.post_attention_layernorm.weight +- model.language_model.layers.{L}.mlp.gate_proj.weight +- model.language_model.layers.{L}.mlp.up_proj.weight +- model.language_model.layers.{L}.mlp.down_proj.weight +- model.language_model.norm.weight +- lm_head.weight +""" + +from .mcore_custom import ( + COL_ETP, + COL_TP, + REPLICATE, + ROW_ETP, + ROW_TP, + CustomModuleMapping, + GatedMLPMerging, + GatedMLPSlicing, + NameRemapping, + QKVMerging, + QKVSlicing, +) + +# Import rules: HuggingFace -> Megatron Core +qwen3vl_causal_lm_import: dict[str, CustomModuleMapping] = { + # Embeddings - note the language_model prefix + "word_embeddings": NameRemapping("model.language_model.embed_tokens.", COL_TP), + # Final layer norm + "final_layernorm": NameRemapping("model.language_model.norm.", REPLICATE), + # Output layer (lm_head is at root level, not under language_model) + "output_layer": NameRemapping("lm_head.", COL_TP), + # Attention - input layernorm + "input_layernorm": NameRemapping("model.language_model.layers.{}.input_layernorm.", REPLICATE), + # Attention - QKV projection (merged) + "linear_qkv": QKVMerging("model.language_model.layers.{}.self_attn.", COL_TP), + # Attention - output projection + "linear_proj": NameRemapping("model.language_model.layers.{}.self_attn.o_proj.", ROW_TP), + # Attention - Q/K layer norms (Qwen3 uses RMSNorm on Q and K) + "q_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.q_norm.", REPLICATE), + "k_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.k_norm.", REPLICATE), + # MLP - pre-MLP layernorm (post_attention_layernorm in HF) + "pre_mlp_layernorm": NameRemapping( + "model.language_model.layers.{}.post_attention_layernorm.", REPLICATE + ), + # MLP - gate_proj + up_proj merged into linear_fc1 + "linear_fc1": GatedMLPMerging("model.language_model.layers.{}.mlp.", COL_TP), + # MLP - down_proj as linear_fc2 + "linear_fc2": NameRemapping("model.language_model.layers.{}.mlp.down_proj.", ROW_TP), + # MoE support (for Qwen3-VL MoE variants like 30B-A3B) + "router": NameRemapping("model.language_model.layers.{}.mlp.gate.", REPLICATE), + "local_experts.linear_fc1": GatedMLPMerging( + "model.language_model.layers.{}.mlp.experts.{}.", COL_ETP + ), + "local_experts.linear_fc2": NameRemapping( + "model.language_model.layers.{}.mlp.experts.{}.down_proj.", ROW_ETP + ), +} + +# Export rules: Megatron Core -> HuggingFace +qwen3vl_causal_lm_export: dict[str, CustomModuleMapping] = { + # Embeddings + "word_embeddings": NameRemapping("model.language_model.embed_tokens."), + # Final layer norm + "final_layernorm": NameRemapping("model.language_model.norm."), + # Output layer + "output_layer": NameRemapping("lm_head."), + # Attention - input layernorm + "input_layernorm": NameRemapping("model.language_model.layers.{}.input_layernorm."), + # Attention - QKV projection (sliced back to separate q/k/v) + "linear_qkv": QKVSlicing("model.language_model.layers.{}.self_attn."), + # Attention - output projection + "linear_proj": NameRemapping("model.language_model.layers.{}.self_attn.o_proj."), + # Attention - Q/K layer norms + "q_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.q_norm."), + "k_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.k_norm."), + # MLP - pre-MLP layernorm + "pre_mlp_layernorm": NameRemapping("model.language_model.layers.{}.post_attention_layernorm."), + # MLP - linear_fc1 sliced back to gate_proj + up_proj + "linear_fc1": GatedMLPSlicing("model.language_model.layers.{}.mlp."), + # MLP - down_proj + "linear_fc2": NameRemapping("model.language_model.layers.{}.mlp.down_proj."), + # MoE support + "router": NameRemapping("model.language_model.layers.{}.mlp.gate."), + "local_experts.linear_fc1": GatedMLPSlicing("model.language_model.layers.{}.mlp.experts.{}."), + "local_experts.linear_fc2": NameRemapping( + "model.language_model.layers.{}.mlp.experts.{}.down_proj." + ), +} \ No newline at end of file From e8b37e35f42ac9a342dcc96e77bbec1f0ff43bd5 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Fri, 13 Feb 2026 13:12:39 -0800 Subject: [PATCH 07/11] Separate CI job for Megatron GPU tests (#888) ## What does this PR do? [Short term]: Megatron based tests take a long time often resulting in CICD timeout. Splitting megatron tests into a dedicated CICD job for faster overall CI/CD run [Mid/Long term]: Run all megatron gpu tests using `torchrun` instead of `pytest` so all dist processes are already created and all individual tests no longer need to setup and destroy their processes which adds a lot of overhead per test ## Testing - [x] 1-GPU CI/CD passing (on this PR) - [x] 2-GPU CI/CD passing (on nightly run - manually triggered): https://github.com/NVIDIA/Model-Optimizer/actions/runs/22000517688 --------- Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Signed-off-by: Hung-Yueh --- .github/workflows/example_tests.yml | 37 +++++-------------- .github/workflows/gpu_tests.yml | 22 +++++++++-- pyproject.toml | 4 +- setup.py | 1 + tests/gpu_megatron/_extensions | 1 + tests/gpu_megatron/torch/conftest.py | 1 + .../distill/plugins/test_distill_megatron.py | 0 .../export/test_unified_export_megatron.py | 0 .../test_vllm_fakequant_megatron_export.py | 0 .../test_megatron_gpt_dynamic_modules.py | 0 .../test_megatron_mamba_dynamic_modules.py | 0 .../opt/plugins/test_megatron_chaining.py | 0 .../torch/peft/plugins}/test_megatron_peft.py | 0 .../test_mcore_gpt_minitron_pruning.py | 0 .../test_mcore_mamba_minitron_pruning.py | 0 .../torch/quantization/plugins/test_apex.py | 0 .../quantization/plugins/test_megatron.py | 0 .../plugins/test_transformer_engine.py | 0 .../plugins/test_megatron_sparsity.py | 0 .../test_speculative_megatron_modules.py | 0 .../utils/plugins/test_utils_megatron.py | 0 tox.ini | 18 +++++---- 22 files changed, 45 insertions(+), 39 deletions(-) create mode 120000 tests/gpu_megatron/_extensions create mode 120000 tests/gpu_megatron/torch/conftest.py rename tests/{gpu => gpu_megatron}/torch/distill/plugins/test_distill_megatron.py (100%) rename tests/{gpu => gpu_megatron}/torch/export/test_unified_export_megatron.py (100%) rename tests/{gpu => gpu_megatron}/torch/export/test_vllm_fakequant_megatron_export.py (100%) rename tests/{gpu => gpu_megatron}/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py (100%) rename tests/{gpu => gpu_megatron}/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py (100%) rename tests/{gpu => gpu_megatron}/torch/opt/plugins/test_megatron_chaining.py (100%) rename tests/{gpu/torch/peft => gpu_megatron/torch/peft/plugins}/test_megatron_peft.py (100%) rename tests/{gpu => gpu_megatron}/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (100%) rename tests/{gpu => gpu_megatron}/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (100%) rename tests/{gpu => gpu_megatron}/torch/quantization/plugins/test_apex.py (100%) rename tests/{gpu => gpu_megatron}/torch/quantization/plugins/test_megatron.py (100%) rename tests/{gpu => gpu_megatron}/torch/quantization/plugins/test_transformer_engine.py (100%) rename tests/{gpu => gpu_megatron}/torch/sparsity/weight_sparsity/plugins/test_megatron_sparsity.py (100%) rename tests/{gpu => gpu_megatron}/torch/speculative/plugins/test_speculative_megatron_modules.py (100%) rename tests/{gpu => gpu_megatron}/torch/utils/plugins/test_utils_megatron.py (100%) diff --git a/.github/workflows/example_tests.yml b/.github/workflows/example_tests.yml index 8442125f3..c1dab5dab 100644 --- a/.github/workflows/example_tests.yml +++ b/.github/workflows/example_tests.yml @@ -56,7 +56,7 @@ jobs: match_pattern: "^DCO$|^linux$" # Wait for DCO and Unit tests / linux to pass delay: 300s - ##### PyTorch Example Tests ##### + ##### PyTorch Example Tests (speculative_decoding requires 26.01 image) ##### torch-pr: needs: [check-file-changes, wait-checks] if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true' @@ -64,10 +64,13 @@ jobs: fail-fast: false matrix: example: [llm_distill, llm_qat, llm_sparsity] + include: + - example: speculative_decoding + docker_image: "nvcr.io/nvidia/pytorch:26.01-py3" uses: ./.github/workflows/_example_tests_runner.yml secrets: inherit with: - docker_image: "nvcr.io/nvidia/pytorch:25.06-py3" + docker_image: ${{ matrix.docker_image || 'nvcr.io/nvidia/pytorch:25.06-py3' }} example: ${{ matrix.example }} pip_install_extras: "[hf,dev-test]" runner: linux-amd64-gpu-l4-latest-1 @@ -78,36 +81,17 @@ jobs: fail-fast: false matrix: example: [llm_distill, llm_qat, llm_sparsity] + include: + - example: speculative_decoding + docker_image: "nvcr.io/nvidia/pytorch:26.01-py3" uses: ./.github/workflows/_example_tests_runner.yml secrets: inherit with: - docker_image: "nvcr.io/nvidia/pytorch:25.06-py3" + docker_image: ${{ matrix.docker_image || 'nvcr.io/nvidia/pytorch:25.06-py3' }} example: ${{ matrix.example }} pip_install_extras: "[hf,dev-test]" runner: linux-amd64-gpu-h100-latest-2 - ##### Speculative Decoding Example Tests (requires 26.01 image) ##### - speculative-decoding-pr: - needs: [check-file-changes, wait-checks] - if: startsWith(github.ref, 'refs/heads/pull-request/') && needs.check-file-changes.outputs.any_changed == 'true' - uses: ./.github/workflows/_example_tests_runner.yml - secrets: inherit - with: - docker_image: "nvcr.io/nvidia/pytorch:26.01-py3" - example: speculative_decoding - pip_install_extras: "[hf,dev-test]" - runner: linux-amd64-gpu-l4-latest-1 - - speculative-decoding-non-pr: - if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }} - uses: ./.github/workflows/_example_tests_runner.yml - secrets: inherit - with: - docker_image: "nvcr.io/nvidia/pytorch:26.01-py3" - example: speculative_decoding - pip_install_extras: "[hf,dev-test]" - runner: linux-amd64-gpu-h100-latest-2 - ##### TensorRT-LLM Example Tests ##### trtllm-pr: needs: [check-file-changes, wait-checks] @@ -172,7 +156,7 @@ jobs: example-pr-required-check: # Run even if example tests are skipped if: ${{ startsWith(github.ref, 'refs/heads/pull-request/') && always() }} - needs: [check-file-changes, torch-pr, speculative-decoding-pr, trtllm-pr, onnx-pr] + needs: [check-file-changes, torch-pr, trtllm-pr, onnx-pr] runs-on: ubuntu-latest steps: - name: Required GPU tests did not succeed @@ -180,7 +164,6 @@ jobs: needs.check-file-changes.result != 'success' || (needs.check-file-changes.outputs.any_changed == 'true' && ( needs.torch-pr.result != 'success' || - needs.speculative-decoding-pr.result != 'success' || needs.trtllm-pr.result != 'success' || needs.onnx-pr.result != 'success' )) diff --git a/.github/workflows/gpu_tests.yml b/.github/workflows/gpu_tests.yml index cb4686815..3e55682cd 100644 --- a/.github/workflows/gpu_tests.yml +++ b/.github/workflows/gpu_tests.yml @@ -59,8 +59,16 @@ jobs: gpu-tests-pr: needs: [check-file-changes, wait-checks] if: needs.check-file-changes.outputs.any_changed == 'true' + strategy: + fail-fast: false + matrix: + include: + - example: py312-cuda12-gpu + timeout: 90 + - example: py312-cuda12-gpu-megatron + timeout: 120 runs-on: linux-amd64-gpu-l4-latest-1 - timeout-minutes: 120 + timeout-minutes: ${{ matrix.timeout }} container: &gpu_container image: nvcr.io/nvidia/pytorch:25.06-py3 env: @@ -74,11 +82,19 @@ jobs: run: | echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/include:/usr/lib/x86_64-linux-gnu" >> $GITHUB_ENV - name: Run gpu tests - run: pip install tox-current-env && tox -e py312-cuda12-gpu --current-env + run: pip install tox-current-env && tox -e ${{ matrix.example }} --current-env gpu-tests-non-pr: if: ${{ !startsWith(github.ref, 'refs/heads/pull-request/') }} + strategy: + fail-fast: false + matrix: + include: + - example: py312-cuda12-gpu + timeout: 90 + - example: py312-cuda12-gpu-megatron + timeout: 120 runs-on: linux-amd64-gpu-h100-latest-2 - timeout-minutes: 150 + timeout-minutes: ${{ matrix.timeout }} container: *gpu_container steps: *gpu_steps gpu-pr-required-check: diff --git a/pyproject.toml b/pyproject.toml index 176866d41..bffa547b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,8 +132,8 @@ disable_error_code = ["attr-defined"] [tool.pytest.ini_options] # Default additional options # Show a short test summary info for all except passed tests with -ra flag -# print execution time for 20 slowest tests and generate coverage reports -addopts = "-v -ra --instafail --cov-report=term-missing --cov-report=html --cov-report=xml:coverage.xml --cov-config=pyproject.toml --durations=20 --strict-markers" +# print execution time for 50 slowest tests and generate coverage reports +addopts = "-v -ra --instafail --cov-report=term-missing --cov-report=html --cov-report=xml:coverage.xml --cov-config=pyproject.toml --durations=50 --strict-markers" pythonpath = ["tests/"] markers = [ "manual: Only run when --run-manual is given", diff --git a/setup.py b/setup.py index 242505302..8f5578e89 100644 --- a/setup.py +++ b/setup.py @@ -77,6 +77,7 @@ "pytest-cov", "pytest-instafail", "pytest-timeout", + "sentencepiece", # For test_unified_export_megatron.py, test_vllm_fakequant_megatron_export.py "timm", "torchprofile>=0.0.4", # For computing flops of CV models "torchvision", diff --git a/tests/gpu_megatron/_extensions b/tests/gpu_megatron/_extensions new file mode 120000 index 000000000..dc4ffce33 --- /dev/null +++ b/tests/gpu_megatron/_extensions @@ -0,0 +1 @@ +../gpu/_extensions/ \ No newline at end of file diff --git a/tests/gpu_megatron/torch/conftest.py b/tests/gpu_megatron/torch/conftest.py new file mode 120000 index 000000000..40eda16c0 --- /dev/null +++ b/tests/gpu_megatron/torch/conftest.py @@ -0,0 +1 @@ +../../gpu/torch/conftest.py \ No newline at end of file diff --git a/tests/gpu/torch/distill/plugins/test_distill_megatron.py b/tests/gpu_megatron/torch/distill/plugins/test_distill_megatron.py similarity index 100% rename from tests/gpu/torch/distill/plugins/test_distill_megatron.py rename to tests/gpu_megatron/torch/distill/plugins/test_distill_megatron.py diff --git a/tests/gpu/torch/export/test_unified_export_megatron.py b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py similarity index 100% rename from tests/gpu/torch/export/test_unified_export_megatron.py rename to tests/gpu_megatron/torch/export/test_unified_export_megatron.py diff --git a/tests/gpu/torch/export/test_vllm_fakequant_megatron_export.py b/tests/gpu_megatron/torch/export/test_vllm_fakequant_megatron_export.py similarity index 100% rename from tests/gpu/torch/export/test_vllm_fakequant_megatron_export.py rename to tests/gpu_megatron/torch/export/test_vllm_fakequant_megatron_export.py diff --git a/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py b/tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py similarity index 100% rename from tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py rename to tests/gpu_megatron/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py diff --git a/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py b/tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py similarity index 100% rename from tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py rename to tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py diff --git a/tests/gpu/torch/opt/plugins/test_megatron_chaining.py b/tests/gpu_megatron/torch/opt/plugins/test_megatron_chaining.py similarity index 100% rename from tests/gpu/torch/opt/plugins/test_megatron_chaining.py rename to tests/gpu_megatron/torch/opt/plugins/test_megatron_chaining.py diff --git a/tests/gpu/torch/peft/test_megatron_peft.py b/tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py similarity index 100% rename from tests/gpu/torch/peft/test_megatron_peft.py rename to tests/gpu_megatron/torch/peft/plugins/test_megatron_peft.py diff --git a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py b/tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py similarity index 100% rename from tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py rename to tests/gpu_megatron/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py diff --git a/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py b/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py similarity index 100% rename from tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py rename to tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py diff --git a/tests/gpu/torch/quantization/plugins/test_apex.py b/tests/gpu_megatron/torch/quantization/plugins/test_apex.py similarity index 100% rename from tests/gpu/torch/quantization/plugins/test_apex.py rename to tests/gpu_megatron/torch/quantization/plugins/test_apex.py diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py similarity index 100% rename from tests/gpu/torch/quantization/plugins/test_megatron.py rename to tests/gpu_megatron/torch/quantization/plugins/test_megatron.py diff --git a/tests/gpu/torch/quantization/plugins/test_transformer_engine.py b/tests/gpu_megatron/torch/quantization/plugins/test_transformer_engine.py similarity index 100% rename from tests/gpu/torch/quantization/plugins/test_transformer_engine.py rename to tests/gpu_megatron/torch/quantization/plugins/test_transformer_engine.py diff --git a/tests/gpu/torch/sparsity/weight_sparsity/plugins/test_megatron_sparsity.py b/tests/gpu_megatron/torch/sparsity/weight_sparsity/plugins/test_megatron_sparsity.py similarity index 100% rename from tests/gpu/torch/sparsity/weight_sparsity/plugins/test_megatron_sparsity.py rename to tests/gpu_megatron/torch/sparsity/weight_sparsity/plugins/test_megatron_sparsity.py diff --git a/tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py b/tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py similarity index 100% rename from tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py rename to tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py diff --git a/tests/gpu/torch/utils/plugins/test_utils_megatron.py b/tests/gpu_megatron/torch/utils/plugins/test_utils_megatron.py similarity index 100% rename from tests/gpu/torch/utils/plugins/test_utils_megatron.py rename to tests/gpu_megatron/torch/utils/plugins/test_utils_megatron.py diff --git a/tox.ini b/tox.ini index ee7acf029..ae296e5bd 100644 --- a/tox.ini +++ b/tox.ini @@ -60,23 +60,27 @@ commands = [testenv:{py310,py311,py312}-cuda12-gpu] commands_pre = # Install deps here so that it gets installed even in --current-env - pip install -U megatron-core pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git + pip install -e .[all,dev-test] +commands = + # Coverage fails with "Can't combine line data with arc data" error so not using "--cov" + python -m pytest tests/gpu + +[testenv:{py310,py311,py312}-cuda12-gpu-megatron] +commands_pre = + # Install deps here so that it gets installed even in --current-env + pip install -U megatron-core + # Skip triton because pytorch-triton is installed in the NGC PyTorch containers pip install pip-mark-installed pip-mark-installed triton pip install --no-build-isolation git+https://github.com/state-spaces/mamba.git - # Install Eagle-3 test dependencies - pip install tiktoken blobfile sentencepiece - - # NOTE: User is expected to have correct torch-cuda version pre-installed if using --current-env - # to avoid possible CUDA version mismatch pip install -e .[all,dev-test] commands = # Coverage fails with "Can't combine line data with arc data" error so not using "--cov" - python -m pytest tests/gpu + python -m pytest tests/gpu_megatron ############################################# # Code quality checks on all files or on diff From 049fd6f3688488e4f7ac00dcf2c64282c8ba3fad Mon Sep 17 00:00:00 2001 From: Hung-Yueh Date: Sat, 14 Feb 2026 03:03:14 +0000 Subject: [PATCH 08/11] add test Signed-off-by: Hung-Yueh --- tests/unit/torch/export/test_mcore_qwen3vl.py | 306 ++++++++++++++++++ 1 file changed, 306 insertions(+) create mode 100644 tests/unit/torch/export/test_mcore_qwen3vl.py diff --git a/tests/unit/torch/export/test_mcore_qwen3vl.py b/tests/unit/torch/export/test_mcore_qwen3vl.py new file mode 100644 index 000000000..3f57cb9c4 --- /dev/null +++ b/tests/unit/torch/export/test_mcore_qwen3vl.py @@ -0,0 +1,306 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for Qwen3-VL Megatron Core export/import plugin.""" + +import pytest + +from modelopt.torch.export.plugins.mcore_custom import ( + COL_TP, + REPLICATE, + ROW_TP, + GatedMLPMerging, + GatedMLPSlicing, + NameRemapping, + QKVMerging, + QKVSlicing, +) +from modelopt.torch.export.plugins.mcore_qwen3vl import ( + qwen3vl_causal_lm_export, + qwen3vl_causal_lm_import, +) + + +# All mcore keys that a dense (non-MoE) Qwen3-VL model should have +DENSE_MCORE_KEYS = { + "word_embeddings", + "final_layernorm", + "output_layer", + "input_layernorm", + "linear_qkv", + "linear_proj", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "linear_fc1", + "linear_fc2", +} + +# Additional MoE keys +MOE_MCORE_KEYS = { + "router", + "local_experts.linear_fc1", + "local_experts.linear_fc2", +} + + +class TestQwen3VLRegistration: + """Test that Qwen3-VL is registered in the global mapping.""" + + def test_registered_in_export_mapping(self): + from modelopt.torch.export.plugins.mcore_common import ( + all_mcore_hf_export_mapping, + ) + + assert "Qwen3VLForConditionalGeneration" in all_mcore_hf_export_mapping + assert ( + all_mcore_hf_export_mapping["Qwen3VLForConditionalGeneration"] + is qwen3vl_causal_lm_export + ) + + def test_registered_in_import_mapping(self): + from modelopt.torch.export.plugins.mcore_common import ( + all_mcore_hf_import_mapping, + ) + + assert "Qwen3VLForConditionalGeneration" in all_mcore_hf_import_mapping + assert ( + all_mcore_hf_import_mapping["Qwen3VLForConditionalGeneration"] + is qwen3vl_causal_lm_import + ) + + +class TestQwen3VLImportMapping: + """Test the HuggingFace -> Megatron Core import mapping.""" + + def test_has_all_dense_keys(self): + assert DENSE_MCORE_KEYS.issubset(qwen3vl_causal_lm_import.keys()) + + def test_has_all_moe_keys(self): + assert MOE_MCORE_KEYS.issubset(qwen3vl_causal_lm_import.keys()) + + def test_language_model_prefix(self): + """Qwen3-VL uses model.language_model. prefix (not model.).""" + prefix_keys = [ + "word_embeddings", + "final_layernorm", + "input_layernorm", + "linear_qkv", + "linear_proj", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "linear_fc1", + "linear_fc2", + ] + for key in prefix_keys: + mapping = qwen3vl_causal_lm_import[key] + assert "model.language_model." in mapping.target_name_or_prefix, ( + f"{key}: expected 'model.language_model.' prefix, " + f"got '{mapping.target_name_or_prefix}'" + ) + + def test_output_layer_at_root(self): + """lm_head is at root level, not under language_model.""" + mapping = qwen3vl_causal_lm_import["output_layer"] + assert mapping.target_name_or_prefix == "lm_head." + + def test_qkv_uses_merging(self): + assert isinstance(qwen3vl_causal_lm_import["linear_qkv"], QKVMerging) + + def test_mlp_uses_gated_merging(self): + assert isinstance( + qwen3vl_causal_lm_import["linear_fc1"], GatedMLPMerging + ) + + @pytest.mark.parametrize( + "key", + [ + "input_layernorm", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "final_layernorm", + ], + ) + def test_layernorms_are_replicated(self, key): + """Layernorms should use REPLICATE (empty func_kwargs).""" + mapping = qwen3vl_causal_lm_import[key] + assert isinstance(mapping, NameRemapping) + assert mapping.func_kwargs == REPLICATE + + @pytest.mark.parametrize( + "key,expected_kwargs", + [ + ("word_embeddings", COL_TP), + ("output_layer", COL_TP), + ("linear_proj", ROW_TP), + ], + ) + def test_tp_sharding(self, key, expected_kwargs): + mapping = qwen3vl_causal_lm_import[key] + assert mapping.func_kwargs == expected_kwargs + + +class TestQwen3VLExportMapping: + """Test the Megatron Core -> HuggingFace export mapping.""" + + def test_has_all_dense_keys(self): + assert DENSE_MCORE_KEYS.issubset(qwen3vl_causal_lm_export.keys()) + + def test_has_all_moe_keys(self): + assert MOE_MCORE_KEYS.issubset(qwen3vl_causal_lm_export.keys()) + + def test_language_model_prefix(self): + """Export paths should also use model.language_model. prefix.""" + prefix_keys = [ + "word_embeddings", + "final_layernorm", + "input_layernorm", + "linear_qkv", + "linear_proj", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "linear_fc1", + "linear_fc2", + ] + for key in prefix_keys: + mapping = qwen3vl_causal_lm_export[key] + assert "model.language_model." in mapping.target_name_or_prefix, ( + f"{key}: expected 'model.language_model.' prefix, " + f"got '{mapping.target_name_or_prefix}'" + ) + + def test_output_layer_at_root(self): + mapping = qwen3vl_causal_lm_export["output_layer"] + assert mapping.target_name_or_prefix == "lm_head." + + def test_qkv_uses_slicing(self): + assert isinstance(qwen3vl_causal_lm_export["linear_qkv"], QKVSlicing) + + def test_mlp_uses_gated_slicing(self): + assert isinstance( + qwen3vl_causal_lm_export["linear_fc1"], GatedMLPSlicing + ) + + def test_export_has_no_parallel_config(self): + """Export mappings should not have parallel configs.""" + for key in ["word_embeddings", "final_layernorm", "output_layer", + "input_layernorm", "linear_proj"]: + mapping = qwen3vl_causal_lm_export[key] + assert "parallel_config" not in mapping.func_kwargs + + +class TestQwen3VLImportExportSymmetry: + """Test that import and export mappings are consistent.""" + + def test_same_mcore_keys(self): + assert set(qwen3vl_causal_lm_import.keys()) == set( + qwen3vl_causal_lm_export.keys() + ) + + @pytest.mark.parametrize( + "key", + [ + "word_embeddings", + "final_layernorm", + "output_layer", + "input_layernorm", + "linear_proj", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "linear_fc2", + "router", + ], + ) + def test_matching_hf_prefixes(self, key): + """Import and export should map to the same HF prefix.""" + imp = qwen3vl_causal_lm_import[key] + exp = qwen3vl_causal_lm_export[key] + assert imp.target_name_or_prefix == exp.target_name_or_prefix, ( + f"{key}: import prefix '{imp.target_name_or_prefix}' != " + f"export prefix '{exp.target_name_or_prefix}'" + ) + + def test_qkv_matching_prefix(self): + imp = qwen3vl_causal_lm_import["linear_qkv"] + exp = qwen3vl_causal_lm_export["linear_qkv"] + assert imp.target_name_or_prefix == exp.target_name_or_prefix + + def test_mlp_fc1_matching_prefix(self): + imp = qwen3vl_causal_lm_import["linear_fc1"] + exp = qwen3vl_causal_lm_export["linear_fc1"] + assert imp.target_name_or_prefix == exp.target_name_or_prefix + + +class TestQwen3VLvsQwen3Difference: + """Test that Qwen3-VL differs from Qwen3 only in the language_model prefix.""" + + def test_same_keys_as_qwen3(self): + from modelopt.torch.export.plugins.mcore_qwen import ( + qwen3_causal_lm_export, + qwen3_causal_lm_import, + ) + + assert set(qwen3vl_causal_lm_import.keys()) == set( + qwen3_causal_lm_import.keys() + ) + assert set(qwen3vl_causal_lm_export.keys()) == set( + qwen3_causal_lm_export.keys() + ) + + @pytest.mark.parametrize( + "key", + [ + "word_embeddings", + "final_layernorm", + "input_layernorm", + "linear_qkv", + "linear_proj", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "linear_fc1", + "linear_fc2", + "router", + "local_experts.linear_fc1", + "local_experts.linear_fc2", + ], + ) + def test_vl_adds_language_model_prefix(self, key): + """Qwen3-VL should have 'language_model.' inserted after 'model.'.""" + from modelopt.torch.export.plugins.mcore_qwen import ( + qwen3_causal_lm_import, + ) + + qwen3_prefix = qwen3_causal_lm_import[key].target_name_or_prefix + qwen3vl_prefix = qwen3vl_causal_lm_import[key].target_name_or_prefix + expected = qwen3_prefix.replace("model.", "model.language_model.", 1) + assert qwen3vl_prefix == expected, ( + f"{key}: expected '{expected}', got '{qwen3vl_prefix}'" + ) + + def test_output_layer_same(self): + """lm_head is at root level for both Qwen3 and Qwen3-VL.""" + from modelopt.torch.export.plugins.mcore_qwen import ( + qwen3_causal_lm_import, + ) + + assert ( + qwen3vl_causal_lm_import["output_layer"].target_name_or_prefix + == qwen3_causal_lm_import["output_layer"].target_name_or_prefix + ) From f50329a19c572d4c5ce2615dd769bd11ffb683fb Mon Sep 17 00:00:00 2001 From: yueshen2016 <39203804+yueshen2016@users.noreply.github.com> Date: Fri, 13 Feb 2026 16:34:52 -0800 Subject: [PATCH 09/11] [OMNIML-3232] Support full TE spec for NemotronH HF-to-Megatron import (#884) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What does this PR do? **Type of change:** new feature **Overview:** Enable full TE spec support for NemotronH (Mamba hybrid) models during HF-to-Megatron weight import via `import_mcore_gpt_from_hf`. Previously, importing HF weights into a Megatron model built with the full TE spec (`TELayerNormColumnParallelLinear`, `TEGroupedMLP`, etc.) failed for NemotronH models due to two issues: 1. **Grouped expert prefix bug**: The `experts.linear_fc1/fc2` import rules had a hard-coded `mtp.layers.{}` prefix, which was only correct for MTP layers. When regular decoder MoE layers use `TEGroupedMLP` (via the full TE spec), the importer generated incorrect HF keys (e.g., `mtp.layers.27.mixer.experts.0.up_proj.weight` instead of `backbone.layers.27.mixer.experts.0.up_proj.weight`). 2. **Fused layer norm loading**: In the full TE spec, layer norms are fused into `TELayerNormColumnParallelLinear` modules as `layer_norm_weight`. The importer's `_name_remapping` would crash trying to load `layer_norm_weight` from a non-existent HF path (e.g., `backbone.layers.X.mixer.in_proj.layer_norm_weight`), when the actual HF norm weight lives at `backbone.layers.X.norm.weight`. ### Changes **`mcore_nemotron.py`**: - Fixed grouped expert prefix from `mtp.layers.{}` to `backbone.layers.{}`. The `_grouped_mlp_merging` function already handles `backbone` → `mtp` replacement when `is_mtp=True`, so both decoder and MTP layers work correctly. - Added `mapping={"layer_norm_weight": None}` to `in_proj` and `linear_fc1` rules to skip `layer_norm_weight` during `_name_remapping` (loaded separately via `fused_norm`). - Added `fused_norm` rule (`NameRemapping("backbone.layers.{}.norm.weight")`) to load HF norm weights into fused TE modules. **`megatron_importer.py`**: - Added `source_key is None` check in `_name_remapping` to skip keys mapped to `None` in the mapping dict (keeps existing value instead of crashing on missing HF key). - Added fused norm loading in `_import_mamba_layer`: after loading `in_proj`, loads `layer_norm_weight` from HF via `fused_norm` rule when `layer.norm` is `IdentityOp`. - Added fused norm loading in `_import_transformer_layer`: loads `layer_norm_weight` into `linear_qkv` (when `input_layernorm` is `IdentityOp`) and into `linear_fc1` (when `pre_mlp_layernorm` is `IdentityOp`). ## Usage The full TE spec is enabled via the `--full-te-spec` flag on the Megatron-LM side (separate PR). On the ModelOpt side, no user-facing changes are needed -- the import rules automatically handle both local spec and full TE spec models. ```bash # Convert HF checkpoint to Megatron with full TE spec (megatron-lm side) unset MLM_MODEL_CKPT && export MLM_MODEL_SAVE=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16_mlm && export HF_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 export PP=2 export MLM_EXTRA_ARGS="--full-te-spec" bash convert.sh nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 # Quantize the converted checkpoint (megatron-lm side) export MLM_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16_mlm export MLM_MODEL_SAVE=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-fp8_mlm export HF_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 export PP=2 && export TP=4 && export EP=4 && export ETP=1 export MLM_EXTRA_ARGS="--full-te-spec" bash quantize.sh nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 FP8_DEFAULT_CFG # Generate export PP=2 && export TP=4 && export EP=4 && export ETP=1 export MLM_EXTRA_ARGS="--full-te-spec" export MLM_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-fp8_mlm && ./generate.sh nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 # MMLU export PP=2 && export TP=4 && export EP=4 && export ETP=1 export MLM_EXTRA_ARGS="--full-te-spec" export MLM_MODEL_CKPT=/models/NVIDIA-Nemotron-3-Nano-30B-A3B-fp8_mlm && export MLM_EXTRA_ARGS="--fraction 0.05 --disable-tqdm" && ./mmlu.sh nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 ``` ## Testing - Tested end-to-end: HF → Megatron conversion → FP8 quantization → inference (generate) → MMLU evaluation with Nemotron-3-Nano-30B-A3B-BF16. - Verified the resulting model structure matches Megatron-Bridge's TE spec output (TELayerNormColumnParallelLinear, TEGroupedMLP, IdentityOp norms, etc.). - Verified quantized model produces coherent text generation outputs. - Verified backward compatibility: all changes are no-ops for existing local-spec pipelines (guarded by `IdentityOp` checks, `hasattr` checks, and `"fused_norm" in self.rules` checks). ## Before your PR is "*Ready for review*" - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes -- all changes are guarded by conditions that only activate for full TE spec models. Local spec models follow the exact same code paths as before. - **Did you write any new necessary tests?**: No - **Did you add or update any necessary documentation?**: No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: No ## Additional Information Companion megatron-lm changes (separate PR): - `megatron/core/post_training/modelopt/mamba/model_specs.py`: Added `use_full_te_spec` parameter to return canonical `mamba_stack_spec` from `mamba_layer_specs.py`. - `megatron/post_training/model_builder.py`: Passes `use_full_te_spec=args.full_te_spec` to `get_mamba_stack_modelopt_spec`. - `megatron/post_training/arguments.py`: Added `--full-te-spec` CLI flag. - `examples/post_training/modelopt/convert_model.py`: Skip `moe_grouped_gemm=False` override when `--full-te-spec` is set. ## Summary by CodeRabbit * **New Features** * Added support for loading fused normalization weights during model import. * **Bug Fixes** * Improved weight mapping logic to correctly skip redundant layer norm weights in specialized model architectures. * **Refactor** * Reorganized expert model parallel configuration paths for better compatibility with mixed parallel processing settings. Signed-off-by: James Shen Signed-off-by: Hung-Yueh --- .../torch/export/plugins/mcore_nemotron.py | 21 +++++++--- .../torch/export/plugins/megatron_importer.py | 39 +++++++++++++++++++ 2 files changed, 55 insertions(+), 5 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index 92611f54b..6883c51c9 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -58,7 +58,11 @@ "D": NameRemapping("backbone.layers.{}.mixer.D", REPLICATE), "dt_bias": NameRemapping("backbone.layers.{}.mixer.dt_bias", REPLICATE), "conv1d": NameRemapping("backbone.layers.{}.mixer.conv1d.", REPLICATE), - "in_proj": NameRemapping("backbone.layers.{}.mixer.in_proj.", COL_TP), + # mapping layer_norm_weight to None tells _name_remapping to skip it; + # the fused layer_norm_weight is loaded separately via the "fused_norm" rule. + "in_proj": NameRemapping( + "backbone.layers.{}.mixer.in_proj.", COL_TP | {"mapping": {"layer_norm_weight": None}} + ), "out_proj": NameRemapping("backbone.layers.{}.mixer.out_proj.", ROW_TP), # Attention "input_layernorm": NameRemapping("backbone.layers.{}.norm.", REPLICATE), @@ -66,8 +70,13 @@ "linear_proj": NameRemapping("backbone.layers.{}.mixer.o_proj.", ROW_TP), # MLP "pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm.", REPLICATE), - "linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj.", COL_TP), + "linear_fc1": NameRemapping( + "backbone.layers.{}.mixer.up_proj.", COL_TP | {"mapping": {"layer_norm_weight": None}} + ), "linear_fc2": NameRemapping("backbone.layers.{}.mixer.down_proj.", ROW_TP), + # Fused layer norm: loads the HF norm weight into fused TELayerNormColumnParallelLinear + # modules (in_proj, linear_qkv, linear_fc1) when using TE spec. + "fused_norm": NameRemapping("backbone.layers.{}.norm.weight"), # MoE "router": NameRemapping( "backbone.layers.{}.mixer.gate.", {"mapping": {"expert_bias": "e_score_correction_bias"}} @@ -92,12 +101,14 @@ "mtp.hnorm": NameRemapping("mtp.layers.{}.hnorm.", {"is_mtp": True}), "mtp.eh_proj": NameRemapping("mtp.layers.{}.eh_proj.", {"is_mtp": True}), "mtp.final_layernorm": NameRemapping("mtp.layers.{}.final_layernorm.", {"is_mtp": True}), - # Grouped local experts in MTP + # Grouped local experts (used for TEGroupedMLP in both decoder and MTP layers). + # The prefix uses "backbone" for regular decoder layers; when called from MTP + # context (is_mtp=True), _grouped_mlp_merging replaces "backbone" with "mtp". "experts.linear_fc1": GroupedMLPMerging( - "mtp.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP | {"is_mtp": True} + "backbone.layers.{}.mixer.experts.{{}}.up_proj", COL_ETP ), "experts.linear_fc2": GroupedMLPMerging( - "mtp.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP | {"is_mtp": True} + "backbone.layers.{}.mixer.experts.{{}}.down_proj", ROW_ETP ), } diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index b4c1ec694..a156f2cd8 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -200,6 +200,12 @@ def _name_remapping( state_dict[key] = val else: source_key = mapping.get(key, key) + # A mapping value of None means "skip this key" (keep existing value). + # This is used for fused TE modules where layer_norm_weight is loaded + # separately from a different HF path. + if source_key is None: + state_dict[key] = val + continue # For bias tensors in ROW_TP layers, don't use parallel config to avoid sharding # since bias should always be replicated, not sharded if ( @@ -537,6 +543,15 @@ def _import_mamba_layer(self, layer, layer_id, layer_pbar): self.rules["in_proj"](layer.mixer.in_proj, layer_id) self.rules["out_proj"](layer.mixer.out_proj, layer_id) + # TE spec: layer norm is fused into in_proj (TELayerNormColumnParallelLinear). + # Load the fused layer_norm_weight from the HF norm path. + if ( + isinstance(layer.norm, IdentityOp) + and hasattr(layer.mixer.in_proj, "layer_norm_weight") + and "fused_norm" in self.rules + ): + self.rules["fused_norm"](layer.mixer.in_proj.layer_norm_weight, layer_id) + def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = False): if not isinstance(layer.input_layernorm, IdentityOp): self.rules["input_layernorm"](layer.input_layernorm, layer_id, is_mtp=is_mtp) @@ -578,6 +593,18 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = attention.core_attention.softmax_offset, layer_id, is_mtp=is_mtp ) + # TE spec: input_layernorm is fused into linear_qkv (TELayerNormColumnParallelLinear). + # Load the fused layer_norm_weight from the HF norm path. + if ( + isinstance(layer.input_layernorm, IdentityOp) + and hasattr(attention, "linear_qkv") + and hasattr(attention.linear_qkv, "layer_norm_weight") + and "fused_norm" in self.rules + ): + self.rules["fused_norm"]( + attention.linear_qkv.layer_norm_weight, layer_id, is_mtp=is_mtp + ) + if not isinstance(layer.pre_mlp_layernorm, IdentityOp): self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id, is_mtp=is_mtp) @@ -671,6 +698,18 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id, is_mtp=is_mtp) self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id, is_mtp=is_mtp) + # TE spec: pre_mlp_layernorm is fused into linear_fc1 + # (TELayerNormColumnParallelLinear). + # Load the fused layer_norm_weight from the HF norm path. + if ( + isinstance(layer.pre_mlp_layernorm, IdentityOp) + and hasattr(layer.mlp.linear_fc1, "layer_norm_weight") + and "fused_norm" in self.rules + ): + self.rules["fused_norm"]( + layer.mlp.linear_fc1.layer_norm_weight, layer_id, is_mtp=is_mtp + ) + def _import_state_dict(self): model = self.model layer_pbar = tqdm(model.decoder.layers, disable=self.disable_tqdm) From 70cdfb46bbc4a62b738e6f3c0c61dc8cae78ab25 Mon Sep 17 00:00:00 2001 From: Hung-Yueh Date: Sat, 14 Feb 2026 03:13:08 +0000 Subject: [PATCH 10/11] update changelog and doc Signed-off-by: Hung-Yueh --- CHANGELOG.rst | 1 + docs/source/deployment/3_unified_hf.rst | 1 + 2 files changed, 2 insertions(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9d7500e58..2e5bcb98b 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -21,6 +21,7 @@ NVIDIA Model Optimizer Changelog (Linux) - Add LTX-2 and Wan2.2 (T2V) support in the diffusers quantization workflow. - Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and export as-is. - Add support for image-text data calibration in PTQ for Nemotron VL models. +- Add Megatron Core export/import mapping for Qwen3-VL (``Qwen3VLForConditionalGeneration``) vision-language models. The mapping handles the ``model.language_model.`` weight prefix used by Qwen3-VL and supports both dense and MoE variants. 0.41 (2026-01-19) ^^^^^^^^^^^^^^^^^ diff --git a/docs/source/deployment/3_unified_hf.rst b/docs/source/deployment/3_unified_hf.rst index 9124164b5..6664f987f 100644 --- a/docs/source/deployment/3_unified_hf.rst +++ b/docs/source/deployment/3_unified_hf.rst @@ -61,6 +61,7 @@ Models: * Llama 4, 3.x (FP8, NVFP4) * Qwen 3, 2.5 (FP8, NVFP4) * Qwen 3 MoE (FP8, NVFP4) + * Qwen 3-VL (FP8, NVFP4) * Deepseek R1/V3 (NVFP4) * Mixtral 8x7B (FP8, NVFP4) * Medusa (FP8) From 6d9773b9ba772fc5d5b8b0928b7abeb19960405a Mon Sep 17 00:00:00 2001 From: mxinO <164952785+mxinO@users.noreply.github.com> Date: Sun, 15 Feb 2026 05:28:21 +0800 Subject: [PATCH 11/11] [OMNIML-3505] LTX-2 Distillation Trainer (#892) ## What does this PR do? **Type of change:** new example **Overview:** Adding LTX-2 distillation trainer. ## Usage ```bash accelerate launch \ --config_file configs/accelerate/fsdp.yaml \ --num_processes 8 \ distillation_trainer.py --config configs/distillation_example.yaml ``` See readme for more details. ## Testing Run training with single/multiple nodes. ## Before your PR is "*Ready for review*" - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes - **Did you write any new necessary tests?**: NA - **Did you add or update any necessary documentation?**: Yes - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes ## Additional Information ## Summary by CodeRabbit ## New Features * Added distillation training support for LTX-2 models with quantization integration. * Introduced comprehensive documentation and example configurations for distillation workflows. * Includes multi-GPU and multi-node training setup with distributed training support and customizable configuration templates. --------- Signed-off-by: Meng Xin Signed-off-by: Hung-Yueh --- CHANGELOG.rst | 1 + examples/diffusers/distillation/README.md | 153 ++ .../distillation/configs/accelerate/fsdp.yaml | 45 + .../configs/distillation_example.yaml | 142 ++ .../distillation/distillation_trainer.py | 1832 +++++++++++++++++ .../diffusers/distillation/requirements.txt | 4 + 6 files changed, 2177 insertions(+) create mode 100644 examples/diffusers/distillation/README.md create mode 100644 examples/diffusers/distillation/configs/accelerate/fsdp.yaml create mode 100644 examples/diffusers/distillation/configs/distillation_example.yaml create mode 100644 examples/diffusers/distillation/distillation_trainer.py create mode 100644 examples/diffusers/distillation/requirements.txt diff --git a/CHANGELOG.rst b/CHANGELOG.rst index bbbe6ab9e..7dec05338 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -22,6 +22,7 @@ NVIDIA Model Optimizer Changelog (Linux) - Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and export as-is. - Add support for image-text data calibration in PTQ for Nemotron VL models. - Add PTQ support for Nemotron Parse. +- Add distillation support for LTX-2. See `examples/diffusers/distillation/README.md `_ for more details. 0.41 (2026-01-19) ^^^^^^^^^^^^^^^^^ diff --git a/examples/diffusers/distillation/README.md b/examples/diffusers/distillation/README.md new file mode 100644 index 000000000..ce57c6036 --- /dev/null +++ b/examples/diffusers/distillation/README.md @@ -0,0 +1,153 @@ +# LTX-2 Distillation Training with ModelOpt + +Knowledge distillation for LTX-2 DiT models using NVIDIA ModelOpt. A frozen **teacher** guides a trainable **student** through a combined loss: + +```text +L_total = α × L_task + (1-α) × L_distill +``` + +Currently supported: + +- **Quantization-Aware Distillation (QAD)** — student uses ModelOpt fake quantization + +Planned: + +- **Sparsity-Aware Distillation (SAD)** — student uses ModelOpt sparsity + +## Installation + +```bash +# From the distillation example directory +cd examples/diffusers/distillation + +# Install Model-Optimizer (from repo root) +pip install -e ../../.. + +# Install all dependencies (ltx-trainer, ltx-core, ltx-pipelines, omegaconf) +pip install -r requirements.txt +``` + +## Quick Start + +### 1. Prepare Your Dataset + +Use the ltx-trainer preprocessing to extract latents and text embeddings: + +```bash +python -m ltx_trainer.preprocess \ + --input_dir /path/to/videos \ + --output_dir /path/to/preprocessed \ + --model_path /path/to/ltx2/checkpoint.safetensors +``` + +### 2. Configure + +Copy and edit the example config: + +```bash +cp configs/distillation_example.yaml configs/my_experiment.yaml +``` + +Key settings to update: + +```yaml +model: + model_path: "/path/to/ltx2/checkpoint.safetensors" + text_encoder_path: "/path/to/gemma/model" + +data: + preprocessed_data_root: "/path/to/preprocessed/data" + +distillation: + distillation_alpha: 0.5 # 1.0 = pure task loss, 0.0 = pure distillation + quant_cfg: "FP8_DEFAULT_CFG" # or INT8_DEFAULT_CFG, NVFP4_DEFAULT_CFG, null + +# IMPORTANT: disable ltx-trainer's built-in quantization +acceleration: + quantization: null +``` + +### 3. Run Training + +#### Single GPU + +```bash +python distillation_trainer.py --config configs/my_experiment.yaml +``` + +#### Multi-GPU (Single Node) with Accelerate + +```bash +accelerate launch \ + --config_file configs/accelerate/fsdp.yaml \ + --num_processes 8 \ + distillation_trainer.py --config configs/my_experiment.yaml +``` + +#### Multi-node Training with Accelerate + +To launch on multiple nodes, make sure to set the following environment variables on each node: + +- `NUM_NODES`: Total number of nodes +- `GPUS_PER_NODE`: Number of GPUs per node +- `NODE_RANK`: Unique rank/index of this node (0-based) +- `MASTER_ADDR`: IP address of the master node (rank 0) +- `MASTER_PORT`: Communication port (e.g., 29500) + +Then run this (on every node): + +```bash +accelerate launch \ + --config_file configs/accelerate/fsdp.yaml \ + --num_machines $NUM_NODES \ + --num_processes $((NUM_NODES * GPUS_PER_NODE)) \ + --machine_rank $NODE_RANK \ + --main_process_ip $MASTER_ADDR \ + --main_process_port $MASTER_PORT \ + distillation_trainer.py --config configs/my_experiment.yaml +``` + +**Config overrides** can be passed via CLI using dotted notation: + +```bash +accelerate launch ... distillation_trainer.py \ + --config configs/my_experiment.yaml \ + ++distillation.distillation_alpha=0.6 \ + ++distillation.quant_cfg=INT8_DEFAULT_CFG \ + ++optimization.learning_rate=1e-5 +``` + +## Configuration Reference + +### Calibration + +Before training begins, calibration runs full denoising inference to collect activation statistics for accurate quantizer scales. This is cached as a step-0 checkpoint and reused on subsequent runs. + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `calibration_prompts_file` | null | Text file with one prompt per line. Use the HuggingFace dataset 'Gustavosta/Stable-Diffusion-Prompts' if null. | +| `calibration_size` | 128 | Number of prompts (each runs a full denoising loop) | +| `calibration_n_steps` | 30 | Denoising steps per prompt | +| `calibration_guidance_scale` | 4.0 | CFG scale (should match inference-time) | + +### Checkpoint Resume + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `resume_from_checkpoint` | null | `"latest"` to auto-detect, or explicit path | +| `must_save_by` | null | Minutes after which to save and exit (for Slurm time limits) | +| `restore_quantized_checkpoint` | null | Restore a pre-quantized model (skips calibration) | +| `save_quantized_checkpoint` | null | Path to save the final quantized model | + +### Custom Quantization Configs + +To define custom quantization configs, add entries to `CUSTOM_QUANT_CONFIGS` in `distillation_trainer.py`: + +```python +CUSTOM_QUANT_CONFIGS["MY_FP8_CFG"] = { + "quant_cfg": mtq.FP8_DEFAULT_CFG["quant_cfg"], + "algorithm": "max", +} +``` + +Then reference it in your YAML: `quant_cfg: MY_FP8_CFG`. diff --git a/examples/diffusers/distillation/configs/accelerate/fsdp.yaml b/examples/diffusers/distillation/configs/accelerate/fsdp.yaml new file mode 100644 index 000000000..35e3edf77 --- /dev/null +++ b/examples/diffusers/distillation/configs/accelerate/fsdp.yaml @@ -0,0 +1,45 @@ +# FSDP Configuration +# +# FULL_SHARD across all GPUs for maximum memory efficiency. +# For multi-node training with `accelerate launch`. +# +# Usage: +# accelerate launch \ +# --config_file configs/accelerate/fsdp.yaml \ +# --num_processes 16 \ +# --num_machines 2 \ +# --machine_rank $MACHINE_RANK \ +# --main_process_ip $MASTER_IP \ +# --main_process_port 29500 \ +# distillation_trainer.py --config configs/distillation_example.yaml + +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false + +fsdp_config: + # FULL_SHARD: Shard optimizer states, gradients, and parameters across ALL GPUs + # This provides maximum memory efficiency for large models like LTX-2 19B + # Parameters are fully sharded across all nodes (not replicated) + fsdp_sharding_strategy: FULL_SHARD + + # Enable activation checkpointing to reduce memory during backward pass + # Critical for 19B model training + fsdp_activation_checkpointing: true + + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch: BACKWARD_PRE + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: false + fsdp_offload_params: false + fsdp_reshard_after_forward: true + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_sync_module_states: true + fsdp_transformer_layer_cls_to_wrap: BasicAVTransformerBlock + fsdp_use_orig_params: true + fsdp_version: 1 + +# Note: num_machines and num_processes are overridden by accelerate launch command-line args +# These are just defaults for local testing +num_machines: 1 +num_processes: 8 diff --git a/examples/diffusers/distillation/configs/distillation_example.yaml b/examples/diffusers/distillation/configs/distillation_example.yaml new file mode 100644 index 000000000..6f7778a35 --- /dev/null +++ b/examples/diffusers/distillation/configs/distillation_example.yaml @@ -0,0 +1,142 @@ +# LTX-2 Distillation Training Configuration with ModelOpt + +# Model Configuration +model: + # Path to the LTX-2 checkpoint (used for both teacher and student) + model_path: "/path/to/ltx2/checkpoint.safetensors" + + # Path to Gemma text encoder (required for LTX-2) + text_encoder_path: "/path/to/gemma/model" + + # Training mode: "lora" is not supported yet + training_mode: "full" + +# Distillation Configuration +distillation: + # Path to teacher model (if different from model.model_path) + # Set to null to use the same checkpoint as student (loaded without quantization) + teacher_model_path: + + # Weight for task loss: L_total = α * L_task + (1-α) * L_distill + # α = 1.0: pure task loss (no distillation) + # α = 0.0: pure distillation loss + distillation_alpha: 0.0 + + # Type of distillation loss + # "mse": Mean squared error (recommended - transformer outputs are continuous velocity predictions) + # "cosine": Cosine similarity loss (matches direction only, ignores magnitude) + distillation_loss_type: "mse" + + # Data type for teacher model (bfloat16 recommended for memory efficiency) + teacher_dtype: "bfloat16" + + # ModelOpt Quantization Settings + # Name of the mtq config, e.g. FP8_DEFAULT_CFG, INT8_DEFAULT_CFG, NVFP4_DEFAULT_CFG. + # Custom configs defined in CUSTOM_QUANT_CONFIGS (distillation_trainer.py) are also supported. + quant_cfg: + + # Full-inference calibration settings (matching PTQ workflow). + # Each prompt runs a complete denoising loop through the DiT, covering all noise levels. + # Path to a text file with one prompt per line. If null, uses the default + # HuggingFace dataset 'Gustavosta/Stable-Diffusion-Prompts' (same as PTQ). + calibration_prompts_file: + # Total number of calibration prompts (set to 0 to skip calibration) + calibration_size: 128 + # Number of denoising steps per prompt (matches PTQ --n-steps) + calibration_n_steps: 30 + # CFG guidance scale during calibration (4.0 = PTQ default, calls transformer + # twice per step for positive + negative prompt; 1.0 = no CFG, saves memory) + calibration_guidance_scale: 4.0 + + # Path to restore a previously quantized model (from mto.save) + restore_quantized_checkpoint: + + # Path to save the final quantized model checkpoint + save_quantized_checkpoint: + + # Resume from a full training state checkpoint (saves model + optimizer + RNG + step) + # Set to "latest" to auto-find the most recent checkpoint in output_dir/checkpoints/ + # Or set to an explicit path like "/path/to/checkpoints/step_001000" + resume_from_checkpoint: latest + + # Time-limit-aware saving for Slurm jobs. + # Minutes after which training must save a checkpoint and exit gracefully. + # Set slightly below your Slurm --time limit (e.g. time=30min -> must_save_by: 25). + # Timer starts when train() is called (after model loading/calibration). + must_save_by: + + # Debug/Test: Use mock data instead of real preprocessed data + # Useful for testing the training pipeline without preparing a dataset + use_mock_data: false + mock_data_samples: 100 + +# Training Strategy +training_strategy: + name: "text_to_video" + first_frame_conditioning_p: 0.1 + with_audio: false + +# Optimization Configuration +optimization: + learning_rate: 2.0e-6 + steps: 10000 + batch_size: 1 + gradient_accumulation_steps: 4 + max_grad_norm: 1.0 + optimizer_type: "adamw" # # Use "adamw8bit" for memory efficiency + scheduler_type: "cosine" + enable_gradient_checkpointing: true # Essential for memory savings + +# Acceleration Configuration +acceleration: + mixed_precision_mode: "bf16" + + # NOTE: Set to null - we use ModelOpt quantization instead of ltx-trainer's quanto + quantization: + + # 8-bit text encoder for memory savings + load_text_encoder_in_8bit: false + +# Data Configuration +data: + # Path to preprocessed training data (created by process_dataset.py) + preprocessed_data_root: "/path/to/preprocessed/data" + num_dataloader_workers: 2 + +# Validation Configuration +validation: + prompts: + - "A beautiful sunset over the ocean with gentle waves" + - "A cat playing with a ball of yarn in a cozy living room" + negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted" + video_dims: [512, 320, 33] # [width, height, frames] + frame_rate: 25.0 + inference_steps: 30 + interval: 500 # Validate every 500 steps + guidance_scale: 4.0 + seed: 42 + +# Checkpointing Configuration +checkpoints: + interval: 1000 # Save checkpoint every 1000 steps + keep_last_n: 3 # Keep only last 3 checkpoints + precision: "bfloat16" + +# Weights & Biases Logging +wandb: + enabled: true + project: "ltx2-distillation" + entity: # Your W&B username or team + tags: + - "distillation" + - "modelopt" + log_validation_videos: true + +# Flow Matching Configuration +flow_matching: + timestep_sampling_mode: "shifted_logit_normal" + timestep_sampling_params: {} + +# General Settings +seed: 42 +output_dir: "./outputs/distillation_experiment" diff --git a/examples/diffusers/distillation/distillation_trainer.py b/examples/diffusers/distillation/distillation_trainer.py new file mode 100644 index 000000000..d98278b9a --- /dev/null +++ b/examples/diffusers/distillation/distillation_trainer.py @@ -0,0 +1,1832 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Distillation Trainer for LTX-2 DiT Model with ModelOpt Quantization + +This module implements sparsity/quantization-aware distillation training where: +- Teacher: Original unsparsified/unquantized model (inference only) +- Student: Quantized model using ModelOpt's fake quantization (trainable) + +The distillation loss combines: +- L_task: Standard flow matching MSE loss (student_pred vs velocity_target) +- L_distill: Distillation MSE loss (student_pred vs teacher_pred) + +Usage: + python distillation_trainer.py --config configs/distillation_example.yaml +""" + +from __future__ import annotations + +import argparse +import gc +import json +import os +import time +from pathlib import Path +from typing import Literal + +import torch +import torch.distributed as dist +from ltx_trainer import logger +from ltx_trainer.config import ConfigBaseModel, LtxTrainerConfig +from ltx_trainer.model_loader import load_transformer +from ltx_trainer.trainer import IS_MAIN_PROCESS, LtxvTrainer +from omegaconf import OmegaConf +from pydantic import Field +from torch import Tensor + +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq + +# Custom quantization configs. Checked before mtq built-in configs. +# Add your own configs here; they take precedence over mtq.* attributes. +CUSTOM_QUANT_CONFIGS: dict[str, dict] = { + # Example: override NVFP4 with a different algorithm + # "MY_NVFP4_CFG": { + # "quant_cfg": mtq.NVFP4_DEFAULT_CFG["quant_cfg"], + # "algorithm": "max", + # }, +} + + +# IS_MAIN_PROCESS (from ltx_trainer) checks LOCAL_RANK == 0, which is True on +# every node in multi-node training. For file writes on a shared filesystem +# (Lustre) we need a global-rank-0 check so that only a single process writes. +def is_global_rank0() -> bool: + """Check if this is global rank 0. Safe to call before or after dist init.""" + if dist.is_initialized(): + return dist.get_rank() == 0 + return os.environ.get("RANK", "0") == "0" + + +def get_quant_config(quant_cfg_name: str) -> dict: + """ + Resolve a quantization config by name. + + Lookup order: + 1. CUSTOM_QUANT_CONFIGS (user-defined overrides in this file) + 2. mtq. (built-in ModelOpt configs, e.g. FP8_DEFAULT_CFG, INT8_DEFAULT_CFG) + + Args: + quant_cfg_name: Name of the quantization config, e.g. "FP8_DEFAULT_CFG". + + Returns: + A copy of the quantization config dict. + """ + # Check custom configs first + if quant_cfg_name in CUSTOM_QUANT_CONFIGS: + logger.info(f"Using custom quant config: {quant_cfg_name}") + return CUSTOM_QUANT_CONFIGS[quant_cfg_name].copy() + + # Fall back to mtq built-in configs + cfg = getattr(mtq, quant_cfg_name, None) + if cfg is None: + available_custom = list(CUSTOM_QUANT_CONFIGS.keys()) + available_mtq = [ + attr for attr in dir(mtq) if attr.endswith("_CFG") and not attr.startswith("_") + ] + raise ValueError( + f"Unknown quant_cfg: '{quant_cfg_name}'. " + f"Available custom: {available_custom}. " + f"Available mtq built-in: {available_mtq}" + ) + logger.info(f"Using mtq built-in quant config: {quant_cfg_name}") + return cfg.copy() + + +class MockDataset(torch.utils.data.Dataset): + """ + Mock dataset that produces random data matching the expected training format. + + This is useful for testing the training pipeline without preparing real data. + The output format matches what PrecomputedDataset produces, with keys: + - "latents": video latent tensors and metadata + - "conditions": text embeddings and attention masks + + Note: prompt_embed_dim should be 3840 (the connector's inner_dim = 30 heads * 128 dim), + NOT 4096 (Gemma's raw hidden size). The PrecomputedDataset stores embeddings that have + already been projected through the feature_extractor_linear layer. + """ + + def __init__( + self, + width: int = 512, + height: int = 320, + num_frames: int = 33, + dataset_length: int = 100, + latent_dim: int = 128, + latent_spatial_compression_ratio: int = 32, + latent_temporal_compression_ratio: int = 8, + prompt_embed_dim: int = 3840, # Connector inner_dim, not Gemma's 4096 + prompt_sequence_length: int = 256, + fps: int = 25, + dtype: torch.dtype = torch.bfloat16, # Must match model dtype + ): + """ + Initialize mock dataset. + + Args: + width: Video width in pixels (must be divisible by 32) + height: Video height in pixels (must be divisible by 32) + num_frames: Number of video frames (should be 8k+1 for proper compression) + dataset_length: Number of samples in the dataset + latent_dim: Latent channel dimension (128 for LTX-2) + latent_spatial_compression_ratio: Spatial compression ratio (32 for LTX-2) + latent_temporal_compression_ratio: Temporal compression ratio (8 for LTX-2) + prompt_embed_dim: Text embedding dimension after projection (3840 for LTX-2, + which is connector's inner_dim = 30 heads * 128 dim_head) + prompt_sequence_length: Max text sequence length + fps: Frames per second + dtype: Data type for floating point tensors (must match model dtype, default bfloat16) + """ + self.width = width + self.height = height + self.num_frames = num_frames + self.dataset_length = dataset_length + self.latent_dim = latent_dim + self.num_latent_frames = (num_frames - 1) // latent_temporal_compression_ratio + 1 + self.latent_height = height // latent_spatial_compression_ratio + self.latent_width = width // latent_spatial_compression_ratio + self.prompt_embed_dim = prompt_embed_dim + self.prompt_sequence_length = prompt_sequence_length + self.fps = fps + self.dtype = dtype + + def __len__(self) -> int: + return self.dataset_length + + def __getitem__(self, idx: int) -> dict: + """ + Get a mock sample. + + Returns format expected by training strategy: + - latents: dict with "latents" tensor [C, F, H, W] and metadata + - conditions: dict with "prompt_embeds" and "prompt_attention_mask" + """ + return { + # Video latents (key: "latents" to match PrecomputedDataset) + "latents": { + "latents": torch.randn( + self.latent_dim, + self.num_latent_frames, + self.latent_height, + self.latent_width, + dtype=self.dtype, # Must match model dtype + ), + "num_frames": torch.tensor(self.num_latent_frames), + "height": torch.tensor(self.latent_height), + "width": torch.tensor(self.latent_width), + "fps": torch.tensor(self.fps), + }, + # Text conditions (key: "conditions" to match PrecomputedDataset) + "conditions": { + "prompt_embeds": torch.randn( + self.prompt_sequence_length, + self.prompt_embed_dim, + dtype=self.dtype, # Must match model dtype + ), + # Attention mask must be numeric (not bool) for _run_connectors + # Using int8 to save memory (1 byte vs 8 bytes for long) + "prompt_attention_mask": torch.ones( + self.prompt_sequence_length, + dtype=torch.int8, + ), + }, + "idx": idx, + } + + +class DistillationConfig(ConfigBaseModel): + """Configuration for distillation-specific parameters.""" + + teacher_model_path: str | Path | None = Field( + default=None, + description="Path to the teacher model checkpoint. If None, uses the same as model.model_path " + "(teacher is loaded without quantization).", + ) + + distillation_alpha: float = Field( + default=0.5, + description="Weight for the task loss. Distillation loss weight is (1 - alpha). " + "alpha=1.0 means no distillation (pure task loss), " + "alpha=0.0 means pure distillation loss.", + ge=0.0, + le=1.0, + ) + + distillation_loss_type: Literal["mse", "cosine"] = Field( + default="mse", + description="Type of distillation loss. 'mse' is recommended since transformer outputs " + "are continuous velocity predictions in latent space (not probabilities). " + "'cosine' matches direction only, ignoring magnitude.", + ) + + teacher_dtype: Literal["bfloat16", "float16", "float32"] = Field( + default="bfloat16", + description="Data type for teacher model. BFloat16 is recommended for memory efficiency.", + ) + + # ModelOpt Quantization Settings + quant_cfg: str | None = Field( + default=None, + description="Name of the ModelOpt quantization config to apply to the student model. " + "Looked up first in CUSTOM_QUANT_CONFIGS (distillation_trainer.py), then as mtq.. " + "Examples: 'FP8_DEFAULT_CFG', 'INT8_DEFAULT_CFG', 'NVFP4_DEFAULT_CFG'. " + "Set to None to disable quantization.", + ) + + # Calibration settings (full-inference calibration, matching PTQ workflow) + calibration_prompts_file: str | Path | None = Field( + default=None, + description="Path to a text file with one calibration prompt per line. " + "If None, uses the HuggingFace dataset 'Gustavosta/Stable-Diffusion-Prompts' ", + ) + + calibration_size: int = Field( + default=128, + description="Total number of calibration prompts to use. Each prompt runs a full " + "denoising inference through the DiT, covering all noise levels. ", + ge=0, + ) + + calibration_n_steps: int = Field( + default=30, + description="Number of denoising steps per calibration prompt. Each step calls the " + "transformer at a different noise level.", + ge=1, + ) + + calibration_guidance_scale: float = Field( + default=4.0, + description="CFG guidance scale during calibration. Default 4.0.", + ge=1.0, + ) + + restore_quantized_checkpoint: str | Path | None = Field( + default=None, + description="Path to restore a previously quantized model from mto.save().", + ) + + save_quantized_checkpoint: str | Path | None = Field( + default=None, + description="Path to save the final quantized model checkpoint.", + ) + + # Checkpoint resume settings + resume_from_checkpoint: str | Path | None = Field( + default=None, + description="Path to a training state checkpoint directory (from save_training_state) to resume " + "training from. Restores model weights, optimizer, LR scheduler, RNG states, and step counter. " + "Set to 'latest' to auto-find the latest checkpoint in output_dir/checkpoints/.", + ) + + must_save_by: float | None = Field( + default=None, + description="Minutes after which training must save a checkpoint and exit. " + "Use this when running under a Slurm time limit — set to a value slightly less " + "than the time limit (e.g., time_limit=30min → must_save_by=25) to ensure " + "a checkpoint is saved before the job is killed. Timer starts at train() entry. " + "Set to None to disable.", + gt=0, + ) + + # Debug/Test options + use_mock_data: bool = Field( + default=False, + description="Use mock data instead of real preprocessed data for testing.", + ) + + mock_data_samples: int = Field( + default=100, + description="Number of mock samples to generate when use_mock_data is True.", + ge=1, + ) + + +class DistillationTrainerConfig(LtxTrainerConfig): + """Extended trainer config with distillation settings.""" + + distillation: DistillationConfig = Field( + default_factory=DistillationConfig, + description="Distillation-specific configuration.", + ) + + +class DistillationTrainer(LtxvTrainer): + """ + Distillation trainer that extends LtxvTrainer with: + - Teacher model loading and inference + - ModelOpt quantization for student + - Combined task + distillation loss + """ + + def __init__(self, config: DistillationTrainerConfig) -> None: + # Store distillation config before parent init (needed by overrides called during super().__init__) + self._distillation_config = config.distillation + # Will be populated by _load_text_encoder_and_cache_embeddings() during super().__init__ + self._cached_calibration_embeddings: list | None = None + + # Create base trainer config (without distillation section) + trainer_config = LtxTrainerConfig( + **{k: v for k, v in config.model_dump().items() if k != "distillation"} + ) + + # Initialize parent (loads student model, sets up accelerator) + # Note: _prepare_models_for_training() is overridden to NOT call + # accelerator.prepare() on the student — we defer that to _init_optimizer() + # so model+optimizer can be prepared together (required by FSDP2). + super().__init__(trainer_config) + + # Load teacher model (after parent init so we have accelerator) + # Teacher is loaded, frozen, and prepared with a dummy optimizer. + self._load_teacher_model() + + logger.info( + f"Distillation training initialized with alpha={self._distillation_config.distillation_alpha:.2f}" + ) + + def _prepare_models_for_training(self) -> None: + """ + Override parent to defer accelerator.prepare() for the student model. + + The parent calls accelerator.prepare(transformer) here, but FSDP2 requires + model and optimizer to be prepared together. So we do everything the parent + does EXCEPT the accelerator.prepare() call — that happens in _init_optimizer() + where we can call prepare(model, optimizer, scheduler) together. + """ + from accelerate.utils import DistributedType + + # For FSDP + LoRA: Cast entire model to FP32 for uniform dtype + if ( + self._accelerator.distributed_type == DistributedType.FSDP + and self._config.model.training_mode == "lora" + ): + logger.debug("FSDP: casting transformer to FP32 for uniform dtype") + self._transformer = self._transformer.to(dtype=torch.float32) + + # Enable gradient checkpointing if requested + transformer = ( + self._transformer.get_base_model() + if hasattr(self._transformer, "get_base_model") + else self._transformer + ) + transformer.set_gradient_checkpointing( + self._config.optimization.enable_gradient_checkpointing + ) + + # Keep frozen models on CPU for memory efficiency + self._vae_decoder = self._vae_decoder.to("cpu") + if self._vae_encoder is not None: + self._vae_encoder = self._vae_encoder.to("cpu") + + # NOTE: We intentionally do NOT call accelerator.prepare(self._transformer) here. + # It will be called in _init_optimizer() together with the optimizer, which is + # required for FSDP2 compatibility. This also works fine with FSDP1. + + # Log GPU memory usage + vram_usage_gb = torch.cuda.memory_allocated() / 1024**3 + logger.debug(f"GPU memory usage after models preparation: {vram_usage_gb:.2f} GB") + + def _load_text_encoder_and_cache_embeddings(self): + """ + Override parent to also cache calibration prompt embeddings before Gemma is unloaded. + + The parent method loads the full Gemma text encoder, caches validation prompt embeddings, + then UNLOADS the heavy Gemma model (sets model/tokenizer/feature_extractor_linear to None) + to free VRAM. Only the lightweight embedding connectors remain. + + We hook in here to also cache calibration prompt embeddings while the full text encoder + is still available. These cached embeddings are later used by _run_inference_calibration() + via the ValidationSampler's CachedPromptEmbeddings mechanism. + """ + from ltx_trainer.model_loader import load_text_encoder + from ltx_trainer.validation_sampler import CachedPromptEmbeddings + + # Call parent to load text encoder, cache validation embeddings, and unload Gemma. + # But we need to intercept BEFORE the unload. We re-implement the parent logic + # with our addition in the middle. + + logger.debug("Loading text encoder...") + self._text_encoder = load_text_encoder( + checkpoint_path=self._config.model.model_path, + gemma_model_path=self._config.model.text_encoder_path, + device="cuda", + dtype=torch.bfloat16, + load_in_8bit=self._config.acceleration.load_text_encoder_in_8bit, + ) + + # Cache validation embeddings (same as parent) + cached_validation = None + if self._config.validation.prompts: + logger.info( + f"Pre-computing embeddings for {len(self._config.validation.prompts)} validation prompts..." + ) + cached_validation = [] + with torch.inference_mode(): + for prompt in self._config.validation.prompts: + v_ctx_pos, a_ctx_pos, _ = self._text_encoder(prompt) + v_ctx_neg, a_ctx_neg, _ = self._text_encoder( + self._config.validation.negative_prompt + ) + cached_validation.append( + CachedPromptEmbeddings( + video_context_positive=v_ctx_pos.cpu(), + audio_context_positive=a_ctx_pos.cpu(), + video_context_negative=v_ctx_neg.cpu() + if v_ctx_neg is not None + else None, + audio_context_negative=a_ctx_neg.cpu() + if a_ctx_neg is not None + else None, + ) + ) + + # Cache calibration prompt embeddings while the heavy text encoder is still loaded. + # Only needed if we'll actually run fresh calibration (Path C). Skip if a + # resumable checkpoint, user-specified checkpoint, or step 0 checkpoint exists. + calib_cfg = self._distillation_config + if ( + calib_cfg.quant_cfg is not None + and calib_cfg.calibration_size > 0 + and self._needs_fresh_calibration() + ): + prompts = self._load_calibration_prompts() + negative_prompt = getattr( + self._config.validation, + "negative_prompt", + "worst quality, inconsistent motion, blurry, jittery, distorted", + ) + logger.info( + f"Pre-computing embeddings for {len(prompts)} calibration prompts " + f"(guidance_scale={calib_cfg.calibration_guidance_scale})..." + ) + self._cached_calibration_embeddings = [] + use_cfg = calib_cfg.calibration_guidance_scale != 1.0 + with torch.inference_mode(): + for prompt in prompts: + v_ctx_pos, a_ctx_pos, _ = self._text_encoder(prompt) + v_ctx_neg, a_ctx_neg = None, None + if use_cfg: + v_ctx_neg, a_ctx_neg, _ = self._text_encoder(negative_prompt) + self._cached_calibration_embeddings.append( + CachedPromptEmbeddings( + video_context_positive=v_ctx_pos.cpu(), + audio_context_positive=a_ctx_pos.cpu(), + video_context_negative=v_ctx_neg.cpu() + if v_ctx_neg is not None + else None, + audio_context_negative=a_ctx_neg.cpu() + if a_ctx_neg is not None + else None, + ) + ) + logger.info(f"Cached {len(self._cached_calibration_embeddings)} calibration embeddings") + + # Unload heavy components to free VRAM, keeping only the embedding connectors + self._text_encoder.model = None + self._text_encoder.tokenizer = None + self._text_encoder.feature_extractor_linear = None + gc.collect() + torch.cuda.empty_cache() + logger.debug("Validation/calibration prompt embeddings cached. Gemma model unloaded") + + return cached_validation + + def _load_models(self) -> None: + """ + Load the LTX-2 model components with ModelOpt quantization for student. + + This overrides the parent method to: + 1. Load models as usual (without ltx-trainer's quantization) + 2. Apply ModelOpt fake quantization to the student transformer + """ + # Call parent to load all models normally + super()._load_models() + + # Apply ModelOpt quantization to student if configured + if self._distillation_config.quant_cfg is not None: + self._apply_modelopt_quantization() + gc.collect() + torch.cuda.empty_cache() + logger.info(f"Quantized model: {self._transformer}") + + def _needs_fresh_calibration(self) -> bool: + """Check whether fresh quantization calibration will be needed. + + Returns False if an existing checkpoint can be restored instead + (Path A, B, or B2 in _apply_modelopt_quantization), meaning we can + skip the expensive calibration embedding caching. + """ + cfg = self._distillation_config + + # Path A: resume checkpoint with modelopt_state.pt + if cfg.resume_from_checkpoint is not None: + checkpoint_dir = self._find_resume_checkpoint(cfg.resume_from_checkpoint) + if checkpoint_dir is not None: + if (checkpoint_dir / "modelopt_state.pt").exists(): + return False + + # Path B: user-specified quantized checkpoint + if cfg.restore_quantized_checkpoint is not None: + return False + + # Path B2: auto-detected step 0 checkpoint + step0_path = self._get_checkpoints_dir() / "step_000000_quantized" / "backbone.pt" + return not step0_path.exists() + + def _apply_modelopt_quantization(self) -> None: + """ + Apply ModelOpt fake quantization to the student transformer. + + Four paths are supported (checked in order): + + Path A - Resume from training checkpoint: + If resume_from_checkpoint is set, restore only the quantization module + architecture (fake quantizer modules) from the saved modelopt_state.pt. + The actual trained weights (including quantizer scales) will be loaded + later by accelerator.load_state() in _load_training_state(). + + Path B - Restore from a user-specified quantized checkpoint: + If restore_quantized_checkpoint is set, restore both architecture and + weights from a complete mto.save() checkpoint. + + Path B2 - Auto-detect step 0 quantized checkpoint: + If a previous run already completed calibration and saved the step 0 + checkpoint (step_000000_quantized/backbone.pt), restore from it + automatically. This avoids re-running the expensive calibration. + + Path C - Fresh quantization with full-inference calibration: + Apply mtq.quantize() with a forward_loop that runs full denoising + inference (like the PTQ workflow), covering all noise levels. + After calibration, saves the result as step 0 checkpoint for future runs. + """ + quant_cfg_name = self._distillation_config.quant_cfg + if not quant_cfg_name: + logger.info("No quant_cfg specified, skipping quantization") + return + + # Path A: Resume from training checkpoint — restore architecture only. + # The trained weights (including quantizer scales) are loaded later by + # accelerator.load_state() in _load_training_state(). + resume_path = self._distillation_config.resume_from_checkpoint + if resume_path is not None: + checkpoint_dir = self._find_resume_checkpoint(resume_path) + if checkpoint_dir is not None: + modelopt_state_path = checkpoint_dir / "modelopt_state.pt" + if modelopt_state_path.exists(): + logger.info( + f"Resuming: restoring quantization architecture from " + f"{modelopt_state_path} (weights loaded later by accelerator)" + ) + # Security NOTE: weights_only=False is used on ModelOpt-generated state, + # not on untrusted user input. + state = torch.load(modelopt_state_path, weights_only=False, map_location="cpu") + self._transformer = mto.restore_from_modelopt_state(self._transformer, state) + logger.info("Quantization architecture restored for resume") + return + else: + logger.warning( + f"modelopt_state.pt not found in {checkpoint_dir}, " + "falling through to fresh quantization" + ) + + # Path B: Restore from a standalone quantized checkpoint (architecture + weights). + if self._distillation_config.restore_quantized_checkpoint is not None: + restore_path = str(self._distillation_config.restore_quantized_checkpoint) + logger.info(f"Restoring quantized model from {restore_path}") + mto.restore(self._transformer, restore_path) + return + + # Path B2: Auto-detect step 0 quantized checkpoint from a previous run. + # If calibration was already completed and saved, reuse it instead of + # re-running the expensive calibration process. + step0_path = self._get_checkpoints_dir() / "step_000000_quantized" / "backbone.pt" + if step0_path.exists(): + logger.info( + f"Found existing step 0 quantized checkpoint at {step0_path}, " + "restoring instead of re-running calibration" + ) + try: + mto.restore(self._transformer, str(step0_path)) + return + except Exception as e: + logger.warning( + f"Failed to restore step 0 checkpoint (file may be corrupted): {e}. " + "Falling through to fresh quantization." + ) + + # Path C: Fresh quantization with full-inference calibration. + logger.info(f"Applying ModelOpt quantization ({quant_cfg_name}) to student transformer...") + + quant_config = get_quant_config(quant_cfg_name) + + def forward_loop(model): + """Run full-inference calibration covering all noise levels.""" + self._run_inference_calibration(model) + + mtq.quantize(self._transformer, quant_config, forward_loop=forward_loop) + + # Free cached calibration embeddings — no longer needed after quantization + self._cached_calibration_embeddings = None + + logger.info(f"ModelOpt quantization ({quant_cfg_name}) applied successfully") + + # Save the freshly quantized+calibrated model as "step 0" checkpoint. + # This avoids re-running calibration if training is interrupted before the + # first regular checkpoint. On resume, Path B2 auto-detects and loads this. + # Only model + quantizer scales are saved (no optimizer/scheduler state at step 0). + # We use atomic save (write to tmp, then rename) to prevent corrupt checkpoints. + step0_dir = self._get_checkpoints_dir() / "step_000000_quantized" + step0_path = step0_dir / "backbone.pt" + # Only global rank 0 saves (all ranks have identical models pre-FSDP); + # others wait at the barrier. Atomic save (tmp + rename) prevents corruption. + if is_global_rank0(): + step0_dir.mkdir(parents=True, exist_ok=True) + step0_tmp_path = step0_dir / "backbone.pt.tmp" + logger.info(f"Saving quantized model (step 0) to {step0_path}") + mto.save(self._transformer, str(step0_tmp_path)) + step0_tmp_path.rename(step0_path) + logger.info("Step 0 quantized checkpoint saved") + if dist.is_initialized(): + dist.barrier() + + def _create_mock_dataset(self) -> MockDataset: + """Create a mock dataset for testing without real data.""" + # Get video dimensions from validation config or use defaults + video_dims = getattr(self._config.validation, "video_dims", [512, 320, 33]) + width, height, num_frames = video_dims + + logger.info( + f"Creating mock dataset with {self._distillation_config.mock_data_samples} samples " + f"(video: {width}x{height}x{num_frames})" + ) + + return MockDataset( + width=width, + height=height, + num_frames=num_frames, + dataset_length=self._distillation_config.mock_data_samples, + ) + + def _load_calibration_prompts(self) -> list[str]: + """ + Load calibration prompts for full-inference quantization calibration. + + Follows the same pattern as the PTQ workflow (examples/diffusers/quantization/): + - If calibration_prompts_file is set: reads a text file with one prompt per line + - Otherwise: loads from the HuggingFace dataset 'Gustavosta/Stable-Diffusion-Prompts' + + Returns: + List of calibration prompts, truncated to calibration_size. + """ + calib_size = self._distillation_config.calibration_size + prompts_file = self._distillation_config.calibration_prompts_file + + if prompts_file is not None: + prompts_path = Path(prompts_file) + if not prompts_path.exists(): + raise FileNotFoundError(f"Calibration prompts file not found: {prompts_path}") + logger.info(f"Loading calibration prompts from {prompts_path}") + with open(prompts_path) as f: + prompts = [line.strip() for line in f if line.strip()] + else: + logger.info( + "Loading calibration prompts from HuggingFace dataset " + "'Gustavosta/Stable-Diffusion-Prompts'..." + ) + from datasets import load_dataset + + dataset = load_dataset("Gustavosta/Stable-Diffusion-Prompts") + prompts = list(dataset["train"]["Prompt"]) + + # Truncate to requested size + prompts = prompts[:calib_size] + logger.info(f"Loaded {len(prompts)} calibration prompts") + return prompts + + def _run_inference_calibration(self, model: torch.nn.Module) -> None: + """ + Run full-inference calibration through the DiT, covering all noise levels. + + This replaces the old training-style calibration with full denoising inference, + matching the PTQ workflow. For each calibration prompt, a complete denoising loop + is run (e.g. 30 steps), so the transformer sees activations at every noise level. + + With CFG guidance_scale > 1.0 (default 4.0), each denoising step calls the + transformer twice (positive + negative prompt), matching real inference patterns. + + Note: Text embeddings were pre-computed and cached in + _load_text_encoder_and_cache_embeddings() BEFORE the Gemma model was unloaded. + We pass these cached embeddings to the ValidationSampler via GenerationConfig. + + Args: + model: The transformer model being calibrated (same reference as self._transformer, + with statistics collection enabled by mtq.quantize). + """ + from ltx_trainer.validation_sampler import GenerationConfig, ValidationSampler + + calib_cfg = self._distillation_config + if calib_cfg.calibration_size == 0: + logger.info("Skipping calibration (calibration_size=0)") + return + + if not self._cached_calibration_embeddings: + raise RuntimeError( + "No cached calibration embeddings available! " + "Probably the saved checkpoint has no modelopt_state.pt or corrupted." + ) + + # Get video dimensions from validation config + video_dims = getattr(self._config.validation, "video_dims", [512, 320, 33]) + width, height, num_frames = video_dims + negative_prompt = getattr( + self._config.validation, + "negative_prompt", + "worst quality, inconsistent motion, blurry, jittery, distorted", + ) + num_prompts = len(self._cached_calibration_embeddings) + + logger.info( + f"Running full-inference calibration: {num_prompts} prompts, " + f"{calib_cfg.calibration_n_steps} steps/prompt, " + f"guidance_scale={calib_cfg.calibration_guidance_scale}, " + f"video={width}x{height}x{num_frames}" + ) + + # Create a ValidationSampler with the model being calibrated. + # The exact model reference matters: mtq.quantize() sets up statistics + # collection on this instance, so all forward passes must go through it. + # text_encoder=None because we use pre-cached embeddings (Gemma is unloaded). + sampler = ValidationSampler( + transformer=model, + vae_decoder=self._vae_decoder, + vae_encoder=self._vae_encoder, + text_encoder=None, # Gemma unloaded; using cached embeddings + audio_decoder=None, # Skip audio for calibration + vocoder=None, + ) + + device = "cuda" + model.eval() + + with torch.no_grad(): + for i, cached_emb in enumerate(self._cached_calibration_embeddings): + gen_config = GenerationConfig( + prompt="", # Not used when cached_embeddings is provided + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + frame_rate=getattr(self._config.validation, "frame_rate", 25.0), + num_inference_steps=calib_cfg.calibration_n_steps, + guidance_scale=calib_cfg.calibration_guidance_scale, + seed=42 + i, # Vary seed per prompt for diverse activations + generate_audio=False, + tiled_decoding=False, # Skip tiling overhead + cached_embeddings=cached_emb, # Pre-computed text embeddings + ) + + try: + sampler.generate(config=gen_config, device=device) + except Exception as e: + logger.warning(f"Calibration prompt {i} failed: {e}") + continue + + if (i + 1) % 10 == 0 or (i + 1) == len(self._cached_calibration_embeddings): + logger.info(f"Calibration progress: {i + 1}/{num_prompts} prompts") + + model.train() + logger.info("Full-inference calibration complete") + + def _init_optimizer(self) -> None: + """ + Override parent to prepare student model + optimizer + scheduler together. + + FSDP2 requires model and optimizer to be passed to accelerator.prepare() + in a single call. This override: + 1. Creates the optimizer (pointing at self._transformer parameters) + 2. Creates the LR scheduler + 3. Calls accelerator.prepare(model, optimizer, scheduler) together + + This is compatible with both FSDP1 and FSDP2. + """ + from torch.optim import AdamW + + opt_cfg = self._config.optimization + + lr = opt_cfg.learning_rate + if opt_cfg.optimizer_type == "adamw": + optimizer = AdamW(self._trainable_params, lr=lr) + elif opt_cfg.optimizer_type == "adamw8bit": + from bitsandbytes.optim import AdamW8bit + + optimizer = AdamW8bit(self._trainable_params, lr=lr) + else: + raise ValueError(f"Unknown optimizer type: {opt_cfg.optimizer_type}") + + lr_scheduler = self._create_scheduler(optimizer) + + # Prepare student model + optimizer + scheduler together (FSDP2 requirement) + logger.info("Preparing student model + optimizer + scheduler with accelerator...") + if lr_scheduler is not None: + self._transformer, self._optimizer, self._lr_scheduler = self._accelerator.prepare( + self._transformer, optimizer, lr_scheduler + ) + else: + self._transformer, self._optimizer = self._accelerator.prepare( + self._transformer, optimizer + ) + self._lr_scheduler = None + + # Log memory after preparation + if torch.cuda.is_available(): + mem_gb = torch.cuda.memory_allocated() / 1024**3 + logger.info(f"GPU memory after model+optimizer preparation: {mem_gb:.2f} GB") + + def _init_dataloader(self) -> None: + """Override to support mock data for training.""" + if self._distillation_config.use_mock_data: + from torch.utils.data import DataLoader + + self._dataset = self._create_mock_dataset() + self._dataloader = DataLoader( + self._dataset, + batch_size=self._config.optimization.batch_size, + shuffle=True, + num_workers=self._config.data.num_dataloader_workers, + pin_memory=True, + drop_last=True, + ) + # Wrap with accelerator + self._dataloader = self._accelerator.prepare(self._dataloader) + else: + # Use parent implementation for real data + super()._init_dataloader() + + def _load_teacher_model(self) -> None: + """ + Load the teacher transformer model for distillation. + + The teacher is loaded, frozen, and prepared with the accelerator using a + dummy SGD optimizer (lr=0, never stepped). The dummy optimizer is needed + because FSDP2 requires model+optimizer together in prepare(). For FSDP1, + this also works fine (prepare just wraps the model). + """ + from torch.optim import SGD + + teacher_path = self._distillation_config.teacher_model_path + if teacher_path is None: + teacher_path = self._config.model.model_path + + # Map dtype string to torch dtype + dtype_map = { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, + } + teacher_dtype = dtype_map[self._distillation_config.teacher_dtype] + + logger.info( + f"Loading teacher model from {teacher_path} with dtype={self._distillation_config.teacher_dtype}" + ) + + # Load teacher transformer to CPU first + self._teacher_transformer = load_transformer( + checkpoint_path=str(teacher_path), + device="cpu", + dtype=teacher_dtype, + ) + + # Teacher is inference-only, freeze it + self._teacher_transformer.requires_grad_(False) + self._teacher_transformer.eval() + + # Prepare teacher with accelerator using a dummy optimizer. + # FSDP2 requires model+optimizer together in prepare(). We use a minimal + # SGD with lr=0 that will never be stepped — just to satisfy the API. + logger.info( + f"Preparing teacher model with accelerator (distributed_type={self._accelerator.distributed_type})" + ) + teacher_params = list(self._teacher_transformer.parameters()) + dummy_optimizer = SGD(teacher_params, lr=0.0) + + self._teacher_transformer, wrapped_dummy_optimizer = self._accelerator.prepare( + self._teacher_transformer, dummy_optimizer + ) + + # Remove the teacher model and dummy optimizer from accelerator's internal + # tracking lists. This prevents save_state()/load_state() from saving/loading + # the teacher (which is frozen and loaded fresh from the original checkpoint + # on each run). The FSDP wrapping is already done at this point, so the + # teacher doesn't need to stay registered. + # Note: _models and _optimizers must stay 1:1 aligned for FSDP optimizer + # save/load (load_fsdp_optimizer uses _models[i] to pair with _optimizers[i]). + # We use the wrapped objects returned by prepare() since _optimizers stores + # AcceleratedOptimizer wrappers, not raw optimizers. + self._accelerator._models.remove(self._teacher_transformer) + self._accelerator._optimizers.remove(wrapped_dummy_optimizer) + + # Re-freeze teacher after prepare (FSDP wrapping may reset requires_grad) + self._teacher_transformer.requires_grad_(False) + self._teacher_transformer.eval() + + # Log memory after teacher loading + if torch.cuda.is_available(): + mem_gb = torch.cuda.memory_allocated() / 1024**3 + logger.info(f"GPU memory after teacher preparation: {mem_gb:.2f} GB") + + logger.info( + "Teacher model loaded and prepared (unregistered from accelerator state tracking)" + ) + + def _training_step(self, batch: dict[str, dict[str, Tensor]]) -> Tensor: + """ + Perform a single distillation training step. + + Computes combined loss: + L_total = alpha * L_task + (1 - alpha) * L_distill + + where: + - L_task: MSE between student prediction and flow matching target + - L_distill: MSE between student prediction and teacher prediction + """ + alpha = self._distillation_config.distillation_alpha + + # Apply embedding connectors to transform pre-computed text embeddings + conditions = batch["conditions"] + video_embeds, audio_embeds, attention_mask = self._text_encoder._run_connectors( + conditions["prompt_embeds"], conditions["prompt_attention_mask"] + ) + conditions["video_prompt_embeds"] = video_embeds + conditions["audio_prompt_embeds"] = audio_embeds + conditions["prompt_attention_mask"] = attention_mask + + # Use strategy to prepare training inputs + model_inputs = self._training_strategy.prepare_training_inputs( + batch, self._timestep_sampler + ) + + # Run student forward pass + student_video_pred, student_audio_pred = self._transformer( + video=model_inputs.video, + audio=model_inputs.audio, + perturbations=None, + ) + + # Compute task loss only if alpha > 0 + if alpha > 0: + task_loss = self._training_strategy.compute_loss( + student_video_pred, student_audio_pred, model_inputs + ) + else: + task_loss = torch.tensor(0.0, device=student_video_pred.device) + + # Compute distillation loss only if alpha < 1 + if alpha < 1.0: + # Run teacher forward pass (no gradients) + with torch.no_grad(): + teacher_video_pred, _teacher_audio_pred = self._teacher_transformer( + video=model_inputs.video, + audio=model_inputs.audio, + perturbations=None, + ) + + # Compute distillation loss + distill_loss = self._compute_distillation_loss( + student_video_pred, + teacher_video_pred, + loss_mask=model_inputs.video_loss_mask, + ) + else: + distill_loss = torch.tensor(0.0, device=student_video_pred.device) + + # Combine losses + total_loss = alpha * task_loss + (1.0 - alpha) * distill_loss + + # Log individual losses using parent's _log_metrics pattern (no explicit step) + # This avoids step conflicts with wandb's auto-incrementing step counter + if hasattr(self, "_accelerator") and self._accelerator.is_main_process: + self._log_metrics( + { + "loss/task": task_loss.item(), + "loss/distillation": distill_loss.item(), + "loss/total": total_loss.item(), + } + ) + + return total_loss + + def _compute_distillation_loss( + self, + student_pred: Tensor, + teacher_pred: Tensor, + loss_mask: Tensor | None = None, + ) -> Tensor: + """Compute distillation loss between student and teacher predictions.""" + loss_type = self._distillation_config.distillation_loss_type + + if loss_type == "mse": + loss = torch.nn.functional.mse_loss(student_pred, teacher_pred, reduction="none") + elif loss_type == "cosine": + student_flat = student_pred.flatten(start_dim=2) + teacher_flat = teacher_pred.flatten(start_dim=2) + cos_sim = torch.nn.functional.cosine_similarity(student_flat, teacher_flat, dim=-1) + loss = 1.0 - cos_sim.mean() + else: + raise ValueError(f"Unknown distillation loss type: {loss_type}") + + # Apply loss mask if provided + # loss_mask is [B, seq_len], need to unsqueeze to [B, seq_len, 1] for broadcasting + # with loss shape [B, seq_len, C] + if loss_mask is not None: + # Unsqueeze and convert to float for multiplication + loss_mask = loss_mask.unsqueeze(-1).float() + # Apply mask and normalize (same as original trainer) + loss = loss.mul(loss_mask).div(loss_mask.mean()) + loss = loss.mean() + else: + loss = loss.mean() + + return loss + + def save_quantized_model(self, path: str | Path | None = None) -> None: + """Save the quantized model using ModelOpt (global rank 0 only).""" + if not is_global_rank0(): + return + if path is None: + path = self._distillation_config.save_quantized_checkpoint + if path is None: + path = Path(self._config.output_dir) / "quantized_model" + + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + + logger.info(f"Saving quantized model to {path}") + mto.save(self._transformer, str(path)) + logger.info("Quantized model saved successfully") + + # ── Overrides to fix multi-node shared-FS writes ────────────────────── + # The parent trainer guards file writes with IS_MAIN_PROCESS (LOCAL_RANK==0), + # which is True on every node. We override to use is_global_rank0() so that + # only a single process writes on a shared filesystem. + + def _save_checkpoint(self) -> Path | None: + """Save model weights (override: use global rank 0 for file writes).""" + from accelerate.utils import DistributedType + from safetensors.torch import save_file + + is_lora = self._config.model.training_mode == "lora" + is_fsdp = self._accelerator.distributed_type == DistributedType.FSDP + + save_dir = Path(self._config.output_dir) / "checkpoints" + prefix = "lora" if is_lora else "model" + filename = f"{prefix}_weights_step_{self._global_step:05d}.safetensors" + saved_weights_path = save_dir / filename + + # Collective operation — all ranks must participate + self._accelerator.wait_for_everyone() + full_state_dict = self._accelerator.get_state_dict(self._transformer) + + if not is_global_rank0(): + return None + + save_dir.mkdir(exist_ok=True, parents=True) + save_dtype = ( + torch.bfloat16 if self._config.checkpoints.precision == "bfloat16" else torch.float32 + ) + + if is_lora: + from peft import get_peft_model_state_dict + + unwrapped = self._accelerator.unwrap_model(self._transformer, keep_torch_compile=False) + state_dict = get_peft_model_state_dict( + unwrapped, state_dict=full_state_dict if is_fsdp else None + ) + state_dict = {k.replace("base_model.model.", "", 1): v for k, v in state_dict.items()} + state_dict = {f"diffusion_model.{k}": v for k, v in state_dict.items()} + state_dict = { + k: v.to(save_dtype) if isinstance(v, Tensor) else v for k, v in state_dict.items() + } + metadata = self._build_checkpoint_metadata() + save_file(state_dict, saved_weights_path, metadata=metadata) + else: + full_state_dict = { + k: v.to(save_dtype) if isinstance(v, Tensor) else v + for k, v in full_state_dict.items() + } + self._accelerator.save(full_state_dict, saved_weights_path) + + rel_path = saved_weights_path.relative_to(self._config.output_dir) + logger.info(f"Model weights for step {self._global_step} saved in {rel_path}") + + self._checkpoint_paths.append(saved_weights_path) + return saved_weights_path + + def _save_config(self) -> None: + """Save training config (override: use global rank 0 for file writes).""" + if not is_global_rank0(): + return + import yaml + + config_path = Path(self._config.output_dir) / "training_config.yaml" + with open(config_path, "w") as f: + yaml.dump(self._config.model_dump(), f, default_flow_style=False, indent=2) + logger.info( + f"Training configuration saved to: {config_path.relative_to(self._config.output_dir)}" + ) + + def _init_wandb(self) -> None: + """Initialize W&B (override: use global rank 0 to avoid duplicate runs).""" + if not self._config.wandb.enabled or not is_global_rank0(): + self._wandb_run = None + return + # Delegate to parent's implementation on global rank 0 + super()._init_wandb() + + def _get_checkpoints_dir(self) -> Path: + """Return the directory used for full training state checkpoints.""" + return Path(self._config.output_dir) / "checkpoints" + + def _save_training_state(self) -> Path | None: + """ + Save the full training state using accelerator.save_state(). + + This saves everything needed to resume training exactly: + - Student model weights (FSDP-sharded) + - Optimizer state + - LR scheduler state + - RNG states (Python, NumPy, PyTorch CPU/CUDA per device) + - Gradient scaler state (if using mixed precision) + - ModelOpt state (quantization architecture for restore on resume) + - Custom metadata (global_step, distillation config) + + Atomic save strategy: + 1. Save everything into step_XXXXXX_tmp/ + 2. After all writes complete, rename to step_XXXXXX/ + Directory rename is atomic on the same filesystem, so either + the final directory exists (complete) or it doesn't. If the + process is killed mid-save, only the _tmp directory remains, + which is cleaned up on the next run. + + Note: The teacher model is NOT saved here — it was unregistered from + the accelerator's tracking lists after prepare() (see _load_teacher_model). + On resume, the teacher is loaded fresh from the original pretrained checkpoint. + + Returns: + Path to the saved state directory, or None on non-main processes. + """ + final_dir = self._get_checkpoints_dir() / f"step_{self._global_step:06d}" + tmp_dir = self._get_checkpoints_dir() / f"step_{self._global_step:06d}_tmp" + + logger.info(f"Saving full training state at step {self._global_step}...") + + # Ensure the checkpoints directory exists before save_state. + if is_global_rank0(): + tmp_dir.mkdir(parents=True, exist_ok=True) + self._accelerator.wait_for_everyone() + + # Save into the _tmp directory first (all ranks participate for FSDP). + self._accelerator.save_state(str(tmp_dir)) + + # Additional saves only on global rank 0 to avoid file write races. + if is_global_rank0(): + # Save modelopt state for quantization architecture restoration on resume. + if self._distillation_config.quant_cfg is not None: + try: + modelopt_state_dict = mto.modelopt_state(self._transformer) + torch.save(modelopt_state_dict, tmp_dir / "modelopt_state.pt") + logger.debug("Saved modelopt_state.pt for resume") + except Exception as e: + logger.warning(f"Failed to save modelopt_state: {e}") + + # Save custom metadata. + metadata = { + "global_step": self._global_step, + "distillation_alpha": self._distillation_config.distillation_alpha, + "quant_cfg": self._distillation_config.quant_cfg, + } + metadata_path = tmp_dir / "distillation_metadata.json" + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + # Barrier: ensure all ranks finished writing before rename + self._accelerator.wait_for_everyone() + + # Atomic rename _tmp → final (only global rank 0) + if is_global_rank0(): + if tmp_dir.exists(): + tmp_dir.rename(final_dir) + logger.info(f"Training state saved to {final_dir} (step={self._global_step})") + else: + logger.error(f"Save directory {tmp_dir} not found after save_state — skipping") + + # Cleanup old / incomplete checkpoints + self._accelerator.wait_for_everyone() + self._cleanup_checkpoints() + + self._accelerator.wait_for_everyone() + return final_dir if is_global_rank0() else None + + def _cleanup_checkpoints(self) -> None: + """Remove old checkpoints, keeping only the last N. + + Also removes any *_tmp directories left behind by interrupted saves. + """ + if not is_global_rank0(): + return + + import shutil + + ckpt_dir = self._get_checkpoints_dir() + if not ckpt_dir.exists(): + return + + # Remove leftover _tmp directories from interrupted saves + for tmp_dir in ckpt_dir.glob("step_*_tmp"): + shutil.rmtree(tmp_dir, ignore_errors=True) + logger.info(f"Removed incomplete checkpoint: {tmp_dir.name}") + + # Keep only last N complete training checkpoints. + # Exclude _tmp (incomplete) and _quantized (calibration-only, not training state). + keep_n = self._config.checkpoints.keep_last_n + if keep_n <= 0: + return + + step_dirs = sorted(ckpt_dir.glob("step_[0-9]*"), key=lambda p: p.name) + step_dirs = [ + d + for d in step_dirs + if not d.name.endswith("_tmp") and not d.name.endswith("_quantized") + ] + if len(step_dirs) <= keep_n: + return + + dirs_to_remove = step_dirs[:-keep_n] + for old_dir in dirs_to_remove: + shutil.rmtree(old_dir, ignore_errors=True) + logger.info(f"Removed old checkpoint: {old_dir.name}") + + def _find_resume_checkpoint(self, path_or_keyword: str | Path) -> Path | None: + """ + Find the checkpoint directory to resume from. + + Only considers fully saved checkpoints (step_XXXXXX, not step_*_tmp). + Incomplete _tmp checkpoints are ignored and cleaned up. + + Args: + path_or_keyword: Either "latest" to auto-find, or an explicit path. + + Returns: + Path to the checkpoint directory, or None if not found. + """ + if str(path_or_keyword).lower() == "latest": + ckpt_dir = self._get_checkpoints_dir() + if not ckpt_dir.exists(): + logger.warning(f"No checkpoints directory found at {ckpt_dir}") + return None + + # Only match step_XXXXXX (6 digits), excluding _tmp (incomplete saves) + # and _quantized (step 0 calibration-only checkpoint, no training state). + step_dirs = sorted(ckpt_dir.glob("step_[0-9]*"), key=lambda p: p.name) + step_dirs = [ + d + for d in step_dirs + if not d.name.endswith("_tmp") and not d.name.endswith("_quantized") + ] + if not step_dirs: + logger.warning(f"No complete checkpoints found in {ckpt_dir}") + return None + + latest = step_dirs[-1] + logger.info(f"Auto-detected latest checkpoint: {latest}") + return latest + else: + path = Path(path_or_keyword) + if not path.exists(): + raise FileNotFoundError(f"Resume checkpoint not found: {path}") + return path + + def _load_training_state(self, checkpoint_dir: Path) -> int: + """ + Load full training state from a checkpoint directory. + + Note: The quantization architecture (fake quantizer modules) must already be + restored BEFORE this method is called. This happens in _apply_modelopt_quantization() + (Path A) which uses mto.restore_from_modelopt_state() to set up the module structure. + This method then loads the trained weights (including quantizer scales) into that + structure via accelerator.load_state(). + + This restores (all via accelerator.load_state()): + - Model weights (student, FSDP-sharded, including quantizer scales) + - Optimizer state + - LR scheduler state + - Dataloader iteration position (auto-skips consumed batches) + - RNG states (Python, NumPy, PyTorch CPU/CUDA per device) + - Gradient scaler (mixed precision) + - global_step (from custom metadata file) + + Args: + checkpoint_dir: Path to the training state checkpoint directory. + + Returns: + The global_step to resume from. + """ + logger.info(f"Resuming training state from {checkpoint_dir}...") + + # accelerator.load_state() is a collective op — all ranks must call it. + # It restores all objects registered via accelerator.prepare() in order: + # 1. Student model weights (self._transformer) — including quantizer scales + # 2. Optimizer state (self._optimizer) + # 3. LR scheduler state (self._lr_scheduler) + # 4. Dataloader iteration position (via skip_first_batches internally) + # 5. RNG states (Python, NumPy, PyTorch CPU/CUDA per device) + # 6. Gradient scaler (mixed precision) + # Note: Teacher model was unregistered from accelerator (see _load_teacher_model), + # so it is NOT loaded here — it is loaded fresh from pretrained on each run. + self._accelerator.load_state(str(checkpoint_dir)) + logger.info( + "Restored: student model (with quantizer scales), optimizer, LR scheduler, " + "dataloader position, RNG states, and gradient scaler via accelerator.load_state()" + ) + + # Load custom metadata to get global_step + metadata_path = checkpoint_dir / "distillation_metadata.json" + if metadata_path.exists(): + with open(metadata_path) as f: + metadata = json.load(f) + resumed_step = metadata.get("global_step", 0) + logger.info(f"Restored global_step={resumed_step} from metadata") + else: + # Fallback: try to parse step from directory name + try: + resumed_step = int(checkpoint_dir.name.split("_")[-1]) + logger.warning( + f"Metadata file not found, parsed step from dir name: {resumed_step}" + ) + except (ValueError, IndexError): + resumed_step = 0 + logger.warning("Could not determine step from checkpoint, resuming from step 0") + + return resumed_step + + def train( + self, + disable_progress_bars: bool = False, + step_callback=None, + ) -> tuple[Path | None, dict]: + """ + Override parent train() to add full checkpoint resume support. + + When `distillation.resume_from_checkpoint` is set, this: + 1. Initializes optimizer/dataloader/scheduler as normal + 2. Loads full training state (model, optimizer, scheduler, RNG) + 3. Skips already-completed steps + 4. Saves full training state at checkpoint intervals + """ + from accelerate.utils import DistributedType, set_seed + from ltx_trainer.gpu_utils import get_gpu_memory_gb + from ltx_trainer.hf_hub_utils import push_to_hub + from ltx_trainer.progress import TrainingProgress + from ltx_trainer.trainer import TrainingStats + + MEMORY_CHECK_INTERVAL = 200 # noqa: N806 + + device = self._accelerator.device + cfg = self._config + start_mem = get_gpu_memory_gb(device) + + train_start_time = time.time() + + # Use the same seed for all processes and ensure deterministic operations + set_seed(cfg.seed) + logger.debug(f"Process {self._accelerator.process_index} using seed: {cfg.seed}") + + self._init_optimizer() + self._init_dataloader() + self._init_timestep_sampler() + + # Synchronize all processes after initialization + self._accelerator.wait_for_everyone() + + Path(cfg.output_dir).mkdir(parents=True, exist_ok=True) + + # Save the training configuration as YAML + self._save_config() + + # ===================================================================== + # Resume from checkpoint if configured + # ===================================================================== + resume_step = 0 + resume_path = self._distillation_config.resume_from_checkpoint + if resume_path is not None: + checkpoint_dir = self._find_resume_checkpoint(resume_path) + if checkpoint_dir is not None: + resume_step = self._load_training_state(checkpoint_dir) + logger.info(f"Resuming training from step {resume_step}") + else: + logger.warning("No checkpoint found to resume from, starting from scratch") + + # Create the dataloader iterator AFTER load_state() so it picks up the + # resumed dataloader state. accelerator.load_state() automatically replaces + # self._dataloader with a version that skips already-consumed batches + # (via skip_first_batches), so iter() here starts at the correct position. + data_iter = iter(self._dataloader) + + # Timer for Slurm time-limit-aware checkpointing + must_save_by_minutes = self._distillation_config.must_save_by + if must_save_by_minutes is not None: + save_deadline = train_start_time + must_save_by_minutes * 60 + logger.info( + f"Time-limit save enabled: will save and exit after " + f"{must_save_by_minutes:.1f} minutes" + ) + else: + save_deadline = None + + logger.info("Starting training...") + config_msg = ( + f"Config: steps={cfg.optimization.steps}, " + f"grad_accum={cfg.optimization.gradient_accumulation_steps}, " + f"checkpoints.interval={cfg.checkpoints.interval}, " + f"checkpoints.keep_last_n={cfg.checkpoints.keep_last_n}, " + f"output_dir={cfg.output_dir}, " + f"must_save_by={must_save_by_minutes}" + ) + logger.info(config_msg) + # Also print to stdout (logger goes to stderr via RichHandler, + # which lands in .err files in Slurm) + if IS_MAIN_PROCESS: + print(f"[distillation_trainer] {config_msg}", flush=True) + + # Create progress tracking + progress_enabled = IS_MAIN_PROCESS and not disable_progress_bars + progress = TrainingProgress( + enabled=progress_enabled, + total_steps=cfg.optimization.steps, + ) + + if IS_MAIN_PROCESS and disable_progress_bars: + logger.warning( + "Progress bars disabled. Intermediate status messages will be logged instead." + ) + + self._transformer.train() + self._global_step = resume_step + + peak_mem_during_training = start_mem + + sampled_videos_paths = None + + # Calculate how many raw steps to skip and how many to run + total_raw_steps = cfg.optimization.steps * cfg.optimization.gradient_accumulation_steps + skip_raw_steps = resume_step * cfg.optimization.gradient_accumulation_steps + + with progress: + # Initial validation before training starts (skip if resuming) + if ( + resume_step == 0 + and cfg.validation.interval + and not cfg.validation.skip_initial_validation + ): + sampled_videos_paths = self._sample_videos(progress) + if ( + IS_MAIN_PROCESS + and sampled_videos_paths + and self._config.wandb.log_validation_videos + ): + self._log_validation_samples(sampled_videos_paths, cfg.validation.prompts) + + self._accelerator.wait_for_everyone() + + # Accumulators for averaging metrics across gradient accumulation steps + grad_accum_steps = cfg.optimization.gradient_accumulation_steps + accum_loss = 0.0 + accum_step_time = 0.0 + + for step in range(skip_raw_steps, total_raw_steps): + # Get next batch, reset the dataloader if needed + try: + batch = next(data_iter) + except StopIteration: + data_iter = iter(self._dataloader) + batch = next(data_iter) + + step_start_time = time.time() + with self._accelerator.accumulate(self._transformer): + is_optimization_step = (step + 1) % grad_accum_steps == 0 + if is_optimization_step: + self._global_step += 1 + + loss = self._training_step(batch) + self._accelerator.backward(loss) + + # Accumulate metrics for this micro-batch + accum_loss += loss.item() + accum_step_time += time.time() - step_start_time + + if self._accelerator.sync_gradients and cfg.optimization.max_grad_norm > 0: + self._accelerator.clip_grad_norm_( + self._trainable_params, + cfg.optimization.max_grad_norm, + ) + + self._optimizer.step() + self._optimizer.zero_grad() + + if self._lr_scheduler is not None: + self._lr_scheduler.step() + + # Run validation if needed + if ( + cfg.validation.interval + and self._global_step > 0 + and self._global_step % cfg.validation.interval == 0 + and is_optimization_step + ): + if self._accelerator.distributed_type == DistributedType.FSDP: + sampled_videos_paths = self._sample_videos(progress) + if ( + IS_MAIN_PROCESS + and sampled_videos_paths + and self._config.wandb.log_validation_videos + ): + self._log_validation_samples( + sampled_videos_paths, cfg.validation.prompts + ) + elif IS_MAIN_PROCESS: + sampled_videos_paths = self._sample_videos(progress) + if sampled_videos_paths and self._config.wandb.log_validation_videos: + self._log_validation_samples( + sampled_videos_paths, cfg.validation.prompts + ) + + # Save training state for resuming (model, optimizer, scheduler, + # dataloader position, RNG states — all handled by accelerator) + saved_this_step = False + ckpt_interval = cfg.checkpoints.interval + if ( + ckpt_interval + and self._global_step > 0 + and self._global_step % ckpt_interval == 0 + and is_optimization_step + ): + logger.info( + f"Saving checkpoint at step {self._global_step} " + f"(interval={ckpt_interval})..." + ) + self._save_training_state() + saved_this_step = True + + # Time-limit save: if we're approaching the Slurm time limit, + # save a checkpoint and exit gracefully. + if ( + save_deadline is not None + and is_optimization_step + and time.time() >= save_deadline + ): + elapsed_min = (time.time() - train_start_time) / 60 + logger.info( + f"Time limit reached ({elapsed_min:.1f} min >= " + f"{must_save_by_minutes:.1f} min). " + f"Saving checkpoint at step {self._global_step} and exiting..." + ) + if not saved_this_step: + self._save_training_state() + # Break out of the training loop; post-loop code + # will collect stats and return. + break + + self._accelerator.wait_for_everyone() + + # Call step callback if provided + if step_callback and is_optimization_step: + step_callback( + self._global_step, cfg.optimization.steps, sampled_videos_paths + ) + + self._accelerator.wait_for_everyone() + + # On optimization steps: compute averaged metrics, log, then reset + if is_optimization_step: + avg_loss = accum_loss / grad_accum_steps + total_step_time = accum_step_time + + current_lr = self._optimizer.param_groups[0]["lr"] + + progress.update_training( + loss=avg_loss, + lr=current_lr, + step_time=total_step_time, + advance=True, + ) + + # Log averaged metrics to W&B + if IS_MAIN_PROCESS: + self._log_metrics( + { + "train/loss": avg_loss, + "train/learning_rate": current_lr, + "train/step_time": total_step_time, + "train/global_step": self._global_step, + } + ) + + # Periodic step logging to console/Slurm logs + if IS_MAIN_PROCESS and self._global_step % 10 == 0: + elapsed = time.time() - train_start_time + progress_pct = self._global_step / cfg.optimization.steps + if progress_pct > 0: + eta = (elapsed / progress_pct) - elapsed + eta_str = f"{eta // 3600:.0f}h {(eta % 3600) // 60:.0f}m" + else: + eta_str = "calculating..." + logger.info( + f"Step {self._global_step}/{cfg.optimization.steps} | " + f"Loss: {avg_loss:.4f} | LR: {current_lr:.2e} | " + f"Time/Step: {total_step_time:.2f}s | ETA: {eta_str}", + ) + + # Reset accumulators + accum_loss = 0.0 + accum_step_time = 0.0 + + # Sample GPU memory periodically + if step % MEMORY_CHECK_INTERVAL == 0: + current_mem = get_gpu_memory_gb(device) + peak_mem_during_training = max(peak_mem_during_training, current_mem) + + # Collect final stats + train_end_time = time.time() + end_mem = get_gpu_memory_gb(device) + peak_mem = max(start_mem, end_mem, peak_mem_during_training) + + total_time_seconds = train_end_time - train_start_time + actual_steps = self._global_step - resume_step + steps_per_second = actual_steps / total_time_seconds if total_time_seconds > 0 else 0 + samples_per_second = ( + steps_per_second * self._accelerator.num_processes * cfg.optimization.batch_size + ) + + stats = TrainingStats( + total_time_seconds=total_time_seconds, + steps_per_second=steps_per_second, + samples_per_second=samples_per_second, + peak_gpu_memory_gb=peak_mem, + num_processes=self._accelerator.num_processes, + global_batch_size=cfg.optimization.batch_size * self._accelerator.num_processes, + ) + + # Save final training state (for potential resume) + self._save_training_state() + + # Save inference-ready model weights (standalone safetensors file) + saved_path = self._save_checkpoint() + + if is_global_rank0(): + self._log_training_stats(stats) + + if cfg.hub.push_to_hub: + push_to_hub(saved_path, sampled_videos_paths, self._config) + + if self._wandb_run is not None: + self._log_metrics( + { + "stats/total_time_minutes": stats.total_time_seconds / 60, + "stats/steps_per_second": stats.steps_per_second, + "stats/samples_per_second": stats.samples_per_second, + "stats/peak_gpu_memory_gb": stats.peak_gpu_memory_gb, + } + ) + self._wandb_run.finish() + + self._accelerator.wait_for_everyone() + self._accelerator.end_training() + + return saved_path, stats + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="LTX-2 Distillation Training with ModelOpt Quantization", + # Allow OmegaConf-style overrides to pass through + allow_abbrev=False, + ) + parser.add_argument( + "--config", + type=str, + required=True, + help="Path to the YAML configuration file", + ) + + # Parse known args to allow for OmegaConf overrides + args, overrides = parser.parse_known_args() + return args, overrides + + +def main(): + """Main entry point for distillation training.""" + # CRITICAL: Set CUDA device BEFORE any model loading. + # + # The LTX trainer loads the text encoder in __init__ BEFORE _setup_accelerator(), + # using device="cuda" which defaults to GPU 0. We must set the device early + # so that "cuda" maps to the correct GPU for each process. + # + # Note: We do NOT call init_process_group() here - let accelerate handle that. + # We only set the CUDA device based on LOCAL_RANK. + + # Read distributed environment variables (set by accelerate launch / torchrun) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + rank = int(os.environ.get("RANK", 0)) + master_addr = os.environ.get("MASTER_ADDR", "localhost") + master_port = os.environ.get("MASTER_PORT", "29500") + + # Debug: Print all relevant environment variables + print( + f"[DEBUG] PID={os.getpid()} RANK={rank} LOCAL_RANK={local_rank} " + f"WORLD_SIZE={world_size} MASTER_ADDR={master_addr} MASTER_PORT={master_port}" + ) + print(f"[DEBUG] torch.cuda.device_count()={torch.cuda.device_count()}") + + # Set CUDA device based on LOCAL_RANK - this ensures device="cuda" uses correct GPU + if torch.cuda.is_available() and local_rank < torch.cuda.device_count(): + torch.cuda.set_device(local_rank) + print( + f"[DEBUG] Set CUDA device to {local_rank}, current device: {torch.cuda.current_device()}" + ) + else: + print(f"[WARNING] LOCAL_RANK={local_rank} but device_count={torch.cuda.device_count()}") + + logger.info(f"Process RANK={rank}, LOCAL_RANK={local_rank}, WORLD_SIZE={world_size}") + + args, cli_overrides = parse_args() + + # Load base config from YAML using OmegaConf + base_config = OmegaConf.load(args.config) + + # Parse CLI overrides using OmegaConf + # Supports formats like: + # distillation.distillation_alpha=0.6 + # ++distillation.quant_cfg=FP8_DEFAULT_CFG + # model.training_mode=lora + if cli_overrides: + # Clean up override strings (remove leading ++, +, etc.) + cleaned_overrides = [] + for override in cli_overrides: + # Strip leading + or ++ (Hydra-style) + clean = override.lstrip("+") + if "=" in clean: + cleaned_overrides.append(clean) + elif IS_MAIN_PROCESS: + logger.warning(f"Ignoring malformed override: {override}") + + if cleaned_overrides: + cli_config = OmegaConf.from_dotlist(cleaned_overrides) + # Merge CLI overrides into base config (CLI takes precedence) + config = OmegaConf.merge(base_config, cli_config) + if IS_MAIN_PROCESS: + logger.info(f"Applied {len(cleaned_overrides)} config overrides:") + for override in cleaned_overrides: + logger.info(f" {override}") + else: + config = base_config + else: + config = base_config + + # Convert OmegaConf to plain dict for Pydantic + config_dict = OmegaConf.to_container(config, resolve=True) + + # Create typed config object + config = DistillationTrainerConfig(**config_dict) + + # Create trainer and run + trainer = DistillationTrainer(config) + + # Train + saved_path, stats = trainer.train() + + # Save quantized model if configured + if config.distillation.quant_cfg is not None: + trainer.save_quantized_model() + + if IS_MAIN_PROCESS: + logger.info(f"Training complete. Model saved to: {saved_path}") + logger.info(f"Training stats: {stats}") + + +if __name__ == "__main__": + main() diff --git a/examples/diffusers/distillation/requirements.txt b/examples/diffusers/distillation/requirements.txt new file mode 100644 index 000000000..964edf625 --- /dev/null +++ b/examples/diffusers/distillation/requirements.txt @@ -0,0 +1,4 @@ +ltx-core @ git+https://github.com/Lightricks/LTX-2.git#subdirectory=packages/ltx-core +ltx-pipelines @ git+https://github.com/Lightricks/LTX-2.git#subdirectory=packages/ltx-pipelines +ltx-trainer @ git+https://github.com/Lightricks/LTX-2.git#subdirectory=packages/ltx-trainer +omegaconf