diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py
index e5d75e1501..7a6f8bddc2 100644
--- a/transformer_engine/jax/cpp_extensions/attention.py
+++ b/transformer_engine/jax/cpp_extensions/attention.py
@@ -624,12 +624,65 @@ def convert_to_2d(offsets, batch, max_seqlen):
         )
         return output, softmax_aux, rng_state
 
+    # Flattened arg indices: 0=q, 1=k, 2=v, 3=bias, 4=softmax_offset, 5=seed,
+    # 6,7=seqlens, 8,9=seq_offsets, 10,11=segment_ids, 12,13=segment_pos.
+    _SEGMENT_IDS_BATCH_DIMS_IDX = (10, 11)
+    _SEGMENT_POS_BATCH_DIMS_IDX = (12, 13)
+
     @staticmethod
     def batcher(batched_args, batch_dims, *, config):
+        # batch_dims: tuple of length len(batched_args); each element is the axis index
+        # that is the batch axis (0, 1, ...) or None if that arg has no batch dim.
+        # check_valid_batch_dims: only 0 or None allowed (single leading batch or no batch).
         check_valid_batch_dims(batch_dims)
         assert FusedAttnFwdPrimitive.outer_primitive is not None
         q_bdim, _, _, _, _, seed_bdim, *_ = batch_dims
 
+        # Ensure segment_pos are batched like segment_ids so impl sees matching shapes.
+        # JAX may give segment_ids batch_dim=0 (i.e. batched) and segment_pos batch_dim=None
+        # (i.e. not batched) when segment_pos were generated inside a vmapped function
+        # (e.g. single or nested vmap).
+        batched_args_list = list(batched_args)
+        first_seg_id_idx = FusedAttnFwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX[0]
+        first_seg_pos_idx = FusedAttnFwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX[0]
+        seg_id_bdim = batch_dims[first_seg_id_idx]
+        seg_pos_bdim = batch_dims[first_seg_pos_idx]
+        # Expand segment_pos to match the segment_ids batch dim when segment_pos has no batch
+        # dim of its own; nothing is expanded when the first segment_ids array is empty.
+        if (
+            seg_id_bdim is not None
+            and seg_pos_bdim is None
+            and batched_args_list[first_seg_id_idx].size > 0
+        ):
+            # Pair (segment_ids idx, segment_pos idx): (10, 12) for q, (11, 13) for kv.
+            for seg_id_idx, seg_pos_idx in zip(
+                FusedAttnFwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX,
+                FusedAttnFwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX,
+            ):
+                segment_ids = batched_args_list[seg_id_idx]
+                segment_pos = batched_args_list[seg_pos_idx]
+                assert segment_ids.ndim > segment_pos.ndim, (
+                    "segment_ids must have more dims than segment_pos when adding batch dims; "
+                    f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}"
+                )
+                # Leading dims present on segment_ids only; the dims after them must match.
+                leading_bdim = segment_ids.ndim - segment_pos.ndim
+                assert segment_ids.shape[leading_bdim:] == segment_pos.shape, (
+                    "segment_pos must have same trailing shape as segment_ids when adding batch"
+                    f" dims; got segment_ids.shape={segment_ids.shape},"
+                    f" segment_pos.shape={segment_pos.shape}"
+                )
+                # broadcast_to prepends the missing leading dims itself, so segment_pos needs
+                # no explicit expand_dims before being broadcast to the batched shape.
+                target_shape = segment_ids.shape[:leading_bdim] + segment_pos.shape
+                batched_args_list[seg_pos_idx] = jnp.broadcast_to(segment_pos, target_shape)
+            # Update the batch_dims to use 0 instead of None for segment_pos batch dims
+            batch_dims = tuple(
+                0 if i in FusedAttnFwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX else b
+                for i, b in enumerate(batch_dims)
+            )
+        batched_args = tuple(batched_args_list)
+
         out_bdims = q_bdim, q_bdim, seed_bdim
         return (
             FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config),
@@ -1079,12 +1132,63 @@ def convert_to_2d(offsets, batch, max_seqlen):
         )
         return dq, dk, dv, dbias, dsoftmax_offset
 
+    # Flattened arg indices: 0=q, 1=k, 2=v, 3=bias, 4=softmax_offset, 5=softmax_aux,
+    # 6=rng_state, 7=output, 8=doutput, 9,10=seqlens, 11,12=seq_offsets,
+    # 13,14=segment_ids, 15,16=segment_pos.
+    _SEGMENT_IDS_BATCH_DIMS_IDX = (13, 14)
+    _SEGMENT_POS_BATCH_DIMS_IDX = (15, 16)
+
     @staticmethod
     def batcher(batched_args, batch_dims, *, config):
         check_valid_batch_dims(batch_dims)
         assert FusedAttnBwdPrimitive.outer_primitive is not None
         q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim, *_ = batch_dims
 
+        # Ensure segment_pos are batched like segment_ids so impl sees matching shapes.
+        # JAX may give segment_ids batch_dim=0 (i.e. batched) and segment_pos batch_dim=None
+        # (i.e. not batched) when segment_pos were generated inside a vmapped function
+        # (e.g. single or nested vmap).
+        batched_args_list = list(batched_args)
+        first_seg_id_idx = FusedAttnBwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX[0]
+        first_seg_pos_idx = FusedAttnBwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX[0]
+        seg_id_bdim = batch_dims[first_seg_id_idx]
+        seg_pos_bdim = batch_dims[first_seg_pos_idx]
+        # Expand segment_pos to match the segment_ids batch dim when segment_pos has no batch
+        # dim of its own; nothing is expanded when the first segment_ids array is empty.
+        if (
+            seg_id_bdim is not None
+            and seg_pos_bdim is None
+            and batched_args_list[first_seg_id_idx].size > 0
+        ):
+            # Pair (segment_ids idx, segment_pos idx): (13, 15) and (14, 16).
+            for seg_id_idx, seg_pos_idx in zip(
+                FusedAttnBwdPrimitive._SEGMENT_IDS_BATCH_DIMS_IDX,
+                FusedAttnBwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX,
+            ):
+                segment_ids = batched_args_list[seg_id_idx]
+                segment_pos = batched_args_list[seg_pos_idx]
+                assert segment_ids.ndim > segment_pos.ndim, (
+                    "segment_ids must have more dims than segment_pos when adding batch dims; "
+                    f"got segment_ids.ndim={segment_ids.ndim}, segment_pos.ndim={segment_pos.ndim}"
+                )
+                # Leading dims present on segment_ids only; the dims after them must match.
+                leading_bdim = segment_ids.ndim - segment_pos.ndim
+                assert segment_ids.shape[leading_bdim:] == segment_pos.shape, (
+                    "segment_pos must have same trailing shape as segment_ids when adding batch"
+                    f" dims; got segment_ids.shape={segment_ids.shape},"
+                    f" segment_pos.shape={segment_pos.shape}"
+                )
+                # broadcast_to prepends the missing leading dims itself, so segment_pos needs
+                # no explicit expand_dims before being broadcast to the batched shape.
+                target_shape = segment_ids.shape[:leading_bdim] + segment_pos.shape
+                batched_args_list[seg_pos_idx] = jnp.broadcast_to(segment_pos, target_shape)
+            # Update the batch_dims to use 0 instead of None for segment_pos batch dims
+            batch_dims = tuple(
+                0 if i in FusedAttnBwdPrimitive._SEGMENT_POS_BATCH_DIMS_IDX else b
+                for i, b in enumerate(batch_dims)
+            )
+        batched_args = tuple(batched_args_list)
+
         out_bdims = q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim
         return (
             FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config),