[OMNIML-2663] Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes #852
Conversation
📝 Walkthrough

This PR introduces FP8 export enhancements and adds utility functions for ONNX graph manipulation. It updates the post-processing pipeline to convert TRT-specific FP8 quantization nodes to native ONNX equivalents, and refactors weight quantization and FP16 conversion logic to use new utility functions for duplicate-cast removal and targeted cast conversions.
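At its core, that conversion renames the TRT-specific Q/DQ ops and gives each node an explicit FP8 zero-point input. A minimal sketch with plain `onnx` protos is shown below; the helper name and details are illustrative assumptions rather than the PR's actual implementation, and it presumes an opset (19+) where QuantizeLinear/DequantizeLinear accept FP8 zero-points:

```python
import onnx


def convert_trt_fp8_qdq_to_native(model: onnx.ModelProto) -> onnx.ModelProto:
    """Rewrite TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear as native ONNX Q/DQ nodes (sketch)."""
    op_map = {
        "TRT_FP8QuantizeLinear": "QuantizeLinear",
        "TRT_FP8DequantizeLinear": "DequantizeLinear",
    }
    graph = model.graph
    for idx, node in enumerate(graph.node):
        if node.op_type not in op_map:
            continue
        node.op_type = op_map[node.op_type]
        node.domain = ""  # move the node back into the default ONNX domain
        if len(node.input) < 3:
            # Native Q/DQ selects the FP8 type via the zero-point dtype, so append one.
            zp = onnx.TensorProto()
            zp.name = f"{node.name or idx}_zero_point"
            zp.data_type = onnx.TensorProto.FLOAT8E4M3FN
            zp.dims.extend([1])        # explicit 1-element shape, as the review suggests
            zp.int32_data.extend([0])  # FP8 values are bit-packed into int32_data; 0 encodes 0.0
            graph.initializer.append(zp)
            node.input.append(zp.name)
    return model
```

Whether the zero-point should be shaped `[1]` or a true scalar depends on how the corresponding scale is shaped; the review's `dims.extend([1])` suggestion is followed here.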
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes. Pre-merge checks: ✅ 3 passed.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 104-140: The post_process function's docstring mentions updating
GELU nodes to tanh approximation and inserting Cast nodes after Sqrt, but the
implementation in post_process only converts
TRT_FP8QuantizeLinear/TRT_FP8DequantizeLinear to
QuantizeLinear/DequantizeLinear; either remove or revise those docstring lines
to reflect current behavior, or implement the missing steps: locate GELU nodes
in graph.nodes and replace/modify them to the tanh-approx variant, and insert
Cast nodes immediately after Sqrt nodes' outputs; reference post_process,
TRT_FP8QuantizeLinear, TRT_FP8DequantizeLinear, GELU, and Sqrt when making the
change.
- Around line 119-126: The FP8 zero-point tensor zp_tensor is missing explicit
shape metadata; update the creation of zp_tensor (used to build zero_point and
appended to node.inputs) to set its dims explicitly (e.g., call
zp_tensor.dims.extend([1]) for a 1-element tensor) so it matches other tensors
created in this module (see the FP8 weights tensor creation) and ensures ONNX
runtimes receive shape info.
In `@modelopt/onnx/utils.py`:
- Around line 1314-1349: In change_casts_to_fp16, only modify Cast nodes that
actually cast from FP32: for each Cast node (node.op_type == "Cast") look up the
source tensor name node.input[0] in graph.initializer, graph.input,
graph.value_info or graph.output to get its element_type and only change the
node.attribute "to" from onnx.TensorProto.FLOAT to onnx.TensorProto.FLOAT16 if
the source dtype is FLOAT; also avoid changing Casts that are FP16->FP32 and add
a debug log entry when you modify a Cast (include node.name or node.output[0]
and original->new dtypes) to aid debugging.
🧹 Nitpick comments (1)
modelopt/onnx/utils.py (1)
1218-1261: Consider the edge case where the first Cast has multiple consumers. The function checks
`len(node.outputs[0].outputs) != 1` (line 1231) to ensure the first Cast's output goes to exactly one node. However, this may be overly restrictive. If the first Cast feeds into a duplicate second Cast AND other nodes, the duplicate Cast could still be removed while preserving the connections to the other consumers. The current logic skips this optimization opportunity. This is a minor optimization opportunity and the current implementation is safe.
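For the record, that restriction could be lifted by rewiring only the duplicate Cast's consumers instead of requiring a single consumer. A rough onnx_graphsurgeon-style sketch (the helper and its name are hypothetical, not code from this PR):

```python
import onnx_graphsurgeon as gs


def bypass_duplicate_cast(first_cast: gs.Node, second_cast: gs.Node) -> None:
    """Rewire consumers of a redundant second Cast to read from the first Cast instead.

    Assumes both nodes are Cast ops with the same 'to' attribute; works even when the
    first Cast's output also feeds other nodes besides the duplicate.
    """
    first_out = first_cast.outputs[0]
    second_out = second_cast.outputs[0]
    # Point every consumer of the duplicate Cast at the first Cast's output tensor.
    for consumer in list(second_out.outputs):
        for i, tensor in enumerate(consumer.inputs):
            if tensor is second_out:
                consumer.inputs[i] = first_out
    # Detach the duplicate Cast so graph.cleanup() can drop it.
    second_cast.outputs.clear()
```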
```python
def change_casts_to_fp16(model: onnx.ModelProto, target_op_types: list[str]) -> onnx.ModelProto:
    """Change Cast nodes that cast to FP32 and feed into specified nodes to cast to FP16 instead.

    Args:
        model: The ONNX model to modify.
        target_op_types: List of op types to check for. Cast nodes feeding into these will be
            changed from FP32 to FP16.

    Returns:
        The modified ONNX model with Cast nodes updated.
    """
    # Build a map of tensor name -> consumer nodes
    tensor_to_consumers: dict[str, list[onnx.NodeProto]] = {}
    for node in model.graph.node:
        for inp in node.input:
            if inp:
                tensor_to_consumers.setdefault(inp, []).append(node)

    # Find Cast nodes that feed into target ops and change FP32 -> FP16
    for node in model.graph.node:
        if node.op_type != "Cast":
            continue

        # Check if this Cast outputs to a target op type
        cast_output = node.output[0]
        consumers = tensor_to_consumers.get(cast_output, [])
        feeds_target = any(c.op_type in target_op_types for c in consumers)

        if not feeds_target:
            continue

        # Check if Cast is to FP32, and change to FP16
        for attr in node.attribute:
            if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
                attr.i = onnx.TensorProto.FLOAT16
                break
```
Function modifies Cast nodes regardless of whether they actually cast from FP32.
The function changes the to attribute from FP32 to FP16, but only checks if the target type is FP32. It doesn't verify the source type. If a Cast node is already casting from FP16 to FP32 (to match precision requirements), changing it to FP16 could break the graph semantics.
Consider also logging when a Cast node is modified for debugging purposes.
📝 Suggested improvement with logging

```diff
         # Check if Cast is to FP32, and change to FP16
         for attr in node.attribute:
             if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
                 attr.i = onnx.TensorProto.FLOAT16
+                logger.debug(f"Changed Cast node {node.name} from FP32 to FP16")
                 break
```
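A possible shape of that source-type check, sketched with plain `onnx` (not the PR's code; the helper names are illustrative, and `value_info` may need `onnx.shape_inference.infer_shapes` to be populated):

```python
import logging

import onnx

logger = logging.getLogger(__name__)


def _tensor_elem_type(graph: onnx.GraphProto, name: str) -> int | None:
    """Return the recorded element type of a tensor, or None if the graph does not know it."""
    for init in graph.initializer:
        if init.name == name:
            return init.data_type
    for vi in list(graph.input) + list(graph.value_info) + list(graph.output):
        if vi.name == name:
            return vi.type.tensor_type.elem_type
    return None


def _maybe_retarget_cast(graph: onnx.GraphProto, node: onnx.NodeProto) -> None:
    """Switch a Cast target from FP32 to FP16 only when the source is known to be FP32."""
    if _tensor_elem_type(graph, node.input[0]) != onnx.TensorProto.FLOAT:
        return  # unknown or non-FP32 source (e.g. an FP16->FP32 upcast); leave it alone
    for attr in node.attribute:
        if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
            attr.i = onnx.TensorProto.FLOAT16
            logger.debug("Retargeted Cast %s: FP32 -> FP16", node.name or node.output[0])
            break
```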
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff            @@
##             main     #852    +/-   ##
========================================
  Coverage    73.73%   73.73%
========================================
  Files          199      199
  Lines        21165    21211     +46
========================================
+ Hits         15606    15640     +34
- Misses        5559     5571     +12
```

☔ View full report in Codecov by Sentry.
@ajrasane can you please add the before and after accuracy results in the PR description? I.e., with FP8 custom Q/DQ nodes vs. with FP8 native Q/DQ nodes. Thanks!

Let's also add this change to the Changelog file.
```python
        op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
    )
    # Change FP32 cast nodes feeding into Concat/Add to FP16
    onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"])
```
Can you please elaborate on the goal/need of this function? Thanks!
This is because after the convert_float_to_float16() function, one of the inputs for these nodes is FP16, while the other is FP32. Hence we run into a compilation issue with TensorRT. To fix this, I manually update them here for these operators.
Got it, thanks for the explanation. Can you please update the docstring to give a bit more details? Thanks!
@ajrasane would you consider using autocast's convert_to_f16 and avoid this patch?
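For context on the mixed-precision failure being worked around here: a small diagnostic like the one below can list Concat/Add nodes whose inputs ended up with both FP16 and FP32 after conversion. This is only a hedged sketch (not part of the PR) and assumes shape inference can annotate the intermediate tensors:

```python
import onnx


def find_mixed_precision_inputs(model: onnx.ModelProto, op_types=("Concat", "Add")):
    """Report nodes whose inputs mix FP16 and FP32 after float16 conversion."""
    model = onnx.shape_inference.infer_shapes(model)
    graph = model.graph
    types = {init.name: init.data_type for init in graph.initializer}
    for vi in list(graph.input) + list(graph.value_info) + list(graph.output):
        types[vi.name] = vi.type.tensor_type.elem_type

    mixed = []
    for node in graph.node:
        if node.op_type not in op_types:
            continue
        dtypes = {types.get(name) for name in node.input if name in types}
        if {onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT16} <= dtypes:
            mixed.append((node.name or node.output[0], node.op_type))
    return mixed
```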
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
Force-pushed from 9b30f17 to e2abd9d.
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
Accuracy looks good, any idea why perf is slower after this PR? Also, can you please specify which model these numbers are for? Thanks.
```python
                logger.debug(f"Failed to fold Constant->Cast {node.name}: {e}")

    if removed_count > 0:
        graph.cleanup().toposort()
```
I recall some issues with toposort.
If you see any failures due to it, we can probably omit it; _bypass_cast maintains node sorting.

AutoCast's unit testing covers this part well, and indeed, I see there are quite a few failures with this refactor.
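If toposort does turn out to be the source of the failures, a defensive variant along these lines could keep the cleanup while making the sort optional. This is purely a sketch, assuming an onnx_graphsurgeon graph and that `_bypass_cast` already preserves node ordering:

```python
import logging

import onnx_graphsurgeon as gs

logger = logging.getLogger(__name__)


def cleanup_after_cast_removal(graph: gs.Graph, removed_count: int) -> gs.Graph:
    """Clean up dangling tensors after cast removal; skip toposort if it proves fragile."""
    if removed_count > 0:
        graph.cleanup()
        try:
            graph.toposort()
        except Exception as e:  # keep going; bypassing casts preserves node order
            logger.warning("Skipping toposort after cast removal: %s", e)
    return graph
```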
What does this PR do?
Type of change: New feature
Overview:
Testing
Results:
Before replacement:
After replacement:
Before your PR is "Ready for review"
Summary by CodeRabbit
Release Notes
New Features
Improvements