[inference_fusion] convert conv3d patch embed to linear #45041
ArthurZucker merged 28 commits into huggingface:main
Conversation
vasqu
left a comment
Just some quick comments from my side. Let me know when the PR is in a ready enough state, I would run our CI for the models to check for the integration tests
(they will fail if commits are pushed in the meantime so we have to sync this properly 😄)
@vasqu thank you for the review! I'll address them tomorrow and ping you when done. Just marking this PR as a draft for now.
@vasqu Hi, all related CI checks passed, I think we can trigger the integration tests.
run-slow: ernie4_5_vl_moe, glm4v, glm4v_moe, glm_image, glm_ocr, qwen2_5_omni, qwen2_5_vl, qwen2_vl, qwen3_5, qwen3_5_moe, qwen3_omni_moe, qwen3_vl, qwen3_vl_moe
Yup, let's check it out. run-slow also compares against main to see whether failures are new or not
This comment contains models: ["models/ernie4_5_vl_moe", "models/glm4v", "models/glm4v_moe", "models/glm_image", "models/glm_ocr", "models/qwen2_5_omni", "models/qwen2_5_vl", "models/qwen2_vl", "models/qwen3_5", "models/qwen3_5_moe", "models/qwen3_omni_moe", "models/qwen3_vl", "models/qwen3_vl_moe"]
CI Results / Commit Info
Model CI Report: ❌ 6 new failed tests from this PR 😭
vasqu
left a comment
Ok, so based on the generations starting to differ, I feel more confident in making this an optional feature in the spirit of #44942: while the diffs are similar, it is hard to guarantee the same performance across all affected models. We had similar experiences with fused qkv vs. split q, k, v; it depends on the trained model whether it's more easily affected or not.
This is why I would like to make this optional if possible, because that allows the best of both worlds:
- The original version, which is the behavior one would expect
- The fused version, for speed gains but slightly different outputs
Wdyt @JJJYmmm? Really appreciate the initiative
@vasqu Hi! I reworked this into a more general fusion dispatcher, while keeping this PR scoped to patch embedding fusion only. The modeling files stay unchanged, and the 3D-conv-to-linear patch fusion is now opt-in:

```python
model = Qwen2VLForConditionalGeneration.from_pretrained(
    ...,
    fusion_config={
        "patch_embeddings": True,
        # compatible with #44942, we can add flags here like
        # "qkv_fusion": True,
        # "mlp_fusion": True,
    },
)
```

I also tried to make this compatible with #44942: the entry point is now a generic
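For intuition, here is a minimal numpy sketch (illustrative only, not the PR's actual implementation) of why a patch-embed Conv3d whose stride equals its kernel size computes exactly the same thing as a Linear layer over flattened patches, with the conv weight reshaped to `(embed_dim, in_features)`:

```python
import numpy as np

rng = np.random.default_rng(0)
C, T, H, W = 3, 4, 8, 8        # channels, frames, height, width
t, p = 2, 4                    # temporal / spatial patch sizes
D = 16                         # embedding dim

x = rng.standard_normal((C, T, H, W))
w = rng.standard_normal((D, C, t, p, p))   # Conv3d weight (no bias)

# Conv3d with stride == kernel size: one dot product per non-overlapping patch.
def conv3d_patch_embed(x, w):
    rows = []
    for ti in range(0, T, t):
        for hi in range(0, H, p):
            for wi in range(0, W, p):
                patch = x[:, ti:ti + t, hi:hi + p, wi:wi + p]
                rows.append(w.reshape(D, -1) @ patch.reshape(-1))
    return np.stack(rows)                   # (num_patches, D)

# Equivalent Linear: flatten each patch and multiply by the reshaped weight.
def linear_patch_embed(x, w):
    patches = (
        x.reshape(C, T // t, t, H // p, p, W // p, p)
         .transpose(1, 3, 5, 0, 2, 4, 6)    # (T/t, H/p, W/p, C, t, p, p)
         .reshape(-1, C * t * p * p)
    )
    return patches @ w.reshape(D, -1).T

assert np.allclose(conv3d_patch_embed(x, w), linear_patch_embed(x, w))
```

Because the patches don't overlap, no weight values are duplicated, so the conversion is lossless in both directions.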
vasqu
left a comment
- Conversions look super clean to me
- We have to exchange the fusion mapping to utilize the monkey patching we already have (on me, I didn't notice it existed 😢)
- Maybe some docs on how to apply this at the end would be nice then
Overall, it hopefully shouldn't change your core logic, but rather use and extend some existing machinery
# Conflicts:
#   tests/models/qwen3_5/test_modeling_qwen3_5.py
#   tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py
[For maintainers] Suggested jobs to run (before merge): run-slow: qwen2_vl
@vasqu really appreciate the detailed review :) I have moved them into classes and added the tests accordingly.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
ArthurZucker
left a comment
Is the fusion config serialized? We could serialize it + test with a model on the hub that has it set to True for example.
Yup, don't have much to add anymore! My hat's off to you
I think Arthur's comment on serialization is the only point left:
- Do we save/load the fusion config anywhere?
- Maybe under the base PreTrainedConfig?
Currently, if I see it correctly, we would always need to pass the fusion config; but if we were to save it, we could avoid that and make it a bit more convenient for users
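The save/load round trip being discussed could look roughly like this. A hypothetical stdlib-only sketch, not the PR's actual code: `fusion_config` is persisted next to the rest of the model config, and its absence on old checkpoints simply means fusion stays off.

```python
import json
import os
import tempfile

def save_config(cfg: dict, save_dir: str) -> None:
    # Write the full config, including any fusion_config, to config.json.
    with open(os.path.join(save_dir, "config.json"), "w") as f:
        json.dump(cfg, f)

def load_config(save_dir: str) -> dict:
    with open(os.path.join(save_dir, "config.json")) as f:
        cfg = json.load(f)
    # Old checkpoints without the field default to "no fusion".
    cfg.setdefault("fusion_config", {})
    return cfg

with tempfile.TemporaryDirectory() as d:
    save_config({"model_type": "qwen2_vl",
                 "fusion_config": {"patch_embeddings": True}}, d)
    assert load_config(d)["fusion_config"] == {"patch_embeddings": True}

    save_config({"model_type": "qwen2_vl"}, d)  # checkpoint without the field
    assert load_config(d)["fusion_config"] == {}
```

With something like this in place, users would not need to pass `fusion_config` on every `from_pretrained` call once it has been saved.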
> if (tuple(converter.source_patterns), tuple(converter.target_patterns), type(converter))
>     not in existing_converter_keys
Yea I think this is fair for now. I think in the future it would make sense to have compatible conversions for example when we apply chunking to qkv fused weights
- If operations are only chunking (qkv -> q, k, v)
- Remove that operation instead
I don't think it makes sense in the current context here but would be a nice addition in the future
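The "chunking is the inverse of fusing" idea mentioned above can be shown in a few lines (a generic numpy sketch, not tied to any transformers API): concatenating q, k, v along the output dimension and then splitting recovers them exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
q, k, v = (rng.standard_normal((d, d)) for _ in range(3))

fused = np.concatenate([q, k, v], axis=0)   # fuse along the output dimension
q2, k2, v2 = np.split(fused, 3, axis=0)     # chunking is the exact inverse

assert all(np.array_equal(a, b) for a, b in [(q, q2), (k, k2), (v, v2)])
```

Since the operation is an exact inverse, a converter whose only operation is chunking could indeed be dropped rather than composed, as suggested.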
@ArthurZucker @vasqu sorry for missing this in the previous round! It is supported now: I added a test covering
vasqu
left a comment
Perfect, one nit would be maybe to really have a remote hub config, but imo this is also enough already. Wdyt @ArthurZucker?
Having a look tomorrow!!!!
ArthurZucker
left a comment
TY! Let's improve tests please
> - uses `is_fusable(...)` as the final structural check
> - builds the patch mapping used by monkey patching
>
> Results are cached per `(fusion_name, cls)` to avoid repeated meta-initialization.
Not 100% sure this is good: if you only want to fuse the first 4 MLPs, you can't with this! What we need to cache is the matched pattern, I think. That way, if you say `layer.[0-4].mlp`, after layer 0 it will skip, and if you had another fusion for the MLP class, they will be skipped!
Yeah, that is a fair point, but I feel the limitation here is more fundamental than the cache itself: the current monkey-patching path `register_patch_mapping(...)` is class-level, so it can express "fuse this module class", but not really "fuse only layers 0-4".
You're right! Mmm, we could potentially make it instance-level? If not, it would be good to make it clear it's not instance-level
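The instance-level selection being discussed essentially means matching module *names* instead of module *classes*. A hypothetical sketch of what that selection could look like, using a regex over qualified module names (the names and pattern are illustrative):

```python
import re

# Class-level patching can only say "fuse every MLP"; matching on the
# instance path lets us express "fuse only layers 0-4".
module_names = [f"model.layers.{i}.mlp" for i in range(8)]
pattern = re.compile(r"model\.layers\.[0-4]\.mlp")

to_fuse = [name for name in module_names if pattern.fullmatch(name)]
assert to_fuse == [f"model.layers.{i}.mlp" for i in range(5)]
```

A cache keyed on the matched pattern (rather than on the class) would then naturally support this kind of partial fusion.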
> def __init__(self, config):
>     super().__init__(config)
>     # Instantiate through the fake module so `apply_patches()` sees the replacement.
Can you elaborate? You mean that we need DummyPatchEmbedding to be importable from transformers?
We need DummyPatchEmbedding to be reachable from a transformers.* module.
The reason is that apply_patches() scans sys.modules and only rewrites class attributes exposed from modules whose name starts with transformers:
transformers/src/transformers/monkey_patching.py, lines 271 to 272 in 47d7765
So if the dummy class were only referenced as a local symbol inside the test file, the registered patch would never actually be applied during model construction.
That is why this test creates DUMMY_TRANSFORMERS_MODULE_NAME and resolves DummyPatchEmbedding through that fake transformers.* module.
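The mechanism described above can be sketched in miniature. This is an illustrative toy, not the actual `monkey_patching.py` code; `toyformers` stands in for the `transformers` prefix that the real scan checks:

```python
import sys
import types

class PatchEmbedConv:      # stand-in for the original module class
    kind = "conv3d"

class PatchEmbedLinear:    # stand-in for the fused replacement
    kind = "linear"

# Expose the original class from a module whose name matches the scanned prefix,
# mirroring what the fake transformers.* module does in the test.
mod = types.ModuleType("toyformers.models.demo")
mod.PatchEmbedConv = PatchEmbedConv
sys.modules["toyformers.models.demo"] = mod

def apply_patches(mapping, prefix="toyformers"):
    # Scan sys.modules and rewrite matching class attributes in place.
    for name, module in list(sys.modules.items()):
        if not name.startswith(prefix):
            continue
        for attr, replacement in mapping.items():
            if getattr(module, attr, None) is not None:
                setattr(module, attr, replacement)

apply_patches({"PatchEmbedConv": PatchEmbedLinear})
assert sys.modules["toyformers.models.demo"].PatchEmbedConv is PatchEmbedLinear
```

A class defined only as a local symbol in a test file lives in no `toyformers.*` (resp. `transformers.*`) module, so the scan would never find it, which is exactly the failure mode described above.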
> for class_name, patchable_class in DUMMY_PATCHABLE_CLASSES.items():
>     setattr(DUMMY_TRANSFORMERS_MODULE, class_name, patchable_class)
there is only 1 item
I kept it this way just for easier extension if we add other fusions e.g. #44942. 🫡
Ty @JJJYmmm 🤗
…45041)
* ok
* fix consistency
* pass qwen35 reverse mapping
* update new failed test according to captured info
* Revert "update new failed test according to captured info" (this reverts commit 445a400)
* make it optional
* make fusion_mapping more general
* make conv3d conversion more general
* make fusion_mapping more general
* better name for conversion
* add fusion_mapping doc and clean tests
* fix reverse mapping test follow gemma3n
* chore: retrigger ci
* tests: move qwen3.5 reverse mapping fix to separate branch
* code clean!
* ruff format and clean test to make it simple
* richer doc
* get converters from config rather than each module
* add explict module_name check for fusion!
* better isolated test and code clean
* support serialized fusion_config
* ruff format
* config can handle unknown attributes
* move fused cls out of spec by mixin
* detailed comments
* ruff
What does this PR do?
This PR addresses the performance issues observed with nn.Conv3d across different PyTorch/cuDNN versions, such as vllm-project/vllm#27418 and https://mp.weixin.qq.com/s/hKRIpB561EdrMY8cbg1hEw.
We replace the original conv3d in PatchEmbed with an equivalent linear forward pass. Unit tests were added and pass.
A speed test script is also included.
result for mps (M5 + torch 2.10):
result for GPU H100 @wulipc:
The results show that the linear forward pass achieves lower latency than Conv3d, so we can just replace it. We temporarily retain nn.Conv3d for downstream compatibility, with plans to migrate to nn.Linear in future checkpoints.
cc @wulipc @ShuaiBai623
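The shape of the speedup can be illustrated with a toy micro-benchmark. This is a hedged sketch, not the PR's speed test script: numpy stands in for torch, the dimensions are invented, and "per-patch" loosely mimics conv-style per-window work while "batched" is the single GEMM the linear path enables.

```python
import time

import numpy as np

rng = np.random.default_rng(0)
n_patches, patch_dim, embed_dim = 256, 3 * 2 * 14 * 14, 1280  # illustrative sizes
patches = rng.standard_normal((n_patches, patch_dim))
w = rng.standard_normal((embed_dim, patch_dim))

def per_patch(x):
    # conv-style: one matrix-vector product per patch
    return np.stack([w @ p for p in x])

def batched(x):
    # linear-style: a single GEMM over all patches at once
    return x @ w.T

def bench(fn, iters=3):
    fn(patches)                          # warmup
    start = time.perf_counter()
    for _ in range(iters):
        fn(patches)
    return (time.perf_counter() - start) / iters

# Both paths compute the same embeddings; only the latency differs.
assert np.allclose(per_patch(patches), batched(patches))
print(f"per-patch: {bench(per_patch):.4f}s  batched: {bench(batched):.4f}s")
```

Actual numbers depend heavily on hardware and BLAS/cuDNN versions, which is exactly why the PR reports both MPS and H100 results.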