Skip to content

[inference_fusion] convert conv3d patch embed to linear#45041

Merged
ArthurZucker merged 28 commits intohuggingface:mainfrom
JJJYmmm:main
Apr 13, 2026
Merged

[inference_fusion] convert conv3d patch embed to linear#45041
ArthurZucker merged 28 commits intohuggingface:mainfrom
JJJYmmm:main

Conversation

@JJJYmmm
Copy link
Copy Markdown
Contributor

@JJJYmmm JJJYmmm commented Mar 27, 2026

What does this PR do?

This PR addresses the performance issues observed with nn.Conv3d across different PyTorch/cuDNN, such as vllm-project/vllm#27418, https://mp.weixin.qq.com/s/hKRIpB561EdrMY8cbg1hEw.

We replace the original conv3d in PatchEmbed with an equivalent linear forward. Unit tests were added and they pass.

Here is also a speed test script.

# benchmark_patch_embed_qwen2_vl.py
import argparse
import statistics
import time
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import nn


@dataclass
class PatchEmbedBenchmarkConfig:
    """Geometry of the 3D patch embedding being benchmarked.

    Mirrors the Conv3d projection in Qwen2-VL-style PatchEmbed: each patch is
    a block of ``in_channels x temporal_patch_size x patch_size x patch_size``
    values projected to ``embed_dim``.
    """

    patch_size: int  # spatial patch edge length (height == width)
    temporal_patch_size: int  # number of frames folded into one patch
    in_channels: int  # input channels (e.g. 3 for RGB)
    embed_dim: int  # output embedding dimension per patch


def conv3d_forward(proj: nn.Conv3d, config: PatchEmbedBenchmarkConfig, hidden_states: torch.Tensor) -> torch.Tensor:
    """Project flattened patches through the Conv3d path (reference implementation).

    `hidden_states` is `(num_patches, patch_volume)`; each row is unflattened
    back into a 5D patch so the conv (stride == kernel) acts as a per-patch
    projection, then the result is flattened to `(num_patches, embed_dim)`.
    """
    weight_dtype = proj.weight.dtype
    patches = hidden_states.view(
        -1, config.in_channels, config.temporal_patch_size, config.patch_size, config.patch_size
    )
    # Cast the input to the conv weight's dtype before projecting.
    return proj(patches.to(dtype=weight_dtype)).view(-1, config.embed_dim)


def sync_device(device: torch.device) -> None:
    """Block until all queued kernels on ``device`` have finished.

    CUDA and MPS launch kernels asynchronously, so wall-clock timings must
    synchronize before reading the clock; CPU execution is synchronous and
    needs no action.
    """
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    elif device.type == "mps":
        # Use the documented availability check. The previous guard
        # (`hasattr(torch, "mps") and torch.mps.is_available()`) could raise
        # AttributeError on torch builds whose `torch.mps` module predates
        # `is_available`.
        if torch.backends.mps.is_available():
            torch.mps.synchronize()


def benchmark(
    fn, proj: nn.Conv3d, config: PatchEmbedBenchmarkConfig, hidden_states: torch.Tensor, warmup_steps: int, benchmark_steps: int
) -> list[float]:
    """Time ``fn(proj, config, hidden_states)`` and return per-call latencies in ms.

    Runs ``warmup_steps`` untimed calls first (kernel autotuning, lazy init),
    then ``benchmark_steps`` timed calls, synchronizing the device after each
    so the measurement covers kernel completion rather than just launch.
    """
    for _ in range(warmup_steps):
        fn(proj, config, hidden_states)
    sync_device(hidden_states.device)

    latencies_ms: list[float] = []
    for _ in range(benchmark_steps):
        tick = time.perf_counter()
        fn(proj, config, hidden_states)
        sync_device(hidden_states.device)
        latencies_ms.append((time.perf_counter() - tick) * 1000)
    return latencies_ms


def linear_forward(proj: nn.Conv3d, config: PatchEmbedBenchmarkConfig, hidden_states: torch.Tensor) -> torch.Tensor:
    """Project flattened patches with an F.linear call equivalent to the Conv3d.

    Because the conv uses stride == kernel size, its kernel flattened per
    output channel is exactly a linear weight of shape
    `(embed_dim, patch_volume)`, so a single matmul reproduces the conv.
    """
    weight_dtype = proj.weight.dtype
    patches = hidden_states.view(
        -1, config.in_channels, config.temporal_patch_size, config.patch_size, config.patch_size
    )
    flat_weight = proj.weight.view(config.embed_dim, -1)
    flat_patches = patches.to(dtype=weight_dtype).reshape(patches.shape[0], -1)
    return F.linear(flat_patches, flat_weight, proj.bias)


