diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index cc1d9ea1..d246cee4 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -128,7 +128,7 @@ def _unflatten_heads(tensor, heads):
   return tensor
 
 
-def _reshape_data_for_flash(tensor, heads):
+def _reshape_data_for_flash(tensor, heads, num_context_shards = 1):
   """
   Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
   Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
@@ -136,7 +136,17 @@ def _reshape_data_for_flash(tensor, heads):
   """
   if tensor.ndim != 4:
     tensor = _unflatten_heads(tensor, heads)
-  return tensor
+
+  # Pad sequence dimension so it is evenly divisible by the context mesh axis,
+  # which shard_map requires.
+  if num_context_shards <= 1:
+    return tensor
+  rem = tensor.shape[2] % num_context_shards
+  if rem == 0:
+    return tensor
+  pad_width = [(0, 0)] * tensor.ndim
+  pad_width[2] = (0, num_context_shards - rem)
+  return jnp.pad(tensor, pad_width)
 
 
 def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
@@ -255,9 +265,11 @@ def _tpu_flash_attention(
       use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
   )
   num_context_shards = mesh.shape["context"]
-  query = _reshape_data_for_flash(query, heads)
-  key = _reshape_data_for_flash(key, heads)
-  value = _reshape_data_for_flash(value, heads)
+  orig_q_seq_len = query.shape[1]
+  query = _reshape_data_for_flash(query, heads, num_context_shards)
+  key = _reshape_data_for_flash(key, heads, num_context_shards)
+  value = _reshape_data_for_flash(value, heads, num_context_shards)
+
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
 
@@ -401,6 +413,8 @@ def ring_scan_body(carry, _):
       f" axis, batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}"
   )
   x = wrap_flash_attention(query, key, value)
+  # Trim back to original sequence length after context-axis padding.
+  x = x[:, :, :orig_q_seq_len, :]
   x = _reshape_heads_to_head_dim(x)
   return x
 