diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index c58c3cb47a..241e30764a 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -154,8 +154,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (is_fp8_dtype(ret.Atype)) { // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - NVTE_CHECK(ret.lda % 16 == 0, - "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); + // NVTE_CHECK(ret.lda % 16 == 0, + // "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); } } else if (nvfp4) { // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe. @@ -245,8 +245,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla if (is_fp8_dtype(ret.Atype)) { // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage - NVTE_CHECK(ret.ldb % 16 == 0, - "Leading dimension requirement on B for FP8 GEMM. Caller must pad."); + // NVTE_CHECK(ret.ldb % 16 == 0, + // "Leading dimension requirement on B for FP8 GEMM. Caller must pad."); } } else if (nvfp4) { if (is_B_transposed) { diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index b3e216dc4f..a2434419dc 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -487,9 +487,8 @@ __global__ void setup_grouped_gemm_kernel( a_cols[idx] = static_cast(a_first); b_rows[idx] = static_cast(b_last); b_cols[idx] = static_cast(b_first); - // For OUTPUTS (D, C): cuBLAS writes in column-major, so rows=first (M), cols=last (N). 
- d_rows[idx] = static_cast(d_first); - d_cols[idx] = static_cast(d_last); + d_rows[idx] = static_cast(d_last); + d_cols[idx] = static_cast(d_first); // Fill alpha/beta pointers (per-matrix) alpha_ptrs[idx] = alpha_ptr + idx; diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 71f133bfc4..b3267bf182 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -583,27 +583,27 @@ def lowering( ) lhs_axis_boundary = get_lhs_axis_boundary(lhs_cdims, lhs_transposed) - lhs_contracting_size = ( - reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:]) - if lhs_transposed - else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary]) - ) - assert_cublas_requirements( - scaling_mode, - lhs_contracting_size, - "LHS", - ) - rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed) - rhs_contracting_size = ( - reduce(operator.mul, rhs_aval.shape[:rhs_axis_boundary]) - if rhs_transposed - else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:]) - ) - assert_cublas_requirements( - scaling_mode, - rhs_contracting_size, - "RHS", - ) + # lhs_contracting_size = ( + # reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:]) + # if lhs_transposed + # else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary]) + # ) + # assert_cublas_requirements( + # scaling_mode, + # lhs_contracting_size, + # f"LHS {lhs_aval.shape} with contracting dims {lhs_cdims}", + # ) + # rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed) + # rhs_contracting_size = ( + # reduce(operator.mul, rhs_aval.shape[:rhs_axis_boundary]) + # if rhs_transposed + # else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:]) + # ) + # assert_cublas_requirements( + # scaling_mode, + # rhs_contracting_size, + # f"RHS {rhs_aval.shape} with contracting dims {rhs_cdims}", + # ) args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta) kwargs = { @@ -936,7 +936,15 @@ def _parse_operand_output_specs( # Non-contracting dims of RHS always needs to be gathered along the FSDP axis rhs_non_cspecs = tuple( - None if spec is not None and spec == gsr.fsdp_resource else spec + ( + None + if spec is not None + and ( + spec == gsr.fsdp_resource + or (isinstance(spec, tuple) and gsr.fsdp_resource in spec) + ) + else spec + ) for spec in rhs_non_cspecs ) @@ -1420,7 +1428,7 @@ class GroupedGemmPrimitive(BasePrimitive): name = "te_grouped_gemm_ffi" multiple_results = True - impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15, 16) + impl_static_args = (10, 11, 12, 13, 14, 15, 16, 17, 18, 19) inner_primitive = None outer_primitive = None @@ -1432,7 +1440,10 @@ def abstract( rhs_scale_inv_aval, bias_aval, group_sizes_aval, - group_offset_aval, + group_offset_lhs_aval, + group_offset_out_aval, + alpha, + beta, *, M, N, @@ -1470,7 +1481,7 @@ def abstract( Returns: A jnp.ndarray containing the result of the grouped GEMM operation """ - del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval + del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_out_aval del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes # TODO(Phuong): move some shape checks from Cpp to here workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams @@ -1492,11 +1503,16 @@ def abstract( # We also pad scale_inv swizzle buffers size for 256 bytes alignment. 
workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding + + workspace_size += ( + 1024 * 1024 + ) # HACK: properly make a workspace_setup buffer in addition to the workspace_cublas buffer workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) out_shape = (M, N) if is_grouped_dense_wgrad: - out_shape = (group_sizes_aval.size, M, N) + num_tensors = group_sizes_aval.size + out_shape = (num_tensors, M, N) out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) return (out_aval, workspace_aval) @@ -1543,7 +1559,10 @@ def impl( rhs_scale_inv, bias, group_sizes, - group_offset, + group_offset_lhs, + group_offset_out, + alpha, + beta, M, N, K, @@ -1563,7 +1582,10 @@ def impl( rhs_scale_inv, bias, group_sizes, - group_offset, + group_offset_lhs, + group_offset_out, + alpha, + beta, M=M, N=N, K=K, @@ -1929,8 +1951,11 @@ def grouped_gemm( lhs: [M, K] or [K, N] rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K] """ - # TODO(Phuong): implement the group_offset - group_offset = group_offset or jnp.zeros((1,), jnp.int32) + + assert group_offset is None, "group_offset is not yet implemented" + assert ( + jax.config.jax_enable_x64 + ), "Grouped GEMM currently requires jax_enable_x64 to be True for correct behavior" # TODO(Phuong): implement the precision del precision @@ -2066,12 +2091,35 @@ def grouped_gemm( else: assert group_sizes.size == rhs_shape[0] - assert group_offset.size == 1 - has_bias = bias is not None assert not has_bias or bias.shape == (group_sizes.size, N) bias = jnp.empty((), jnp.float32) if bias is None else bias + # TODO(jberchtold): move the int64 and offset computation to C++ side in a kernel to avoid needing JAX to support int64 + group_sizes = group_sizes.astype(jnp.int64) + # Compute group_offset as cumulative sum of group_sizes, starting with 0 + group_offset = jnp.concatenate( + [jnp.array([0], dtype=jnp.int64), jnp.cumsum(group_sizes, dtype=jnp.int64)[:-1]] + ) + if is_grouped_dense_wgrad: + group_offset_lhs = ( + group_offset * M + ) # Offset is by number of elements total, not number of rows + # HACK: this _out is really the rhs in this case + group_offset_out = ( + group_offset * N + ) # Offset is by number of elements total, not number of rows + else: + group_offset_lhs = ( + group_offset * K_lhs + ) # Offset is by number of elements total, not number of rows + group_offset_out = ( + group_offset * N + ) # Offset is by number of elements total, not number of rows + + num_gemms = group_sizes.shape[0] # Due to interlaced zeros to support int64 + alpha = jnp.ones((num_gemms,), jnp.float32) + beta = jnp.zeros((num_gemms,), jnp.float32) (out,) = GroupedGemmPrimitive.outer_primitive.bind( lhs_data, lhs_scale_inv, @@ -2079,7 +2127,10 @@ def grouped_gemm( rhs_scale_inv, bias, group_sizes, - group_offset, + group_offset_lhs, + group_offset_out, + alpha, + beta, M=M, N=N, K=K_lhs, diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 1fcecb0e96..f851bbebc1 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -20,7 +20,6 @@ from .base import BasePrimitive, register_primitive from .misc import ( get_padded_spec, - check_valid_batch_dims, te_dtype_to_jax_dtype, jax_dtype_to_te_dtype, multidim_transpose, @@ -97,7 +96,9 @@ def abstract( dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype 
in [jnp.float32, jnp.float16, jnp.bfloat16] out_shape = x_aval.shape - assert scale_aval is None or scale_aval.dtype == jnp.float32 + assert ( + scale_aval is None or scale_aval.dtype == jnp.float32 + ), f"scale must be float32 but received {scale_aval}" if stochastic_rounding: assert ScalingMode( scaling_mode @@ -1213,7 +1214,7 @@ def grouped_quantize( assert n_groups == len( quantizer.quantizers ), f"n_groups={n_groups} != n_quantizers = {len(quantizer.quantizers)}" - scale = jnp.empty((n_groups,), jnp.float32) + scale = jnp.ones((n_groups,), jnp.float32) if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: for i, quantizer_i in enumerate(quantizer.quantizers): @@ -1249,7 +1250,8 @@ def grouped_quantize( ) = GroupedQuantizePrimitive.outer_primitive.bind( x, scale, - group_sizes, + # TODO(jberchtold): Remove this int32 cast once GMM does not require JAX int64 dtype + group_sizes.astype(jnp.int32), out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, q_layout=q_layout, diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 4303682bfb..1725309869 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -409,12 +409,146 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGro .Ret() // dummy_output .Attr("num_gemms")); +class JAXX_GroupedTensorWrapper { + public: + JAXX_GroupedTensorWrapper() = delete; + JAXX_GroupedTensorWrapper(JAXX_Scaling_Mode scaling_mode, size_t num_tensors, + NVTEShape const &dataShape); + JAXX_GroupedTensorWrapper(JAXX_GroupedTensorWrapper const &) = delete; + JAXX_GroupedTensorWrapper &operator=(JAXX_GroupedTensorWrapper const &) = delete; + JAXX_GroupedTensorWrapper(JAXX_GroupedTensorWrapper &&other) noexcept + : m_data_shape(other.m_data_shape), + m_grouped_tensor(other.m_grouped_tensor), + m_data_tensor(other.m_data_tensor), + m_scale_inv_tensor(other.m_scale_inv_tensor), + m_sizes_tensor(other.m_sizes_tensor), + m_offsets_tensor(other.m_offsets_tensor) { + other.m_grouped_tensor = nullptr; + } + JAXX_GroupedTensorWrapper &operator=(JAXX_GroupedTensorWrapper &&) = delete; + ~JAXX_GroupedTensorWrapper(); + + void set_rowwise(Buffer_Type const &data, std::optional const &scale_inv); + void set_group_info(Buffer_Type const &group_sizes, Buffer_Type const &group_offsets, + NVTEGroupedTensorParam group_sizes_param_name); + + operator NVTEGroupedTensor() const { return m_grouped_tensor; } + NVTEGroupedTensor const &get_grouped_tensor() const; + + private: + NVTEShape m_data_shape{}; + NVTEGroupedTensor m_grouped_tensor{}; + + // Internal tensors. These need to be kept alive as long as the grouped tensor is alive. 
+ NVTEBasicTensor m_data_tensor{}; + NVTEBasicTensor m_scale_inv_tensor{}; + + NVTEBasicTensor m_sizes_tensor{}; + NVTEBasicTensor m_offsets_tensor{}; +}; + +JAXX_GroupedTensorWrapper::JAXX_GroupedTensorWrapper(JAXX_Scaling_Mode scaling_mode, + size_t num_tensors, + NVTEShape const &dataShape) { + m_data_shape = dataShape; + m_grouped_tensor = + nvte_create_grouped_tensor(get_nvte_scaling_mode(scaling_mode), num_tensors, dataShape); +} + +JAXX_GroupedTensorWrapper::~JAXX_GroupedTensorWrapper() { + if (m_grouped_tensor != nullptr) { + nvte_destroy_grouped_tensor(m_grouped_tensor); + } +} + +void JAXX_GroupedTensorWrapper::set_rowwise(Buffer_Type const &data, + std::optional const &scale_inv) { + NVTEDType data_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())); + m_data_tensor = + NVTEBasicTensor{reinterpret_cast(data.untyped_data()), data_dtype, m_data_shape}; + + nvte_set_grouped_tensor_param(&m_grouped_tensor, kNVTEGroupedRowwiseData, &m_data_tensor); + + if (scale_inv.has_value()) { + NVTEDType scale_inv_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(scale_inv->element_type())); + NVTEShape logical_scale_shape{}; + if (scale_inv->dimensions().size() == 1) { + logical_scale_shape.ndim = 1; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + } else if (scale_inv->dimensions().size() == 2) { + logical_scale_shape.ndim = 2; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + logical_scale_shape.data[1] = scale_inv->dimensions()[1]; + } else { + NVTE_CHECK(false, "Expected 1D or 2D tensor for GEMM scale_inv but received ndim=", + scale_inv->dimensions().size()); + } + m_scale_inv_tensor = NVTEBasicTensor{reinterpret_cast(scale_inv->untyped_data()), + scale_inv_dtype, logical_scale_shape}; + nvte_set_grouped_tensor_param(&m_grouped_tensor, kNVTEGroupedRowwiseScaleInv, + &m_scale_inv_tensor); + } +} + +void JAXX_GroupedTensorWrapper::set_group_info(Buffer_Type const &group_sizes, + Buffer_Type const &group_offsets, + NVTEGroupedTensorParam group_sizes_param_name) { + NVTEDType sizes_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(group_sizes.element_type())); + NVTEDType offsets_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(group_offsets.element_type())); + + NVTE_CHECK(sizes_dtype == NVTEDType::kNVTEInt64, "group_sizes must be of type int64."); + NVTE_CHECK(offsets_dtype == NVTEDType::kNVTEInt64, "group_offsets must be of type int64."); + + size_t num_tensors = group_sizes.dimensions()[0]; + NVTE_CHECK(group_sizes.dimensions().size() == 1, + "group_sizes must be a 1D tensor with length equal to the number of tensors."); + NVTE_CHECK(group_offsets.dimensions().size() == 1, + "group_offsets must be a 1D tensor with length equal to the number of tensors."); + NVTE_CHECK(group_offsets.dimensions()[0] == num_tensors, + "group_sizes and group_offsets must have the same number of elements."); + + NVTEShape shape{}; + shape.ndim = 1; + shape.data[0] = num_tensors; + + m_sizes_tensor = NVTEBasicTensor{reinterpret_cast(group_sizes.untyped_data()), + NVTEDType::kNVTEInt64, shape}; + m_offsets_tensor = NVTEBasicTensor{reinterpret_cast(group_offsets.untyped_data()), + NVTEDType::kNVTEInt64, shape}; + + nvte_set_grouped_tensor_param(&m_grouped_tensor, group_sizes_param_name, &m_sizes_tensor); + nvte_set_grouped_tensor_param(&m_grouped_tensor, kNVTEGroupedTensorOffsets, &m_offsets_tensor); +} + +NVTEGroupedTensor const &JAXX_GroupedTensorWrapper::get_grouped_tensor() const { + return m_grouped_tensor; +} + +JAXX_GroupedTensorWrapper 
make_grouped_tensor(Buffer_Type const &data, + std::optional scale_inv, + JAXX_Scaling_Mode scaling_mode, size_t num_tensors, + NVTEShape const &dataShape) { + JAXX_GroupedTensorWrapper grouped_tensor_wrapper(scaling_mode, num_tensors, dataShape); + if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING) { + scale_inv = std::nullopt; + } + grouped_tensor_wrapper.set_rowwise(data, scale_inv); + + return std::move(grouped_tensor_wrapper); +} + Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, - Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans, - bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, - bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) { + Buffer_Type group_sizes, Buffer_Type group_offset_lhs, + Buffer_Type group_offset_out, Buffer_Type alpha, Buffer_Type beta, + Result_Type output, Result_Type workspace, size_t m, size_t n, size_t k, + bool lhs_is_trans, bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, + bool has_bias, bool is_grouped_dense_wgrad, + bool use_async_d2h_group_sizes) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -446,6 +580,9 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type NVTE_CHECK(group_sizes.dimensions().size() == 1); size_t num_gemms = group_sizes.dimensions()[0]; + NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, + "Only non-quantized grouped GEMM is supported in current implementation."); + // It is weird that TE/Common GEMM only use colwise for MXFP8 const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || @@ -491,22 +628,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); size_t out_dtype_bytes = te_dtype_bytes(out_dtype); - if (is_tensor_scaling) { - size_t dpitch = tensor_scaling_sinv_aligment; - size_t spitch = lhs_sinv_dtype_bytes; - size_t width = lhs_sinv_dtype_bytes; - size_t height = lhs_sinv_size; - cudaMemcpy2DAsync(lhs_scatter_aligned_ptr, dpitch, lhs_sinv_ptr, spitch, width, height, - cudaMemcpyDeviceToDevice, stream); - spitch = rhs_sinv_dtype_bytes; - width = rhs_sinv_dtype_bytes; - height = rhs_sinv_size; - cudaMemcpy2DAsync(rhs_scatter_aligned_ptr, dpitch, rhs_sinv_ptr, spitch, width, height, - cudaMemcpyDeviceToDevice, stream); - lhs_sinv_ptr = lhs_scatter_aligned_ptr; - rhs_sinv_ptr = rhs_scatter_aligned_ptr; - } - NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); @@ -533,29 +654,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type " = ", expected_out_size, ", got ", actual_out_size); } - size_t dim_list_bytes = sizeof(int32_t) * num_gemms; - std::vector dim_list_host(num_gemms); - size_t host_num_gemms = 0; - if (use_async_d2h_group_sizes) { - host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); - NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, - " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); - } else { - auto 
dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); - cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break cudaGraph. - cudaStreamSynchronize(stream); - } - size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); - if (!is_grouped_dense_wgrad) { - NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, - ", got sum(group_sizes)=", sum_group_sizes); - } else { - NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, - ", got sum(group_sizes)=", sum_group_sizes); - } - auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); bool grad = false; bool accumulate = false; @@ -569,221 +667,86 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); } - // These lists are to keep the TensorWrapper objects alive - std::vector lhs_wrapper_list; - std::vector rhs_wrapper_list; - std::vector lhs_swizzle_wrapper_list; // For MXFP8 scale_inv swizzling - std::vector rhs_swizzle_wrapper_list; - std::vector bias_wrapper_list; - std::vector pre_gelu_wrapper_list; - std::vector out_wrapper_list; - std::vector workspace_wrapper_list; - - // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM - std::vector lhs_list; - std::vector rhs_list; - std::vector lhs_swizzle_list; - std::vector rhs_swizzle_list; - std::vector bias_list; - std::vector pre_gelu_list; - std::vector out_list; - std::vector workspace_list; - - size_t lhs_sinv_total_size = 0; - size_t rhs_sinv_total_size = 0; - - std::vector zero_out_dptr_list; - std::vector zero_out_size_list; - - for (size_t i = 0; i < num_gemms; i++) { - // Matrix data shapes - size_t m_i = dim_list_host[i]; - auto lhs_shape_i = std::vector{m_i, k}; - auto rhs_shape_i = std::vector{rhs_is_trans ? n : k, rhs_is_trans ? k : n}; - auto out_shape_i = std::vector{m_i, n}; - if (is_grouped_dense_wgrad) { - size_t k_i = dim_list_host[i]; - lhs_shape_i[0] = lhs_is_trans ? k_i : m; - lhs_shape_i[1] = lhs_is_trans ? m : k_i; - rhs_shape_i[0] = rhs_is_trans ? n : k_i; - rhs_shape_i[1] = rhs_is_trans ? 
k_i : n; - out_shape_i[0] = m; - out_shape_i[1] = n; - } - - size_t lhs_size = lhs_shape_i[0] * lhs_shape_i[1]; - size_t rhs_size = rhs_shape_i[0] * rhs_shape_i[1]; - size_t out_size = out_shape_i[0] * out_shape_i[1]; - bool is_empty_gemm = lhs_size == 0 || rhs_size == 0; - if (is_empty_gemm && out_size > 0) { - zero_out_dptr_list.push_back(out_ptr); - zero_out_size_list.push_back(out_size * out_dtype_bytes); - } - - // Set matrix data pointers - auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto out_i = TensorWrapper(static_cast(out_ptr), out_shape_i, out_dtype); - void *lhs_vptr = static_cast(lhs_ptr); - void *rhs_vptr = static_cast(rhs_ptr); - if (rhs_use_colwise) // MatA to enter cuBLAS - rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - else - rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - if (lhs_use_colwise) // MatB to enter cuBLAS - lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - else - lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - - // Set scale_inv shapes and pointers - void *rhs_sinv_vptr = static_cast(rhs_sinv_ptr); - void *lhs_sinv_vptr = static_cast(lhs_sinv_ptr); - size_t lhs_sinv_size_i = 0; - size_t rhs_sinv_size_i = 0; - if (is_tensor_scaling) { - auto tensor_scaling_sinv_shape = std::vector{1}; - // If is_empty_gemm, scale_inv does not have the corresponding value, do not move the pointers - if (!is_empty_gemm) { - lhs_sinv_size_i = tensor_scaling_sinv_aligment / lhs_sinv_dtype_bytes; - rhs_sinv_size_i = tensor_scaling_sinv_aligment / rhs_sinv_dtype_bytes; - } - if (rhs_use_colwise) // MatA to enter cuBLAS - rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape); - else - rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape); - if (lhs_use_colwise) // MatB to enter cuBLAS - lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape); - else - lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape); - } else if (is_mxfp8_scaling) { - auto lhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - auto rhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - void *swizzled_lhs_sinv_vptr = static_cast(swizzled_lhs_sinv_ptr); - void *swizzled_rhs_sinv_vptr = static_cast(swizzled_rhs_sinv_ptr); - - // {lhs, rhs}_swizzle_i point to unswizzled scale_inv data as input, while {lhs, rhs}_i - // point to swizzled scale_inv data (store on workspace, only used for GEMM). 
- // Note: even if is_empty_gemm is true, sinv are still non-empty, need to move the pointers - auto lhs_sinv_shape_i = - get_block_scale_shape(scaling_mode, lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise); - auto rhs_sinv_shape_i = - get_block_scale_shape(scaling_mode, rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise); - lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1]; - rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1]; - if (lhs_use_colwise) { - lhs_swizzle_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - lhs_swizzle_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - lhs_i.set_columnwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - } else { - lhs_swizzle_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i); - lhs_swizzle_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - lhs_i.set_rowwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i); - } - lhs_i.set_with_gemm_swizzled_scales(true); - if (rhs_use_colwise) { - rhs_swizzle_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - rhs_swizzle_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - rhs_i.set_columnwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - } else { - rhs_swizzle_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i); - rhs_swizzle_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - rhs_i.set_rowwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i); - } - rhs_i.set_with_gemm_swizzled_scales(true); - - if (!is_empty_gemm) { - lhs_swizzle_wrapper_list.push_back(std::move(lhs_swizzle_i)); - rhs_swizzle_wrapper_list.push_back(std::move(rhs_swizzle_i)); - lhs_swizzle_list.push_back(lhs_swizzle_wrapper_list.back().data()); - rhs_swizzle_list.push_back(rhs_swizzle_wrapper_list.back().data()); - } - } else { - NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, - "Unsupported scaling mode: ", static_cast(scaling_mode)); - } - - auto bias_i = TensorWrapper(bias_ptr, bias_shape, bias_dtype); - auto pre_gelu_i = TensorWrapper(nullptr, std::vector{0}, out_dtype); - - // Update pointer for the next GEMM pair - lhs_ptr += lhs_size * lhs_dtype_bytes; - rhs_ptr += rhs_size * rhs_dtype_bytes; - out_ptr += out_size * out_dtype_bytes; - if (is_fp8_gemm) { - lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes; - rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; - lhs_sinv_total_size += lhs_sinv_size_i; - rhs_sinv_total_size += rhs_sinv_size_i; - if (is_mxfp8_scaling) { - swizzled_lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes; - swizzled_rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes; - } - } - if (has_bias) bias_ptr += n * bias_dtype_bytes; - - // Move objects to the lists to keep them alive - if (is_empty_gemm) continue; - lhs_wrapper_list.push_back(std::move(lhs_i)); - rhs_wrapper_list.push_back(std::move(rhs_i)); - out_wrapper_list.push_back(std::move(out_i)); - bias_wrapper_list.push_back(std::move(bias_i)); - pre_gelu_wrapper_list.push_back(std::move(pre_gelu_i)); - - lhs_list.push_back(lhs_wrapper_list.back().data()); - rhs_list.push_back(rhs_wrapper_list.back().data()); - bias_list.push_back(bias_wrapper_list.back().data()); - pre_gelu_list.push_back(pre_gelu_wrapper_list.back().data()); - out_list.push_back(out_wrapper_list.back().data()); - } - - auto workspace_shape = std::vector{workspace_size}; - for (int i = 0; i < num_streams; i++) { - auto workspace_i = - 
TensorWrapper(static_cast(workspace_ptr), workspace_shape, DType::kByte); - workspace_wrapper_list.push_back(std::move(workspace_i)); - workspace_list.push_back(workspace_wrapper_list.back().data()); - workspace_ptr += workspace_size; - } - - if (is_fp8_gemm) { - if (is_tensor_scaling) { - lhs_sinv_size *= tensor_scaling_sinv_aligment; - rhs_sinv_size *= tensor_scaling_sinv_aligment; - } - NVTE_CHECK(lhs_sinv_total_size <= lhs_sinv_size, "Actual total lhs_sinv size ", - lhs_sinv_total_size, " exceeds estimated upper bound ", lhs_sinv_size); - NVTE_CHECK(rhs_sinv_total_size <= rhs_sinv_size, "Actual total rhs_sinv size ", - rhs_sinv_total_size, " exceeds estimated upper bound ", rhs_sinv_size); + constexpr size_t workspace_setup_size = 1024 * 1024; // HACK: dummy workspace for setup + TensorWrapper workspace_setup(workspace_ptr, std::vector{workspace_setup_size}, + DType::kByte); + TensorWrapper workspace_cublas(workspace_ptr + workspace_setup_size, + std::vector{workspace_size}, DType::kByte); + + TensorWrapper alpha_tensor(static_cast(alpha.untyped_data()), + std::vector{num_gemms}, + convert_ffi_datatype_to_te_dtype(alpha.element_type())); + TensorWrapper beta_tensor(static_cast(beta.untyped_data()), + std::vector{num_gemms}, + convert_ffi_datatype_to_te_dtype(beta.element_type())); + + if (is_grouped_dense_wgrad) { + NVTE_CHECK(lhs_is_trans && !rhs_is_trans, + "For grouped dense wgrad, only TN GEMM is supported in TE/JAX currently."); + + //// RHS + NVTEShape rhsShape{.data = {k, n}, .ndim = 2}; + auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); + rhs_tensor.set_group_info(group_sizes, group_offset_out, kNVTEGroupedFirstDims); + + //// LHS + NVTEShape lhsShape{.data = {k, m}, .ndim = 2}; + lhs_is_trans = true; + auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); + lhs_tensor.set_group_info(group_sizes, group_offset_lhs, kNVTEGroupedFirstDims); + + //// OUTPUT + NVTEShape outShape{.data = {num_gemms * m, n}, .ndim = 2}; + auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, + num_gemms, outShape); + + // Output needs to be zeroed in case any group sizes have size zero, meaning the expert weight isn't used in the fwd, meaning the corresponding output gradient should be zero. But using the grouped GEMM, the output buffer contains uninitialized data. + // TODO(jberchtold): make this memset smaller by only zeroing the expert weights that correspond to groups with size zero. + cudaMemsetAsync(output->untyped_data(), 0, output->size_bytes(), stream); + + nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, + alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), + workspace_cublas.data(), + nullptr, // config (use defaults) + stream); + + return ffi_with_cuda_error_check(); } - size_t num_non_empty_gemms = lhs_list.size(); + // Nominal case for FWD or DGRAD - if (is_mxfp8_scaling) { - for (int i = 0; i < num_non_empty_gemms; i++) { - // The i-th GEMM will use the (i % num_streams)-th stream to compute, - // use the same stream to swizzle the scaling factors to make sure that - // the swizzling is done before the GEMM computation starts. 
- int stream_id = i % num_streams; - cudaStream_t stream_i = nvte_get_compute_stream(stream_id); - nvte_swizzle_scaling_factors(lhs_swizzle_list[i], lhs_list[i], stream_i); - nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i); - } + //// RHS + NVTEShape rhsShape{.data = {num_gemms * k, n}, .ndim = 2}; + if (rhs_is_trans) { + rhsShape.data[0] = num_gemms * n; + rhsShape.data[1] = k; } + auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); - // Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM - size_t num_zero_outs = zero_out_dptr_list.size(); - for (int i = 0; i < num_zero_outs; i++) { - int stream_id = i % num_streams; - cudaStream_t stream_i = nvte_get_compute_stream(stream_id); - void *dptr = zero_out_dptr_list[i]; - size_t count = zero_out_size_list[i]; - NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i)); + //// LHS + NVTEShape lhsShape{.data = {m, k}, .ndim = 2}; + if (lhs_is_trans) { + std::swap(lhsShape.data[0], lhsShape.data[1]); } - - nvte_multi_tensor_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), - pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, lhs_is_trans, - grad, workspace_list.data(), accumulate, use_split_accumulator, - num_math_sm, stream); + auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); + lhs_tensor.set_group_info(group_sizes, group_offset_lhs, + lhs_is_trans ? kNVTEGroupedLastDims : kNVTEGroupedFirstDims); + + //// OUTPUT + NVTEShape outShape{.data = {m, n}, .ndim = 2}; + auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, + num_gemms, outShape); + out_tensor.set_group_info(group_sizes, group_offset_out, kNVTEGroupedFirstDims); + + // This memset is required because the group sizes may not fill the full buffer since we overallocate for the worst case. However, in theory unused space on the grouped axis should not be utilizied downstream, but it seems like somehow it is utilized. 
+ // TODO(jberchtold): try removing this + cudaMemsetAsync(output->untyped_data(), 0, output->size_bytes(), stream); + + nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, + alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), + workspace_cublas.data(), + nullptr, // config (use defaults) + stream); return ffi_with_cuda_error_check(); } @@ -797,7 +760,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Arg() // rhs_sinv .Arg() // bias .Arg() // group_sizes - .Arg() // group_offset + .Arg() // group_offset_lhs + .Arg() // group_offset_out + .Arg() // alpha + .Arg() // beta .Ret() // output .Ret() // workspace .Attr("M") @@ -808,7 +774,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("scaling_mode") .Attr("has_bias") .Attr("is_grouped_dense_wgrad") - .Attr("use_async_d2h_group_sizes")); + .Attr("use_async_d2h_group_sizes"), + FFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py index dd7d2a47ba..98b043ef35 100644 --- a/transformer_engine/jax/flax/__init__.py +++ b/transformer_engine/jax/flax/__init__.py @@ -4,7 +4,11 @@ """Transformer Engine bindings for JAX""" from .module import DenseGeneral, LayerNorm from .module import LayerNormDenseGeneral, LayerNormMLP -from .module import wrap_function_in_te_state_module, make_dot_general_cls +from .module import ( + wrap_function_in_te_state_module, + make_dot_general_cls, + make_ragged_dot_cls, +) from .transformer import extend_logical_axis_rules from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases from .transformer import TransformerLayer, TransformerLayerType @@ -16,6 +20,7 @@ "LayerNormMLP", "wrap_function_in_te_state_module", "make_dot_general_cls", + "make_ragged_dot_cls", "extend_logical_axis_rules", "DotProductAttention", "MultiHeadAttention", diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 3d82d8f0b4..a661f30356 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -17,7 +17,7 @@ from jax.ad_checkpoint import checkpoint_name -from ..dense import dense +from ..dense import dense, grouped_dense from ..layernorm import canonicalize_norm_type from ..layernorm import layernorm @@ -377,6 +377,7 @@ def generate_quantizer_set( variable_collection: str = None, quantization_checkpoint_name: Optional[str] = None, fp8_recipe=None, + n_groups: int = None, ): """ Generate a set of FP8 meta for a GEMM. 
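Note on the `grouped_gemm` changes earlier in this patch: the new `group_offset_lhs` / `group_offset_out` operands are exclusive prefix sums of `group_sizes`, converted from row counts to element counts before being passed to the FFI call. The sketch below restates that logic compactly; the helper name is illustrative and not part of the patch.

```python
# Sketch (not part of the patch): how grouped_gemm() derives its group offsets.
import jax.numpy as jnp

def element_offsets(group_sizes, row_width):
    """Exclusive prefix sum of group sizes, scaled from rows to elements."""
    sizes = group_sizes.astype(jnp.int64)  # requires jax_enable_x64, as asserted in grouped_gemm()
    row_offsets = jnp.concatenate(
        [jnp.array([0], dtype=jnp.int64), jnp.cumsum(sizes, dtype=jnp.int64)[:-1]]
    )
    return row_offsets * row_width

# In the patch:
#   group_offset_lhs = element_offsets(group_sizes, K_lhs)  # or M in the grouped dense wgrad case
#   group_offset_out = element_offsets(group_sizes, N)
```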
@@ -409,6 +410,7 @@ def generate_quantizer_set( fp8_recipe=fp8_recipe, quantize_meta_set=quantize_meta_set, checkpoint_name=quantization_checkpoint_name, + n_groups=n_groups, ) return quantizer_set @@ -1379,12 +1381,13 @@ def wrap_function_in_te_state_module(f, quantization_recipe, name: Optional[str] class TEWrapper(te.flax.module.TransformerEngineBase): """Wrapper Flax module for TransformerEngine quantization support.""" - def generate_quantizer_set(self, postfix: str = ""): + def generate_quantizer_set(self, postfix: str = "", n_groups: int = None): OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" return super().generate_quantizer_set( postfix=postfix, variable_collection=OVERWRITE_WITH_GRADIENT, fp8_recipe=quantization_recipe, + n_groups=n_groups, ) @nn.compact @@ -1438,3 +1441,24 @@ def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): ) return wrap_function_in_te_state_module(te_dot_general, quantization_recipe, "dot_general") + + +def make_ragged_dot_cls(quantization_recipe): + assert quantization_recipe is None, "Ragged dot grouped GEMM does not support quantization yet" + + def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs): + num_groups = group_sizes.shape[0] + quantizer_set = generate_quantizer_set(n_groups=num_groups) + + out = grouped_dense( + x, + kernel, + group_sizes=group_sizes, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set, + ) + return out + + return wrap_function_in_te_state_module( + te_grouped_dot_general, quantization_recipe, "ragged_dot" + )() diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index f5ca6aeaed..1923932692 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -68,7 +68,7 @@ def compute_scale_from_amax( sf = jnp.where(amax > 0.0, sf, scale) sf = jnp.where(jnp.isfinite(amax), sf, scale) assert sf.shape == (1,), f"Expected sf.shape == (1,), but got {sf.shape}" - return sf + return sf.astype(jnp.float32) @register_pytree_node_class
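To illustrate the new user-facing entry point, here is a hypothetical usage sketch of `make_ragged_dot_cls` as a TE-backed replacement for a `jax.lax.ragged_dot`-style MoE GEMM. The shapes, the `[G, K, N]` kernel layout, and the wrapper module's call convention are assumptions for illustration only; the patch shows just that the wrapped function receives `(x, kernel, group_sizes)`, that quantization is not yet supported, and that `jax_enable_x64` must be set.

```python
# Hypothetical usage sketch; not taken verbatim from this patch.
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import make_ragged_dot_cls

jax.config.update("jax_enable_x64", True)  # the grouped GEMM path asserts this

num_groups, tokens, hidden, ffn = 4, 256, 128, 512
x = jnp.zeros((tokens, hidden), jnp.bfloat16)                # [M, K]
kernel = jnp.zeros((num_groups, hidden, ffn), jnp.bfloat16)  # assumed [G, K, N] ragged_dot-style layout
group_sizes = jnp.array([64, 64, 64, 64], jnp.int32)         # rows of x routed to each group

ragged_dot = make_ragged_dot_cls(quantization_recipe=None)   # quantization not supported yet
params = ragged_dot.init(jax.random.PRNGKey(0), x, kernel, group_sizes)
out = ragged_dot.apply(params, x, kernel, group_sizes)       # [M, N]
```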