
[OMNIML-2663] Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes #852

Open
ajrasane wants to merge 6 commits into main from ajrasane/onnx_qdq

Conversation

@ajrasane
Contributor

@ajrasane ajrasane commented Feb 4, 2026

What does this PR do?

Type of change:
New feature

Overview:

  • Updated the FP8 quant exporter to replace modelopt custom QDQ nodes with native ONNX QDQ nodes (see the sketch below)
  • Updated get_onnx_bytes_and_metadata to make convert_float_to_float16() the default instead of AutoCast
  • Added utility functions to fix the graph structure after conversion
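
For illustration, the node rewrite is roughly the following minimal sketch using the plain onnx protobuf API. The TRT_* op names are the ones being replaced; the helper name, the FP8 zero-point details, and the onnx version requirement are illustrative assumptions, and the actual implementation in fp8_exporter.py may differ.

import onnx
from onnx import TensorProto, helper

def convert_trt_fp8_qdq_to_native(model: onnx.ModelProto) -> onnx.ModelProto:
    """Rewrite TRT FP8 QDQ custom ops as native ONNX QDQ ops (illustrative only)."""
    op_map = {
        "TRT_FP8QuantizeLinear": "QuantizeLinear",
        "TRT_FP8DequantizeLinear": "DequantizeLinear",
    }
    for node in model.graph.node:
        if node.op_type not in op_map:
            continue
        node.op_type = op_map[node.op_type]
        node.domain = ""  # leave the TRT custom-op domain
        # Native QDQ carries an explicit zero-point input; append a 1-element
        # FP8 (E4M3) zero point with explicit dims (assumes onnx >= 1.14 for FP8 support).
        zp_name = f"{node.name or node.output[0]}_zero_point"
        zp_tensor = helper.make_tensor(zp_name, TensorProto.FLOAT8E4M3FN, dims=[1], vals=[0.0])
        model.graph.initializer.append(zp_tensor)
        node.input.append(zp_name)
    return model

Note that native FP8 QuantizeLinear/DequantizeLinear require opset 19 or newer, so the exported model's opset has to be high enough for the replacement to validate.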

Testing

python torch_quant_to_onnx.py --quantize_mode=fp8 \
	--onnx_save_path=<model_path> \
	--calibration_data_size 64 \
	--batch_size 128

python evaluate.py --onnx_path=<model_path> \
	--model_name=vit_base_patch16_224 \
	--results_path=./results.txt \
	--batch_size 128

Results:
Before replacement:

The top1 accuracy of the model is 85.06%
The top5 accuracy of the model is 97.558%
Inference latency of the model is 5.27963 ms

After replacement:

The top1 accuracy of the model is 85.054%
The top5 accuracy of the model is 97.542%
Inference latency of the model is 5.74771 ms

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: No
  • Replaced modelopt QDQ nodes with native ONNX QDQ nodes
  • Did you write any new necessary tests?: No
  • Did you add or update any necessary documentation?: No
  • Did you update Changelog?: No

Summary by CodeRabbit

Release Notes

  • New Features

    • Added automatic graph optimization utilities for ONNX models
  • Improvements

    • Enhanced FP8 export functionality with improved ONNX standard compliance
    • Streamlined weight quantization and FP16 conversion pipeline
    • Optimized export workflow for more efficient model deployment

@ajrasane ajrasane requested review from a team as code owners February 4, 2026 01:08
@ajrasane ajrasane requested a review from cjluo-nv February 4, 2026 01:08
@coderabbitai
Contributor

coderabbitai bot commented Feb 4, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

This PR introduces FP8 export enhancements and adds utility functions for ONNX graph manipulation. It updates the post-processing pipeline to convert TRT-specific FP8 quantization nodes to native ONNX equivalents, and refactors weight quantization and FP16 conversion logic to use new utility functions for duplicate cast removal and targeted cast conversions.

Changes

FP8 Export Post-Processing (modelopt/onnx/export/fp8_exporter.py):
Enhanced post_process method from no-op to full post-processing routine converting TRT_FP8QuantizeLinear and TRT_FP8DequantizeLinear to native ONNX equivalents, with graph cleanup and topological sort. Added logging throughout conversion pipeline. Updated docstrings in compress_weights to clarify FP8 QDQ processing behavior.

ONNX Graph Utilities (modelopt/onnx/utils.py):
Added remove_duplicate_casts function to identify and eliminate consecutive Cast nodes targeting identical types, rewiring consumers and re-sorting graph. Added change_casts_to_fp16 function to scan and convert FP32 casts to FP16 when feeding target operation types.

Torch ONNX Integration (modelopt/torch/_deploy/utils/torch_onnx.py):
Updated imports to use new utility functions. Simplified FP16 handling path with dedicated block using convert_float_to_float16 and change_casts_to_fp16. Refactored to always apply quantize_weights and consistently remove duplicate casts in streamlined flow.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 3 passed
Title check | ✅ Passed: The title directly and concisely summarizes the main change: replacing ModelOpt FP8 QDQ nodes with native ONNX QDQ nodes, which is clearly reflected in the file changes across fp8_exporter.py, utils.py, and torch_onnx.py.
Docstring Coverage | ✅ Passed: Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%.
Description Check | ✅ Passed: Check skipped - CodeRabbit's high-level summary is enabled.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 104-140: The post_process function's docstring mentions updating
GELU nodes to tanh approximation and inserting Cast nodes after Sqrt, but the
implementation in post_process only converts
TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear to
QuantizeLinear/DequantizeLinear; either remove or revise those docstring lines
to reflect current behavior, or implement the missing steps: locate GELU nodes
in graph.nodes and replace/modify them to the tanh-approx variant, and insert
Cast nodes immediately after Sqrt nodes' outputs; reference post_process,
TRT_FP8QuantizeLinear, TRT_FP8DequantizeLinear, GELU, and Sqrt when making the
change.
- Around line 119-126: The FP8 zero-point tensor zp_tensor is missing explicit
shape metadata; update the creation of zp_tensor (used to build zero_point and
appended to node.inputs) to set its dims explicitly (e.g., call
zp_tensor.dims.extend([1]) for a 1-element tensor) so it matches other tensors
created in this module (see the FP8 weights tensor creation) and ensures ONNX
runtimes receive shape info.

In `@modelopt/onnx/utils.py`:
- Around line 1314-1349: In change_casts_to_fp16, only modify Cast nodes that
actually cast from FP32: for each Cast node (node.op_type == "Cast") look up the
source tensor name node.input[0] in graph.initializer, graph.input,
graph.value_info or graph.output to get its element_type and only change the
node.attribute "to" from onnx.TensorProto.FLOAT to onnx.TensorProto.FLOAT16 if
the source dtype is FLOAT; also avoid changing Casts that are FP16->FP32 and add
a debug log entry when you modify a Cast (include node.name or node.output[0]
and original->new dtypes) to aid debugging.
🧹 Nitpick comments (1)
modelopt/onnx/utils.py (1)

1218-1261: Consider edge case where first Cast has multiple consumers.

The function checks len(node.outputs[0].outputs) != 1 (line 1231) to ensure the first Cast's output goes to exactly one node. However, this may be overly restrictive. If the first Cast feeds into a duplicate second Cast AND other nodes, you could still remove the duplicate Cast while preserving the connection to other consumers. The current logic skips this optimization opportunity.

This is a minor optimization opportunity and the current implementation is safe.
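
For context, the consecutive-Cast pattern discussed here can be collapsed roughly as in the sketch below. This is a hypothetical onnx_graphsurgeon version, not the PR's implementation; graph-output handling and chains of more than two Casts are ignored for brevity.

import onnx_graphsurgeon as gs

def remove_duplicate_casts_sketch(graph: gs.Graph) -> gs.Graph:
    """Drop the second of two back-to-back Casts that target the same dtype."""
    for node in graph.nodes:
        if node.op != "Cast" or len(node.outputs[0].outputs) != 1:
            continue  # only handle a single consumer, matching the check noted above
        consumer = node.outputs[0].outputs[0]
        if consumer.op != "Cast" or consumer.attrs.get("to") != node.attrs.get("to"):
            continue
        # Rewire every user of the second Cast to read the first Cast's output.
        for tensor in list(consumer.outputs):
            for user in list(tensor.outputs):
                user.inputs = [node.outputs[0] if t is tensor else t for t in user.inputs]
        consumer.outputs.clear()  # detach the node; cleanup() then drops it
    graph.cleanup().toposort()
    return graph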

Comment on lines +1314 to +1349
def change_casts_to_fp16(model: onnx.ModelProto, target_op_types: list[str]) -> onnx.ModelProto:
    """Change Cast nodes that cast to FP32 and feed into specified nodes to cast to FP16 instead.

    Args:
        model: The ONNX model to modify.
        target_op_types: List of op types to check for. Cast nodes feeding into these will be
            changed from FP32 to FP16.

    Returns:
        The modified ONNX model with Cast nodes updated.
    """
    # Build a map of tensor name -> consumer nodes
    tensor_to_consumers: dict[str, list[onnx.NodeProto]] = {}
    for node in model.graph.node:
        for inp in node.input:
            if inp:
                tensor_to_consumers.setdefault(inp, []).append(node)

    # Find Cast nodes that feed into target ops and change FP32 -> FP16
    for node in model.graph.node:
        if node.op_type != "Cast":
            continue

        # Check if this Cast outputs to a target op type
        cast_output = node.output[0]
        consumers = tensor_to_consumers.get(cast_output, [])
        feeds_target = any(c.op_type in target_op_types for c in consumers)

        if not feeds_target:
            continue

        # Check if Cast is to FP32, and change to FP16
        for attr in node.attribute:
            if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
                attr.i = onnx.TensorProto.FLOAT16
                break

⚠️ Potential issue | 🟡 Minor

Function modifies Cast nodes regardless of whether they actually cast from FP32.

The function changes the to attribute from FP32 to FP16, but only checks if the target type is FP32. It doesn't verify the source type. If a Cast node is already casting from FP16 to FP32 (to match precision requirements), changing it to FP16 could break the graph semantics.

Consider also logging when a Cast node is modified for debugging purposes.

📝 Suggested improvement with logging
         # Check if Cast is to FP32, and change to FP16
         for attr in node.attribute:
             if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
                 attr.i = onnx.TensorProto.FLOAT16
+                logger.debug(f"Changed Cast node {node.name} from FP32 to FP16")
                 break
🤖 Prompt for AI Agents
In `@modelopt/onnx/utils.py` around lines 1314 - 1349, In change_casts_to_fp16,
only modify Cast nodes that actually cast from FP32: for each Cast node
(node.op_type == "Cast") look up the source tensor name node.input[0] in
graph.initializer, graph.input, graph.value_info or graph.output to get its
element_type and only change the node.attribute "to" from onnx.TensorProto.FLOAT
to onnx.TensorProto.FLOAT16 if the source dtype is FLOAT; also avoid changing
Casts that are FP16->FP32 and add a debug log entry when you modify a Cast
(include node.name or node.output[0] and original->new dtypes) to aid debugging.
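
A minimal sketch of the dtype lookup the prompt describes is below; the helper name is hypothetical, and intermediate tensors resolve to None unless shape inference has populated value_info.

import onnx

def get_tensor_elem_type(graph: onnx.GraphProto, name: str) -> int | None:
    """Best-effort lookup of a tensor's element type from graph metadata."""
    for init in graph.initializer:
        if init.name == name:
            return init.data_type
    for vi in list(graph.input) + list(graph.value_info) + list(graph.output):
        if vi.name == name:
            return vi.type.tensor_type.elem_type
    return None  # unknown, e.g. shape inference was not run

The retargeting in change_casts_to_fp16 would then only fire when get_tensor_elem_type(model.graph, node.input[0]) returns onnx.TensorProto.FLOAT, which is the guard the comment above asks for.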

@codecov

codecov bot commented Feb 4, 2026

Codecov Report

❌ Patch coverage is 39.68254% with 38 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.73%. Comparing base (95511a0) to head (e2abd9d).

Files with missing lines | Patch % | Missing lines
modelopt/onnx/utils.py | 30.61% | 34 ⚠️
modelopt/torch/_deploy/utils/torch_onnx.py | 71.42% | 4 ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main     #852   +/-   ##
=======================================
  Coverage   73.73%   73.73%           
=======================================
  Files         199      199           
  Lines       21165    21211   +46     
=======================================
+ Hits        15606    15640   +34     
- Misses       5559     5571   +12     

☔ View full report in Codecov by Sentry.

@gcunhase
Contributor

gcunhase commented Feb 9, 2026

@ajrasane can you please add the before and after accuracy results in the PR description? I.e., with FP8 custom Q/DQ nodes vs. with FP8 native Q/DQ nodes. Thanks!

@gcunhase
Contributor

gcunhase commented Feb 9, 2026

Let's also add this change to the Changelog file.

op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
)
# Change FP32 cast nodes feeding into Concat/Add to FP16
onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"])
Contributor

Can you please elaborate the goal/need of this function? Thanks!

Contributor Author

This is because after the convert_float_to_float16() function, one of the inputs for these nodes is FP16, while the other is FP32. Hence we run into a compilation issue with TensorRT. To fix this, I manually update them here for these operators.
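
For reference, the flow being described is roughly the sketch below. It assumes convert_float_to_float16 is the onnxconverter-common converter and reuses the block list from the diff above; the real call site in torch_onnx.py may differ, and the file paths are placeholders.

import onnx
from onnxconverter_common.float16 import convert_float_to_float16

from modelopt.onnx.utils import change_casts_to_fp16

model = onnx.load("model_fp8_qdq.onnx")  # placeholder path

# Keep QDQ (and Div) in FP32; the converter inserts boundary Casts around them,
# which can leave Concat/Add with one FP16 input and one FP32 input.
model = convert_float_to_float16(
    model,
    op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
)

# Retarget the FP32 boundary Casts that feed Concat/Add so both inputs are FP16,
# avoiding the TensorRT compilation failure described above.
model = change_casts_to_fp16(model, ["Concat", "Add"])

onnx.save(model, "model_fp16.onnx")  # placeholder path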

Contributor

Got it, thanks for the explanation. Can you please update the docstring to give a bit more details? Thanks!

Contributor

@ajrasane would you consider using autocast's convert_to_f16 and avoid this patch?

@ajrasane ajrasane requested a review from a team as a code owner February 13, 2026 14:09
@ajrasane ajrasane requested a review from galagam February 13, 2026 14:09
@gcunhase
Contributor

gcunhase commented Feb 13, 2026

5.74771 ms

Accuracy looks good, any idea why perf is slower after this PR?

Also, can you please specify which model these numbers are for?

Thanks.

Contributor

@gcunhase gcunhase left a comment


LGTM, @galagam are you okay with making the redundant casts function a utils function? Thanks!

Contributor

@galagam galagam left a comment


LGTM

op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
)
# Change FP32 cast nodes feeding into Concat/Add to FP16
onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"])
Contributor

@ajrasane would you consider using autocast's convert_to_f16 and avoid this patch?

logger.debug(f"Failed to fold Constant->Cast {node.name}: {e}")

if removed_count > 0:
    graph.cleanup().toposort()
Contributor

I recall some issues with toposort.
If you see any failures due to it, we can probably omit it; _bypass_cast maintains node sorting.

@galagam
Contributor

galagam commented Feb 15, 2026

LGTM, @galagam are you okay with making the redundant casts function a utils function? Thanks!

AutoCast's unit testing covers this part well, and indeed, I see there are quite a few failures with this refactor.
Approved the general concept, but need to make sure we don't cause regressions/behavior changes for AutoCast.
Thanks.
@gcunhase @ajrasane