def resolve_device(device: str) -> torch.device:
    """Map a CLI device string to a torch.device, auto-detecting accelerators.

    For "auto", prefers CUDA, then Apple MPS, then CPU; any other string is
    handed to ``torch.device`` verbatim.
    """
    if device != "auto":
        return torch.device(device)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def resolve_dtype(dtype: str) -> torch.dtype:
    """Map a CLI dtype string to a torch.dtype, defaulting to float32."""
    named = {"float16": torch.float16, "bfloat16": torch.bfloat16}
    return named.get(dtype, torch.float32)


def main():
    """Benchmark the linear patch-embed path against Conv3d and print a report."""
    parser = argparse.ArgumentParser(description="Benchmark Qwen2VL PatchEmbed linear path against Conv3d.")
    parser.add_argument("--device", default="auto", help="auto, cpu, cuda, or mps")
    parser.add_argument("--dtype", default="bfloat16", choices=["float32", "float16", "bfloat16"])
    parser.add_argument("--num-patches", type=int, default=4096, help="Number of flattened patches to benchmark")
    parser.add_argument("--embed-dim", type=int, default=1152)
    parser.add_argument("--in-channels", type=int, default=3)
    parser.add_argument("--patch-size", type=int, default=14)
    parser.add_argument("--temporal-patch-size", type=int, default=2)
    parser.add_argument("--warmup-steps", type=int, default=20)
    parser.add_argument("--benchmark-steps", type=int, default=100)
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    device = resolve_device(args.device)
    dtype = resolve_dtype(args.dtype)
    # Half-precision CPU kernels are slow/unreliable enough to skew the comparison.
    if device.type == "cpu" and dtype in {torch.float16, torch.bfloat16}:
        raise ValueError("CPU benchmark only supports float32 reliably for this script.")

    config = PatchEmbedBenchmarkConfig(
        patch_size=args.patch_size,
        temporal_patch_size=args.temporal_patch_size,
        in_channels=args.in_channels,
        embed_dim=args.embed_dim,
    )
    # Stride == kernel size makes the conv an exact per-patch projection.
    kernel = [config.temporal_patch_size, config.patch_size, config.patch_size]
    proj = nn.Conv3d(
        config.in_channels, config.embed_dim, kernel_size=kernel, stride=kernel, bias=False
    ).to(device=device, dtype=dtype)
    proj.eval()

    patch_volume = config.in_channels * config.temporal_patch_size * config.patch_size * config.patch_size
    hidden_states = torch.randn(args.num_patches, patch_volume, device=device, dtype=dtype)

    with torch.inference_mode():
        # Sanity-check numerical equivalence of the two paths before timing.
        linear_output = linear_forward(proj, config, hidden_states)
        conv3d_output = conv3d_forward(proj, config, hidden_states)
        max_abs_diff = (linear_output - conv3d_output).abs().max().item()
        torch.testing.assert_close(linear_output, conv3d_output)

        linear_timings = benchmark(
            linear_forward,
            proj,
            config,
            hidden_states,
            warmup_steps=args.warmup_steps,
            benchmark_steps=args.benchmark_steps,
        )
        conv3d_timings = benchmark(
            conv3d_forward,
            proj,
            config,
            hidden_states,
            warmup_steps=args.warmup_steps,
            benchmark_steps=args.benchmark_steps,
        )

    linear_mean, linear_median = statistics.mean(linear_timings), statistics.median(linear_timings)
    conv3d_mean, conv3d_median = statistics.mean(conv3d_timings), statistics.median(conv3d_timings)

    print(f"device={device} dtype={dtype} num_patches={args.num_patches}")
    print(f"max_abs_diff={max_abs_diff:.8f}")
    print(f"linear_mean_ms={linear_mean:.4f}")
    print(f"linear_median_ms={linear_median:.4f}")
    print(f"conv3d_mean_ms={conv3d_mean:.4f}")
    print(f"conv3d_median_ms={conv3d_median:.4f}")
    print(f"mean_speedup_vs_conv3d={conv3d_mean / linear_mean:.4f}x")
    print(f"median_speedup_vs_conv3d={conv3d_median / linear_median:.4f}x")


# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    main()

result for mps(m5 + torch2.10):

