8 changes: 4 additions & 4 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -154,8 +154,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla

if (is_fp8_dtype(ret.Atype)) {
// Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
NVTE_CHECK(ret.lda % 16 == 0,
"Leading dimension requirement on A for FP8 GEMM. Caller must pad.");
// NVTE_CHECK(ret.lda % 16 == 0,
// "Leading dimension requirement on A for FP8 GEMM. Caller must pad.");
Comment on lines +157 to +158
Check that the new grouped GEMM properly handles FP8 inputs. The leading-dimension alignment check (lda % 16 == 0) has been commented out, which could cause correctness issues if unaligned inputs are passed.
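If the check stays disabled in C++, the invariant could still be guarded on the caller side. A minimal sketch (hypothetical helper, not part of this PR) that zero-pads the contracting dimension to a multiple of 16 before an FP8 GEMM; zero padding does not change the contraction result, though block-scaled formats would also need their scale_inv buffers sized for the padded shape:

import jax.numpy as jnp

def pad_contracting_dim(x, axis=-1, multiple=16):
    # Hypothetical caller-side guard: zero-pad `axis` of `x` up to a multiple
    # of `multiple`. The padded zeros contribute nothing to the contraction.
    pad = (-x.shape[axis]) % multiple
    if pad == 0:
        return x
    pad_width = [(0, 0)] * x.ndim
    pad_width[axis % x.ndim] = (0, pad)
    return jnp.pad(x, pad_width)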

}
} else if (nvfp4) {
// NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe.
@@ -245,8 +245,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla

if (is_fp8_dtype(ret.Atype)) {
// Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage
NVTE_CHECK(ret.ldb % 16 == 0,
"Leading dimension requirement on B for FP8 GEMM. Caller must pad.");
// NVTE_CHECK(ret.ldb % 16 == 0,
// "Leading dimension requirement on B for FP8 GEMM. Caller must pad.");
}
} else if (nvfp4) {
if (is_B_transposed) {
5 changes: 2 additions & 3 deletions transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
@@ -487,9 +487,8 @@ __global__ void setup_grouped_gemm_kernel(
a_cols[idx] = static_cast<int>(a_first);
b_rows[idx] = static_cast<int>(b_last);
b_cols[idx] = static_cast<int>(b_first);
// For OUTPUTS (D, C): cuBLAS writes in column-major, so rows=first (M), cols=last (N).
d_rows[idx] = static_cast<int>(d_first);
d_cols[idx] = static_cast<int>(d_last);
d_rows[idx] = static_cast<int>(d_last);
d_cols[idx] = static_cast<int>(d_first);

// Fill alpha/beta pointers (per-matrix)
alpha_ptrs[idx] = alpha_ptr + idx;
117 changes: 84 additions & 33 deletions transformer_engine/jax/cpp_extensions/gemm.py
@@ -583,27 +583,27 @@ def lowering(
)

lhs_axis_boundary = get_lhs_axis_boundary(lhs_cdims, lhs_transposed)
lhs_contracting_size = (
reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:])
if lhs_transposed
else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary])
)
assert_cublas_requirements(
scaling_mode,
lhs_contracting_size,
"LHS",
)
rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed)
rhs_contracting_size = (
reduce(operator.mul, rhs_aval.shape[:rhs_axis_boundary])
if rhs_transposed
else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:])
)
assert_cublas_requirements(
scaling_mode,
rhs_contracting_size,
"RHS",
)
# lhs_contracting_size = (
# reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:])
# if lhs_transposed
# else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary])
# )
# assert_cublas_requirements(
# scaling_mode,
# lhs_contracting_size,
# f"LHS {lhs_aval.shape} with contracting dims {lhs_cdims}",
# )
# rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed)
# rhs_contracting_size = (
# reduce(operator.mul, rhs_aval.shape[:rhs_axis_boundary])
# if rhs_transposed
# else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:])
# )
# assert_cublas_requirements(
# scaling_mode,
# rhs_contracting_size,
# f"RHS {rhs_aval.shape} with contracting dims {rhs_cdims}",
# )

args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta)
kwargs = {
@@ -936,7 +936,15 @@ def _parse_operand_output_specs(

# Non-contracting dims of RHS always needs to be gathered along the FSDP axis
rhs_non_cspecs = tuple(
None if spec is not None and spec == gsr.fsdp_resource else spec
(
None
if spec is not None
and (
spec == gsr.fsdp_resource
or (isinstance(spec, tuple) and gsr.fsdp_resource in spec)
)
else spec
)
for spec in rhs_non_cspecs
)

@@ -1420,7 +1428,7 @@ class GroupedGemmPrimitive(BasePrimitive):

name = "te_grouped_gemm_ffi"
multiple_results = True
impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
impl_static_args = (10, 11, 12, 13, 14, 15, 16, 17, 18, 19)
inner_primitive = None
outer_primitive = None

@@ -1432,7 +1440,10 @@ def abstract(
rhs_scale_inv_aval,
bias_aval,
group_sizes_aval,
group_offset_aval,
group_offset_lhs_aval,
group_offset_out_aval,
alpha,
beta,
*,
M,
N,
@@ -1470,7 +1481,7 @@
Returns:
A jnp.ndarray containing the result of the grouped GEMM operation
"""
del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval
del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_out_aval
del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes
# TODO(Phuong): move some shape checks from Cpp to here
workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
@@ -1492,11 +1503,16 @@
# We also pad scale_inv swizzle buffers size for 256 bytes alignment.
workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding
workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding

workspace_size += (
1024 * 1024
Comment on lines +1506 to +1508
This hardcoded 1 MB workspace buffer allocation is a HACK, as the inline comment notes. The workspace size calculation should account for the setup and cuBLAS workspace needs separately; otherwise this could lead to buffer overruns or inefficient memory usage.

See gemm.cpp:669-673, where this workspace is split into setup (1 MB) and cuBLAS portions.
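One way to make the split explicit, a rough sketch rather than the PR's implementation; the per-group field counts and names below are assumptions:

def grouped_gemm_workspace_bytes(num_gemms, cublas_workspace_bytes, num_streams, alignment=256):
    # Size the setup region from the per-group metadata it holds (illustrative:
    # 6 int32 dims plus 5 device pointers per GEMM), then round up so the
    # cuBLAS region that follows stays aligned.
    ints_per_gemm = 6
    ptrs_per_gemm = 5
    setup_bytes = num_gemms * (ints_per_gemm * 4 + ptrs_per_gemm * 8)
    setup_bytes = (setup_bytes + alignment - 1) // alignment * alignment
    return setup_bytes + cublas_workspace_bytes * num_streams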

) # HACK: properly make a workspace_setup buffer in addition to the workspace_cublas buffer
workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)

out_shape = (M, N)
if is_grouped_dense_wgrad:
out_shape = (group_sizes_aval.size, M, N)
num_tensors = group_sizes_aval.size
out_shape = (num_tensors, M, N)
out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
return (out_aval, workspace_aval)

@@ -1543,7 +1559,10 @@ def impl(
rhs_scale_inv,
bias,
group_sizes,
group_offset,
group_offset_lhs,
group_offset_out,
alpha,
beta,
M,
N,
K,
@@ -1563,7 +1582,10 @@
rhs_scale_inv,
bias,
group_sizes,
group_offset,
group_offset_lhs,
group_offset_out,
alpha,
beta,
M=M,
N=N,
K=K,
@@ -1929,8 +1951,11 @@ def grouped_gemm(
lhs: [M, K] or [K, N]
rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K]
"""
# TODO(Phuong): implement the group_offset
group_offset = group_offset or jnp.zeros((1,), jnp.int32)

assert group_offset is None, "group_offset is not yet implemented"
assert (
The assertion error message contains an f-string but doesn't actually format anything since there's no f prefix on the outer string.

Suggested change
assert (
assert jax.config.jax_enable_x64, "Grouped GEMM currently requires jax_enable_x64 to be True for correct behavior"

jax.config.jax_enable_x64
), "Grouped GEMM currently requires jax_enable_x64 to be True for correct behavior"

# TODO(Phuong): implement the precision
del precision
@@ -2066,20 +2091,46 @@ def grouped_gemm(
else:
assert group_sizes.size == rhs_shape[0]

assert group_offset.size == 1

has_bias = bias is not None
assert not has_bias or bias.shape == (group_sizes.size, N)
bias = jnp.empty((), jnp.float32) if bias is None else bias

# TODO(jberchtold): move the int64 and offset computation to C++ side in a kernel to avoid needing JAX to support int64
Comment on lines 2097 to +2098
Computing offsets in Python with JAX int64 is a workaround. Move this computation to C++ to avoid requiring jax_enable_x64 and reduce overhead.
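For reference, a worked example of what the Python-side workaround below computes (illustrative values; the int64 arithmetic is why the jax_enable_x64 assertion above exists):

import jax.numpy as jnp

group_sizes = jnp.array([3, 5, 2], dtype=jnp.int64)   # rows per group
# Exclusive cumulative sum -> row offset of each group: [0, 3, 8]
group_offset = jnp.concatenate(
    [jnp.array([0], dtype=jnp.int64), jnp.cumsum(group_sizes)[:-1]]
)
K_lhs = 4
# Scale row offsets to element offsets into the flattened LHS buffer: [0, 12, 32]
group_offset_lhs = group_offset * K_lhs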

group_sizes = group_sizes.astype(jnp.int64)
# Compute group_offset as cumulative sum of group_sizes, starting with 0
group_offset = jnp.concatenate(
[jnp.array([0], dtype=jnp.int64), jnp.cumsum(group_sizes, dtype=jnp.int64)[:-1]]
)
if is_grouped_dense_wgrad:
group_offset_lhs = (
group_offset * M
) # Offset is by number of elements total, not number of rows
# HACK: this _out is really the rhs in this case
group_offset_out = (
group_offset * N
) # Offset is by number of elements total, not number of rows
else:
group_offset_lhs = (
group_offset * K_lhs
) # Offset is by number of elements total, not number of rows
group_offset_out = (
group_offset * N
) # Offset is by number of elements total, not number of rows

num_gemms = group_sizes.shape[0] # Due to interlaced zeros to support int64
alpha = jnp.ones((num_gemms,), jnp.float32)
beta = jnp.zeros((num_gemms,), jnp.float32)
(out,) = GroupedGemmPrimitive.outer_primitive.bind(
lhs_data,
lhs_scale_inv,
rhs_data,
rhs_scale_inv,
bias,
group_sizes,
group_offset,
group_offset_lhs,
group_offset_out,
alpha,
beta,
M=M,
N=N,
K=K_lhs,
10 changes: 6 additions & 4 deletions transformer_engine/jax/cpp_extensions/quantization.py
@@ -20,7 +20,6 @@
from .base import BasePrimitive, register_primitive
from .misc import (
get_padded_spec,
check_valid_batch_dims,
te_dtype_to_jax_dtype,
jax_dtype_to_te_dtype,
multidim_transpose,
@@ -97,7 +96,9 @@ def abstract(
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
out_shape = x_aval.shape
assert scale_aval is None or scale_aval.dtype == jnp.float32
assert (
scale_aval is None or scale_aval.dtype == jnp.float32
), f"scale must be float32 but received {scale_aval}"
if stochastic_rounding:
assert ScalingMode(
scaling_mode
@@ -1213,7 +1214,7 @@ def grouped_quantize(
assert n_groups == len(
quantizer.quantizers
), f"n_groups={n_groups} != n_quantizers = {len(quantizer.quantizers)}"
scale = jnp.empty((n_groups,), jnp.float32)
scale = jnp.ones((n_groups,), jnp.float32)

if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
for i, quantizer_i in enumerate(quantizer.quantizers):
@@ -1249,7 +1250,8 @@
) = GroupedQuantizePrimitive.outer_primitive.bind(
x,
scale,
group_sizes,
# TODO(jberchtold): Remove this int32 cast once GMM does not require JAX int64 dtype
group_sizes.astype(jnp.int32),
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
q_layout=q_layout,