Merged
67 commits
2a27823
Test working as I think it should work
vthumbe1503 Aug 26, 2025
d4c06c5
initial draft of changes to get GPT oss based swiglu integrated, gate…
vthumbe1503 Sep 5, 2025
1f596af
redundant implementation for the pytorch to te hook up, refactoring t…
vthumbe1503 Sep 6, 2025
42f85c3
all gated kernels modified, pytest working for oss swiglu
vthumbe1503 Sep 8, 2025
c9d3311
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2025
5d06c2a
fix the merge conflict
vthumbe1503 Sep 8, 2025
025ce6b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2025
d964b24
accidentally had removed some activations, minor bug in the templated…
vthumbe1503 Sep 8, 2025
de9ef2f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2025
8e17473
parent de9ef2fe450daae0d4ea1b647a37219f72814f66
vthumbe1503 Sep 8, 2025
1f2c65b
accidentally removed the copyright
vthumbe1503 Sep 8, 2025
75c4b13
fix linting issue
vthumbe1503 Sep 8, 2025
288e926
minor issue in comments
vthumbe1503 Sep 8, 2025
448eceb
Commit is for another PR
vthumbe1503 Sep 10, 2025
23b5822
revert changes since this belongs to another PR
vthumbe1503 Sep 10, 2025
a1a5794
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2025
0d6a3ea
Revert change back since belongs to another PR
vthumbe1503 Sep 10, 2025
33c3364
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2025
a724c2d
Changes belong to another PR
vthumbe1503 Sep 10, 2025
34d9815
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2025
3475264
Revert changes here
vthumbe1503 Sep 10, 2025
5e687d1
address review comments
vthumbe1503 Sep 15, 2025
8535dfb
cleanup
vthumbe1503 Sep 15, 2025
fa0e9a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 15, 2025
aee3fb9
fix linting error
vthumbe1503 Sep 15, 2025
87ae3d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 15, 2025
3858eab
Address review comments, fix mxfp8 kernel bug: was not passing clampe…
vthumbe1503 Sep 18, 2025
de3080e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2025
7bf0bc4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2025
fe93c01
Use limit=0.75 in clamped SwiGLU test
timmoon10 Sep 19, 2025
5d3b169
Address review comments
vthumbe1503 Sep 19, 2025
0c17c7e
JAX integration changes
vthumbe1503 Sep 24, 2025
90e070c
Merge branch 'main' into gpt-oss-jax
vthumbe1503 Sep 24, 2025
66c7086
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 24, 2025
af19dbf
revert line break
vthumbe1503 Sep 24, 2025
4f29915
revert line break
vthumbe1503 Sep 24, 2025
24828f3
missed adding oss swiglu to nvte enum in common
vthumbe1503 Sep 24, 2025
19410b6
fix jax linting errors
vthumbe1503 Sep 24, 2025
5480d29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 24, 2025
7a917ea
fix jax linting errors
vthumbe1503 Sep 24, 2025
53dd179
revert multi_gpu_encoder change
vthumbe1503 Sep 24, 2025
d048807
Merge branch 'main' into gpt-oss-jax
vthumbe1503 Sep 25, 2025
3bfae54
fix flax integration bug
vthumbe1503 Sep 25, 2025
9c60c47
Merge branch 'gpt-oss-jax' of github.com:vthumbe1503/TransformerEngin…
vthumbe1503 Sep 25, 2025
38382dc
fix linting error
vthumbe1503 Sep 25, 2025
c7ef078
bug fixed in other branch and not here
vthumbe1503 Sep 26, 2025
c39ab8d
Merge branch 'main' into gpt-oss-jax
vthumbe1503 Sep 26, 2025
8446cc4
Merge branch 'main' into gpt-oss-jax
vthumbe1503 Sep 29, 2025
2a2e6de
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2025
b2f4fcb
bug in dbias computation
vthumbe1503 Sep 29, 2025
b7df6b6
Merge branch 'main' into gpt-oss-jax
vthumbe1503 Oct 1, 2025
4f41c1b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2025
115e528
Merge branch 'main' into gpt-oss-jax
vthumbe1503 Oct 1, 2025
d2072b1
address review comments
vthumbe1503 Oct 1, 2025
6d9df80
Merge branch 'gpt-oss-jax' of github.com:vthumbe1503/TransformerEngin…
vthumbe1503 Oct 1, 2025
978fcde
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2025
13a3e3c
minor bug because of merge conflict
vthumbe1503 Oct 1, 2025
df0f449
Merge branch 'gpt-oss-jax' of github.com:vthumbe1503/TransformerEngin…
vthumbe1503 Oct 1, 2025
d59526b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 1, 2025
4783514
accept copilot suggestion
vthumbe1503 Oct 1, 2025
cda0c82
Merge branch 'gpt-oss-jax' of github.com:vthumbe1503/TransformerEngin…
vthumbe1503 Oct 1, 2025
14f8971
fix test and remove a redundant test addition
vthumbe1503 Oct 1, 2025
b9d7da7
Merge branch 'main' into gpt-oss-jax
vthumbe1503 Oct 1, 2025
6b0c73c
Merge branch 'main' into gpt-oss-jax
vthumbe1503 Oct 2, 2025
d462da1
Merge branch 'main' into gpt-oss-jax
vthumbe1503 Oct 2, 2025
5a55b0d
address review comments
vthumbe1503 Oct 3, 2025
bf3e04b
Merge branch 'gpt-oss-jax' of github.com:vthumbe1503/TransformerEngin…
vthumbe1503 Oct 3, 2025
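The commits above integrate a clamped ("GPT-OSS-style") SwiGLU into Transformer Engine's gated-activation kernels, parameterized by `limit` and `alpha` (the tests use `limit=0.75`, `alpha=1.702`). As a rough, framework-free sketch of what such an activation and its gradient compute — an illustration based on the public GPT-OSS formulation, not the kernel code in this PR:

```python
import numpy as np

def clamped_swiglu(x, limit=0.75, alpha=1.702):
    """Clamped SwiGLU on a packed [..., 2*d] input (gate half, linear half)."""
    x_glu, x_lin = np.split(x, 2, axis=-1)
    x_glu = np.minimum(x_glu, limit)       # gate half clamped from above only
    x_lin = np.clip(x_lin, -limit, limit)  # linear half clamped symmetrically
    silu = x_glu / (1.0 + np.exp(-alpha * x_glu))  # z * sigmoid(alpha * z)
    return silu * (x_lin + 1.0)

def clamped_swiglu_grad(g, x, limit=0.75, alpha=1.702):
    """Gradient of clamped_swiglu w.r.t. x, given upstream gradient g."""
    x_glu, x_lin = np.split(x, 2, axis=-1)
    glu_c = np.minimum(x_glu, limit)
    lin_c = np.clip(x_lin, -limit, limit)
    sig = 1.0 / (1.0 + np.exp(-alpha * glu_c))
    silu = glu_c * sig
    dsilu = sig + alpha * glu_c * sig * (1.0 - sig)  # d/dz [z * sigmoid(alpha*z)]
    # Gradient is zero wherever a clamp is active.
    dglu = np.where(x_glu < limit, dsilu * (lin_c + 1.0), 0.0) * g
    dlin = np.where(np.abs(x_lin) < limit, silu, 0.0) * g
    return np.concatenate([dglu, dlin], axis=-1)
```

Note the commit "fix mxfp8 kernel bug: was not passing clampe…" (message truncated): in a formulation like the one above, it matters that the clamped values, not the raw inputs, feed both the forward product and the derivative terms.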
89 changes: 66 additions & 23 deletions tests/jax/test_custom_call_compute.py
@@ -170,6 +170,7 @@ def assert_dequantized_grouped_scaled_tensor(
("quick_gelu", "linear"),
("squared_relu",),
("squared_relu", "linear"),
("clamped_silu", "clamped_linear"),
]

ACTIVATION_TYPES = {
@@ -182,17 +183,21 @@ def assert_dequantized_grouped_scaled_tensor(


class TestActivation:
def ref_act(self, x, activation_type):
return _jax_act_lu(x, activation_type).data
def ref_act(self, x, activation_type, act_params):
return _jax_act_lu(x, activation_type, act_params=act_params).data

def value_n_grad_ref_func(self, x, activation_type):
def value_n_grad_ref_func(self, x, activation_type, act_params):
jitted_reference = jit(
value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,))
value_and_grad(
lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,)
)
)
return jitted_reference(x)

def primitive_func(self, inputs, activation_type, quantizer):
out = activation(inputs, activation_type=activation_type, quantizer=quantizer)
def primitive_func(self, inputs, activation_type, quantizer, act_params):
out = activation(
inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params
)
return jnp.mean(out)

@pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
@@ -209,12 +214,20 @@ def test_act_grad(self, shape, activation_type):
x = jnp.repeat(x, len(activation_type), axis=-2)

value_n_grad_primitive_func = jit(
value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3)
)

prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)

act_args = (
{"limit": 0.75, "alpha": 1.702}
if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)
assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype)

@@ -234,17 +247,30 @@ def test_act_grad_with_tensor_scaling_fp8(
self.activation_type = activation_type

value_n_grad_primitive_func = jit(
value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
value_and_grad(self.primitive_func, (0,)),
static_argnums=(1, 3),
)

quantizer = QuantizerFactory.create(
scaling_mode=scaling_mode,
q_dtype=output_type,
q_layout=QuantizeLayout.ROWWISE,
)
act_args = (
{"limit": 0.75, "alpha": 1.702}
if activation_type == ("clamped_silu", "clamped_linear")
else {}
)

prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
prim_out, (prim_grad,) = value_n_grad_primitive_func(
x, activation_type, quantizer, act_params
)
ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)

assert_allclose(prim_out, ref_out, dtype=output_type)
assert_allclose(prim_grad, ref_grad, dtype=output_type)
@@ -273,10 +299,18 @@ def test_act_forward_with_tensor_scaling_fp8(
q_dtype=output_type,
q_layout=q_layout,
)

te_output = tex.act_lu(x, activation_type, te_quantizer)
jax_output = _jax_act_lu(x, activation_type, jax_quantizer)

act_args = (
{"limit": 0.75, "alpha": 1.702}
if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
te_output = tex.act_lu(x, activation_type, te_quantizer, act_params)
jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params)
assert_bitwise_scaled_tensors(te_output, jax_output)

@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
@@ -296,10 +330,18 @@ def test_act_forward_with_block_scaling_fp8(
quantizer = QuantizerFactory.create(
scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
)

output = tex.act_lu(x, activation_type, quantizer)
ref_out = self.ref_act(x, activation_type)

act_args = (
{"limit": 0.75, "alpha": 1.702}
if activation_type == ("clamped_silu", "clamped_linear")
else {}
)
act_params = (
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args)
if activation_type == ("clamped_silu", "clamped_linear")
else None
)
output = tex.act_lu(x, activation_type, quantizer, act_params)
ref_out = self.ref_act(x, activation_type, act_params)
assert_dequantized_scaled_tensor(output, ref_out)


@@ -734,6 +776,7 @@ def test_quantize_dbias(
def _test_quantize_dact_dbias(
self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout
):

key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
@@ -785,7 +828,7 @@ def _test_quantize_dact_dbias(
(in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling())
# Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently.
or (
activation_type == ("squared_relu",)
activation_type in {("squared_relu",), ("clamped_silu", "clamped_linear")}
and in_dtype == jnp.bfloat16
and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
)
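The test changes above repeat one construction pattern for the new `act_params` argument. As a standalone sketch of that pattern — with a stand-in dataclass in place of `tex.activation.ActivationParams`, whose real fields are an assumption here:

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class ActivationParams:
    """Hypothetical stand-in for tex.activation.ActivationParams."""
    limit: float = 0.75
    alpha: float = 1.702

def make_act_params(activation_type: Tuple[str, ...]) -> Optional[ActivationParams]:
    # Only the clamped SwiGLU pair carries parameters; every other
    # activation passes act_params=None, as in the test diff above.
    if activation_type == ("clamped_silu", "clamped_linear"):
        return ActivationParams(limit=0.75, alpha=1.702)
    return None
```

Because `act_params` is threaded through `jit` via `static_argnums=(1, 3)` and through `custom_vjp` via `nondiff_argnums`, whatever object carries it must be hashable — a frozen-dataclass-style container is a natural fit.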
@@ -39,6 +39,7 @@ enum class NVTE_Activation_Type {
QGEGLU,
SRELU,
SREGLU,
CLAMPED_SWIGLU
};

/*! \brief Computes the GeLU activation of the input.
14 changes: 6 additions & 8 deletions transformer_engine/common/util/cast_gated_kernels.cuh
@@ -924,7 +924,7 @@

template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p,
void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) {
checkCuDriverContext(stream);

@@ -1006,7 +1006,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu

template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p,
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) {
checkCuDriverContext(stream);

@@ -1138,7 +1138,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise, p);

NVTE_CHECK_CUDA(cudaGetLastError());
break;
case ScalingType::COLWISE:
@@ -1155,7 +1154,6 @@
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,

scale_stride_colwise, p);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
@@ -1180,7 +1178,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
}

template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) {
void cast_gated(const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) {
CheckInputTensor(input, "gated_act_input");
CheckOutputTensor(*output, "gated_act_output");
NVTE_CHECK(input.flat_last_dim() % 2 == 0,
@@ -1213,7 +1211,7 @@ void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t st

template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP &p,
void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP p,
cudaStream_t stream) {
CheckInputTensor(grad, "dgated_act_grad");
CheckInputTensor(input, "dgated_act_input");
@@ -1252,7 +1250,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamO

template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p,
void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) {
constexpr bool allow_empty = false;
CheckInputTensor(gated_input, "gated_input");
@@ -1318,7 +1316,7 @@ namespace detail {
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output,
ParamOP &p, cudaStream_t stream) {
ParamOP p, cudaStream_t stream) {
using namespace gated_kernels;
Tensor grad_empty_tensor;
const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor;
26 changes: 17 additions & 9 deletions transformer_engine/jax/activation.py
@@ -11,7 +11,6 @@

import jax
import jax.numpy as jnp

from . import cpp_extensions as tex

from .quantize.tensor import NoScaleTensor
@@ -22,6 +21,7 @@ def activation(
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None,
act_params: Optional[tex.activation.ActivationParams] = None,
) -> jnp.ndarray:
"""Apply activation functions to input tensor with optional quantization.

@@ -32,17 +32,19 @@
x: Input tensor to apply activations to
activation_type: Sequence of activation functions
quantizer: Optional quantizer for quantizing the output
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.

Returns:
Activated output tensor
"""
assert x.shape[-1] % len(activation_type) == 0
output = _activation(x, activation_type, quantizer)
output = _activation(x, activation_type, quantizer, act_params)
return output


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation(x, activation_type, quantizer):
@partial(jax.custom_vjp, nondiff_argnums=(1, 3))
def _activation(x, activation_type, quantizer, act_params):
"""Internal implementation of activation with custom VJP.

This function implements the core activation logic with support for
@@ -52,36 +54,42 @@ def _activation(x, activation_type, quantizer):
x: Input tensor
activation_type: Sequence of activation functions
quantizer: Optional quantizer
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.

Returns:
Activated tensor
"""
_output, _ = _activation_fwd_rule(x, activation_type, quantizer)
_output, _ = _activation_fwd_rule(x, activation_type, quantizer, act_params)
return _output


def _activation_fwd_rule(x, activation_type, quantizer):
def _activation_fwd_rule(x, activation_type, quantizer, act_params):
"""Forward pass rule for activation function.

Args:
x: Input tensor
activation_type: Sequence of activation functions
quantizer: Optional quantizer
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.

Returns:
Tuple of (output, context) for backward pass
"""
fwd_output = tex.act_lu(x, activation_type, quantizer)
fwd_output = tex.act_lu(x, activation_type, quantizer, act_params)
# This is a no-op for higher-precision tensors
fwd_output = fwd_output.dequantize()
return fwd_output, (x, quantizer)


def _activation_bwd_rule(activation_type, ctx, g):
def _activation_bwd_rule(activation_type, act_params, ctx, g):
"""Backward pass rule for activation function.

Args:
activation_type: Sequence of activation functions
act_params: Optional activation parameters. Currently used
just for ClampedSwiGLU.
ctx: Context from forward pass
g: Gradient from upstream

@@ -90,7 +98,7 @@ def _activation_bwd_rule(activation_type, ctx, g):
"""
(x, _) = ctx
assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type)
dx = tex.dact_lu(g, x, activation_type, act_params=act_params)
# No quantization is used in this VJP backward, so the output should
# always be a NoScaleTensor
assert isinstance(dx, NoScaleTensor)