From eae429b9ea2aa04a5c54f5114e9763b2413517c6 Mon Sep 17 00:00:00 2001 From: qq_30035749 Date: Thu, 21 May 2026 17:05:49 +0800 Subject: [PATCH 1/7] feat: add support for forward methods with incompatible kwargs Add `_call_with_supported_kwargs` utility to filter out unsupported keyword arguments when calling forward methods, preventing errors from incompatible function signatures. This fixes issues where `origin_forward` methods may not accept all passed kwargs. --- src/twinkle/patch/gdn_padding_free.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/twinkle/patch/gdn_padding_free.py b/src/twinkle/patch/gdn_padding_free.py index d34b7ec7..ccd4c9b2 100644 --- a/src/twinkle/patch/gdn_padding_free.py +++ b/src/twinkle/patch/gdn_padding_free.py @@ -1,6 +1,8 @@ +import inspect +from typing import Optional + import torch from transformers.utils.import_utils import is_flash_linear_attention_available -from typing import Optional from twinkle.patch import Patch @@ -33,6 +35,13 @@ def _get_flash_linear_attention_kernels(): return causal_conv1d, chunk_gated_delta_rule +def _call_with_supported_kwargs(fn, *args, **kwargs): + signature = inspect.signature(fn) + if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()): + kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters} + return fn(*args, **kwargs) + + def _patch_gdn_kernels_for_cu_seqlens( mod: torch.nn.Module, *, @@ -64,7 +73,7 @@ def chunk_gated_delta_rule_wrapper(query, key, value, **kwargs): mod.causal_conv1d_fn = causal_conv1d_wrapper mod.chunk_gated_delta_rule = chunk_gated_delta_rule_wrapper try: - return origin_forward(mod, *forward_args, **forward_kwargs) + return _call_with_supported_kwargs(origin_forward, mod, *forward_args, **forward_kwargs) finally: mod.causal_conv1d_fn = old_conv_fn mod.chunk_gated_delta_rule = old_chunk_rule @@ -147,7 +156,8 @@ def forward( **extra_kwargs, ): if cu_seq_lens_q is 
None: - return origin_forward( + return _call_with_supported_kwargs( + origin_forward, mod, hidden_states, cache_params=cache_params, From aa9bf73b54e0f1f1af16c2b06e34fa0e5ce1bf6e Mon Sep 17 00:00:00 2001 From: qq_30035749 Date: Fri, 22 May 2026 00:24:29 +0800 Subject: [PATCH 2/7] fix: support native padding-free in GatedDeltaNet and improve kwargs handling - Add `_call_with_supported_kwargs` and `_call_create_causal_mask` helpers to filter unsupported kwargs - Remove the `flash_attention_mask` override and rename the `cache_position` parameter of sdpa_mask to `q_length` for clarity - Fix device detection in sdpa_mask when `q_length` is not a tensor - Ensure compatibility with models that don't accept `cache_position` in causal mask functions --- .../strategy/sequence_parallel/__init__.py | 128 ++++++++++++------ src/twinkle/patch/gdn_padding_free.py | 15 +- 2 files changed, 99 insertions(+), 44 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py index 51a28015..bbe06873 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
+import inspect import math import torch import torch.distributed as dist @@ -28,6 +29,38 @@ def is_qwen3_omni(model): return 'qwen3_omni' in mt +def _call_with_supported_kwargs(fn, *args, **kwargs): + signature = inspect.signature(fn) + if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()): + kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters} + return fn(*args, **kwargs) + + +def _call_create_causal_mask(fn, config, input_embeds, attention_mask, cache_position_or_past_key_values, *args, + **kwargs): + if 'cache_position' in inspect.signature(fn).parameters: + return _call_with_supported_kwargs( + fn, + config, + input_embeds, + attention_mask, + cache_position_or_past_key_values, + *args, + **kwargs, + ) + if cache_position_or_past_key_values is None and 'past_key_values' in kwargs: + return _call_with_supported_kwargs(fn, config, input_embeds, attention_mask, *args, **kwargs) + return _call_with_supported_kwargs( + fn, + config, + input_embeds, + attention_mask, + cache_position_or_past_key_values, + *args, + **kwargs, + ) + + # main content copied from ms-swift class SequenceParallel: @@ -77,59 +110,72 @@ def _prepare_flash_attn(self, base_model: torch.nn.Module): try: from transformers import masking_utils - _origin_flash_attention_mask = masking_utils.flash_attention_mask - - # Patch attention masks for SP: avoid masking when full sequence is reconstructed. 
- def flash_attention_mask(batch_size, - cache_position, - kv_length, - kv_offset=0, - mask_function=masking_utils.causal_mask_function, - attention_mask=None, - **kwargs): - if self.world_size == 1: - return _origin_flash_attention_mask(batch_size, cache_position, kv_length, kv_offset, mask_function, - attention_mask, **kwargs) - if attention_mask is not None: - if attention_mask.all(): - attention_mask = None - - return attention_mask + def sdpa_mask(batch_size, q_length=None, kv_length=None, *args, **kwargs): + origin_sdpa = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'] + origin_uses_cache_position = 'cache_position' in inspect.signature(origin_sdpa).parameters + q_length = q_length if q_length is not None else kwargs.pop('cache_position', None) + device = q_length.device if torch.is_tensor(q_length) else kwargs.get('device') + if device is None: + device = self.real_position_ids.device + + cache_position = None + if self.world_size > 1 and origin_uses_cache_position: + padded_position_ids = self.pad( + self.real_position_ids[0], + padding_value=-1, + position_ids=self.real_position_ids, + dim=0, + ) + cache_position = torch.arange(0, padded_position_ids.shape[0], device=device) + kv_length = cache_position.shape[0] - masking_utils.flash_attention_mask = flash_attention_mask - masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['flash_attention_2'] = flash_attention_mask + if origin_uses_cache_position: + if cache_position is None: + cache_position = q_length if torch.is_tensor(q_length) else torch.arange( + q_length, device=device) + return origin_sdpa(batch_size, cache_position, kv_length, *args, **kwargs) - def sdpa_mask(batch_size, cache_position, kv_length, *args, **kwargs): - if self.world_size == 1: - return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](batch_size, - cache_position, - kv_length, *args, - **kwargs) - device = cache_position.device - cache_position = self.real_position_ids[0] - 
cache_position = self.pad(cache_position, padding_value=-1, position_ids=self.real_position_ids, dim=0) - cache_position = torch.arange(0, cache_position.shape[0], device=device) - kv_length = cache_position.shape[0] - return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](batch_size, - cache_position, - kv_length, *args, - **kwargs) + return origin_sdpa(batch_size, q_length, kv_length, *args, device=device, **kwargs) masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[ 'sdpa_origin'] = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask - def create_causal_mask(config, input_embeds, attention_mask, cache_position, *args, **kwargs): + def create_causal_mask(config, + input_embeds, + attention_mask, + cache_position_or_past_key_values=None, + *args, + **kwargs): if self.world_size == 1: - return masking_utils.origin_create_causal_mask(config, input_embeds, attention_mask, cache_position, - *args, **kwargs) + return _call_create_causal_mask( + masking_utils.origin_create_causal_mask, + config, + input_embeds, + attention_mask, + cache_position_or_past_key_values, + *args, + **kwargs, + ) input_embeds = torch.ones( (input_embeds.shape[0], input_embeds.shape[1] * self.sp_world_size, input_embeds.shape[2]), dtype=input_embeds.dtype, device=input_embeds.device) - cache_position = torch.arange(0, input_embeds.shape[1], device=input_embeds.device) - return masking_utils.origin_create_causal_mask(config, input_embeds, attention_mask, cache_position, - *args, **kwargs) + if 'cache_position' in inspect.signature(masking_utils.origin_create_causal_mask).parameters: + cache_position_or_past_key_values = torch.arange( + 0, + input_embeds.shape[1], + device=input_embeds.device, + ) + return _call_create_causal_mask( + masking_utils.origin_create_causal_mask, + config, + input_embeds, + attention_mask, + cache_position_or_past_key_values, + *args, + 
**kwargs, + ) masking_utils.origin_create_causal_mask = masking_utils.create_causal_mask masking_utils.create_causal_mask = create_causal_mask diff --git a/src/twinkle/patch/gdn_padding_free.py b/src/twinkle/patch/gdn_padding_free.py index ccd4c9b2..dde8bdf7 100644 --- a/src/twinkle/patch/gdn_padding_free.py +++ b/src/twinkle/patch/gdn_padding_free.py @@ -1,8 +1,7 @@ import inspect -from typing import Optional - import torch from transformers.utils.import_utils import is_flash_linear_attention_available +from typing import Optional from twinkle.patch import Patch @@ -42,6 +41,13 @@ def _call_with_supported_kwargs(fn, *args, **kwargs): return fn(*args, **kwargs) +def _supports_native_padding_free(Qwen3_5GatedDeltaNet) -> bool: + try: + return 'cu_seq_lens_q' in inspect.getsource(Qwen3_5GatedDeltaNet.forward) + except (OSError, TypeError): + return False + + def _patch_gdn_kernels_for_cu_seqlens( mod: torch.nn.Module, *, @@ -93,6 +99,8 @@ def __call__(self, module, *args, **kwargs): if getattr(Qwen3_5GatedDeltaNet, '_twinkle_sp_linear_patched', False): return module._twinkle_gdn_padding_free_patched = True + if _supports_native_padding_free(Qwen3_5GatedDeltaNet): + return if not getattr(Qwen3_5DecoderLayer, '_twinkle_padding_free_cu_seqlens_patched', False): origin_decoder_forward = Qwen3_5DecoderLayer.forward @@ -108,7 +116,8 @@ def decoder_forward( **extra_kwargs, ): if getattr(layer, 'layer_type', None) != 'linear_attention': - return origin_decoder_forward( + return _call_with_supported_kwargs( + origin_decoder_forward, layer, hidden_states=hidden_states, position_embeddings=position_embeddings, From 584ad663d7912ea127e3b74277aa998c80702d8a Mon Sep 17 00:00:00 2001 From: qq_30035749 Date: Fri, 22 May 2026 18:56:38 +0800 Subject: [PATCH 3/7] fix(sequence_parallel): restore global query length for no-cache prefill path In sequence parallel training, when newer Transformers versions pass q_length/q_offset instead of cache_position, the causal mask creation may still 
see the local shard length. This change restores the global query length for the no-cache prefill path while keeping cache/sliding paths with their upstream offsets. Also refactor GDN padding-free detection to use transformers version check instead of source inspection, supporting transformers >= 5.9.0. --- .../strategy/sequence_parallel/__init__.py | 25 ++++++++++++++++--- src/twinkle/patch/gdn_padding_free.py | 11 ++++---- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py index bbe06873..08f5b990 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py @@ -114,20 +114,37 @@ def sdpa_mask(batch_size, q_length=None, kv_length=None, *args, **kwargs): origin_sdpa = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'] origin_uses_cache_position = 'cache_position' in inspect.signature(origin_sdpa).parameters q_length = q_length if q_length is not None else kwargs.pop('cache_position', None) - device = q_length.device if torch.is_tensor(q_length) else kwargs.get('device') + device = q_length.device if torch.is_tensor(q_length) else kwargs.pop('device', None) if device is None: device = self.real_position_ids.device cache_position = None - if self.world_size > 1 and origin_uses_cache_position: + if self.world_size > 1: padded_position_ids = self.pad( self.real_position_ids[0], padding_value=-1, position_ids=self.real_position_ids, dim=0, ) - cache_position = torch.arange(0, padded_position_ids.shape[0], device=device) - kv_length = cache_position.shape[0] + global_length = padded_position_ids.shape[0] + if origin_uses_cache_position: + cache_position = torch.arange(0, global_length, device=device) + kv_length = global_length + else: + # Newer Transformers passes q_length/q_offset instead of 
cache_position. In SP training, + # create_causal_mask may still see the local shard length, so restore the global query length + # only for the no-cache prefill path; cache/sliding paths keep their upstream offsets. + q_offset = kwargs.get('q_offset', 0) + kv_offset = kwargs.get('kv_offset', 0) + no_cache_offsets = ((not torch.is_tensor(q_offset) and q_offset == 0) + and (not torch.is_tensor(kv_offset) and kv_offset == 0)) + if no_cache_offsets: + q_length = global_length + attention_mask = kwargs.get('attention_mask') + if attention_mask is not None and torch.is_tensor(attention_mask): + kv_length = attention_mask.shape[-1] + else: + kv_length = global_length if origin_uses_cache_position: if cache_position is None: diff --git a/src/twinkle/patch/gdn_padding_free.py b/src/twinkle/patch/gdn_padding_free.py index dde8bdf7..1026a175 100644 --- a/src/twinkle/patch/gdn_padding_free.py +++ b/src/twinkle/patch/gdn_padding_free.py @@ -1,5 +1,7 @@ import inspect import torch +from packaging.version import Version +from transformers import __version__ as transformers_version from transformers.utils.import_utils import is_flash_linear_attention_available from typing import Optional @@ -41,11 +43,8 @@ def _call_with_supported_kwargs(fn, *args, **kwargs): return fn(*args, **kwargs) -def _supports_native_padding_free(Qwen3_5GatedDeltaNet) -> bool: - try: - return 'cu_seq_lens_q' in inspect.getsource(Qwen3_5GatedDeltaNet.forward) - except (OSError, TypeError): - return False +def _supports_native_padding_free() -> bool: + return Version(Version(transformers_version).base_version) >= Version('5.9.0') def _patch_gdn_kernels_for_cu_seqlens( @@ -99,7 +98,7 @@ def __call__(self, module, *args, **kwargs): if getattr(Qwen3_5GatedDeltaNet, '_twinkle_sp_linear_patched', False): return module._twinkle_gdn_padding_free_patched = True - if _supports_native_padding_free(Qwen3_5GatedDeltaNet): + if _supports_native_padding_free(): return if not getattr(Qwen3_5DecoderLayer, 
'_twinkle_padding_free_cu_seqlens_patched', False): From 93eea384397250a08fc8fae444fabf62b642b53e Mon Sep 17 00:00:00 2001 From: qq_30035749 Date: Sun, 24 May 2026 15:32:53 +0800 Subject: [PATCH 4/7] fix: remove native padding-free version check for GDN patch The version check for transformers >= 5.9.0 was removed because it is no longer needed. The GDN padding-free patch should always be applied regardless of the transformers version, as the native support check is handled elsewhere or the patch is required for all versions. --- src/twinkle/patch/gdn_padding_free.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/twinkle/patch/gdn_padding_free.py b/src/twinkle/patch/gdn_padding_free.py index 1026a175..ab3deffb 100644 --- a/src/twinkle/patch/gdn_padding_free.py +++ b/src/twinkle/patch/gdn_padding_free.py @@ -1,7 +1,5 @@ import inspect import torch -from packaging.version import Version -from transformers import __version__ as transformers_version from transformers.utils.import_utils import is_flash_linear_attention_available from typing import Optional @@ -43,10 +41,6 @@ def _call_with_supported_kwargs(fn, *args, **kwargs): return fn(*args, **kwargs) -def _supports_native_padding_free() -> bool: - return Version(Version(transformers_version).base_version) >= Version('5.9.0') - - def _patch_gdn_kernels_for_cu_seqlens( mod: torch.nn.Module, *, @@ -98,8 +92,6 @@ def __call__(self, module, *args, **kwargs): if getattr(Qwen3_5GatedDeltaNet, '_twinkle_sp_linear_patched', False): return module._twinkle_gdn_padding_free_patched = True - if _supports_native_padding_free(): - return if not getattr(Qwen3_5DecoderLayer, '_twinkle_padding_free_cu_seqlens_patched', False): origin_decoder_forward = Qwen3_5DecoderLayer.forward From a0443e5b78fad0799e37075bcaf94c3cef704724 Mon Sep 17 00:00:00 2001 From: qq_30035749 Date: Mon, 25 May 2026 11:04:13 +0800 Subject: [PATCH 5/7] fix: conditionally patch chunk_gated_delta_rule for transformers < 5.9.0 Add version check to 
only apply the chunk_gated_delta_rule cu_seqlens patch when using transformers versions below 5.9.0, preventing compatibility issues with newer releases. --- src/twinkle/patch/gdn_padding_free.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/twinkle/patch/gdn_padding_free.py b/src/twinkle/patch/gdn_padding_free.py index ab3deffb..9198daf3 100644 --- a/src/twinkle/patch/gdn_padding_free.py +++ b/src/twinkle/patch/gdn_padding_free.py @@ -1,5 +1,7 @@ import inspect import torch +import transformers +from packaging.version import Version from transformers.utils.import_utils import is_flash_linear_attention_available from typing import Optional @@ -41,10 +43,15 @@ def _call_with_supported_kwargs(fn, *args, **kwargs): return fn(*args, **kwargs) +def _needs_chunk_gated_delta_rule_cu_seqlens_patch() -> bool: + return Version(transformers.__version__) < Version('5.9.0') + + def _patch_gdn_kernels_for_cu_seqlens( mod: torch.nn.Module, *, cu_seqlens: torch.Tensor, + patch_chunk_rule: bool, origin_forward, forward_args, forward_kwargs, @@ -70,12 +77,14 @@ def chunk_gated_delta_rule_wrapper(query, key, value, **kwargs): return chunk_gated_delta_rule(query, key, value, **kwargs) mod.causal_conv1d_fn = causal_conv1d_wrapper - mod.chunk_gated_delta_rule = chunk_gated_delta_rule_wrapper + if patch_chunk_rule: + mod.chunk_gated_delta_rule = chunk_gated_delta_rule_wrapper try: return _call_with_supported_kwargs(origin_forward, mod, *forward_args, **forward_kwargs) finally: mod.causal_conv1d_fn = old_conv_fn - mod.chunk_gated_delta_rule = old_chunk_rule + if patch_chunk_rule: + mod.chunk_gated_delta_rule = old_chunk_rule class GatedDeltaNetPaddingFreePatch(Patch): @@ -145,6 +154,7 @@ def decoder_forward( if not getattr(Qwen3_5GatedDeltaNet, '_twinkle_padding_free_gdn_patched', False): origin_forward = Qwen3_5GatedDeltaNet.forward + patch_chunk_rule = _needs_chunk_gated_delta_rule_cu_seqlens_patch() def forward( mod, @@ -168,12 +178,14 @@ def 
forward( return _patch_gdn_kernels_for_cu_seqlens( mod, cu_seqlens=cu_seq_lens_q, + patch_chunk_rule=patch_chunk_rule, origin_forward=origin_forward, forward_args=(hidden_states, ), forward_kwargs={ 'cache_params': cache_params, 'cache_position': cache_position, 'attention_mask': attention_mask, + 'cu_seq_lens_q': cu_seq_lens_q, **extra_kwargs, }, ) From 1822cddf41139aac92430dfc80bf31fc6b2d8ca3 Mon Sep 17 00:00:00 2001 From: qq_30035749 Date: Mon, 25 May 2026 16:27:38 +0800 Subject: [PATCH 6/7] fix: address Gemini code review feedback --- .../strategy/sequence_parallel/__init__.py | 8 ++++---- src/twinkle/patch/gdn_padding_free.py | 15 ++++++++++++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py index 08f5b990..21b3836a 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py @@ -110,9 +110,10 @@ def _prepare_flash_attn(self, base_model: torch.nn.Module): try: from transformers import masking_utils + origin_sdpa = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] + origin_uses_cache_position = 'cache_position' in inspect.signature(origin_sdpa).parameters + def sdpa_mask(batch_size, q_length=None, kv_length=None, *args, **kwargs): - origin_sdpa = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'] - origin_uses_cache_position = 'cache_position' in inspect.signature(origin_sdpa).parameters q_length = q_length if q_length is not None else kwargs.pop('cache_position', None) device = q_length.device if torch.is_tensor(q_length) else kwargs.pop('device', None) if device is None: @@ -154,8 +155,7 @@ def sdpa_mask(batch_size, q_length=None, kv_length=None, *args, **kwargs): return origin_sdpa(batch_size, q_length, kv_length, *args, device=device, **kwargs) - 
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[ - 'sdpa_origin'] = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] + masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'] = origin_sdpa masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask def create_causal_mask(config, diff --git a/src/twinkle/patch/gdn_padding_free.py b/src/twinkle/patch/gdn_padding_free.py index 9198daf3..38ef83be 100644 --- a/src/twinkle/patch/gdn_padding_free.py +++ b/src/twinkle/patch/gdn_padding_free.py @@ -1,6 +1,7 @@ import inspect import torch import transformers +from functools import lru_cache from packaging.version import Version from transformers.utils.import_utils import is_flash_linear_attention_available from typing import Optional @@ -36,10 +37,18 @@ def _get_flash_linear_attention_kernels(): return causal_conv1d, chunk_gated_delta_rule -def _call_with_supported_kwargs(fn, *args, **kwargs): +@lru_cache(maxsize=None) +def _supported_kwarg_names(fn): signature = inspect.signature(fn) - if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()): - kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters} + if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()): + return None + return frozenset(signature.parameters) + + +def _call_with_supported_kwargs(fn, *args, **kwargs): + supported_kwargs = _supported_kwarg_names(fn) + if supported_kwargs is not None: + kwargs = {key: value for key, value in kwargs.items() if key in supported_kwargs} return fn(*args, **kwargs) From 4d248a302302caa55d1df88f11946383c8660c2e Mon Sep 17 00:00:00 2001 From: qq_30035749 Date: Wed, 27 May 2026 14:55:43 +0800 Subject: [PATCH 7/7] refactor: move signature-filtering helpers into twinkle.utils.utils --- .../strategy/sequence_parallel/__init__.py | 21 ++++++---------- src/twinkle/patch/gdn_padding_free.py | 24 
++++--------------- src/twinkle/utils/utils.py | 20 ++++++++++++++++ 3 
files changed, 31 insertions(+), 34 deletions(-) diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py index 21b3836a..bb8419cb 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py @@ -1,5 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import inspect import math import torch import torch.distributed as dist @@ -13,6 +12,7 @@ from twinkle.patch import apply_patch from twinkle.utils import DeviceMesh from twinkle.utils.transformers_utils import get_llm_model +from twinkle.utils.utils import call_with_supported_kwargs, has_signature_parameter from .linear_attention_sp import Qwen3_5GatedDeltaNetUlyssesPatch from .utils import (DistributedAttention, GatherLoss, _derive_sequence_parallel_sizes, _get_seq_groups_from_device_mesh, _get_ulysses_size, _SeqAllToAll, get_config_attr, get_cu_seqlens_from_position_ids, is_hccl_backend, @@ -29,17 +29,10 @@ def is_qwen3_omni(model): return 'qwen3_omni' in mt -def _call_with_supported_kwargs(fn, *args, **kwargs): - signature = inspect.signature(fn) - if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()): - kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters} - return fn(*args, **kwargs) - - def _call_create_causal_mask(fn, config, input_embeds, attention_mask, cache_position_or_past_key_values, *args, **kwargs): - if 'cache_position' in inspect.signature(fn).parameters: - return _call_with_supported_kwargs( + if has_signature_parameter(fn, 'cache_position'): + return call_with_supported_kwargs( fn, config, input_embeds, @@ -49,8 +42,8 @@ def _call_create_causal_mask(fn, config, input_embeds, attention_mask, cache_pos **kwargs, ) if cache_position_or_past_key_values is None and 'past_key_values' in kwargs: - return 
_call_with_supported_kwargs(fn, config, input_embeds, attention_mask, *args, **kwargs) - return _call_with_supported_kwargs( + return call_with_supported_kwargs(fn, config, input_embeds, attention_mask, *args, **kwargs) + return call_with_supported_kwargs( fn, config, input_embeds, @@ -111,7 +104,7 @@ def _prepare_flash_attn(self, base_model: torch.nn.Module): from transformers import masking_utils origin_sdpa = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] - origin_uses_cache_position = 'cache_position' in inspect.signature(origin_sdpa).parameters + origin_uses_cache_position = has_signature_parameter(origin_sdpa, 'cache_position') def sdpa_mask(batch_size, q_length=None, kv_length=None, *args, **kwargs): q_length = q_length if q_length is not None else kwargs.pop('cache_position', None) @@ -178,7 +171,7 @@ def create_causal_mask(config, (input_embeds.shape[0], input_embeds.shape[1] * self.sp_world_size, input_embeds.shape[2]), dtype=input_embeds.dtype, device=input_embeds.device) - if 'cache_position' in inspect.signature(masking_utils.origin_create_causal_mask).parameters: + if has_signature_parameter(masking_utils.origin_create_causal_mask, 'cache_position'): cache_position_or_past_key_values = torch.arange( 0, input_embeds.shape[1], diff --git a/src/twinkle/patch/gdn_padding_free.py b/src/twinkle/patch/gdn_padding_free.py index 38ef83be..759a222f 100644 --- a/src/twinkle/patch/gdn_padding_free.py +++ b/src/twinkle/patch/gdn_padding_free.py @@ -1,12 +1,11 @@ -import inspect import torch import transformers -from functools import lru_cache from packaging.version import Version from transformers.utils.import_utils import is_flash_linear_attention_available from typing import Optional from twinkle.patch import Patch +from twinkle.utils.utils import call_with_supported_kwargs def _is_qwen35_model(hf_config) -> bool: @@ -37,21 +36,6 @@ def _get_flash_linear_attention_kernels(): return causal_conv1d, chunk_gated_delta_rule 
-@lru_cache(maxsize=None) -def _supported_kwarg_names(fn): - signature = inspect.signature(fn) - if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()): - return None - return frozenset(signature.parameters) - - -def _call_with_supported_kwargs(fn, *args, **kwargs): - supported_kwargs = _supported_kwarg_names(fn) - if supported_kwargs is not None: - kwargs = {key: value for key, value in kwargs.items() if key in supported_kwargs} - return fn(*args, **kwargs) - - def _needs_chunk_gated_delta_rule_cu_seqlens_patch() -> bool: return Version(transformers.__version__) < Version('5.9.0') @@ -89,7 +73,7 @@ def chunk_gated_delta_rule_wrapper(query, key, value, **kwargs): if patch_chunk_rule: mod.chunk_gated_delta_rule = chunk_gated_delta_rule_wrapper try: - return _call_with_supported_kwargs(origin_forward, mod, *forward_args, **forward_kwargs) + return call_with_supported_kwargs(origin_forward, mod, *forward_args, **forward_kwargs) finally: mod.causal_conv1d_fn = old_conv_fn if patch_chunk_rule: @@ -125,7 +109,7 @@ def decoder_forward( **extra_kwargs, ): if getattr(layer, 'layer_type', None) != 'linear_attention': - return _call_with_supported_kwargs( + return call_with_supported_kwargs( origin_decoder_forward, layer, hidden_states=hidden_states, @@ -175,7 +159,7 @@ def forward( **extra_kwargs, ): if cu_seq_lens_q is None: - return _call_with_supported_kwargs( + return call_with_supported_kwargs( origin_forward, mod, hidden_states, diff --git a/src/twinkle/utils/utils.py b/src/twinkle/utils/utils.py index ed87d974..40894a68 100644 --- a/src/twinkle/utils/utils.py +++ b/src/twinkle/utils/utils.py @@ -1,8 +1,10 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
import fnmatch import glob +import inspect import os import shutil +from functools import lru_cache def deep_getattr(obj, attr: str, default=None): @@ -17,6 +19,24 @@ def deep_getattr(obj, attr: str, default=None): return obj +@lru_cache(maxsize=None) +def signature_info(fn): + signature = inspect.signature(fn) + accepts_kwargs = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()) + return accepts_kwargs, frozenset(signature.parameters) + + +def has_signature_parameter(fn, name: str) -> bool: + return name in signature_info(fn)[1] + + +def call_with_supported_kwargs(fn, *args, **kwargs): + accepts_kwargs, parameters = signature_info(fn) + if not accepts_kwargs: + kwargs = {key: value for key, value in kwargs.items() if key in parameters} + return fn(*args, **kwargs) + + def copy_files_by_pattern(source_dir, dest_dir, patterns, exclude_patterns=None): if not os.path.exists(dest_dir): os.makedirs(dest_dir)