python benchmark_patch_embed_qwen2_vl.py --num-patches 4096 --warmup-steps 20 --benchmark-steps 100
device=mps dtype=torch.bfloat16 num_patches=4096
max_abs_diff=0.00781250
linear_mean_ms=0.9433
linear_median_ms=0.9374
conv3d_mean_ms=703.3915
conv3d_median_ms=704.5094
mean_speedup_vs_conv3d=745.6549x
median_speedup_vs_conv3d=751.5601x

result for GPU h100 @wulipc:

python benchmark_patch_embed_qwen2_vl.py --num-patches 4096 --warmup-steps 20 --benchmark-steps 100
device=cuda dtype=torch.bfloat16 num_patches=4096
max_abs_diff=0.00781250
linear_mean_ms=0.0392
linear_median_ms=0.0372
conv3d_mean_ms=0.5528
conv3d_median_ms=0.5490
mean_speedup_vs_conv3d=14.1141x
median_speedup_vs_conv3d=14.7596x

python benchmark_patch_embed_qwen2_vl.py --num-patches 40960 --warmup-steps 20 --benchmark-steps 100
device=cuda dtype=torch.bfloat16 num_patches=40960
max_abs_diff=0.01562500
linear_mean_ms=0.1884
linear_median_ms=0.1850
conv3d_mean_ms=4.8897
conv3d_median_ms=4.8880
mean_speedup_vs_conv3d=25.9581x
median_speedup_vs_conv3d=26.4224x

The results show that the linear forward pass achieves lower latency than Conv3d, so we can just replace it. We temporarily retain nn.Conv3d for downstream compatibility, with plans to migrate to nn.Linear in future checkpoints.

cc @wulipc @ShuaiBai623

Comment thread src/transformers/models/glm4v/modeling_glm4v.py Outdated
Copy link
Copy Markdown
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some quick comments from my side. Let me know when the PR is in a ready enough state, I would run our CI for the models to check for the integration tests

(they will fail if commits are pushed in the meantime so we have to sync this properly 😄)

Comment thread src/transformers/conversion_mapping.py Outdated
Comment thread src/transformers/core_model_loading.py Outdated
Comment thread src/transformers/core_model_loading.py Outdated
@JJJYmmm
Copy link
Copy Markdown
Contributor Author

JJJYmmm commented Mar 27, 2026

@vasqu thank you for the review! I'll address them tomorrow and ping you when done. just marking this pr as draft for now.

@JJJYmmm JJJYmmm marked this pull request as draft March 27, 2026 17:50
@JJJYmmm
Copy link
Copy Markdown
Contributor Author

JJJYmmm commented Mar 28, 2026

@vasqu Hi, all related CI checks passed, I think we can trigger integration tests.
I also ran the qwen3_5, qwen2_vl and glm4v integration tests locally; qwen3_5 and qwen2_vl passed everything, and glm4v failed some cases that also fail on the main branch. So maybe we can run first and correct them if needed.

@vasqu
Copy link
Copy Markdown
Contributor

vasqu commented Mar 28, 2026

run-slow: ernie4_5_vl_moe, glm4v, glm4v_moe, glm_image, glm_ocr, qwen2_5_omni, qwen2_5_vl, qwen2_vl, qwen3_5, qwen3_5_moe, qwen3_omni_moe, qwen3_vl, qwen3_vl_moe

@vasqu
Copy link
Copy Markdown
Contributor

vasqu commented Mar 28, 2026

Yup, let's check it out, run slow also compares against main whether failures are new or not

@github-actions
Copy link
Copy Markdown
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/ernie4_5_vl_moe", "models/glm4v", "models/glm4v_moe", "models/glm_image", "models/glm_ocr", "models/qwen2_5_omni", "models/qwen2_5_vl", "models/qwen2_vl", "models/qwen3_5", "models/qwen3_5_moe", "models/qwen3_omni_moe", "models/qwen3_vl", "models/qwen3_vl_moe"]
quantizations: []

@github-actions
Copy link
Copy Markdown
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN 4242527a workflow commit (merge commit)
PR 461382ca branch commit (from PR)
main 9a9997fd base commit (on main)

Model CI Report

