From b61922ad50f4e76556323ef80b676013b6872390 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 2 Apr 2026 21:29:31 +0200 Subject: [PATCH 1/9] fix missing file in modelbuilder --- onnx_diagnostic/helpers/model_builder_helper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnx_diagnostic/helpers/model_builder_helper.py b/onnx_diagnostic/helpers/model_builder_helper.py index 6a383e9b..02c06172 100644 --- a/onnx_diagnostic/helpers/model_builder_helper.py +++ b/onnx_diagnostic/helpers/model_builder_helper.py @@ -48,6 +48,7 @@ def download_model_builder_to_cache( "phi.py", "qwen.py", "smollm.py", + "whipser.py", ]: u = f"{'/'.join(url.split('/')[:-1])}/builders/{subfile}" response = requests.get(u) From ca60329dfe27c4190508b06bf2710bc70b062887 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 2 Apr 2026 21:37:42 +0200 Subject: [PATCH 2/9] remove mamba --- _unittests/ut_helpers/test_cache_helper.py | 1 - onnx_diagnostic/helpers/cache_helper.py | 52 ---------------------- onnx_diagnostic/helpers/torch_helper.py | 4 -- onnx_diagnostic/tasks/text_generation.py | 2 +- 4 files changed, 1 insertion(+), 58 deletions(-) diff --git a/_unittests/ut_helpers/test_cache_helper.py b/_unittests/ut_helpers/test_cache_helper.py index 16b444c5..a1fbc4c8 100644 --- a/_unittests/ut_helpers/test_cache_helper.py +++ b/_unittests/ut_helpers/test_cache_helper.py @@ -9,7 +9,6 @@ make_dynamic_cache, make_encoder_decoder_cache, make_hybrid_cache, - make_mamba_cache, make_sliding_window_cache, make_static_cache, ) diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index a672e411..7f90f8be 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -537,58 +537,6 @@ def make_encoder_decoder_cache( make_encoder_decoder_cache = None # type: ignore[assignment] -def make_mamba_cache( - key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], - cls_layers: 
Optional[Union[str, List[type]]] = None, - cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None, -) -> "MambaCache": # noqa: F821 - """Creates a ``MambaCache``. `cls_layers`, `cls_kwargs` are unused.""" - # import is moved here because this part is slow. - try: - from transformers.models.mamba.modeling_mamba import MambaCache - except ImportError: - from transformers.cache_utils import MambaCache - dtype = key_value_pairs[0][0].dtype - - class _config: - def __init__(self): - self.intermediate_size = key_value_pairs[0][0].shape[1] - self.conv_kernel = key_value_pairs[0][0].shape[-1] - self.state_size = key_value_pairs[0][1].shape[-1] - self.num_hidden_layers = len(key_value_pairs) - self.dtype = dtype - - def get_text_config(self, *args, **kwargs): - return self - - cache = MambaCache( - _config(), - max_batch_size=key_value_pairs[0][0].shape[0], - device=key_value_pairs[0][0].device, - dtype=dtype, - ) - for i in range(len(key_value_pairs)): - assert cache.conv_states[i].dtype == dtype, ( - f"Type mismatch for cache.conv_states[{i}].dtype=" - f"{cache.conv_states[i].dtype} != {dtype}" - ) - assert cache.ssm_states[i].dtype == dtype, ( - f"Type mismatch for cache.ssm_states[{i}].dtype=" - f"{cache.ssm_states[i].dtype} != {dtype}" - ) - assert cache.conv_states[i].shape == key_value_pairs[i][0].shape, ( - f"Shape mismatch, expected {cache.conv_states[i].shape}, " - f"got {key_value_pairs[i][0].shape}" - ) - cache.conv_states[i][:, :, :] = key_value_pairs[i][0] - assert cache.ssm_states[i].shape == key_value_pairs[i][1].shape, ( - f"Shape mismatch, expected {cache.ssm_states[i].shape}, " - f"got {key_value_pairs[i][1].shape}" - ) - cache.ssm_states[i][:, :, :] = key_value_pairs[i][1] - return finalize_cache(cache) - - if hasattr(transformers.cache_utils, "SlidingWindowCache"): def make_sliding_window_cache( diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py index 1d8d7f40..28bcf4b5 100644 --- 
a/onnx_diagnostic/helpers/torch_helper.py +++ b/onnx_diagnostic/helpers/torch_helper.py @@ -898,10 +898,6 @@ def torch_deepcopy(value: Any) -> Any: torch_deepcopy(value.self_attention_cache), torch_deepcopy(value.cross_attention_cache), ) - if value.__class__.__name__ == "MambaCache": - from .cache_helper import make_mamba_cache - - return make_mamba_cache(list(zip(value.conv_states, value.ssm_states))) if value.__class__ in torch.utils._pytree.SUPPORTED_NODES: args, spec = torch.utils._pytree.tree_flatten(value) diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index 25b4d29c..cd7b2e80 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -1,6 +1,6 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union import torch -from ..helpers.cache_helper import make_dynamic_cache, make_mamba_cache, make_static_cache +from ..helpers.cache_helper import make_dynamic_cache, make_static_cache from ..helpers.config_helper import ( update_config, check_hasattr, From 69f87fe2582ddaf34d4545666f6c996b054fc8df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 2 Apr 2026 21:42:11 +0200 Subject: [PATCH 3/9] remove mamba --- _unittests/ut_helpers/test_cache_helper.py | 17 --- _unittests/ut_helpers/test_torch_helper.py | 21 +-- _unittests/ut_tasks/test_tasks.py | 18 +-- _unittests/ut_tasks/try_tasks.py | 74 ----------- .../test_onnx_export_errors.py | 122 +----------------- onnx_diagnostic/export/shape_helper.py | 8 -- onnx_diagnostic/tasks/text_generation.py | 83 +----------- .../onnx_export_serialization.py | 31 ----- .../serialization/transformers_impl.py | 66 ---------- .../hghub/hub_data_cached_configs.py | 39 ------ 10 files changed, 4 insertions(+), 475 deletions(-) diff --git a/_unittests/ut_helpers/test_cache_helper.py b/_unittests/ut_helpers/test_cache_helper.py index a1fbc4c8..31b7043a 100644 --- a/_unittests/ut_helpers/test_cache_helper.py +++ 
b/_unittests/ut_helpers/test_cache_helper.py @@ -149,23 +149,6 @@ def test_unflatten_flatten_encoder_decoder_cache(self): self.string_type(c2, with_shape=True), ) - @requires_transformers("4.51") # the structure changes - def test_make_mamba_cache(self): - cache = make_mamba_cache( - [ - (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), - (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), - (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), - ] - ) - text = self.string_type(cache, with_shape=True) - self.assertEqual( - "MambaCache(conv_states=#3[T1s4x4x4,T1s4x4x4,T1s4x4x4], " - "ssm_states=#3[T1s4x4x4,T1s4x4x4,T1s4x4x4])", - text, - ) - self.assertEqual(0, max_diff(cache, cache)["abs"]) - @unittest.skipIf( not make_sliding_window_cache, "SlidingWindowCache removed in transformers>=5" ) diff --git a/_unittests/ut_helpers/test_torch_helper.py b/_unittests/ut_helpers/test_torch_helper.py index 6f38c1db..da2efd9a 100644 --- a/_unittests/ut_helpers/test_torch_helper.py +++ b/_unittests/ut_helpers/test_torch_helper.py @@ -4,7 +4,7 @@ import onnx import torch import transformers -from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, requires_torch +from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout from onnx_diagnostic.helpers import max_diff, string_type from onnx_diagnostic.helpers.torch_helper import ( dummy_llm, @@ -22,7 +22,6 @@ from onnx_diagnostic.helpers.cache_helper import ( make_dynamic_cache, make_encoder_decoder_cache, - make_mamba_cache, make_sliding_window_cache, CacheKeyValue, ) @@ -313,24 +312,6 @@ def test_torch_deepcopy_cache_dce(self): self.assertEqual(hash1, hash2) self.assertGreater(torch_tensor_size(cc), 1) - @requires_torch("4.50") - def test_torch_deepcopy_mamba_cache(self): - cache = make_mamba_cache( - [ - (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), - (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), - (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), - ] - ) - at = torch_deepcopy(cache) - self.assertEqual(type(cache), 
type(at)) - self.assertEqual(max_diff(cache, at)["abs"], 0) - hash1 = string_type(at, with_shape=True, with_min_max=True) - cache.conv_states[0] += 1000 - hash2 = string_type(at, with_shape=True, with_min_max=True) - self.assertEqual(hash1, hash2) - self.assertGreater(torch_tensor_size(cache), 1) - def test_torch_deepcopy_base_model_outputs(self): bo = transformers.modeling_outputs.BaseModelOutput( last_hidden_state=torch.rand((4, 4, 4)) diff --git a/_unittests/ut_tasks/test_tasks.py b/_unittests/ut_tasks/test_tasks.py index b44eed47..e8728fad 100644 --- a/_unittests/ut_tasks/test_tasks.py +++ b/_unittests/ut_tasks/test_tasks.py @@ -1,7 +1,7 @@ import os import unittest import torch -from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers +from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout from onnx_diagnostic.helpers.torch_helper import to_any, torch_deepcopy from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs from onnx_diagnostic.torch_export_patches import torch_export_patches @@ -257,22 +257,6 @@ def test_sentence_similary(self): model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False ) - @hide_stdout() - def test_falcon_mamba_dev(self): - mid = "tiiuae/falcon-mamba-tiny-dev" - data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True) - self.assertEqual(data["task"], "text-generation") - model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] - model(**inputs) - model(**data["inputs2"]) - self.assertIn((data["size"], data["n_weights"]), [(274958336, 68739584)]) - if not has_transformers("5.3.99"): - raise unittest.SkipTest("The model has control flow.") - with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1): - torch.export.export( - model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False - ) - if __name__ == "__main__": unittest.main(verbosity=2) diff --git 
a/_unittests/ut_tasks/try_tasks.py b/_unittests/ut_tasks/try_tasks.py index e47aeed0..8ad9e3b6 100644 --- a/_unittests/ut_tasks/try_tasks.py +++ b/_unittests/ut_tasks/try_tasks.py @@ -683,80 +683,6 @@ def mean_pooling(model_output, attention_mask): print("Sentence embeddings:") print(sentence_embeddings) - @never_test() - def test_falcon_mamba_dev(self): - # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k falcon_mamba_dev - # https://huggingface.co/tiiuae/falcon-mamba-tiny-dev - - from transformers import AutoTokenizer - import transformers - import torch - - model = "tiiuae/falcon-mamba-tiny-dev" - - tokenizer = AutoTokenizer.from_pretrained(model) - pipeline = transformers.pipeline( - "text-generation", - model=model, - tokenizer=tokenizer, - dtype=torch.bfloat16, - trust_remote_code=True, - device_map="auto", - ) - print() - with steal_forward(pipeline.model): - sequences = pipeline( - "Girafatron is obsessed with giraffes, " - "the most glorious animal on the face of this Earth. " - "Giraftron believes all other animals are irrelevant " - "when compared to the glorious majesty of the giraffe." 
- "\nDaniel: Hello, Girafatron!\nGirafatron:", - max_length=200, - do_sample=True, - top_k=10, - num_return_sequences=1, - eos_token_id=tokenizer.eos_token_id, - ) - for seq in sequences: - print(f"Result: {seq['generated_text']}") - - @never_test() - def test_falcon_mamba_7b(self): - # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k falcon_mamba_7b - # https://huggingface.co/tiiuae/falcon-mamba-7b - - from transformers import AutoTokenizer - import transformers - import torch - - model = "tiiuae/falcon-mamba-7b" - - tokenizer = AutoTokenizer.from_pretrained(model) - pipeline = transformers.pipeline( - "text-generation", - model=model, - tokenizer=tokenizer, - dtype=torch.bfloat16, - trust_remote_code=True, - device_map="auto", - ) - print() - with steal_forward(pipeline.model): - sequences = pipeline( - "Girafatron is obsessed with giraffes, " - "the most glorious animal on the face of this Earth. " - "Giraftron believes all other animals are irrelevant " - "when compared to the glorious majesty of the giraffe." 
- "\nDaniel: Hello, Girafatron!\nGirafatron:", - max_length=200, - do_sample=True, - top_k=10, - num_return_sequences=1, - eos_token_id=tokenizer.eos_token_id, - ) - for seq in sequences: - print(f"Result: {seq['generated_text']}") - @never_test() def test_object_detection(self): # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k object_ diff --git a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py index 48863955..6f70133c 100644 --- a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py +++ b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py @@ -1,128 +1,8 @@ import unittest -from onnx_diagnostic.ext_test_case import ( - ExtTestCase, - requires_torch, - requires_transformers, - skipif_ci_windows, - ignore_warnings, - hide_stdout, -) -from onnx_diagnostic.helpers import string_type -from onnx_diagnostic.torch_export_patches.onnx_export_errors import ( - torch_export_patches, -) +from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings class TestOnnxExportErrors(ExtTestCase): - @requires_transformers("4.49.999") - @skipif_ci_windows("not working on Windows") - @ignore_warnings(UserWarning) - @hide_stdout() - def test_pytree_flatten_mamba_cache(self): - import torch - import torch.utils._pytree as py_pytree - - try: - from transformers.models.mamba.modeling_mamba import MambaCache - except ImportError: - from transformers.cache_utils import MambaCache - - class _config: - def __init__(self): - self.intermediate_size = 8 - self.state_size = 16 - self.conv_kernel = 32 - self.num_hidden_layers = 64 - self.dtype = torch.float16 - - cache = MambaCache(_config(), max_batch_size=1, device="cpu") - - with torch_export_patches(verbose=1): - values, spec = py_pytree.tree_flatten(cache) - cache2 = py_pytree.tree_unflatten(values, spec) - self.assertEqual(cache.max_batch_size, cache2.max_batch_size) - self.assertEqual(cache.intermediate_size, 
cache2.intermediate_size) - self.assertEqual(cache.ssm_state_size, cache2.ssm_state_size) - self.assertEqual(cache.conv_kernel_size, cache2.conv_kernel_size) - self.assertEqualArrayAny(cache.conv_states, cache2.conv_states) - self.assertEqualArrayAny(cache.ssm_states, cache2.ssm_states) - - @requires_transformers("4.50") - @requires_torch("2.7") - @skipif_ci_windows("not working on Windows") - @ignore_warnings(UserWarning) - @hide_stdout() - def test_exportable_mamba_cache(self): - import torch - from transformers.models.mamba.modeling_mamba import MambaCache - - class _config: - def __init__(self): - self.intermediate_size = 8 - self.state_size = 16 - self.conv_kernel = 32 - self.num_hidden_layers = 64 - self.dtype = torch.float16 - - class Model(torch.nn.Module): - def forward(self, x: torch.Tensor, cache: MambaCache): - x1 = cache.ssm_states[0] + x - x2 = cache.conv_states[0][:, :, ::2] + x1 - return x2 - - cache = MambaCache(_config(), max_batch_size=1, device="cpu") - # MambaCache was updated in 4.50 - self.assertEqual( - "MambaCache(conv_states=#64[T10r3,...], ssm_states=#64[T10r3,...])", - string_type(cache), - ) - x = torch.ones(2, 8, 16).to(torch.float16) - model = Model() - model(x, cache) - - with torch_export_patches(verbose=1, patch_transformers=True): - cache = MambaCache(_config(), max_batch_size=1, device="cpu") - torch.export.export(Model(), (x, cache)) - - @requires_transformers("4.49.999") - @skipif_ci_windows("not working on Windows") - @ignore_warnings(UserWarning) - def test_exportable_mamba_cache_dynamic(self): - import torch - from transformers.models.mamba.modeling_mamba import MambaCache - - class _config: - def __init__(self): - self.intermediate_size = 8 - self.state_size = 16 - self.conv_kernel = 32 - self.num_hidden_layers = 2 - self.dtype = torch.float16 - - class Model(torch.nn.Module): - def forward(self, x: torch.Tensor, cache: MambaCache): - x1 = cache.ssm_states[0] + x - x2 = cache.conv_states[0][:, :, ::2] + x1 - return x2 - - 
cache = MambaCache(_config(), max_batch_size=1, device="cpu") - self.assertEqual( - string_type(cache), - "MambaCache(conv_states=#2[T10r3,T10r3], ssm_states=#2[T10r3,T10r3])", - ) - x = torch.ones(2, 8, 16).to(torch.float16) - model = Model() - model(x, cache) - DYN = torch.export.Dim.DYNAMIC - - with torch_export_patches(patch_transformers=True): - cache = MambaCache(_config(), max_batch_size=2, device="cpu") - torch.export.export( - Model(), - (x, cache), - dynamic_shapes=({0: DYN}, [[{0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}]]), - ) - @ignore_warnings(UserWarning) def test_exportable_dynamic_shapes_constraints(self): import torch diff --git a/onnx_diagnostic/export/shape_helper.py b/onnx_diagnostic/export/shape_helper.py index 5bc45de1..39b51adb 100644 --- a/onnx_diagnostic/export/shape_helper.py +++ b/onnx_diagnostic/export/shape_helper.py @@ -46,7 +46,6 @@ def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any: from onnx_diagnostic.helpers.cache_helper import ( make_dynamic_cache, make_encoder_decoder_cache, - make_mamba_cache, make_static_cache, ) from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs @@ -84,13 +83,6 @@ def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any: ], max_cache_len=15, ), - make_mamba_cache( - [ - (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), - (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), - (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), - ] - ), ] with torch_export_patches(patch_transformers=True): diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index 68713a59..d335f5e9 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -54,74 +54,6 @@ def reduce_model_config(config: Any) -> Dict[str, Any]: return kwargs -def _get_input_mamba( - model: torch.nn.Module, - config: Optional[Any], - dummy_max_token_id: int, - num_hidden_layers: int, - batch_size: int = 2, - 
sequence_length: int = 30, - sequence_length2: int = 3, - dynamic_rope: bool = False, - num_key_value_heads: Optional[int] = None, - head_dim: Optional[int] = None, - cls_cache: Optional[Union[type, str]] = None, - **kwargs, # unused -): - try: - from transformers.models.mamba.modeling_mamba import MambaCache - except ImportError: - from transformers.cache_utils import MambaCache - - assert cls_cache in ( - "MambaCache", - MambaCache, - ), f"Unexpected value for cls_cache={cls_cache} and config={config}" - - batch = "batch" - seq_length_multiple = 8 - sequence_length = ( - (sequence_length + seq_length_multiple) // seq_length_multiple * seq_length_multiple - ) - # sequence_inc = seq_length_multiple - sequence_length2 = seq_length_multiple - - shapes = { - "input_ids": {0: batch, 1: "sequence_length"}, - "attention_mask": { - 0: batch, - 1: "cache+seq", # cache_length + seq_length - }, - "cache_position": { - 0: batch, - 1: "cache+seq", # cache_length + seq_length - }, - "cache_params": [{0: batch} for _ in range(num_hidden_layers * 2)], - } - inputs = dict( - input_ids=torch.randint( - 0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2) - ).to(torch.int64), - attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to( - torch.int64 - ), - cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64), - # .expand((batch_size, -1)) - cache_params=make_mamba_cache( - [ - ( - torch.randn( - batch_size, kwargs["intermediate_size"], kwargs["conv_kernel"] - ), - torch.randn(batch_size, kwargs["intermediate_size"], kwargs["state_size"]), - ) - for i in range(num_hidden_layers) - ] - ), - ) - return dict(inputs=inputs, dynamic_shapes=shapes) - - def get_inputs( model: torch.nn.Module, config: Optional[Any], @@ -158,20 +90,7 @@ def get_inputs( cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096) if config is not None and hasattr(config, "use_mambapy"): - res = _get_input_mamba( - model=model, - 
config=config, - dummy_max_token_id=dummy_max_token_id, - num_hidden_layers=num_hidden_layers, - batch_size=batch_size, - sequence_length=sequence_length, - sequence_length2=sequence_length2, - dynamic_rope=dynamic_rope, - num_key_value_heads=num_key_value_heads, - head_dim=head_dim, - cls_cache=cls_cache, - **kwargs, # unused - ) + raise NotImplementedError(f"Config {config} is nuot supported.") else: if head_dim is None: assert config, "head_dim is None, the value cannot be set without a configuration" diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index d70b1cc3..c8990456 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -16,20 +16,6 @@ PATCH_OF_PATCHES: Set[Any] = set() -def get_mamba_cache_cls() -> type: - try: - from transformers.models.mamba.modeling_mamba import MambaCache - - return MambaCache - except ImportError: - try: - from transformers.cache_utils import MambaCache - - return MambaCache - except ImportError: - return None - - def get_hybrid_cache_cls() -> type: try: from transformers.cache_utils import HybridCache @@ -226,23 +212,6 @@ def serialization_functions( verbose=verbose, ), } - MambaCache = get_mamba_cache_cls() - if MambaCache: - from .serialization.transformers_impl import ( - flatten_mamba_cache, - unflatten_mamba_cache, - flatten_with_keys_mamba_cache, - ) - - transformers_classes[MambaCache] = ( - lambda verbose=verbose: register_class_serialization( - MambaCache, - flatten_mamba_cache, - unflatten_mamba_cache, - flatten_with_keys_mamba_cache, - verbose=verbose, - ) - ) HybridCache = get_hybrid_cache_cls() if HybridCache: from .serialization.transformers_impl import ( diff --git a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py index 
cb716720..9a537c0d 100644 --- a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +++ b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py @@ -15,10 +15,6 @@ except ImportError: HybridCache = None -try: - from transformers.models.mamba.modeling_mamba import MambaCache -except ImportError: - from transformers.cache_utils import MambaCache from transformers.modeling_outputs import BaseModelOutput from ...helpers.cache_helper import make_dynamic_cache, make_static_cache, CacheKeyValue from . import make_serialization_function_for_dataclass @@ -302,68 +298,6 @@ def unflatten_encoder_decoder_cache( ) -############ -# MambaCache -############ - - -def flatten_mamba_cache( - mamba_cache: MambaCache, -) -> Tuple[List[Any], torch.utils._pytree.Context]: - """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" - assert isinstance(mamba_cache.conv_states, list) and isinstance( - mamba_cache.ssm_states, list - ), ( - f"Unexpected types for conv_states and ssm_states {type(mamba_cache.conv_states)}, " - f"{type(mamba_cache.ssm_states)}" - ) - flat = [ - ("conv_states", mamba_cache.conv_states), - ("ssm_states", mamba_cache.ssm_states), - ] - return [f[1] for f in flat], [f[0] for f in flat] - - -def unflatten_mamba_cache( - values: List[Any], context: torch.utils._pytree.Context, output_type=None -) -> MambaCache: - """Restores a :class:`transformers.cache_utils.MambaCache` from python objects.""" - conv_states, ssm_states = values - - class _config: - def __init__(self): - if isinstance(conv_states, list): - self.intermediate_size = conv_states[0].shape[1] - self.state_size = ssm_states[0].shape[2] - self.conv_kernel = conv_states[0].shape[2] - self.num_hidden_layers = len(conv_states) - else: - self.intermediate_size = conv_states.shape[2] - self.state_size = ssm_states.shape[3] - self.conv_kernel = conv_states.shape[3] - self.num_hidden_layers = conv_states.shape[0] - - cache = MambaCache( - _config(), 
- max_batch_size=1, - dtype=values[-1][0].dtype, - device="cpu" if values[-1][0].get_device() < 0 else "cuda", - ) - values = dict(zip(context, values)) - for k, v in values.items(): - setattr(cache, k, v) - return cache - - -def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[ - List[Tuple[torch.utils._pytree.KeyEntry, Any]], - torch.utils._pytree.Context, -]: - """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" - values, context = flatten_mamba_cache(cache) - return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context - - ############# # dataclasses ############# diff --git a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py index daee5225..3999135d 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +++ b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py @@ -3774,45 +3774,6 @@ def _ccached_sentence_transformers_all_MiniLM_L6_v1(): ) -def _ccached_tiiuae_falcon_mamba_tiny_dev(): - "tiiuae/falcon-mamba-tiny-dev" - return transformers.FalconMambaConfig( - **{ - "architectures": ["FalconMambaForCausalLM"], - "bos_token_id": 0, - "conv_kernel": 4, - "eos_token_id": 11, - "expand": 16, - "hidden_act": "silu", - "hidden_size": 4096, - "initializer_range": 0.1, - "intermediate_size": 8192, - "layer_norm_epsilon": 1e-05, - "mixer_rms_eps": 1e-06, - "model_type": "falcon_mamba", - "num_hidden_layers": 64, - "pad_token_id": 11, - "rescale_prenorm_residual": false, - "residual_in_fp32": true, - "state_size": 16, - "tie_word_embeddings": false, - "time_step_floor": 0.0001, - "time_step_init_scheme": "random", - "time_step_max": 0.1, - "time_step_min": 0.001, - "time_step_rank": 256, - "time_step_scale": 1.0, - "torch_dtype": "bfloat16", - "transformers_version": "4.52.0.dev0", - "use_bias": false, - "use_cache": true, - "use_conv_bias": true, - "use_mambapy": false, - "vocab_size": 65024, - } 
- ) - - def _ccached_facebook_bart_base(): "facebook/bart-base" return transformers.BartConfig( From 2cffb06da5e3be408614aadc6b7e6e6eec86ac0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 2 Apr 2026 21:43:48 +0200 Subject: [PATCH 4/9] changes --- CHANGELOGS.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 154b4fd1..ed4c1242 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -6,6 +6,7 @@ Change Logs * :pr:`422`: add remove_inputs to InputObserver * :pr:`421`: fix a few patches for MoE +* :pr:`426`: remove MambaCache 0.9.2 +++++ From 25215c6bc48b3ad78720f6353c83ac16c30faae9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 2 Apr 2026 22:23:08 +0200 Subject: [PATCH 5/9] spell --- onnx_diagnostic/helpers/model_builder_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/helpers/model_builder_helper.py b/onnx_diagnostic/helpers/model_builder_helper.py index 02c06172..ce5b29b2 100644 --- a/onnx_diagnostic/helpers/model_builder_helper.py +++ b/onnx_diagnostic/helpers/model_builder_helper.py @@ -48,7 +48,7 @@ def download_model_builder_to_cache( "phi.py", "qwen.py", "smollm.py", - "whipser.py", + "whisper.py", ]: u = f"{'/'.join(url.split('/')[:-1])}/builders/{subfile}" response = requests.get(u) From e93c73cf24ee7fd653f1c7d3a10ed9dd94596522 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 2 Apr 2026 22:24:36 +0200 Subject: [PATCH 6/9] fix --- onnx_diagnostic/torch_onnx/runtime_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/torch_onnx/runtime_info.py b/onnx_diagnostic/torch_onnx/runtime_info.py index 2aa30f65..b4af9a9d 100644 --- a/onnx_diagnostic/torch_onnx/runtime_info.py +++ b/onnx_diagnostic/torch_onnx/runtime_info.py @@ -130,7 +130,7 @@ def set_value(self, value: Union[torch.Tensor, TensorLike]): ) else: self.dtype = value.dtype - self.shape = None if is_sequence else 
tuple(map(int, value.shape)) + self.shape = None if is_sequence else tuple(map(int, value.shape)) # type: ignore def clean_value(self): """Sets value to None.""" From 37c516406991a0206f2342be6e71d3c47de35e97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 2 Apr 2026 23:17:26 +0200 Subject: [PATCH 7/9] bug --- _unittests/ut_torch_models/test_tiny_llms.py | 11 +- .../_patch_transformers_masking_utils.py | 119 ++++++++++++------ 2 files changed, 88 insertions(+), 42 deletions(-) diff --git a/_unittests/ut_torch_models/test_tiny_llms.py b/_unittests/ut_torch_models/test_tiny_llms.py index ac37a7b7..9203a8bc 100644 --- a/_unittests/ut_torch_models/test_tiny_llms.py +++ b/_unittests/ut_torch_models/test_tiny_llms.py @@ -51,17 +51,18 @@ def test_tiny_llm_run_static(self): @requires_torch("2.8") def test_tiny_llm_export_static(self): data = get_tiny_llm(use_static_cache=True) - model, inputs = data["model"], data["inputs"] + model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] + if "cache_position" in inputs: + del inputs["cache_position"] + del ds["cache_position"] expected = model(**copy.deepcopy(inputs)) - self.assertEqual( - {"attention_mask", "past_key_values", "input_ids", "cache_position"}, set(inputs) - ) + self.assertEqual({"attention_mask", "past_key_values", "input_ids"}, set(inputs)) with torch_export_patches(patch_transformers=True, stop_if_static=0): ep = torch.export.export( model, (), kwargs=copy.deepcopy(inputs), - dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]), + dynamic_shapes=use_dyn_not_str(ds), ) got = ep.module()(**inputs) self.assertEqualArrayAny(expected, got) diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py index 46e2b33f..744a931a 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +++ 
b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py @@ -146,61 +146,106 @@ def patched_eager_mask( return mask def patched_sdpa_mask_recent_torch( - batch_size: int, - cache_position: torch.Tensor, - kv_length: int, + batch_size: int = 0, + q_length: int = 0, + kv_length: int = 0, + q_offset: int = 0, kv_offset: int = 0, mask_function: Callable = causal_mask_function, attention_mask: Optional[torch.Tensor] = None, local_size: Optional[int] = None, allow_is_causal_skip: bool = True, allow_is_bidirectional_skip: bool = False, + use_vmap: bool = False, + device: torch.device | str = "cpu", **kwargs, ) -> Optional[torch.Tensor]: """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``.""" - q_length = cache_position.shape[0] - padding_mask = prepare_padding_mask( - attention_mask, kv_length, kv_offset, **_prepare_padding_mask_kwargs - ) - if allow_is_causal_skip and _ignore_causal_mask_sdpa( - padding_mask, q_length, kv_length, kv_offset, local_size - ): - return None - if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa: - # transformers<=5.0: 1 parameter, 3 for transformers>5.0 - n_parameters = len(inspect.signature(_ignore_bidirectional_mask_sdpa).parameters) - if _ignore_bidirectional_mask_sdpa( - *[padding_mask, kv_length, kv_offset][:n_parameters] + if isinstance(q_length, torch.Tensor): + # `cache_position` is deprecated as an arg, + # and will be removed in Transformers v5.6. 
Please use `q_length` and " + # `q_offset` instead, similarly to `kv_length` and `kv_offset`" + q_length, q_offset = q_length.shape[0], q_length[0].to(device) + device = q_length.device + + padding_mask = prepare_padding_mask( + attention_mask, kv_length, kv_offset, **_prepare_padding_mask_kwargs + ) + if allow_is_causal_skip and _ignore_causal_mask_sdpa( + padding_mask, q_length, kv_length, kv_offset, local_size ): return None + if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa: + # transformers<=5.0: 1 parameter, 3 for transformers>5.0 + n_parameters = len( + inspect.signature(_ignore_bidirectional_mask_sdpa).parameters + ) + if _ignore_bidirectional_mask_sdpa( + *[padding_mask, kv_length, kv_offset][:n_parameters] + ): + return None - if mask_function is bidirectional_mask_function: - if padding_mask is not None: - # used for slicing without data-dependent slicing - mask_indices = ( - torch.arange(kv_length, device=cache_position.device) + kv_offset + if mask_function is bidirectional_mask_function: + if padding_mask is not None: + # used for slicing without data-dependent slicing + mask_indices = torch.arange(kv_length, device=device) + kv_offset + return padding_mask[:, None, None, mask_indices].expand( + -1, -1, q_length, -1 + ) + return torch.ones( + batch_size, + 1, + q_length, + kv_length, + dtype=torch.bool, + device=device, ) - return padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1) - return torch.ones( - batch_size, - 1, - q_length, - kv_length, - dtype=torch.bool, - device=cache_position.device, + + kv_arange = torch.arange(kv_length, device=device) + kv_arange += kv_offset + if padding_mask is not None: + mask_function = and_masks(mask_function, padding_mask_function(padding_mask)) + batch_arange = torch.arange(batch_size, device=device) + head_arange = torch.arange(1, device=device) + # PATCHED: this line calls the patched version of vmap_for_bhqkv + causal_mask = patched__vmap_for_bhqkv(mask_function)( + 
batch_arange, head_arange, q_length, kv_arange ) + return causal_mask + + padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset) + + # Under specific conditions, we can avoid materializing the mask + # 1. Causal masks can rely on the `is_causal` argument + # 2. Bidirectional do not need any further processing (no bias) + if allow_is_causal_skip and _ignore_causal_mask_sdpa( + padding_mask, q_length, kv_length, kv_offset, local_size + ): + return None + if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa( + padding_mask, kv_length, local_size + ): + return None - kv_arange = torch.arange(kv_length, device=cache_position.device) - kv_arange += kv_offset + # Potentially add the padding 2D mask if padding_mask is not None: mask_function = and_masks(mask_function, padding_mask_function(padding_mask)) - batch_arange = torch.arange(batch_size, device=cache_position.device) - head_arange = torch.arange(1, device=cache_position.device) - # PATCHED: this line calls the patched version of vmap_for_bhqkv - causal_mask = patched__vmap_for_bhqkv(mask_function)( - batch_arange, head_arange, cache_position, kv_arange + + batch_arange = torch.arange(batch_size, device=device) + head_arange = torch.arange(1, device=device) + q_arange = torch.arange(q_length, device=device) + q_offset + kv_arange = torch.arange(kv_length, device=device) + kv_offset + + # Actual mask creation + # Option 1: Fast non-vmap mask creation (default) + # Apply mask function element-wise through broadcasting + attention_mask = mask_function( + *_non_vmap_expansion_sdpa(batch_arange, head_arange, q_arange, kv_arange) ) - return causal_mask + # Expand the mask to match batch size and + # query length if they weren't used in the mask function + attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length) + return attention_mask def patched_sdpa_mask( batch_size: int, From 5d39ae1b83bca35c4f6607abbec60735b71c9cd1 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 3 Apr 2026 14:06:50 +0200 Subject: [PATCH 8/9] fix --- _unittests/ut_export/test_api.py | 2 +- .../test_patch_transformers.py | 4 +- .../_patch_transformers_masking_utils.py | 134 ++++++++++++------ 3 files changed, 92 insertions(+), 48 deletions(-) diff --git a/_unittests/ut_export/test_api.py b/_unittests/ut_export/test_api.py index b1395cf3..55ae4739 100644 --- a/_unittests/ut_export/test_api.py +++ b/_unittests/ut_export/test_api.py @@ -46,7 +46,7 @@ def forward(self, x, y): @hide_stdout() @ignore_warnings(FutureWarning) - @requires_transformers("4.50") + @requires_transformers("4.57") def test_tiny_llm_to_onnx(self): import onnxruntime diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index 66bbc6fa..a361c062 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -74,7 +74,7 @@ def test_sdpa_mask_patched(self): patched_sdpa_mask = patch_transformers.patched_sdpa_mask kwargs = { "batch_size": 1, - "cache_position": torch.tensor([3], dtype=torch.int64), + "q_length": torch.tensor([3], dtype=torch.int64), "kv_length": 4, "kv_offset": 0, "mask_function": transformers.masking_utils.causal_mask_function, @@ -89,7 +89,7 @@ def test_sdpa_mask_patched(self): kwargs = { "batch_size": 1, - "cache_position": torch.tensor([3], dtype=torch.int64), + "q_length": torch.tensor([3], dtype=torch.int64), "kv_length": 4, "kv_offset": 0, "mask_function": transformers.masking_utils.causal_mask_function, diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py index 744a931a..4c5a7ccd 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +++ 
b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py @@ -248,9 +248,10 @@ def patched_sdpa_mask_recent_torch( return attention_mask def patched_sdpa_mask( - batch_size: int, - cache_position: torch.Tensor, - kv_length: int, + batch_size: int = 0, + q_length: int = 0, + kv_length: int = 0, + q_offset: int = 0, kv_offset: int = 0, mask_function: Callable = causal_mask_function, attention_mask: torch.Tensor | None = None, @@ -262,7 +263,79 @@ def patched_sdpa_mask( **kwargs, ) -> torch.Tensor | None: """manual patch for function ``transformers.masking_utils.sdpa_mask``.""" - q_length = cache_position.shape[0] + if isinstance(q_length, torch.Tensor): + # `cache_position` is deprecated as an arg, + # and will be removed in Transformers v5.6. Please use `q_length` and " + # `q_offset` instead, similarly to `kv_length` and `kv_offset`" + cache_position = q_length + device = q_length.device + q_length = q_length.shape[0] + + # Potentially pad the 2D mask + padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset) + + # Under specific conditions, we can avoid materializing the mask + # 1. Causal masks can rely on the `is_causal` argument + # 2. Bidirectional do not need any further processing (no bias) + if allow_is_causal_skip and _ignore_causal_mask_sdpa( + padding_mask, q_length, kv_length, kv_offset, local_size + ): + return None + if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa( + padding_mask, kv_length, local_size + ): + return None + + # Potentially add the padding 2D mask + if padding_mask is not None: + mask_function = and_masks(mask_function, padding_mask_function(padding_mask)) + + batch_arange = torch.arange(batch_size, device=device) + head_arange = torch.arange(1, device=device) + # Similar to `kv_arange = torch.arange(start=kv_offset, + # end=kv_offset + kv_length, device=cache_position.device)` + # but without data-dependent slicing (i.e. 
torch.compile friendly) + kv_arange = torch.arange(kv_length, device=device) + kv_offset + + # Actual mask creation + # Option 1: Fast non-vmap mask creation (default) + # PATCHED + use_vmap = False + if not use_vmap: + # Apply mask function element-wise through broadcasting + attention_mask = mask_function( + *_non_vmap_expansion_sdpa( + batch_arange, head_arange, cache_position, kv_arange + ) + ) + # Expand the mask to match batch size + # and query length if they weren't used in the mask function + attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length) + + # Option 2: Vmap mask creation (torch>=2.6 and custom patterns) + # elif _is_torch_greater_or_equal_than_2_6: + # This creates the 4D mask easily. + # Note that we need this context manager as vmap + # cannot handle slicing a tensor from + # scalar tensor (it internally calls `.item()` which vmap does not allow, + # but this context works around it + # We don't need to add an offset to the mask_function either, + # as we vmap directly the correct indices for k and kv indices + # with TransformGetItemToIndex(): + # attention_mask = _vmap_expansion_sdpa(mask_function)( + # batch_arange, head_arange, cache_position, kv_arange + # ) + + # Option 3: Error out since it indicates that the user did something custom, + # which they shouldn't have (torch<2.6) + else: + raise ValueError( + "The vmap functionality for mask creation " + "is only supported from torch>=2.6. " + "Please update your torch version or use " + "`use_vmap=False` with index-based masks." 
+ ) + return attention_mask # Potentially pad the 2D mask padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset) @@ -283,46 +356,17 @@ def patched_sdpa_mask( if padding_mask is not None: mask_function = and_masks(mask_function, padding_mask_function(padding_mask)) - batch_arange = torch.arange(batch_size, device=cache_position.device) - head_arange = torch.arange(1, device=cache_position.device) - # Similar to `kv_arange = torch.arange(start=kv_offset, - # end=kv_offset + kv_length, device=cache_position.device)` - # but without data-dependent slicing (i.e. torch.compile friendly) - kv_arange = torch.arange(kv_length, device=cache_position.device) + kv_offset + batch_arange = torch.arange(batch_size, device=device) + head_arange = torch.arange(1, device=device) + q_arange = torch.arange(q_length, device=device) + q_offset + kv_arange = torch.arange(kv_length, device=device) + kv_offset + + # Apply mask function element-wise through broadcasting + attention_mask = mask_function( + *_non_vmap_expansion_sdpa(batch_arange, head_arange, q_arange, kv_arange) + ) + # Expand the mask to match batch size and query + # length if they weren't used in the mask function + attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length) - # Actual mask creation - # Option 1: Fast non-vmap mask creation (default) - # PATCHED - use_vmap = False - if not use_vmap: - # Apply mask function element-wise through broadcasting - attention_mask = mask_function( - *_non_vmap_expansion_sdpa(batch_arange, head_arange, cache_position, kv_arange) - ) - # Expand the mask to match batch size - # and query length if they weren't used in the mask function - attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length) - - # Option 2: Vmap mask creation (torch>=2.6 and custom patterns) - # elif _is_torch_greater_or_equal_than_2_6: - # This creates the 4D mask easily. 
- # Note that we need this context manager as vmap cannot handle slicing a tensor from - # scalar tensor (it internally calls `.item()` which vmap does not allow, - # but this context works around it - # We don't need to add an offset to the mask_function either, - # as we vmap directly the correct indices for k and kv indices - # with TransformGetItemToIndex(): - # attention_mask = _vmap_expansion_sdpa(mask_function)( - # batch_arange, head_arange, cache_position, kv_arange - # ) - - # Option 3: Error out since it indicates that the user did something custom, - # which they shouldn't have (torch<2.6) - else: - raise ValueError( - "The vmap functionality for mask creation " - "is only supported from torch>=2.6. " - "Please update your torch version or use " - "`use_vmap=False` with index-based masks." - ) return attention_mask From 9f913f1cc3bc4e098704a0b2d9709286a61fc8d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 3 Apr 2026 14:35:11 +0200 Subject: [PATCH 9/9] fix --- .../test_patch_transformers.py | 77 ------------------- .../_patch_transformers_masking_utils.py | 2 +- 2 files changed, 1 insertion(+), 78 deletions(-) diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index a361c062..44c4d327 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -102,83 +102,6 @@ def test_sdpa_mask_patched(self): got = patched_sdpa_mask(**kwargs) self.assertEqualArray(expected, got) - @requires_transformers("4.99") - def test_sdpa_mask_recent_torch_is_running(self): - def _copy_vmap_for_bhqkv(mask_function, bh_indices=True): - dimensions = [(None, None, None, 0), (None, None, 0, None)] - if bh_indices: - dimensions.extend([(None, 0, None, None), (0, None, None, None)]) - for dims in dimensions: - mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0) - return 
mask_function - - def copy_of_sdpa_mask_recent_torch( - batch_size, - cache_position, - kv_length, - kv_offset=0, - mask_function=transformers.masking_utils.causal_mask_function, - attention_mask=None, - local_size=None, - allow_is_causal_skip=True, - **kwargs, - ): - q_length = cache_position.shape[0] - padding_mask = transformers.masking_utils.prepare_padding_mask( - attention_mask, kv_length, kv_offset - ) - if allow_is_causal_skip and transformers.masking_utils._ignore_causal_mask_sdpa( - padding_mask, q_length, kv_length, kv_offset, local_size - ): - return None - kv_arange = torch.arange(kv_length, device=cache_position.device) - kv_arange += kv_offset - if padding_mask is not None: - mask_function = transformers.masking_utils.and_masks( - mask_function, - transformers.masking_utils.padding_mask_function(padding_mask), - ) - - batch_arange = torch.arange(batch_size, device=cache_position.device) - head_arange = torch.arange(1, device=cache_position.device) - with transformers.masking_utils.TransformGetItemToIndex(): - causal_mask = _copy_vmap_for_bhqkv(mask_function)( - batch_arange, head_arange, cache_position, kv_arange - ) - return causal_mask - - sdpa_mask_recent_torch = copy_of_sdpa_mask_recent_torch - patched_sdpa_mask_recent_torch = patch_transformers.patched_sdpa_mask_recent_torch - kwargs = { - "batch_size": 1, - "cache_position": torch.tensor([3], dtype=torch.int64), - "kv_length": 4, - "kv_offset": 0, - "mask_function": transformers.masking_utils.causal_mask_function, - "attention_mask": torch.tensor([[True, True, True, True]]), - "local_size": None, - "allow_is_causal_skip": True, - "allow_is_bidirectional_skip": False, - } - expected = sdpa_mask_recent_torch(**kwargs) - got = patched_sdpa_mask_recent_torch(**kwargs) - self.assertEqual(expected, got) - - kwargs = { - "batch_size": 1, - "cache_position": torch.tensor([3], dtype=torch.int64), - "kv_length": 4, - "kv_offset": 0, - "mask_function": transformers.masking_utils.causal_mask_function, - 
"attention_mask": torch.tensor([[True, True, True, True]]), - "local_size": None, - "allow_is_causal_skip": False, - "allow_is_bidirectional_skip": False, - } - expected = sdpa_mask_recent_torch(**kwargs) - got = patched_sdpa_mask_recent_torch(**kwargs) - self.assertEqualArray(expected, got) - def test_sdpa_attention_forward_not_causal(self): sdpa_attention_forward = sdpa_attention.sdpa_attention_forward patched_sdpa_attention_forward = patch_transformers.patched_sdpa_attention_forward diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py index 4c5a7ccd..c0d94a2e 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py @@ -165,8 +165,8 @@ def patched_sdpa_mask_recent_torch( # `cache_position` is deprecated as an arg, # and will be removed in Transformers v5.6. Please use `q_length` and " # `q_offset` instead, similarly to `kv_length` and `kv_offset`" - q_length, q_offset = q_length.shape[0], q_length[0].to(device) device = q_length.device + q_length, q_offset = q_length.shape[0], q_length[0].to(device) padding_mask = prepare_padding_mask( attention_mask, kv_length, kv_offset, **_prepare_padding_mask_kwargs