Changes from all commits (18 commits)
4db1e24
Add embedding_scale and final_logit_softcap for Gemma 4 support
jlamypoirier Apr 25, 2026
cd0a54c
Add QK norm and post-mixer/MLP normalization for Gemma 4 support
jlamypoirier Apr 27, 2026
894a44e
Expand test_attention with MQA/MHA/rotary/norm coverage
jlamypoirier Apr 28, 2026
05601ce
Add ProportionalRotary and rewrite test_rotary as parametrized suite
jlamypoirier Apr 28, 2026
37d1e2c
Add value_norm, shared_key_value, and FixedRMSNorm (Gemma 4 attention)
jlamypoirier Apr 28, 2026
9408fcb
Fix in-place rotary corruption of query_norm backward context
jlamypoirier Apr 28, 2026
89722b3
Add HybridMoEMLP and independent pre-norm controls for decoder block
jlamypoirier Apr 29, 2026
f760d98
Add Gemma 4 HuggingFace checkpoint converter
jlamypoirier Apr 30, 2026
d939f9d
Improve Gemma4 converter: shared_key_value, MoE weight shapes, roundt…
jlamypoirier Apr 30, 2026
43fdf2b
Fix test_rotary and test_mlp failures
jlamypoirier Apr 30, 2026
b91620d
Add output_scale to DecoderBlock for Gemma 4 layer_scalar
jlamypoirier May 1, 2026
7d32174
Fold output_scale into bias_dropout_add; fix subprocess interpreter
jlamypoirier May 1, 2026
092d2c4
Move pre/post norm to MLPBaseConfig; add MoE router preprocessing
jlamypoirier May 1, 2026
ec6d55b
Add learnable per-expert scale to MoE router
jlamypoirier May 1, 2026
b49d835
Exercise HybridMoE in gemma4 test fixture
jlamypoirier May 4, 2026
35ea13c
Guard converters against unsupported features
jlamypoirier May 4, 2026
0177a9e
Address PR review feedback
jlamypoirier May 5, 2026
2d697ce
Address second-round PR review
jlamypoirier May 5, 2026
56 changes: 35 additions & 21 deletions fast_llm/functional/triton/normalization.py
@@ -18,6 +18,7 @@ def triton_normalization_forward_kernel(
n_cols,
eps,
has_bias: tl_constexpr,
has_weight: tl_constexpr,
zero_centered: tl_constexpr,
block_size: tl_constexpr,
):
@@ -40,11 +41,13 @@ def triton_normalization_forward_kernel(
tl.store(inv_var_ptr + row, inv_var)

# Weight
weight = tl.load(weight_ptr + cols, mask=mask)
if zero_centered:
weight += 1

output = input_ * inv_var * weight
if has_weight:
weight = tl.load(weight_ptr + cols, mask=mask)
if zero_centered:
weight += 1
output = input_ * inv_var * weight
else:
output = input_ * inv_var

# Bias
if has_bias:
@@ -69,6 +72,7 @@ def triton_normalization_backward_kernel_1(
n_rows,
eps,
has_bias: tl_constexpr,
has_weight: tl_constexpr,
parameter_grad: tl_constexpr,
zero_centered: tl_constexpr,
block_size: tl_constexpr,
@@ -87,10 +91,6 @@ def triton_normalization_backward_kernel_1(
# Load data
output = tl.load(output_ptr + offsets, mask=mask, other=0).to(tl.float32)
grad_output = tl.load(grad_output_ptr + offsets, mask=mask, other=0).to(tl.float32)
weight = tl.load(weight_ptr + cols, mask=col_mask).to(tl.float32)
if zero_centered:
weight += 1

inv_var = tl.load(inv_var_ptr + rows, mask=row_mask)

# Bias
@@ -99,9 +99,18 @@ def triton_normalization_backward_kernel_1(
output = output - bias

# Input grad
weight_regularised = tl.where(weight >= 0, tl.maximum(weight, eps), tl.minimum(weight, -eps))
input_normalized = tl.where(mask, output / weight_regularised, 0.0)
weight_grad_output = tl.where(mask, weight * grad_output * inv_var, 0.0)
if has_weight:
weight = tl.load(weight_ptr + cols, mask=col_mask).to(tl.float32)
if zero_centered:
weight += 1
weight_regularised = tl.where(weight >= 0, tl.maximum(weight, eps), tl.minimum(weight, -eps))
input_normalized = tl.where(mask, output / weight_regularised, 0.0)
weight_grad_output = tl.where(mask, weight * grad_output * inv_var, 0.0)
else:
# weight == 1 everywhere: forward output = input * inv_var, so input_normalized = output
input_normalized = tl.where(mask, output, 0.0)
weight_grad_output = tl.where(mask, grad_output * inv_var, 0.0)

grad_input = weight_grad_output - input_normalized * (
tl.sum(input_normalized * weight_grad_output, axis=1)[:, None] / n_cols
)
@@ -170,7 +179,7 @@ def triton_normalization_backward_kernel_2(

def triton_normalization_forward(
input_: torch.Tensor,
weight: torch.Tensor,
weight: torch.Tensor | None,
bias: torch.Tensor | None,
eps: float,
training: bool,
@@ -179,14 +188,15 @@ def triton_normalization_forward(
# Note: Converting input automatically to training dtype to match Apex behaviour,
# needed for full precision residual.
# TODO: Review this?
assert weight.shape == input_.shape[-1:]
if bias is not None:
assert weight.shape == bias.shape
if weight is not None:
assert weight.shape == input_.shape[-1:]
if bias is not None:
assert weight.shape == bias.shape
assert input_.is_contiguous()
n_rows = input_.shape[:-1].numel()
n_cols = weight.numel()
n_cols = input_.shape[-1]

output = torch.empty_like(input_, dtype=weight.dtype)
output = torch.empty_like(input_, dtype=weight.dtype if weight is not None else input_.dtype)
inv_var = torch.empty(n_rows, dtype=torch.float32, device=input_.device)

block_size = triton.next_power_of_2(n_cols)
@@ -202,6 +212,7 @@ def triton_normalization_forward(
n_cols,
eps,
bias is not None,
weight is not None,
zero_centered,
block_size,
num_warps=num_warps,
@@ -217,16 +228,18 @@ def triton_normalization_backward(grad_output: torch.Tensor, context: list[typin
# We delete the context to prevent a memory leak
context.clear()
has_bias = bias is not None
has_weight = weight is not None

parameter_grad = weight.requires_grad
assert parameter_grad == hasattr(weight, "grad_buffer")
parameter_grad = weight.requires_grad if has_weight else False
if has_weight:
assert parameter_grad == hasattr(weight, "grad_buffer")
if has_bias:
assert parameter_grad == bias.requires_grad

grad_output = grad_output.contiguous()

n_rows = grad_output.shape[:-1].numel()
n_cols = weight.numel()
n_cols = grad_output.shape[-1]
# TODO: Improve heuristics
# The ones from triton tutorial (32, 128) are terrible.
# These seem to match torch compile heuristics and were near-optimal for A100 tests with [8192, 4096], bf16.
@@ -274,6 +287,7 @@ def triton_normalization_backward(grad_output: torch.Tensor, context: list[typin
n_rows,
eps,
has_bias,
has_weight,
parameter_grad,
zero_centered,
block_size,
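For reference, here is a minimal sketch of what the optional-weight path computes, in plain PyTorch. This is an illustrative helper, not the repo's API, and it assumes the RMS-style statistics implied by the single per-row inv_var and the one-term backward correction:

import torch

def norm_reference(
    x: torch.Tensor,
    weight: torch.Tensor | None,
    bias: torch.Tensor | None,
    eps: float,
    zero_centered: bool,
) -> torch.Tensor:
    # Per-row inverse RMS, matching the kernel's stored inv_var.
    inv_var = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    out = x.float() * inv_var
    if weight is not None:
        # Zero-centered weights store w - 1, so 1 is added back before scaling.
        out = out * (weight.float() + 1 if zero_centered else weight.float())
    if bias is not None:
        out = out + bias.float()
    return out

With has_weight false the backward treats the weight as 1 everywhere, so input_normalized coincides with the stored output and no weight gradient needs to be produced.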
23 changes: 19 additions & 4 deletions fast_llm/functional/triton/rotary.py
@@ -9,6 +9,7 @@
@triton_jit()
def triton_rotary_kernel(
input_ptr,
output_ptr,
frequencies_ptr,
stride_0,
stride_1,
@@ -30,6 +31,8 @@ def triton_rotary_kernel(
input_offsets = stride_0 * (pid_0 // seq_len) + stride_1 * position_id + stride_2 * head_offsets + offsets[None, :]
input_re_ptr = input_ptr + input_offsets
input_im_ptr = input_re_ptr + rotary_dim
output_re_ptr = output_ptr + input_offsets
output_im_ptr = output_re_ptr + rotary_dim

if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0:
input_re = tl.load(input_re_ptr).to(tl.float32)
@@ -54,31 +57,42 @@ def triton_rotary_kernel(
out_im = input_im * frequencies_re + input_re * frequencies_im

if rotary_block_size % rotary_dim == 0 and num_heads % head_block_size == 0:
tl.store(input_re_ptr, out_re)
tl.store(input_im_ptr, out_im)
tl.store(output_re_ptr, out_re)
tl.store(output_im_ptr, out_im)
else:
tl.store(input_re_ptr, out_re, mask=mask) # noqa
tl.store(input_im_ptr, out_im, mask=mask)
tl.store(output_re_ptr, out_re, mask=mask) # noqa
tl.store(output_im_ptr, out_im, mask=mask)


def triton_rotary_(
input_: torch.Tensor,
frequencies: torch.Tensor,
is_key_value: bool = False,
backward: bool = False,
inplace: bool = True,
) -> torch.Tensor:
# TODO: Improve assumptions.
# TODO: Make a transposed version to avoid contiguous call in key backward.
# TODO: Improve block size heuristics.
out = input_
write = input_
if input_.stride(-1) != 1:
# TODO: Make a transposed version to avoid contiguous call in key backward.
input_ = input_.contiguous()
write = input_
if not inplace:
out = torch.empty_like(input_)
write = out
if is_key_value:
# The kernel only writes the key chunk; copy the value chunk so `out` is fully defined.
out.chunk(2, dim=-2)[1].copy_(input_.chunk(2, dim=-2)[1])
if input_.ndim == 3:
input_ = input_.unsqueeze(0)
write = write.unsqueeze(0)
frequencies = frequencies.unsqueeze(0)
if is_key_value:
input_ = input_.chunk(2, dim=-2)[0]
write = write.chunk(2, dim=-2)[0]
batch_size, seq_len, num_heads, head_size = input_.shape
rotary_dim = div(head_size, 2)
rotary_block_size = triton.next_power_of_2(rotary_dim)
@@ -89,6 +103,7 @@ def triton_rotary_(
# Folded the large y dim into the x dim as gridDim.x is 32 bit while gridDim.y and gridDim.z are 16 bit registers
triton_rotary_kernel[(batch_size * seq_len, triton.cdiv(num_heads, head_block_size))](
input_,
write,
frequencies,
input_.stride(0),
input_.stride(1),
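The output_ptr / inplace additions let the rotary kernel write into a separate buffer instead of mutating its input (the in-place write was corrupting activations saved for the new QK-norm backward, per commit 9408fcb). A rough sketch of the buffer handling in plain PyTorch, with illustrative names; rotate stands in for the Triton kernel launch:

import torch
from typing import Callable

def rotary_buffers_sketch(
    qkv: torch.Tensor,
    rotate: Callable[[torch.Tensor], torch.Tensor],
    is_key_value: bool,
    inplace: bool,
) -> torch.Tensor:
    # Destination: reuse the input when in-place, otherwise allocate a fresh buffer.
    out = qkv if inplace else torch.empty_like(qkv)
    if is_key_value:
        # The kernel only rotates the key half; the value half is copied so a fresh
        # buffer is fully defined (the diff does this copy unconditionally, which is
        # a no-op self-copy in the in-place case).
        out.chunk(2, dim=-2)[1].copy_(qkv.chunk(2, dim=-2)[1])
        out.chunk(2, dim=-2)[0].copy_(rotate(qkv.chunk(2, dim=-2)[0]))
    else:
        out.copy_(rotate(qkv))
    return out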
109 changes: 97 additions & 12 deletions fast_llm/layers/attention/attention.py
@@ -104,12 +104,17 @@ def __init__(

head_size_dim = TensorDim("head_size", self._config.head_size)
query_dim = CompositeTensorDim("query", (head_group_dim, group_heads_dim, head_size_dim))
key_value_dim = ConcatenatedTensorDim(
"key_value",
(
CompositeTensorDim("key", (head_group_dim, head_size_dim)),
CompositeTensorDim("value", (head_group_dim, head_size_dim)),
),
key_dim = CompositeTensorDim("key", (head_group_dim, head_size_dim))
key_value_dim = (
key_dim
if self._config.shared_key_value
else ConcatenatedTensorDim(
"key_value",
(
key_dim,
CompositeTensorDim("value", (head_group_dim, head_size_dim)),
),
)
)
self._dense_dim = CompositeTensorDim("dense", (head_group_dim, group_heads_dim, head_size_dim))

@@ -136,7 +141,7 @@ def __init__(
lr_scale=self._lr_scale,
peft=None if self._config.key_layer.apply_peft is None else self._peft,
)
if self._peft is not None and self._config.key_layer.apply_peft is None:
if self._peft is not None and self._config.key_layer.apply_peft is None and not self._config.shared_key_value:
# Default: Apply to value only.
# TODO: Avoid this hack.
self.key_value = self._peft.apply_linear(
@@ -148,6 +153,23 @@ def __init__(
# Rotary embeddings.
self._rotary = self._config.rotary.get_layer(head_size_dim)

# QKV norms (applied after projection, before RoPE).
self.query_norm = (
self._config.query_norm.get_layer(head_size_dim, lr_scale=self._lr_scale, peft=None)
if self._config.query_norm is not None
else None
)
self.key_norm = (
self._config.key_norm.get_layer(head_size_dim, lr_scale=self._lr_scale, peft=None)
if self._config.key_norm is not None
else None
)
self.value_norm = (
self._config.value_norm.get_layer(head_size_dim, lr_scale=self._lr_scale, peft=None)
if self._config.value_norm is not None
else None
)

# Output.
self.dense = self._config.dense_layer.get_layer(
self._dense_dim,
@@ -236,6 +258,28 @@ def _attn_flash(
softmax_scale=self._softmax_scale,
)

def _apply_norm_with_grad_capture(
self, norm: typing.Callable, x: torch.Tensor
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
# Run `norm` and, in training, save (leaf, output) so backward can replay only the norm
# sub-graph between rotary's manual fwd/bwd. Eval shares the unified contiguous() call —
# RMSNormalization uses .view() internally and rejects non-contiguous inputs.
x = x.contiguous()
if self.training:
with torch.enable_grad():
leaf = x.detach().requires_grad_()
normed = norm(leaf)
return normed.detach(), (leaf, normed)
return norm(x), None

@staticmethod
def _backward_norm_capture(context: tuple[torch.Tensor, torch.Tensor] | None, grad: torch.Tensor) -> torch.Tensor:
if context is None:
return grad
leaf, normed = context
normed.backward(grad.contiguous())
return leaf.grad

def _query_key_value_forward(
self, input_: torch.Tensor, kwargs: dict[str, typing.Any]
) -> tuple[torch.Tensor, torch.Tensor, dict[str, typing.Any]]:
@@ -252,10 +296,29 @@ def _query_key_value_forward(
# TODO: This is probably unnecessary.
handle.wait()

query_unflat = query.unflatten(1, (self._local_heads, self._config.head_size))
if self._config.shared_key_value:
kv_unflat = key_value.unflatten(1, (self._local_head_groups, self._config.head_size))
kv_unflat = torch.cat([kv_unflat, kv_unflat], dim=1)
else:
kv_unflat = key_value.unflatten(1, (2 * self._local_head_groups, self._config.head_size))

query_norm_context = None
if self._config.query_norm is not None:
query_unflat, query_norm_context = self._apply_norm_with_grad_capture(self.query_norm, query_unflat)

key_norm_context = None
value_norm_context = None
if self._config.key_norm is not None or self._config.value_norm is not None:
key_unflat, value_unflat = kv_unflat.chunk(2, dim=1)
if self._config.key_norm is not None:
key_unflat, key_norm_context = self._apply_norm_with_grad_capture(self.key_norm, key_unflat)
if self._config.value_norm is not None:
value_unflat, value_norm_context = self._apply_norm_with_grad_capture(self.value_norm, value_unflat)
kv_unflat = torch.cat([key_unflat, value_unflat], dim=1)

query, key_value, rotary_context = self._rotary.forward_only(
query.unflatten(1, (self._local_heads, self._config.head_size)),
key_value.unflatten(1, (2 * self._local_head_groups, self._config.head_size)),
kwargs,
query_unflat, kv_unflat, kwargs, inplace_query=query_norm_context is None
)

if self._sequence_data_parallel_dim.group:
@@ -266,7 +329,14 @@
if handle:
handle.wait()

context = {"query": query_context, "key_value": key_value_context, "rotary": rotary_context}
context = {
"query": query_context,
"key_value": key_value_context,
"rotary": rotary_context,
"query_norm": query_norm_context,
"key_norm": key_norm_context,
"value_norm": value_norm_context,
}
return query, key_value, context

def _query_key_value_backward(
@@ -283,14 +353,29 @@
rotary_context = context.pop("rotary")
query_grad, _ = self._rotary.backward(query_grad, None, rotary_context)

query_grad = self._backward_norm_capture(context.pop("query_norm"), query_grad)

# TODO: Overlap with both.
input_grad = self.query.backward(query_grad.flatten(1), context.pop("query"))

if handle:
handle.wait()

_, key_value_grad = self._rotary.backward(None, key_value_grad, rotary_context)
key_value_grad = key_value_grad.flatten(1)

key_norm_context = context.pop("key_norm")
value_norm_context = context.pop("value_norm")
if key_norm_context is not None or value_norm_context is not None:
key_grad, value_grad = key_value_grad.chunk(2, dim=1)
key_grad = self._backward_norm_capture(key_norm_context, key_grad)
value_grad = self._backward_norm_capture(value_norm_context, value_grad)
key_value_grad = torch.cat([key_grad, value_grad], dim=1)

if self._config.shared_key_value:
key_grad, value_grad = key_value_grad.chunk(2, dim=1)
key_value_grad = (key_grad + value_grad).flatten(1)
else:
key_value_grad = key_value_grad.flatten(1)

if self._config.head_groups == 1 and (group := self._parallel_dim.group):
if self._sequence_parallel:
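The capture-and-replay helpers above (_apply_norm_with_grad_capture / _backward_norm_capture) boil down to running the norm on a detached leaf and replaying only that sub-graph during the hand-written backward. A self-contained sketch with illustrative names; torch.nn.RMSNorm (available in recent PyTorch) stands in for the configured query/key/value norm, and the real code threads this around the rotary forward/backward pair:

import torch

norm = torch.nn.RMSNorm(8)          # stand-in for the per-head norm layer
x = torch.randn(2, 8)

# Forward: build a tiny autograd graph for the norm only and keep (leaf, output).
with torch.enable_grad():
    leaf = x.detach().requires_grad_()
    normed = norm(leaf)
out = normed.detach()               # flows on through the manual forward (rotary, attention, ...)

# Backward: given the gradient w.r.t. `out`, replay just the captured sub-graph.
grad_out = torch.randn_like(out)
normed.backward(grad_out)
grad_x = leaf.grad                  # gradient handed back to the manual backward chain

In eval mode the helpers skip the capture entirely and just call the norm on the contiguous input.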
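On shared_key_value: the forward materializes the shared projection twice (torch.cat([kv_unflat, kv_unflat], dim=1)), so the backward folds the key and value halves of the incoming gradient back into a single gradient by summing them before flattening. A small check of that identity, with illustrative shapes only:

import torch

kv = torch.randn(4, 3, 16, requires_grad=True)   # one projection reused as both key and value
stacked = torch.cat([kv, kv], dim=1)              # forward: duplicate along the heads dimension

upstream = torch.randn_like(stacked)              # gradient arriving from the attention core
stacked.backward(upstream)

# The manual backward in this diff does the same thing: chunk the incoming
# gradient and add the key and value halves.
key_grad, value_grad = upstream.chunk(2, dim=1)
assert torch.allclose(kv.grad, key_grad + value_grad)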