6 new failed tests from this PR 😭

  • glm4v:
    tests/models/glm4v/test_modeling_glm4v.py::Glm4vIntegrationTest::test_small_model_integration_test_batch_wo_image (✅ ⟹ ❌)

  • glm_ocr:
    tests/models/glm_ocr/test_modeling_glm_ocr.py::GlmOcrIntegrationTest::test_small_model_integration_test_expand (❌ ⟹ ❌)

  • qwen2_5_vl:
    tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py::Qwen2_5_VLIntegrationTest::test_small_model_integration_test_expand (✅ ⟹ ❌)

  • qwen3_vl_moe:
    tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py::Qwen3VLMoeIntegrationTest::test_small_model_integration_test (✅ ⟹ ❌)
    tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py::Qwen3VLMoeIntegrationTest::test_small_model_integration_test_batch (❌ ⟹ ❌)
    tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py::Qwen3VLMoeIntegrationTest::test_small_model_integration_test_batch_different_resolutions (❌ ⟹ ❌)

@JJJYmmm JJJYmmm marked this pull request as ready for review March 28, 2026 17:24
Copy link
Copy Markdown
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, so based on the generations starting to differ, I feel more confident in making this an optional feature in the spirit of #44942 - while the diffs are similar, it is hard to guarantee same performance across all affected models. We had similar experiences with fused qkv and split q, k, v - it depends on the trained model whether it's more easily affected or not.

This is why I would like to make this optional if possible because that will allow the best of both worlds

  • The original version which is the behavior one would expect
  • The fused version for speed gains but slightly different outputs

Wdyt @JJJYmmm? Really appreciate the initiative

@JJJYmmm
Copy link
Copy Markdown
Contributor Author

JJJYmmm commented Mar 30, 2026

@vasqu I agree! Since this pr also affects other models (e.g., glm-v), the impact on performance is still uncertain. I'll make it optional following #44942!

@JJJYmmm
Copy link
Copy Markdown
Contributor Author

JJJYmmm commented Mar 30, 2026

@vasqu Hi! I reworked this into a more general fusion dispatcher, while keeping this pr scoped to patch embedding fusion only. the modeling files stays unchanged, and the 3dconv->linear patch fusion is now opt-in through from_pretrained(...):

model = Qwen2VLForConditionalGeneration.from_pretrained(
    ...,
    fusion_config={
        "patch_embeddings": True,
        # compatible with #44942, we can add flags here like
        # "qkv_fusion": True,
        # "mlp_fusion": True, 
    },
)

I also tried to make this compatible with #44942: the entry point is now a generic fusion_config, and the fusion logic is dispatched through fusion_mapping.py. So if we want to support more fusion types later, we should only need to add new handlers in fusion_mapping.py, instead of introducing another dedicated loading flag.

@JJJYmmm JJJYmmm changed the title [3dconv][qwenvl] convert patch embed forward to linear [inference_fusion] convert conv3d patch embed to linear Mar 31, 2026
Copy link
Copy Markdown
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Conversions look super clean to me
  2. We have to exchange the fusion mapping to utilize the monkey patching we already have (on me didnt notice it existed 😢)
  3. Maybe some docs on how to apply this at the end would be nice then

Overall, it hopefully shouldn't change your core logic but use some existing stuff / extending it

Comment thread src/transformers/core_model_loading.py
Comment thread src/transformers/fusion_mapping.py
Comment thread tests/models/qwen3_5/test_modeling_qwen3_5.py Outdated
Comment thread tests/models/qwen3_5_moe/test_modeling_qwen3_5_moe.py Outdated
@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 1, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: qwen2_vl

@JJJYmmm
Copy link
Copy Markdown
Contributor Author

JJJYmmm commented Apr 1, 2026

@vasqu really appreciate the detailed review :) I have moved them into classes and add the tests accordingly.

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Copy Markdown
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the fusion config serialized? We could serialize it + test with a model on the hub that has it set to True for example.

Comment thread src/transformers/fusion_mapping.py
Comment thread src/transformers/fusion_mapping.py Outdated
Comment thread src/transformers/fusion_mapping.py Outdated
Comment thread src/transformers/fusion_mapping.py
Copy link
Copy Markdown
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, don't have much to add anymore! my hats off to you

I think Arthur's comment on serialization is the only point left:

  • Do we save/load the fusion config anywhere?
  • Maybe under the base PreTrainedConfig?

Currently, if I see it correctly we would always need to pass the fusion config but if we were to save it, we could avoid it and make it a bit more convenient for users

Comment thread src/transformers/fusion_mapping.py Outdated
Comment thread src/transformers/fusion_mapping.py Outdated
Comment on lines +165 to +166
if (tuple(converter.source_patterns), tuple(converter.target_patterns), type(converter))
not in existing_converter_keys
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea I think this is fair for now. I think in the future it would make sense to have compatible conversions for example when we apply chunking to qkv fused weights

  • If operations are only chunking (qkv -> q, k, v)
  • Remove that operation instead

