Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 109 additions & 46 deletions src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import inspect
import math
import torch
import torch.distributed as dist
Expand Down Expand Up @@ -28,6 +29,38 @@ def is_qwen3_omni(model):
return 'qwen3_omni' in mt


def _call_with_supported_kwargs(fn, *args, **kwargs):
signature = inspect.signature(fn)
if not any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()):
kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters}
return fn(*args, **kwargs)
Comment thread
meichangsu1 marked this conversation as resolved.


def _call_create_causal_mask(fn, config, input_embeds, attention_mask, cache_position_or_past_key_values, *args,
                             **kwargs):
    """Invoke a ``create_causal_mask``-style callable with a compatible argument list.

    The fourth positional value is forwarded to ``fn`` in every case except one: when
    ``fn`` has no ``cache_position`` parameter, the value is ``None`` and the caller
    supplied ``past_key_values`` by keyword. In that case the value is left out so only
    the keyword form reaches ``fn``. Unsupported keyword arguments are filtered by
    ``_call_with_supported_kwargs``.

    Returns:
        Whatever ``fn`` returns.
    """
    takes_cache_position = 'cache_position' in inspect.signature(fn).parameters
    omit_fourth = (not takes_cache_position and cache_position_or_past_key_values is None
                   and 'past_key_values' in kwargs)
    leading = (config, input_embeds, attention_mask)
    if not omit_fourth:
        leading += (cache_position_or_past_key_values, )
    return _call_with_supported_kwargs(fn, *leading, *args, **kwargs)


# main content copied from ms-swift
class SequenceParallel:

Expand Down Expand Up @@ -77,59 +110,89 @@ def _prepare_flash_attn(self, base_model: torch.nn.Module):
try:
from transformers import masking_utils

_origin_flash_attention_mask = masking_utils.flash_attention_mask

# Patch attention masks for SP: avoid masking when full sequence is reconstructed.
def flash_attention_mask(batch_size,
cache_position,
kv_length,
kv_offset=0,
mask_function=masking_utils.causal_mask_function,
attention_mask=None,
**kwargs):
if self.world_size == 1:
return _origin_flash_attention_mask(batch_size, cache_position, kv_length, kv_offset, mask_function,
attention_mask, **kwargs)
if attention_mask is not None:
if attention_mask.all():
attention_mask = None

return attention_mask

masking_utils.flash_attention_mask = flash_attention_mask
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['flash_attention_2'] = flash_attention_mask

def sdpa_mask(batch_size, cache_position, kv_length, *args, **kwargs):
if self.world_size == 1:
return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](batch_size,
cache_position,
kv_length, *args,
**kwargs)
device = cache_position.device
cache_position = self.real_position_ids[0]
cache_position = self.pad(cache_position, padding_value=-1, position_ids=self.real_position_ids, dim=0)
cache_position = torch.arange(0, cache_position.shape[0], device=device)
kv_length = cache_position.shape[0]
return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](batch_size,
cache_position,
kv_length, *args,
**kwargs)

masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[
'sdpa_origin'] = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa']
origin_sdpa = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa']
origin_uses_cache_position = 'cache_position' in inspect.signature(origin_sdpa).parameters

def sdpa_mask(batch_size, q_length=None, kv_length=None, *args, **kwargs):
q_length = q_length if q_length is not None else kwargs.pop('cache_position', None)
device = q_length.device if torch.is_tensor(q_length) else kwargs.pop('device', None)
if device is None:
device = self.real_position_ids.device

cache_position = None
if self.world_size > 1:
padded_position_ids = self.pad(
self.real_position_ids[0],
padding_value=-1,
position_ids=self.real_position_ids,
dim=0,
)
global_length = padded_position_ids.shape[0]
if origin_uses_cache_position:
cache_position = torch.arange(0, global_length, device=device)
kv_length = global_length
else:
# Newer Transformers passes q_length/q_offset instead of cache_position. In SP training,
# create_causal_mask may still see the local shard length, so restore the global query length
# only for the no-cache prefill path; cache/sliding paths keep their upstream offsets.
q_offset = kwargs.get('q_offset', 0)
kv_offset = kwargs.get('kv_offset', 0)
no_cache_offsets = ((not torch.is_tensor(q_offset) and q_offset == 0)
and (not torch.is_tensor(kv_offset) and kv_offset == 0))
if no_cache_offsets:
q_length = global_length
attention_mask = kwargs.get('attention_mask')
if attention_mask is not None and torch.is_tensor(attention_mask):
kv_length = attention_mask.shape[-1]
else:
kv_length = global_length

if origin_uses_cache_position:
if cache_position is None:
cache_position = q_length if torch.is_tensor(q_length) else torch.arange(
q_length, device=device)
return origin_sdpa(batch_size, cache_position, kv_length, *args, **kwargs)

return origin_sdpa(batch_size, q_length, kv_length, *args, device=device, **kwargs)
Comment thread
meichangsu1 marked this conversation as resolved.

masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'] = origin_sdpa
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask

def create_causal_mask(config, input_embeds, attention_mask, cache_position, *args, **kwargs):
def create_causal_mask(config,
input_embeds,
attention_mask,
cache_position_or_past_key_values=None,
*args,
**kwargs):
if self.world_size == 1:
return masking_utils.origin_create_causal_mask(config, input_embeds, attention_mask, cache_position,
*args, **kwargs)
return _call_create_causal_mask(
masking_utils.origin_create_causal_mask,
config,
input_embeds,
attention_mask,
cache_position_or_past_key_values,
*args,
**kwargs,
)
input_embeds = torch.ones(
(input_embeds.shape[0], input_embeds.shape[1] * self.sp_world_size, input_embeds.shape[2]),
dtype=input_embeds.dtype,
device=input_embeds.device)
cache_position = torch.arange(0, input_embeds.shape[1], device=input_embeds.device)
return masking_utils.origin_create_causal_mask(config, input_embeds, attention_mask, cache_position,
*args, **kwargs)
if 'cache_position' in inspect.signature(masking_utils.origin_create_causal_mask).parameters:
cache_position_or_past_key_values = torch.arange(
0,
input_embeds.shape[1],
device=input_embeds.device,
)
Comment on lines +181 to +186
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the optimization suggested for sdpa_mask, the signature of masking_utils.origin_create_causal_mask should be inspected once outside the create_causal_mask function to avoid overhead in the forward pass.

return _call_create_causal_mask(
masking_utils.origin_create_causal_mask,
config,
input_embeds,
attention_mask,
cache_position_or_past_key_values,
*args,
**kwargs,
)

masking_utils.origin_create_causal_mask = masking_utils.create_causal_mask
masking_utils.create_causal_mask = create_causal_mask
Expand Down
41 changes: 36 additions & 5 deletions src/twinkle/patch/gdn_padding_free.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import inspect
import torch
import transformers
from functools import lru_cache
from packaging.version import Version
from transformers.utils.import_utils import is_flash_linear_attention_available
from typing import Optional

Expand Down Expand Up @@ -33,10 +37,30 @@ def _get_flash_linear_attention_kernels():
return causal_conv1d, chunk_gated_delta_rule


@lru_cache(maxsize=None)
def _supported_kwarg_names(fn):
signature = inspect.signature(fn)
if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()):
return None
return frozenset(signature.parameters)


def _call_with_supported_kwargs(fn, *args, **kwargs):
    """Call ``fn`` with only the keyword arguments it is known to support.

    The supported names come from ``_supported_kwarg_names``; a ``None`` result means
    every keyword argument is forwarded unchanged. Positional arguments always pass
    through untouched.

    Returns:
        Whatever ``fn`` returns.
    """
    allowed = _supported_kwarg_names(fn)
    if allowed is None:
        return fn(*args, **kwargs)
    accepted = {name: kwargs[name] for name in kwargs if name in allowed}
    return fn(*args, **accepted)


def _needs_chunk_gated_delta_rule_cu_seqlens_patch() -> bool:
    """Return True when the installed ``transformers`` version is older than 5.9.0."""
    installed = Version(transformers.__version__)
    threshold = Version('5.9.0')
    return installed < threshold


def _patch_gdn_kernels_for_cu_seqlens(
mod: torch.nn.Module,
*,
cu_seqlens: torch.Tensor,
patch_chunk_rule: bool,
origin_forward,
forward_args,
forward_kwargs,
Expand All @@ -62,12 +86,14 @@ def chunk_gated_delta_rule_wrapper(query, key, value, **kwargs):
return chunk_gated_delta_rule(query, key, value, **kwargs)

mod.causal_conv1d_fn = causal_conv1d_wrapper
mod.chunk_gated_delta_rule = chunk_gated_delta_rule_wrapper
if patch_chunk_rule:
mod.chunk_gated_delta_rule = chunk_gated_delta_rule_wrapper
try:
return origin_forward(mod, *forward_args, **forward_kwargs)
return _call_with_supported_kwargs(origin_forward, mod, *forward_args, **forward_kwargs)
finally:
mod.causal_conv1d_fn = old_conv_fn
mod.chunk_gated_delta_rule = old_chunk_rule
if patch_chunk_rule:
mod.chunk_gated_delta_rule = old_chunk_rule


class GatedDeltaNetPaddingFreePatch(Patch):
Expand Down Expand Up @@ -99,7 +125,8 @@ def decoder_forward(
**extra_kwargs,
):
if getattr(layer, 'layer_type', None) != 'linear_attention':
return origin_decoder_forward(
return _call_with_supported_kwargs(
origin_decoder_forward,
layer,
hidden_states=hidden_states,
position_embeddings=position_embeddings,
Expand Down Expand Up @@ -136,6 +163,7 @@ def decoder_forward(

if not getattr(Qwen3_5GatedDeltaNet, '_twinkle_padding_free_gdn_patched', False):
origin_forward = Qwen3_5GatedDeltaNet.forward
patch_chunk_rule = _needs_chunk_gated_delta_rule_cu_seqlens_patch()

def forward(
mod,
Expand All @@ -147,7 +175,8 @@ def forward(
**extra_kwargs,
):
if cu_seq_lens_q is None:
return origin_forward(
return _call_with_supported_kwargs(
origin_forward,
mod,
hidden_states,
cache_params=cache_params,
Expand All @@ -158,12 +187,14 @@ def forward(
return _patch_gdn_kernels_for_cu_seqlens(
mod,
cu_seqlens=cu_seq_lens_q,
patch_chunk_rule=patch_chunk_rule,
origin_forward=origin_forward,
forward_args=(hidden_states, ),
forward_kwargs={
'cache_params': cache_params,
'cache_position': cache_position,
'attention_mask': attention_mask,
'cu_seq_lens_q': cu_seq_lens_q,
**extra_kwargs,
},
)
Expand Down
Loading