89 changes: 89 additions & 0 deletions transformer_engine/jax/cpp_extensions/attention.py
@@ -624,12 +624,58 @@ def convert_to_2d(offsets, batch, max_seqlen):
)
return output, softmax_aux, rng_state

# Flattened arg indices: 0=q, 1=k, 2=v, 3=bias, 4=softmax_offset, 5=seed,
# 6,7=seqlens, 8,9=seq_offsets, 10,11=segment_ids, 12,13=segment_pos.
_SEGMENT_IDS_BATCH_DIMS_IDX = (10, 11)
_SEGMENT_POS_BATCH_DIMS_IDX = (12, 13)

@staticmethod
def batcher(batched_args, batch_dims, *, config):
# batch_dims: tuple of length len(batched_args); each element is the axis index
# that is the batch axis (0, 1, ...) or None if that arg has no batch dim.
# check_valid_batch_dims: only 0 or None allowed (single leading batch or no batch).
check_valid_batch_dims(batch_dims)
assert FusedAttnFwdPrimitive.outer_primitive is not None
q_bdim, _, _, _, _, seed_bdim, *_ = batch_dims
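        # e.g. a single vmap that batches q/k/v, seed, seqlens and segment_ids but
        # closes over segment_pos could yield:
        #   batch_dims = (0, 0, 0, None, None, 0, 0, 0, None, None, 0, 0, None, None)
        # (one entry per flattened arg listed above; illustrative, not exhaustive).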

        # Ensure segment_pos is batched like segment_ids so the impl sees matching shapes.
        # JAX may give segment_ids batch_dim=0 (i.e. batched) and segment_pos batch_dim=None
        # (i.e. not batched) when segment_pos was generated inside a vmapped function
        # (e.g. a single or nested vmap).
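        # Concretely (hypothetical shapes): segment_ids may arrive as (B, batch, seqlen)
        # with batch_dim=0 while a closed-over segment_pos arrives as (batch, seqlen)
        # with batch_dim=None; the broadcast below lifts it to (B, batch, seqlen).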
batched_args_list = list(batched_args)
seg_id_bdim = batch_dims[FusedAttnFwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX[0]]
seg_pos_bdim = batch_dims[FusedAttnFwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX[0]]
        # Expand segment_pos to match the segment_ids batch dim if segment_pos has no
        # batch dim of its own (and segment_ids is non-empty).
        if (
            seg_id_bdim is not None
            and seg_pos_bdim is None
            and batched_args_list[FusedAttnFwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX[0]].size > 0
        ):
# Pair (segment_ids idx, segment_pos idx): (10, 12) for q, (11, 13) for kv.
for seg_id_idx, seg_pos_idx in zip(
FusedAttnFwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX,
FusedAttnFwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX,
):
segment_ids = batched_args_list[seg_id_idx]
segment_pos = batched_args_list[seg_pos_idx]
                leading_bdim = segment_ids.ndim - segment_pos.ndim
                assert leading_bdim > 0, (
                    "segment_ids must have more dims than segment_pos when adding batch dims; "
                    f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}"
                )
                assert segment_ids.shape[leading_bdim:] == segment_pos.shape, (
                    "segment_pos must have the same trailing shape as segment_ids when adding"
                    f" batch dims; got segment_ids.shape={segment_ids.shape},"
                    f" segment_pos.shape={segment_pos.shape}"
                )
                target_shape = segment_ids.shape[:leading_bdim] + segment_pos.shape
                # Prepend as many leading batch dims to segment_pos as segment_ids has.
                expanded = segment_pos
                for _ in range(leading_bdim):
                    expanded = lax.expand_dims(expanded, (0,))
batched_args_list[seg_pos_idx] = jnp.broadcast_to(expanded, target_shape)
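                # Shape trace (illustrative, leading_bdim=1): (batch, seqlen)
                # -> (1, batch, seqlen) -> (B, batch, seqlen), matching segment_ids.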
# Update the batch_dims to use 0 instead of None for segment_pos batch dims
batch_dims = tuple(
0 if i in FusedAttnFwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX else b
for i, b in enumerate(batch_dims)
)
batched_args = tuple(batched_args_list)

out_bdims = q_bdim, q_bdim, seed_bdim
return (
FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config),
@@ -1079,12 +1125,55 @@ def convert_to_2d(offsets, batch, max_seqlen):
)
return dq, dk, dv, dbias, dsoftmax_offset

# Flattened arg indices: 0=q, 1=k, 2=v, 3=bias, 4=softmax_offset, 5=softmax_aux,
# 6=rng_state, 7=output, 8=doutput, 9,10=seqlens, 11,12=seq_offsets,
# 13,14=segment_ids, 15,16=segment_pos.
_SEGMENT_IDS_BATCH_DIMS_IDX = (13, 14)
_SEGMENT_POS_BATCH_DIMS_IDX = (15, 16)

@staticmethod
def batcher(batched_args, batch_dims, *, config):
check_valid_batch_dims(batch_dims)
assert FusedAttnBwdPrimitive.outer_primitive is not None
q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim, *_ = batch_dims

        # Ensure segment_pos is batched like segment_ids so the impl sees matching shapes.
        # JAX may give segment_ids batch_dim=0 (i.e. batched) and segment_pos batch_dim=None
        # (i.e. not batched) when segment_pos was generated inside a vmapped function
        # (e.g. a single or nested vmap).
batched_args_list = list(batched_args)
seg_id_bdim = batch_dims[FusedAttnBwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX[0]]
seg_pos_bdim = batch_dims[FusedAttnBwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX[0]]
        # Expand segment_pos to match the segment_ids batch dim if segment_pos has no
        # batch dim of its own (and segment_ids is non-empty).
        if (
            seg_id_bdim is not None
            and seg_pos_bdim is None
            and batched_args_list[FusedAttnBwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX[0]].size > 0
        ):
for seg_id_idx, seg_pos_idx in zip(
FusedAttnBwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX,
FusedAttnBwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX,
):
segment_ids = batched_args_list[seg_id_idx]
segment_pos = batched_args_list[seg_pos_idx]
                leading_bdim = segment_ids.ndim - segment_pos.ndim
                assert leading_bdim > 0, (
                    "segment_ids must have more dims than segment_pos when adding batch dims; "
                    f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}"
                )
                assert segment_ids.shape[leading_bdim:] == segment_pos.shape, (
                    "segment_pos must have the same trailing shape as segment_ids when adding"
                    f" batch dims; got segment_ids.shape={segment_ids.shape},"
                    f" segment_pos.shape={segment_pos.shape}"
                )
                target_shape = segment_ids.shape[:leading_bdim] + segment_pos.shape
                # Prepend as many leading batch dims to segment_pos as segment_ids has.
                expanded = segment_pos
                for _ in range(leading_bdim):
                    expanded = lax.expand_dims(expanded, (0,))
batched_args_list[seg_pos_idx] = jnp.broadcast_to(expanded, target_shape)
# Update the batch_dims to use 0 instead of None for segment_pos batch dims
batch_dims = tuple(
0 if i in FusedAttnBwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX else b
for i, b in enumerate(batch_dims)
)
batched_args = tuple(batched_args_list)

out_bdims = q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim
return (
FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config),
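For reference, a minimal standalone sketch of the situation these batchers handle and of the expand-and-broadcast step they apply. The shapes and the align_segment_pos helper are illustrative assumptions, not Transformer Engine API:

import jax.numpy as jnp
from jax import lax

B, batch, seqlen = 4, 2, 8

# Built outside any vmapped region: vmap would report batch_dim=None for this.
segment_pos = jnp.broadcast_to(jnp.arange(seqlen), (batch, seqlen))
# Carries the vmap axis: batch_dim=0, shape (B, batch, seqlen).
segment_ids = jnp.ones((B, batch, seqlen), dtype=jnp.int32)

def align_segment_pos(seg_ids, seg_pos):
    # Prepend one axis per missing leading batch dim, then broadcast so that
    # seg_pos matches seg_ids exactly; this mirrors the step the batchers perform.
    leading = seg_ids.ndim - seg_pos.ndim
    expanded = seg_pos
    for _ in range(leading):
        expanded = lax.expand_dims(expanded, (0,))
    return jnp.broadcast_to(expanded, seg_ids.shape[:leading] + seg_pos.shape)

aligned = align_segment_pos(segment_ids, segment_pos)
assert aligned.shape == segment_ids.shape  # (4, 2, 8)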