I don't think it makes sense in the current context here but would be a nice addition in the future

Comment thread tests/utils/test_fusion_mapping.py Outdated
@JJJYmmm
Copy link
Copy Markdown
Contributor Author

JJJYmmm commented Apr 2, 2026

@ArthurZucker @vasqu sorry for missing this in the previous round! it is supported now: if fusion_config is saved in config.json, from_pretrained() will pick it up automatically on the next load.

I added a test covering from_pretrained(fusion_config={...}) -> save_pretrained() -> from_pretrained(), so the serialized config is exercised as well. I also verified it on my local ckpt with modified config.json. 🫡

Copy link
Copy Markdown
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect, one nit would be maybe to really have a remote hub config but this imo also enough already - wdyt @ArthurZucker?

@ArthurZucker
Copy link
Copy Markdown
Collaborator

Having a look tomorrow!!!!

Copy link
Copy Markdown
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TY! Let's improve tests please

Comment thread src/transformers/fusion_mapping.py Outdated
- uses `is_fusable(...)` as the final structural check
- builds the patch mapping used by monkey patching

Results are cached per `(fusion_name, cls)` to avoid repeated meta-initialization.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not 100% sure this is good, if you only want to fuse the first 4 mlps you can't with this! what we need to cache is the matched pattern I think. this way if you say layer.[0-4].mlp after 0, it will skip and if you had another fusion for the MLP class, they will be skipped!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah that is a fair point, but I feel the limitation here is more fundamental than the cache itself: the current monkey-patching path register_patch_mapping(...) is class-level, so it can express “fuse this module class”, but not really “fuse only layers 0-4”.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right! mmm we could make it instance level potentially? if not perfect to make it clear its not instance level

Comment thread src/transformers/fusion_mapping.py Outdated
Comment thread tests/utils/test_fusion_mapping.py Outdated

def __init__(self, config):
super().__init__(config)
# Instantiate through the fake module so `apply_patches()` sees the replacement.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you elaborate? you mean that me need DummyPatchEmbedding to be importable from transformers?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need DummyPatchEmbedding to be reachable from a transformers.* module.

the reason is that apply_patches() scans sys.modules and only rewrites class attributes exposed from modules whose name starts with transformers.

if not module.__name__.startswith("transformers"):
continue

so if the dummy class were only referenced as a local symbol inside the test file, the registered patch would never actually be applied during model construction.

that is why this test creates DUMMY_TRANSFORMERS_MODULE_NAME and resolves DummyPatchEmbedding through that fake transformers.* module

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sg

Comment on lines +80 to +81
for class_name, patchable_class in DUMMY_PATCHABLE_CLASSES.items():
setattr(DUMMY_TRANSFORMERS_MODULE, class_name, patchable_class)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is only 1 item

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kept it this way just for easier extension if we add other fusions e.g. #44942. 🫡

Comment thread tests/utils/test_fusion_mapping.py
Copy link
Copy Markdown
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

- uses `is_fusable(...)` as the final structural check
- builds the patch mapping used by monkey patching

Results are cached per `(fusion_name, cls)` to avoid repeated meta-initialization.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right! mmm we could make it instance level potentially? if not perfect to make it clear its not instance level

@ArthurZucker ArthurZucker enabled auto-merge April 13, 2026 15:01
@ArthurZucker
Copy link
Copy Markdown
Collaborator

Ty @JJJYmmm 🤗

@ArthurZucker ArthurZucker added this pull request to the merge queue Apr 13, 2026
Merged via the queue into huggingface:main with commit 3a947e2 Apr 13, 2026
28 checks passed
sirzechs66 pushed a commit to sirzechs66/transformers that referenced this pull request Apr 18, 2026
…45041)

* ok

* fix consistency

* pass qwen35 reverse mapping

* update new failed test according to captured info

* Revert "update new failed test according to captured info"

This reverts commit 445a400.

* make it optional

* make fusion_mapping more general

* make conv3d conversion more general

* make fusion_mapping more general

* better name for conversion

* add fusion_mapping doc and clean tests

* fix reverse mapping test follow gemma3n

* chore: retrigger ci

* tests: move qwen3.5 reverse mapping fix to separate branch

* code clean!

* ruff format and clean test to make it simple

* richer doc

* get converters from config rather than each module

* add explict module_name check for fusion!

* better isolated test and code clean

* support serialized fusion_config

* ruff format

* config can handle unknown attributes

* move fused cls out of spec by mixin

* detailed comments

* ruff
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants