
Refactor quantization parameters and methods for text encoder and DiT models#1128

Draft
KyleShao1016 wants to merge 2 commits into hao-ai-lab:main from KyleShao1016:ltx2_fp8_support

Conversation

@KyleShao1016
Contributor

Motivation

FastVideo's existing FP8 path (absmax_fp8.py) works but is monolithic — quantization config,
weight handling, and kernel dispatch are all tangled together. This makes it hard to extend
(e.g., adding static activation quantization, new kernel backends, or per-layer granularity).

This PR refactors the FP8 quantization layer to follow vLLM's Config → Method → Kernel
architecture, which cleanly separates concerns and makes future extensions straightforward.

What this PR does

New architecture (vLLM-aligned)

  • Fp8Config — declares quantization settings (activation type, weight dtype, ignored layers)
  • Fp8LinearMethod / Fp8OnlineLinearMethod — handle weight creation, loading, and forward dispatch
  • FP8ScaledMMLinearKernel — abstract kernel interface with a PyTorch (torch._scaled_mm) backend
  • QuantFP8 — activation quantization (dynamic per-tensor scaling)
  • QuantKey / GroupShape / ScaleDesc — granularity descriptors for future extensibility
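The Config → Method → Kernel layering described above can be sketched in a few lines. This is a minimal illustration of the pattern, not the PR's actual signatures: class bodies, the `get_quant_method` hook, and the string return value are all simplified stand-ins.

```python
# Minimal sketch of the Config -> Method layering (names are illustrative).
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class Fp8Config:
    """Declares *what* to quantize: scheme and layers to skip."""
    activation_scheme: str = "dynamic"   # "dynamic" or "static"
    ignored_layers: list = field(default_factory=list)

    def get_quant_method(self, layer_name: str):
        # The Config decides which Method (if any) handles a given layer.
        if layer_name in self.ignored_layers:
            return None  # layer stays unquantized
        return Fp8LinearMethod(self)


class LinearMethodBase(ABC):
    """A Method owns weight creation/loading and forward dispatch."""
    @abstractmethod
    def apply(self, layer, x): ...


class Fp8LinearMethod(LinearMethodBase):
    def __init__(self, config: Fp8Config):
        self.config = config

    def apply(self, layer, x):
        # In the real design this dispatches to a Kernel backend
        # (e.g. a torch._scaled_mm wrapper); here we just tag the call.
        return f"fp8({layer}, scheme={self.config.activation_scheme})"
```

The key property is that the Config never touches tensors and the Method never parses settings, so adding a new scheme or backend changes exactly one layer of the stack.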

Refactoring of existing code

  • Extracted DiT-specific bridge functions (scan_fp8_modules, prepare_model_for_fp8) from
    absmax_fp8.py into dit_fp8_bridge.py — these inject FP8 into plain nn.Linear modules
    that don't use LinearBase
  • Extracted shared utilities (supports_fp8_compute, quantize_input_dynamic, etc.) into fp8_utils.py
  • absmax_fp8.py re-exports bridge functions for backward compatibility
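Scanning for FP8 modules only needs the safetensors header, not the tensor data: the format stores an 8-byte little-endian length followed by a JSON index of tensor names, dtypes, and offsets. The helper below borrows the `scan_fp8_modules` name from the PR, but its body is a hedged sketch of the general approach, not the PR's implementation.

```python
# Sketch: find modules whose weights are stored as FP8 (E4M3) by reading
# only the safetensors header. Illustrative, not the PR's exact code.
import json
import struct


def scan_fp8_modules(path: str) -> set:
    """Return module prefixes whose '.weight' tensors have dtype F8_E4M3."""
    with open(path, "rb") as f:
        # safetensors layout: u64 LE header length, then a JSON header.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    modules = set()
    for name, meta in header.items():
        if name == "__metadata__":
            continue
        if meta.get("dtype") == "F8_E4M3" and name.endswith(".weight"):
            modules.add(name.rsplit(".weight", 1)[0])
    return modules
```

Because only the header is parsed, this works on multi-gigabyte checkpoints without mapping any tensor data, which is why header scanning is a common pattern for deciding ahead of load time which `nn.Linear` modules need FP8 parameters injected.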

Naming standardization

  • Renamed scale parameters from scale_weight/scale_input to weight_scale/input_scale
    to match checkpoint format and vLLM convention
  • Removed renaming regex rules from ltx2.py that were papering over the mismatch

CLI & loader integration

  • Renamed --override-dit-quant → --dit-quantization and --override-text-encoder-quant →
    --text-encoder-quantization (old flags kept as backward-compatible aliases)
  • component_loader.py now uses the quantization registry to instantiate any registered method
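Keeping the old flags working is straightforward with argparse, since one argument can carry several option strings that share a destination. This is a generic sketch of the aliasing pattern, not the PR's actual parser code:

```python
# Sketch: backward-compatible CLI aliases via multiple option strings.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dit-quantization",
    "--override-dit-quant",          # deprecated alias, kept for compatibility
    dest="dit_quantization",
    default=None,
)
parser.add_argument(
    "--text-encoder-quantization",
    "--override-text-encoder-quant",  # deprecated alias
    dest="text_encoder_quantization",
    default=None,
)
```

Both spellings parse to the same attribute, so downstream code only ever sees the new name.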

New files

File                                 Purpose
fp8.py                               Fp8Config, Fp8LinearMethod, Fp8OnlineLinearMethod
fp8_utils.py                         Shared FP8 utility functions
input_quant_fp8.py                   QuantFP8 activation quantization
dit_fp8_bridge.py                    FP8 injection for plain nn.Linear (DiT models)
kernels/scaled_mm/                   Kernel abstraction + PyTorch backend
utils/quant_utils.py                 QuantKey, GroupShape, ScaleDesc descriptors
tests/ops/quantization/test_fp8.py   Tests for all new components

Test plan

  • test_fp8.py — kernel correctness, config/method integration, online quantization vs bf16 reference
  • test_absmax_fp8.py — existing tests pass with renamed scale parameters
  • Pre-commit hooks (yapf, ruff, mypy, codespell) all pass

Next Steps

  • Wiring Fp8LinearMethod into LinearBase.forward for end-to-end inference
  • Static activation quantization support
  • Per-channel / block-wise granularity

@gemini-code-assist
Contributor

Summary of Changes

Hello @KyleShao1016, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly overhauls the FP8 quantization framework by adopting a modular, vLLM-inspired architecture. The primary goal is to enhance the maintainability and extensibility of quantization methods, making it easier to introduce new features like static activation quantization or different kernel backends. The changes involve a clear separation of configuration, method implementation, and kernel dispatch, alongside standardizing naming conventions and improving integration with model loading processes, particularly for DiT models.

Highlights

  • Refactored FP8 Quantization Architecture: The FP8 quantization layer has been refactored to align with vLLM's Config → Method → Kernel architecture, promoting cleaner separation of concerns and improved extensibility. This includes new Fp8Config, Fp8LinearMethod, Fp8OnlineLinearMethod, and FP8ScaledMMLinearKernel components.
  • DiT Model FP8 Integration: Dedicated bridge functions (scan_fp8_modules, prepare_model_for_fp8) were extracted into dit_fp8_bridge.py to inject FP8 quantization into plain nn.Linear modules used by DiT models, which do not use LinearBase.
  • Naming Standardization: Scale parameters have been renamed from scale_weight/scale_input to weight_scale/input_scale to match checkpoint formats and vLLM conventions, and corresponding renaming rules were removed from ltx2.py.
  • New FP8 Utility and Kernel Modules: Several new modules were introduced for FP8 quantization, including fp8_utils.py for shared utilities, input_quant_fp8.py for activation quantization, and a kernels/scaled_mm/ directory for abstract kernel interfaces and PyTorch backends using torch._scaled_mm.
  • CLI and Loader Updates: CLI arguments for quantization were renamed (--override-dit-quant to --dit-quantization, --override-text-encoder-quant to --text-encoder-quantization) with backward compatibility. The component_loader.py now uses a quantization registry for method instantiation, and model loading logic was updated to correctly handle FP8 quantized parameters.


Changelog
  • examples/inference/optimizations/text_encoder_quant_example.py
    • Updated CLI argument name for text encoder quantization.
  • fastvideo/fastvideo_args.py
    • Renamed override_text_encoder_quant to text_encoder_quantization.
    • Added dit_quantization argument for DiT model quantization.
    • Updated CLI argument parsing to reflect new names and ensure backward compatibility.
  • fastvideo/layers/quantization/__init__.py
    • Expanded QuantizationMethods literal type to include 'fp8'.
    • Imported and registered Fp8Config in the quantization method mapping.
  • fastvideo/layers/quantization/absmax_fp8.py
    • Added new imports for FP8 utility functions and DiT bridge functions.
    • Re-exported bridge and utility functions for backward compatibility.
    • Modified AbsMaxFP8Config.from_config to validate the quantization method.
    • Removed outdated comments regarding supported linear types.
    • Renamed scale_weight and scale_input parameters to weight_scale and input_scale.
    • Introduced process_weights_after_loading for weight transposition.
    • Refactored the apply method to dispatch to _apply_fp8 for GPU compute or _apply_dequant as a fallback.
    • Implemented _apply_fp8 utilizing torch._scaled_mm for efficient FP8 matrix multiplication.
    • Updated _apply_dequant to use the new scale parameter names and handle scale dimensions.
  • fastvideo/layers/quantization/dit_fp8_bridge.py
    • Added new file to provide functions for injecting FP8 quantization into nn.Linear modules for DiT models.
    • Implemented scan_fp8_modules to identify FP8 modules from safetensors headers.
    • Implemented prepare_model_for_fp8 to replace nn.Linear weights with FP8 parameters and register FP8 forward hooks.
  • fastvideo/layers/quantization/fp8.py
    • Added new file defining the core FP8 quantization configuration and linear methods.
    • Defined Fp8Config for managing FP8 quantization settings like activation scheme and ignored layers.
    • Implemented Fp8LinearMethod for handling pre-quantized FP8 checkpoints.
    • Implemented Fp8OnlineLinearMethod for on-the-fly quantization of BF16/FP16 weights to FP8.
  • fastvideo/layers/quantization/fp8_utils.py
    • Added new file containing shared utility functions for FP8 operations.
    • Included supports_fp8_compute to check GPU capability for FP8.
    • Provided quantize_input_dynamic and quantize_input_static for input quantization.
    • Added is_fp8_dtype to check if a dtype is FP8.
  • fastvideo/layers/quantization/input_quant_fp8.py
    • Added new file implementing QuantFP8 module for dynamic or static activation quantization.
  • fastvideo/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
    • Added new file defining the abstract base class FP8ScaledMMLinearKernel for scaled matrix multiplication.
    • Introduced FP8ScaledMMLinearLayerConfig dataclass for kernel configuration.
  • fastvideo/layers/quantization/kernels/scaled_mm/__init__.py
    • Added new file for selecting and initializing the appropriate FP8 scaled matrix multiplication kernel.
  • fastvideo/layers/quantization/kernels/scaled_mm/pytorch.py
    • Added new file containing PyTorch-specific implementations of FP8ScaledMMLinearKernel.
    • Implemented PerTensorTorchFP8ScaledMMLinearKernel for per-tensor scaling.
    • Implemented ChannelWiseTorchFP8ScaledMMLinearKernel for channel-wise scaling.
  • fastvideo/layers/quantization/utils/__init__.py
    • Added new file to export quantization utility classes and constants.
  • fastvideo/layers/quantization/utils/quant_utils.py
    • Added new file defining GroupShape, ScaleDesc, and QuantKey for describing quantization granularity and properties.
    • Included is_layer_skipped utility function.
  • fastvideo/models/loader/component_loader.py
    • Updated model loading logic to use the new text_encoder_quantization argument.
    • Integrated dit_quantization configuration into DiT model loading.
    • Adjusted parameter dtype assertion to correctly handle mixed-precision FP8 models.
  • fastvideo/models/loader/fsdp_load.py
    • Added logic to prepare nn.Linear modules for FP8 quantization using the dit_fp8_bridge during FSDP model loading.
    • Modified parameter loading to respect the original dtype of quantized parameters (e.g., FP8 weights, float32 scales).
  • fastvideo/tests/ops/quantization/test_absmax_fp8.py
    • Updated imports to reflect the refactored absmax_fp8.py structure.
    • Renamed scale parameters in test cases to weight_scale and input_scale.
    • Added new unit tests for AbsMaxFP8Config and FP8 helper functions.
    • Introduced comprehensive tests for the FP8 compute path using torch._scaled_mm on compatible GPUs, covering various scaling scenarios, 3D input, bias, and approximate correctness against dequantized output.
  • fastvideo/tests/ops/quantization/test_fp8.py
    • Added new file with extensive unit tests for the newly introduced FP8 quantization components.
    • Includes tests for GroupShape, ScaleDesc, QuantKey, Fp8Utils, QuantFP8, kernel dispatch logic, Fp8Config, quantization registry, Fp8LinearMethodDequant, Fp8OnlineLinearMethod, and the FP8 compute path, including comparisons to BF16 references.

@gemini-code-assist bot left a comment

Code Review

This pull request is a significant and well-executed refactoring of the FP8 quantization logic, aligning it with vLLM's Config -> Method -> Kernel architecture. This greatly improves modularity, maintainability, and extensibility for future quantization work. The changes are comprehensive, including new abstractions for configuration, methods, and kernels, as well as thorough test coverage for the new components. The backward compatibility for existing flags and functions is also a nice touch. I have one suggestion to refactor a small piece of duplicated code to further improve maintainability.

… models

- Renamed `override_text_encoder_quant` to `text_encoder_quantization` for consistency.
- Introduced `dit_quantization` parameter for DiT model quantization.
- Updated argument parsing to reflect new parameter names.
- Enhanced quantization methods to support FP8 and added new utility functions for quantization.
- Implemented a bridge for injecting FP8 parameters into nn.Linear modules for DiT models.
- Added tests for new quantization methods and configurations.

This refactor aims to streamline the quantization process and improve compatibility with FP8 compute.
- Introduced new parameters for FP8 activation granularity and ignored layers in FastVideoArgs.
- Added functionality to detect FP8 weights from safetensors and configure models accordingly.
- Enhanced the quantization methods to support both offline and online FP8 processing.
- Updated the AbsMaxFP8Config to delegate quantization methods to the new Fp8LinearMethod.
- Added tests to validate the new FP8 configurations and ensure correct behavior.

This update aims to improve the flexibility and performance of quantization in models using FP8.