-
Notifications
You must be signed in to change notification settings - Fork 640
[JAX] Integrate BF16 Grouped GEMM with on-device group sizes #2680
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -583,27 +583,27 @@ def lowering( | |||||
| ) | ||||||
|
|
||||||
| lhs_axis_boundary = get_lhs_axis_boundary(lhs_cdims, lhs_transposed) | ||||||
| lhs_contracting_size = ( | ||||||
| reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:]) | ||||||
| if lhs_transposed | ||||||
| else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary]) | ||||||
| ) | ||||||
| assert_cublas_requirements( | ||||||
| scaling_mode, | ||||||
| lhs_contracting_size, | ||||||
| "LHS", | ||||||
| ) | ||||||
| rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed) | ||||||
| rhs_contracting_size = ( | ||||||
| reduce(operator.mul, rhs_aval.shape[:rhs_axis_boundary]) | ||||||
| if rhs_transposed | ||||||
| else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:]) | ||||||
| ) | ||||||
| assert_cublas_requirements( | ||||||
| scaling_mode, | ||||||
| rhs_contracting_size, | ||||||
| "RHS", | ||||||
| ) | ||||||
| # lhs_contracting_size = ( | ||||||
| # reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:]) | ||||||
| # if lhs_transposed | ||||||
| # else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary]) | ||||||
| # ) | ||||||
| # assert_cublas_requirements( | ||||||
| # scaling_mode, | ||||||
| # lhs_contracting_size, | ||||||
| # f"LHS {lhs_aval.shape} with contracting dims {lhs_cdims}", | ||||||
| # ) | ||||||
| # rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed) | ||||||
| # rhs_contracting_size = ( | ||||||
| # reduce(operator.mul, rhs_aval.shape[:rhs_axis_boundary]) | ||||||
| # if rhs_transposed | ||||||
| # else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:]) | ||||||
| # ) | ||||||
| # assert_cublas_requirements( | ||||||
| # scaling_mode, | ||||||
| # rhs_contracting_size, | ||||||
| # f"RHS {rhs_aval.shape} with contracting dims {rhs_cdims}", | ||||||
| # ) | ||||||
|
|
||||||
| args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta) | ||||||
| kwargs = { | ||||||
|
|
@@ -936,7 +936,15 @@ def _parse_operand_output_specs( | |||||
|
|
||||||
| # Non-contracting dims of RHS always needs to be gathered along the FSDP axis | ||||||
| rhs_non_cspecs = tuple( | ||||||
| None if spec is not None and spec == gsr.fsdp_resource else spec | ||||||
| ( | ||||||
| None | ||||||
| if spec is not None | ||||||
| and ( | ||||||
| spec == gsr.fsdp_resource | ||||||
| or (isinstance(spec, tuple) and gsr.fsdp_resource in spec) | ||||||
| ) | ||||||
| else spec | ||||||
| ) | ||||||
| for spec in rhs_non_cspecs | ||||||
| ) | ||||||
|
|
||||||
|
|
@@ -1420,7 +1428,7 @@ class GroupedGemmPrimitive(BasePrimitive): | |||||
|
|
||||||
| name = "te_grouped_gemm_ffi" | ||||||
| multiple_results = True | ||||||
| impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15, 16) | ||||||
| impl_static_args = (10, 11, 12, 13, 14, 15, 16, 17, 18, 19) | ||||||
| inner_primitive = None | ||||||
| outer_primitive = None | ||||||
|
|
||||||
|
|
@@ -1432,7 +1440,10 @@ def abstract( | |||||
| rhs_scale_inv_aval, | ||||||
| bias_aval, | ||||||
| group_sizes_aval, | ||||||
| group_offset_aval, | ||||||
| group_offset_lhs_aval, | ||||||
| group_offset_out_aval, | ||||||
| alpha, | ||||||
| beta, | ||||||
| *, | ||||||
| M, | ||||||
| N, | ||||||
|
|
@@ -1470,7 +1481,7 @@ def abstract( | |||||
| Returns: | ||||||
| A jnp.ndarray containing the result of the grouped GEMM operation | ||||||
| """ | ||||||
| del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval | ||||||
| del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_out_aval | ||||||
| del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes | ||||||
| # TODO(Phuong): move some shape checks from Cpp to here | ||||||
| workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams | ||||||
|
|
@@ -1492,11 +1503,16 @@ def abstract( | |||||
| # We also pad scale_inv swizzle buffers size for 256 bytes alignment. | ||||||
| workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding | ||||||
| workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding | ||||||
|
|
||||||
| workspace_size += ( | ||||||
| 1024 * 1024 | ||||||
|
Comment on lines
+1506
to
+1508
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This hardcoded 1 MB addition to the workspace size is a HACK, as the inline comment itself notes. The workspace size calculation should account for the setup buffer and the cuBLAS workspace separately; otherwise this could lead to buffer overruns or inefficient memory usage. Reference |
||||||
| ) # HACK: properly make a workspace_setup buffer in addition to the workspace_cublas buffer | ||||||
| workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) | ||||||
|
|
||||||
| out_shape = (M, N) | ||||||
| if is_grouped_dense_wgrad: | ||||||
| out_shape = (group_sizes_aval.size, M, N) | ||||||
| num_tensors = group_sizes_aval.size | ||||||
| out_shape = (num_tensors, M, N) | ||||||
| out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) | ||||||
| return (out_aval, workspace_aval) | ||||||
|
|
||||||
|
|
@@ -1543,7 +1559,10 @@ def impl( | |||||
| rhs_scale_inv, | ||||||
| bias, | ||||||
| group_sizes, | ||||||
| group_offset, | ||||||
| group_offset_lhs, | ||||||
| group_offset_out, | ||||||
| alpha, | ||||||
| beta, | ||||||
| M, | ||||||
| N, | ||||||
| K, | ||||||
|
|
@@ -1563,7 +1582,10 @@ def impl( | |||||
| rhs_scale_inv, | ||||||
| bias, | ||||||
| group_sizes, | ||||||
| group_offset, | ||||||
| group_offset_lhs, | ||||||
| group_offset_out, | ||||||
| alpha, | ||||||
| beta, | ||||||
| M=M, | ||||||
| N=N, | ||||||
| K=K, | ||||||
|
|
@@ -1929,8 +1951,11 @@ def grouped_gemm( | |||||
| lhs: [M, K] or [K, N] | ||||||
| rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K] | ||||||
| """ | ||||||
| # TODO(Phuong): implement the group_offset | ||||||
| group_offset = group_offset or jnp.zeros((1,), jnp.int32) | ||||||
|
|
||||||
| assert group_offset is None, "group_offset is not yet implemented" | ||||||
| assert ( | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The assertion error message is written like an f-string but doesn't actually format anything, since the string literal has no `f` prefix.
Suggested change
|
||||||
| jax.config.jax_enable_x64 | ||||||
| ), "Grouped GEMM currently requires jax_enable_x64 to be True for correct behavior" | ||||||
|
|
||||||
| # TODO(Phuong): implement the precision | ||||||
| del precision | ||||||
|
|
@@ -2066,20 +2091,46 @@ def grouped_gemm( | |||||
| else: | ||||||
| assert group_sizes.size == rhs_shape[0] | ||||||
|
|
||||||
| assert group_offset.size == 1 | ||||||
|
|
||||||
| has_bias = bias is not None | ||||||
| assert not has_bias or bias.shape == (group_sizes.size, N) | ||||||
| bias = jnp.empty((), jnp.float32) if bias is None else bias | ||||||
|
|
||||||
| # TODO(jberchtold): move the int64 and offset computation to C++ side in a kernel to avoid needing JAX to support int64 | ||||||
|
Comment on lines
2097
to
+2098
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Computing the offsets in Python with JAX int64 is a workaround. Move this computation to the C++ side to avoid requiring `jax_enable_x64`. Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time! |
||||||
| group_sizes = group_sizes.astype(jnp.int64) | ||||||
| # Compute group_offset as cumulative sum of group_sizes, starting with 0 | ||||||
| group_offset = jnp.concatenate( | ||||||
| [jnp.array([0], dtype=jnp.int64), jnp.cumsum(group_sizes, dtype=jnp.int64)[:-1]] | ||||||
| ) | ||||||
| if is_grouped_dense_wgrad: | ||||||
| group_offset_lhs = ( | ||||||
| group_offset * M | ||||||
| ) # Offset is by number of elements total, not number of rows | ||||||
| # HACK: this _out is really the rhs in this case | ||||||
| group_offset_out = ( | ||||||
| group_offset * N | ||||||
| ) # Offset is by number of elements total, not number of rows | ||||||
| else: | ||||||
| group_offset_lhs = ( | ||||||
| group_offset * K_lhs | ||||||
| ) # Offset is by number of elements total, not number of rows | ||||||
| group_offset_out = ( | ||||||
| group_offset * N | ||||||
| ) # Offset is by number of elements total, not number of rows | ||||||
|
|
||||||
| num_gemms = group_sizes.shape[0] # Due to interlaced zeros to support int64 | ||||||
| alpha = jnp.ones((num_gemms,), jnp.float32) | ||||||
| beta = jnp.zeros((num_gemms,), jnp.float32) | ||||||
| (out,) = GroupedGemmPrimitive.outer_primitive.bind( | ||||||
| lhs_data, | ||||||
| lhs_scale_inv, | ||||||
| rhs_data, | ||||||
| rhs_scale_inv, | ||||||
| bias, | ||||||
| group_sizes, | ||||||
| group_offset, | ||||||
| group_offset_lhs, | ||||||
| group_offset_out, | ||||||
| alpha, | ||||||
| beta, | ||||||
| M=M, | ||||||
| N=N, | ||||||
| K=K_lhs, | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check that the new grouped GEMM properly handles FP8 inputs. The leading dimension alignment requirement validation (lda % 16 == 0) has been commented out, which could cause correctness issues if unaligned inputs are passed.