diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 7f15eec892..7a4fa268af 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -170,6 +170,7 @@ def assert_dequantized_grouped_scaled_tensor( ("quick_gelu", "linear"), ("squared_relu",), ("squared_relu", "linear"), + ("clamped_silu", "clamped_linear"), ] ACTIVATION_TYPES = { @@ -182,17 +183,21 @@ def assert_dequantized_grouped_scaled_tensor( class TestActivation: - def ref_act(self, x, activation_type): - return _jax_act_lu(x, activation_type).data + def ref_act(self, x, activation_type, act_params): + return _jax_act_lu(x, activation_type, act_params=act_params).data - def value_n_grad_ref_func(self, x, activation_type): + def value_n_grad_ref_func(self, x, activation_type, act_params): jitted_reference = jit( - value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,)) + value_and_grad( + lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,) + ) ) return jitted_reference(x) - def primitive_func(self, inputs, activation_type, quantizer): - out = activation(inputs, activation_type=activation_type, quantizer=quantizer) + def primitive_func(self, inputs, activation_type, quantizer, act_params): + out = activation( + inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params + ) return jnp.mean(out) @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @@ -209,12 +214,20 @@ def test_act_grad(self, shape, activation_type): x = jnp.repeat(x, len(activation_type), axis=-2) value_n_grad_primitive_func = jit( - value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) + value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3) ) - - prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None) - ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type) - + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type == 
("clamped_silu", "clamped_linear") + else {} + ) + act_params = ( + tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params) + ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params) assert_allclose(prim_out, ref_out, dtype=x.dtype) assert_allclose(prim_grad, ref_grad, dtype=x.dtype) @@ -234,7 +247,8 @@ def test_act_grad_with_tensor_scaling_fp8( self.activation_type = activation_type value_n_grad_primitive_func = jit( - value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) + value_and_grad(self.primitive_func, (0,)), + static_argnums=(1, 3), ) quantizer = QuantizerFactory.create( @@ -242,9 +256,21 @@ def test_act_grad_with_tensor_scaling_fp8( q_dtype=output_type, q_layout=QuantizeLayout.ROWWISE, ) + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type == ("clamped_silu", "clamped_linear") + else {} + ) - prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer) - ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type) + act_params = ( + tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + prim_out, (prim_grad,) = value_n_grad_primitive_func( + x, activation_type, quantizer, act_params + ) + ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params) assert_allclose(prim_out, ref_out, dtype=output_type) assert_allclose(prim_grad, ref_grad, dtype=output_type) @@ -273,10 +299,18 @@ def test_act_forward_with_tensor_scaling_fp8( q_dtype=output_type, q_layout=q_layout, ) - - te_output = tex.act_lu(x, activation_type, te_quantizer) - jax_output = _jax_act_lu(x, activation_type, jax_quantizer) - + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type 
== ("clamped_silu", "clamped_linear") + else {} + ) + act_params = ( + tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + te_output = tex.act_lu(x, activation_type, te_quantizer, act_params) + jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params) assert_bitwise_scaled_tensors(te_output, jax_output) @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) @@ -296,10 +330,18 @@ def test_act_forward_with_block_scaling_fp8( quantizer = QuantizerFactory.create( scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout ) - - output = tex.act_lu(x, activation_type, quantizer) - ref_out = self.ref_act(x, activation_type) - + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type == ("clamped_silu", "clamped_linear") + else {} + ) + act_params = ( + tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + output = tex.act_lu(x, activation_type, quantizer, act_params) + ref_out = self.ref_act(x, activation_type, act_params) assert_dequantized_scaled_tensor(output, ref_out) @@ -734,6 +776,7 @@ def test_quantize_dbias( def _test_quantize_dact_dbias( self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout ): + key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1) @@ -785,7 +828,7 @@ def _test_quantize_dact_dbias( (in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling()) # Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently. 
or ( - activation_type == ("squared_relu",) + activation_type in {("squared_relu",), ("clamped_silu", "clamped_linear")} and in_dtype == jnp.bfloat16 and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING ) diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index e50d71040d..4e48088586 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -39,6 +39,7 @@ enum class NVTE_Activation_Type { QGEGLU, SRELU, SREGLU, + CLAMPED_SWIGLU }; /*! \brief Computes the GeLU activation of the input. diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index ca37a28319..93086bd827 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -924,7 +924,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) template -void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, +void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, cudaStream_t stream) { checkCuDriverContext(stream); @@ -1006,7 +1006,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu template -void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, +void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, cudaStream_t stream) { checkCuDriverContext(stream); @@ -1138,7 +1138,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); - NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::COLWISE: @@ -1155,7 
+1154,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise, p); NVTE_CHECK_CUDA(cudaGetLastError()); break; @@ -1180,7 +1178,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out } template -void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) { +void cast_gated(const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) { CheckInputTensor(input, "gated_act_input"); CheckOutputTensor(*output, "gated_act_output"); NVTE_CHECK(input.flat_last_dim() % 2 == 0, @@ -1213,7 +1211,7 @@ void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t st template -void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP &p, +void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) { CheckInputTensor(grad, "dgated_act_grad"); CheckInputTensor(input, "dgated_act_input"); @@ -1252,7 +1250,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamO template -void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, +void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, cudaStream_t stream) { constexpr bool allow_empty = false; CheckInputTensor(gated_input, "gated_input"); @@ -1318,7 +1316,7 @@ namespace detail { template void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, - ParamOP &p, cudaStream_t stream) { + ParamOP p, cudaStream_t stream) { using namespace gated_kernels; Tensor grad_empty_tensor; const Tensor &grad_tensor = IS_DGATED ? 
*(convertNVTETensorCheck(grad)) : grad_empty_tensor; diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index 12b35ec43c..daa3679c48 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -11,7 +11,6 @@ import jax import jax.numpy as jnp - from . import cpp_extensions as tex from .quantize.tensor import NoScaleTensor @@ -22,6 +21,7 @@ def activation( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[tex.activation.ActivationParams] = None, ) -> jnp.ndarray: """Apply activation functions to input tensor with optional quantization. @@ -32,17 +32,19 @@ def activation( x: Input tensor to apply activations to activation_type: Sequence of activation functions quantizer: Optional quantizer for quantizing the output + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. Returns: Activated output tensor """ assert x.shape[-1] % len(activation_type) == 0 - output = _activation(x, activation_type, quantizer) + output = _activation(x, activation_type, quantizer, act_params) return output -@partial(jax.custom_vjp, nondiff_argnums=(1,)) -def _activation(x, activation_type, quantizer): +@partial(jax.custom_vjp, nondiff_argnums=(1, 3)) +def _activation(x, activation_type, quantizer, act_params): """Internal implementation of activation with custom VJP. This function implements the core activation logic with support for @@ -52,36 +54,42 @@ def _activation(x, activation_type, quantizer): x: Input tensor activation_type: Sequence of activation functions quantizer: Optional quantizer + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. 
Returns: Activated tensor """ - _output, _ = _activation_fwd_rule(x, activation_type, quantizer) + _output, _ = _activation_fwd_rule(x, activation_type, quantizer, act_params) return _output -def _activation_fwd_rule(x, activation_type, quantizer): +def _activation_fwd_rule(x, activation_type, quantizer, act_params): """Forward pass rule for activation function. Args: x: Input tensor activation_type: Sequence of activation functions quantizer: Optional quantizer + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. Returns: Tuple of (output, context) for backward pass """ - fwd_output = tex.act_lu(x, activation_type, quantizer) + fwd_output = tex.act_lu(x, activation_type, quantizer, act_params) # This is a no-op for higher-precision tensors fwd_output = fwd_output.dequantize() return fwd_output, (x, quantizer) -def _activation_bwd_rule(activation_type, ctx, g): +def _activation_bwd_rule(activation_type, act_params, ctx, g): """Backward pass rule for activation function. Args: activation_type: Sequence of activation functions + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. 
ctx: Context from forward pass g: Gradient from upstream @@ -90,7 +98,7 @@ def _activation_bwd_rule(activation_type, ctx, g): """ (x, _) = ctx assert x.dtype == g.dtype - dx = tex.dact_lu(g, x, activation_type) + dx = tex.dact_lu(g, x, activation_type, act_params=act_params) # No quantization is used in this VJP backward, so the output should # always be a NoScaleTensor assert isinstance(dx, NoScaleTensor) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index a8c14a6087..925c1d01ae 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -5,6 +5,7 @@ from typing import Sequence, Union, Callable, Optional, Tuple import operator from functools import reduce, partial +from dataclasses import dataclass import jax import jax.numpy as jnp @@ -12,9 +13,9 @@ from jax.experimental.custom_partitioning import SdyShardingRule from jax.sharding import PartitionSpec +import numpy as np import transformer_engine_jax from transformer_engine_jax import NVTE_Activation_Type - from .base import BasePrimitive, register_primitive from .misc import ( jax_dtype_to_te_dtype, @@ -51,17 +52,87 @@ ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU, ("squared_relu",): NVTE_Activation_Type.SRELU, ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU, + ("clamped_silu", "clamped_linear"): NVTE_Activation_Type.CLAMPED_SWIGLU, } -def _convert_to_activation_function(fn_or_string): +@dataclass(frozen=True) +class ClampedSwigluParams: + """Parameters for the Clamped SwiGLU activation function + used in GPT OSS.""" + + limit: float = 7.0 + alpha: float = 1.702 + + def __hash__(self): + """Custom hash function to ensure dataclass is hashable for jax jit to work. + + Returns: + int: Hash value of the dataclass instance. 
+ """ + return hash((self.limit, self.alpha)) + + def to_ffi_lowering_dict(self): + """Convert the activation parameters to a dictionary format for FFI lowering. + + Returns: + dict: A dictionary representation of the activation parameters consumable by + XLA FFI bindings for activation functions. + """ + return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} + + +@dataclass(frozen=True) +class ActivationParams: + """Parameters for various activation functions. + Currently only Clamped SwiGLU activation has parameters. + """ + + clamped_swiglu: ClampedSwigluParams = ClampedSwigluParams() + + @staticmethod + def create(activation_type, **kwargs): + """Factory method to create ActivationParams based on activation_type.""" + CLAMPED_ACTIVATION_TYPES = { + ("clamped_silu", "clamped_linear"), + "clamped_silu", + "clamped_linear", + } + if activation_type in CLAMPED_ACTIVATION_TYPES: + return ActivationParams(ClampedSwigluParams(**kwargs)) + return ActivationParams() # Default params for activations without parameters + + def __hash__(self): + """Custom hash function to ensure dataclass is hashable for jax jit to work""" + return hash((self.clamped_swiglu,)) + + def to_ffi_lowering_dict(self): + """Convert the activation parameters to a dictionary format for FFI lowering. + Returns: + dict: A dictionary representation of the activation parameters consumable by + XLA FFI bindings for activation functions. 
+ """ + return {"clamped_swiglu": self.clamped_swiglu.to_ffi_lowering_dict()} + + +def _convert_to_activation_function(fn_or_string, act_params: ActivationParams): """Convert a string to an activation function.""" if fn_or_string == "linear": return lambda x: x + if fn_or_string == "clamped_linear": + # This function is used for ClampedSwiGLU + # used in GPT OSS where the gates are not only clamped + # but also shifted by +1 + limit = act_params.clamped_swiglu.limit + return lambda x: jnp.clip(x, min=-limit, max=limit) + 1 if fn_or_string == "quick_gelu": return lambda x: jax.nn.sigmoid(1.702 * x) * x if fn_or_string == "squared_relu": return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)]) + if fn_or_string == "clamped_silu": + limit = act_params.clamped_swiglu.limit + alpha = act_params.clamped_swiglu.alpha + return lambda x: jax.nn.sigmoid(alpha * jnp.minimum(x, limit)) * jnp.minimum(x, limit) if isinstance(fn_or_string, str): return getattr(jax.nn, fn_or_string) if callable(fn_or_string): @@ -84,7 +155,8 @@ class ActLuPrimitive(BasePrimitive): 6, 7, 8, - ) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer + 9, + ) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer, act_params inner_primitive = None outer_primitive = None @@ -100,11 +172,12 @@ def abstract( is_2x, scale_dtype, is_outer, + act_params, ): """ te_act_lu_p abstract """ - del act_enum + del act_enum, act_params dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 @@ -150,6 +223,7 @@ def lowering( is_2x, scale_dtype, is_outer, + act_params, ): """ te_gated_act_lu_p lowering rules @@ -158,9 +232,14 @@ def lowering( x_aval, scale_aval = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 - out = ffi.ffi_lowering(ActLuPrimitive.name)( - ctx, x, 
scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x + ctx, + x, + scale, + act_enum=act_enum, + scaling_mode=scaling_mode.value, + is_2x=is_2x, + act_params=act_params.to_ffi_lowering_dict(), ) return out @@ -175,6 +254,7 @@ def impl( is_2x, scale_dtype, is_outer, + act_params, ): """ to describe implementation @@ -193,6 +273,7 @@ def impl( is_2x=is_2x, scale_dtype=scale_dtype, is_outer=False, + act_params=act_params, ) ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -221,6 +302,7 @@ def batcher( is_2x, scale_dtype, is_outer, + act_params, ): """ to describe batch rules for vmap @@ -242,6 +324,7 @@ def batcher( scaling_mode=scaling_mode, is_2x=is_2x, scale_dtype=scale_dtype, + act_params=act_params, ), out_bdims, ) @@ -255,6 +338,7 @@ def infer_sharding_from_operands( is_2x, scale_dtype, is_outer, + act_params, mesh, arg_infos, result_infos, @@ -266,6 +350,7 @@ def infer_sharding_from_operands( scale_dtype, act_len, is_outer, + act_params, ) # Unused. 
x_spec = get_padded_spec(arg_infos[0]) scale_spec = get_padded_spec(arg_infos[1]) @@ -318,6 +403,7 @@ def partition( is_2x, scale_dtype, is_outer, + act_params, mesh, arg_infos, result_infos, @@ -378,6 +464,7 @@ def sharded_impl(x, scale): is_2x=is_2x, scale_dtype=scale_dtype, is_outer=True, + act_params=act_params, ) ) @@ -405,11 +492,12 @@ def shardy_sharding_rule( is_2x, scale_dtype, is_outer, + act_params, mesh, value_types, result_types, ): - del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types + del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types, act_params prefix = "ActLu_" input_shape = value_types[0].shape output_shape = input_shape[:-2] + input_shape[-1:] @@ -455,8 +543,8 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): name = "te_dact_dbias_quantize_ffi" multiple_results = True - # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer - impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10) + # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer, act_params + impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11) inner_primitive = None outer_primitive = None @@ -474,11 +562,12 @@ def abstract( act_enum, act_len, is_outer, + act_params, ): """ te_dact_dbias_quantize_p abstract """ - del act_enum + del act_enum, act_params dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype) assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype == dz_dtype @@ -575,6 +664,7 @@ def lowering( act_enum, act_len, is_outer, + act_params, ): """ te_dact_dbias_quantize_p lowering rules @@ -593,6 +683,7 @@ def lowering( is_2x=is_2x, is_dbias=is_dbias, act_enum=int(act_enum), + act_params=act_params.to_ffi_lowering_dict(), ) @staticmethod @@ -608,6 +699,7 @@ def impl( act_enum, act_len, is_outer, + act_params, ): """ te_dact_dbias_quantize_p impl @@ -627,6 +719,7 @@ def impl( act_enum=act_enum, act_len=act_len, is_outer=False, + act_params=act_params, ) ) 
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -655,6 +748,7 @@ def batcher( act_enum, act_len, is_outer, + act_params, ): """ to describe batch rules for vmap @@ -685,6 +779,7 @@ def batcher( is_dbias=is_dbias, act_enum=act_enum, act_len=act_len, + act_params=act_params, ), out_bdims, ) @@ -699,11 +794,12 @@ def infer_sharding_from_operands( act_enum, act_len, is_outer, + act_params, mesh, arg_infos, result_infos, ): - del out_dtype, result_infos, act_enum + del out_dtype, result_infos, act_enum, act_params del scale_dtype, act_len, is_outer x_spec = get_padded_spec(arg_infos[1]) scale_spec = get_padded_spec(arg_infos[2]) @@ -774,6 +870,7 @@ def partition( act_enum, act_len, is_outer, + act_params, mesh, arg_infos, result_infos, @@ -854,6 +951,7 @@ def sharded_impl(dz, x, scale): act_enum=act_enum, act_len=act_len, is_outer=True, + act_params=act_params, ) ) if is_dbias: @@ -880,11 +978,13 @@ def shardy_sharding_rule( act_enum, act_len, is_outer, + act_params, mesh, value_types, result_types, ): - del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types + + del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types, act_params prefix = "DActLuDBias_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2 @@ -923,20 +1023,22 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. 
No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" -def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[NoScaleTensor, ScaledTensor]: +def _jax_act_lu( + inputs, activation_type, quantizer=None, act_params: Optional[ActivationParams] = None +) -> Union[NoScaleTensor, ScaledTensor]: """ JAX native activation implementation """ + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert inputs.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {inputs.shape} and act_len {act_len}" ) - x = jnp.split(inputs, act_len, axis=-2) acts = [] for idx, act_fn in enumerate(activation_type): - x_i = _convert_to_activation_function(act_fn)(x[idx]) + x_i = _convert_to_activation_function(act_fn, act_params)(x[idx]) acts.append(x_i) x = reduce(operator.mul, acts) x = jnp.squeeze(x, axis=-2) @@ -951,10 +1053,12 @@ def _jax_quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]], is_dbias: bool = True, quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, ): """ JAX implementation of dact_lu and dbias with optional quantization """ + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" @@ -962,7 +1066,8 @@ def _jax_quantize_dact_dbias( ) _, vjp_func = jax.vjp( - partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32) + partial(_jax_act_lu, activation_type=activation_type, act_params=act_params), + x.astype(jnp.float32), ) # VJP is using non-quantized backward for dact, so the input should always be wrapped in NoScaleTensor regardless of whether the forward pass used quantization or this dact will quantize afterwards. 
dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None) @@ -985,6 +1090,7 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, amax_scope: AmaxScope = AmaxScope.LOCAL, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -1008,24 +1114,22 @@ def act_lu( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {x.shape} and act_len {act_len}" ) - + act_params = act_params if act_params is not None else ActivationParams() if not ActLuPrimitive.enabled(): - return _jax_act_lu(x, activation_type, quantizer) + return _jax_act_lu(x, activation_type, quantizer, act_params) # TE/common does not support colwise-only quantization yet if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: - return _jax_act_lu(x, activation_type, quantizer) - + return _jax_act_lu(x, activation_type, quantizer, act_params) # TE/common does not support 2x quantization for DelayedScaling yet war_output = try_apply_delayed_scaling_2x_war( - f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer + f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer, act_params=act_params ) if war_output is not None: return war_output scale = jnp.empty((1,), jnp.float32) output_shape = (*x.shape[:-2], x.shape[-1]) - if quantizer is None: out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind( x, @@ -1037,6 +1141,7 @@ def act_lu( is_2x=False, scale_dtype=jnp.float32, is_outer=True, + act_params=act_params, ) out = out.reshape(output_shape) out = NoScaleTensor( @@ -1051,6 +1156,7 @@ def act_lu( x=x, activation_type=activation_type, quantizer=None, + act_params=act_params, ) out, _ = _quantize_dbias_impl( out, @@ -1060,7 +1166,6 @@ def act_lu( amax_scope=amax_scope, ) return out - if isinstance(quantizer, DelayedScaleQuantizer): scale = quantizer.scale @@ -1080,6 +1185,7 @@ def act_lu( 
is_2x=quantizer.is_2x2x(), scale_dtype=quantizer.get_scale_dtype(), is_outer=True, + act_params=act_params, ) quantizer.update(updated_amax) @@ -1102,6 +1208,7 @@ def quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]] = ("gelu",), is_dbias: bool = True, quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, ) -> Tuple[ScaledTensor, jnp.ndarray]: """Compute gradients of activation and bias with optional quantization. @@ -1118,7 +1225,7 @@ def quantize_dact_dbias( - The gradient of the activation with respect to the input. - The gradient of the activation with respect to the bias. """ - + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" @@ -1131,8 +1238,7 @@ def quantize_dact_dbias( if not PrimitiveClass.enabled() or ( quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE ): - return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) - + return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params) if quantizer is None: output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( dz, @@ -1148,6 +1254,7 @@ def quantize_dact_dbias( act_enum=act_type_id, act_len=act_len, is_outer=True, + act_params=act_params, ) output = output.astype(x.dtype) dbias = None @@ -1163,7 +1270,11 @@ def quantize_dact_dbias( # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): out = dact_lu( - dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None + dz.astype(jnp.float32), + x.astype(jnp.float32), + activation_type, + quantizer=None, + act_params=act_params, ) return _quantize_dbias_impl( out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 @@ -1180,6 +1291,7 @@ 
def quantize_dact_dbias( is_dbias=is_dbias, quantizer=quantizer, flatten_axis=-2, + act_params=act_params, ) if war_output is not None: return war_output @@ -1191,6 +1303,7 @@ def quantize_dact_dbias( x=x, activation_type=activation_type, quantizer=None, + act_params=act_params, ) out, dbias = _quantize_dbias_impl( out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 @@ -1203,7 +1316,10 @@ def quantize_dact_dbias( # TE/common dact_dbias_quantize does not support gated act yet if is_dbias and is_gated: dgated = dact_lu( - dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type + dz.astype(jnp.float32), + x.astype(jnp.float32), + activation_type=activation_type, + act_params=act_params, ) out, dbias = _quantize_dbias_impl( dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 @@ -1229,6 +1345,7 @@ def quantize_dact_dbias( act_enum=act_type_id, act_len=act_len, is_outer=True, + act_params=act_params, ) # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise @@ -1257,6 +1374,7 @@ def dact_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. @@ -1270,11 +1388,13 @@ def dact_lu( Returns: The gradient of the activation with respect to the input. 
""" + act_params = act_params if act_params is not None else ActivationParams() output, _ = quantize_dact_dbias( dz=dz, x=x, activation_type=activation_type, is_dbias=False, quantizer=quantizer, + act_params=act_params, ) return output diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 2ab95002fa..bbfc62120a 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -36,6 +36,15 @@ namespace transformer_engine { namespace jax { +struct ClampedSwigluConfig { + float limit; + float alpha; +}; + +struct ActivationConfig { + ClampedSwigluConfig clamped_swiglu; +}; + inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } // Activation @@ -137,6 +146,14 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); } // namespace jax } // namespace transformer_engine +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig, + ::xla::ffi::StructMember("limit"), + ::xla::ffi::StructMember("alpha")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::ActivationConfig, + ::xla::ffi::StructMember("clamped_swiglu")); + // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op); diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index b2b3db52c8..0ecf791505 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -18,7 +18,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, int64_t act_enum, 
JAXX_Scaling_Mode scaling_mode, - bool is_2x_int) { + bool is_2x_int, ActivationConfig act_params) { + // parameters for clamped swiglu used in GPT OSS + auto swiglu_limit = act_params.clamped_swiglu.limit; + auto swiglu_alpha = act_params.clamped_swiglu.alpha; auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); @@ -125,6 +128,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal case NVTE_Activation_Type::SREGLU: nvte_sreglu(input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::CLAMPED_SWIGLU: + nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha, + stream); + break; default: NVTE_ERROR("Unsupported ActivationEnum"); break; @@ -145,17 +152,19 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, .Ret() // amax .Attr("act_enum") .Attr("scaling_mode") - .Attr("is_2x"), + .Attr("is_2x") + .Attr("act_params"), FFI_CudaGraph_Traits); Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, int64_t act_enum, - JAXX_Scaling_Mode scaling_mode, bool is_2x_int) { + JAXX_Scaling_Mode scaling_mode, bool is_2x_int, + ActivationConfig act_params) { return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, - act_enum, scaling_mode, is_2x_int); + act_enum, scaling_mode, is_2x_int, act_params); } XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, @@ -170,7 +179,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, .Ret() // amax .Attr("act_enum") .Attr("scaling_mode") - .Attr("is_2x")); + .Attr("is_2x") + .Attr("act_params")); pybind11::tuple 
GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, @@ -240,7 +250,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, - int64_t act_enum, bool is_2x, bool is_dbias) { + int64_t act_enum, bool is_2x, bool is_dbias, + ActivationConfig act_params) { + // parameters for clamped swiglu used in GPT OSS + auto swiglu_limit = act_params.clamped_swiglu.limit; + auto swiglu_alpha = act_params.clamped_swiglu.alpha; auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -407,6 +421,10 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, case NVTE_Activation_Type::SREGLU: nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::CLAMPED_SWIGLU: + nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + swiglu_limit, swiglu_alpha, stream); + break; default: NVTE_ERROR("Unsupported ActivationEnum"); break; @@ -432,21 +450,20 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Attr("scaling_mode") .Attr("act_enum") .Attr("is_2x") - .Attr("is_dbias"), + .Attr("is_dbias") + .Attr("act_params"), FFI_CudaGraph_Traits); -Error_Type DActLuDBiasQuantizeInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, - Buffer_Type act_input_buf, Buffer_Type scale_buf, - Result_Type output_buf, Result_Type colwise_output_buf, - Result_Type scale_inv_buf, - Result_Type colwise_scale_inv_buf, Result_Type amax_buf, - Result_Type dbias_buf, Result_Type workspace_buf, - JAXX_Scaling_Mode scaling_mode, 
int64_t act_enum, - bool is_2x, bool is_dbias) { +Error_Type DActLuDBiasQuantizeInitializeFFI( + cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf, + Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, + Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, + Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x, + bool is_dbias, ActivationConfig act_params) { return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf, act_input_buf, scale_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf, - workspace_buf, scaling_mode, act_enum, is_2x, is_dbias); + workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params); } XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, @@ -466,7 +483,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, .Attr("scaling_mode") .Attr("act_enum") .Attr("is_2x") - .Attr("is_dbias")); + .Attr("is_dbias") + .Attr("act_params")); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 36dd8205bf..23d46b3384 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -143,6 +143,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("QGEGLU", NVTE_Activation_Type::QGEGLU) .value("SRELU", NVTE_Activation_Type::SRELU) .value("SREGLU", NVTE_Activation_Type::SREGLU) + .value("CLAMPED_SWIGLU", NVTE_Activation_Type::CLAMPED_SWIGLU) .export_values(); pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index c548c54efa..f02876d8f4 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py 
@@ -898,6 +898,10 @@ class LayerNormMLP(TransformerEngineBase): activations: Sequence[Union[str, Callable]], default = ('relu',) The sequence of activation functions to apply after the first dense layer transformation. Each activation has its own transformation layer. + activation_params: dict, default = None + The parameters needed (if any) by the activation functions specified in :attr:`activations`. + At the moment only ('clamped_silu', 'clamped_linear'), i.e. the clamped_swiglu used in GPT OSS, + needs additional parameters. intermediate_dropout_rng_name: str, default = 'dropout' The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks. intermediate_dropout_rate: float, default = 0.1 @@ -956,6 +960,7 @@ class LayerNormMLP(TransformerEngineBase): bias_axes_2: Tuple[str, ...] = ("embed",) return_layernorm_output: bool = True activations: Sequence[Union[str, Callable]] = ("relu",) + activation_params: dict = None intermediate_dropout_rng_name: str = "dropout" intermediate_dropout_rate: float = 0.1 intermediate_hidden_dropout_dims: Sequence[int] = () @@ -1023,6 +1028,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: ("relu", "linear"), ("quick_gelu", "linear"), ("squared_relu", "linear"), + ("clamped_silu", "clamped_linear"), ] act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)] normalized_acts = [] @@ -1031,7 +1037,9 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: return False normalized_acts.append(act.lower()) normalized_acts = tuple( - reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts + reversed(normalized_acts) + if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear") + else normalized_acts ) is_act_implemented = normalized_acts in (gated_act_pool + act_pool) @@ -1150,6 +1158,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): ffn1_ckpt_name=self.ffn1_ckpt_name, 
ffn2_ckpt_name=self.ffn2_ckpt_name, activation_type=normalized_acts, + activation_params=self.activation_params, quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), ) out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple) @@ -1287,4 +1296,4 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): out = checkpoint_name(out, self.ffn2_ckpt_name) assert out.dtype == input_dtype - return out, ln_output # Output, layner_norm_output + return out, ln_output # Output, layer_norm_output diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index ad66684f2b..868bcfa057 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -1632,6 +1632,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods mlp_activations: Sequence[str], default = ('relu', ) The sequence of activation functions to apply after the first linear transformation. Each activation has its own transformation layer. + mlp_activation_params: dict, default = None + This is only used when ('clamped_silu', 'clamped_linear') is in :attr:`mlp_activations`. At the moment + ClampedSwiglu is the only activation that requires parameters. use_bias: bool, default = False Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2. If set to False, the layer will not learn additive biases. 
@@ -1752,6 +1755,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods mha_kernel_init: Initializer = None mlp_kernel_init: Initializer = None mlp_activations: Sequence[str] = ("relu",) + mlp_activation_params: dict = None use_bias: bool = False bias_init: Initializer = nn.initializers.zeros apply_residual_connection_post_layernorm: bool = False @@ -2046,6 +2050,7 @@ def hidden_dropout(x, deterministic): return_layernorm_output=self.apply_residual_connection_post_layernorm, intermediate_dim=self.mlp_hidden_size, activations=self.mlp_activations, + activation_params=self.mlp_activation_params, intermediate_dropout_rng_name=self.dropout_rng_name, intermediate_dropout_rate=self.intermediate_dropout, intermediate_hidden_dropout_dims=self.intermediate_dropout_dims, diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index cf77f8e0a0..77daa4672c 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -50,6 +50,7 @@ def layernorm_mlp( ffn1_ckpt_name: str = "ffn1", ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), + activation_params: dict = None, collective_op_sets: Tuple[tex.CollectiveOpSet] = ( tex.noop_collective_op_set, tex.noop_collective_op_set, @@ -138,13 +139,14 @@ def layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, collective_op_sets, quantizer_sets, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)) def _layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -165,6 +167,7 @@ def _layernorm_mlp( ffn1_ckpt_name: str, ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], + activation_params: dict, collective_op_sets: Tuple[tex.CollectiveOpSet], quantizer_sets, ): @@ -220,6 +223,7 @@ def _layernorm_mlp( 
ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, collective_op_sets, quantizer_sets, ) @@ -246,6 +250,7 @@ def _layernorm_mlp_fwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, collective_op_sets, quantizer_sets, ): @@ -335,6 +340,11 @@ def _layernorm_mlp_fwd_rule( dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, + act_params=( + tex.activation.ActivationParams.create(activation_type, **activation_params) + if activation_params + else None + ), ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) @@ -402,6 +412,7 @@ def _layernorm_mlp_bwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, collective_op_sets, ctx, grad, @@ -497,6 +508,11 @@ def _layernorm_mlp_bwd_rule( activation_type=activation_type, is_dbias=use_bias_1, quantizer=ffn2_quantizer_set.dgrad, + act_params=( + tex.activation.ActivationParams.create(activation_type, **activation_params) + if activation_params + else None + ), ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim