Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,25 @@ def _unflatten_heads(tensor, heads):
return tensor


def _reshape_data_for_flash(tensor, heads):
def _reshape_data_for_flash(tensor, heads, num_context_shards = 1):
"""
Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
blocks is divisible by the number of shards.
"""
if tensor.ndim != 4:
tensor = _unflatten_heads(tensor, heads)
return tensor

# Pad sequence dimension so it is evenly divisible by the context mesh axis,
# which shard_map requires.
if num_context_shards <= 1:
return tensor
rem = tensor.shape[2] % num_context_shards
if rem == 0:
return tensor
pad_width = [(0, 0)] * tensor.ndim
pad_width[2] = (0, num_context_shards - rem)
return jnp.pad(tensor, pad_width)


def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
Expand Down Expand Up @@ -255,9 +265,11 @@ def _tpu_flash_attention(
use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
)
num_context_shards = mesh.shape["context"]
query = _reshape_data_for_flash(query, heads)
key = _reshape_data_for_flash(key, heads)
value = _reshape_data_for_flash(value, heads)
orig_q_seq_len = query.shape[1]
query = _reshape_data_for_flash(query, heads, num_context_shards)
key = _reshape_data_for_flash(key, heads, num_context_shards)
value = _reshape_data_for_flash(value, heads, num_context_shards)

q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

Expand Down Expand Up @@ -401,6 +413,8 @@ def ring_scan_body(carry, _):
f" axis, batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}"
)
x = wrap_flash_attention(query, key, value)
# Trim back to original sequence length after context-axis padding.
x = x[:, :, :orig_q_seq_len, :]
x = _reshape_heads_to_head_dim(x)

return x
Expand Down
Loading