NCC_IXLV002: Core barrier name mismatch when LNC2 NKI kernel with sendrecv is called 2+ times in a traced graph #1298

@binaryks

Description

Describe the bug

Environment

  • Instance: trn2.3xlarge (LNC2, logical-neuroncore-config: 2)
  • OS: Ubuntu 24.04.4 LTS (x86_64)
  • neuronx-cc: 2.23.6484.0+3b612583
  • torch-neuronx: 2.9.0.2.12.22436+0f1dac25

Description

When an NKI kernel that uses nisa.sendrecv (LNC2 inter-core communication) is launched with grid [2] (i.e., kernel[2](...)) two or more times in a single traced graph, neuronx-cc crashes during compilation with error NCC_IXLV002 (core barrier name mismatch in lnc_verifier).

We first encountered this in a real multi-layer transformer model (Qwen3-MoE with fused QKV kernels using sendrecv for cross-core reduction). The reproducer below is a simplified version that isolates the issue — out_dim=128 is just one configuration that triggers the bug, not the only one. Our production model's graph is more complex and hits the same error.

The same graph compiles successfully if the kernel is called only once.
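For reference, the kernel's intended numerics are just an elementwise sum of the two input rows. A plain-Python sketch of those semantics (lnc2_reduce_reference is a name introduced here for illustration only, not part of the NKI API):

```python
def lnc2_reduce_reference(x):
    """Host-side reference for the LNC2 kernel's semantics.

    x is a [2, H] input given as two rows. On-device, each core holds one
    row, sendrecv exchanges the rows, and tensor_tensor adds them, so the
    net result is the elementwise sum of the two rows as a [1, H] output.
    """
    row0, row1 = x
    return [[a + b for a, b in zip(row0, row1)]]

# Example: [[1, 2, 3], [10, 20, 30]] -> [[11, 22, 33]]
```

The compiler crash is independent of these numerics; a single call to the kernel produces the expected sum.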

Error Message

[INTERNAL_ERROR] [NCC_IXLV002] Expected core barriers to have name
LocalSendRecv-0_476_sr_cb_begin, but core barrier 6 in subgraph 0
on core 1 has name LocalSendRecv-0_176_sr_cb_begin

Model Name

N/A

Describe the workload type

Tensor all-reduce (LNC2 cross-core reduction)

Instance Type

trn2.3xlarge

Release version

  • Instance: trn2.3xlarge (LNC2, logical-neuroncore-config: 2)
  • OS: Ubuntu 24.04.4 LTS (x86_64)

apt list --installed | grep -i -e neuron

aws-neuronx-collectives/unknown,now 2.30.59.0-f5cdefb39 amd64 [installed]
aws-neuronx-dkms/unknown,now 2.26.5.0 all [installed,upgradable to: 2.26.10.0]
aws-neuronx-oci-hook/unknown,now 2.14.102.0 amd64 [installed]
aws-neuronx-runtime-lib/unknown,now 2.30.51.0-faafe26f0 amd64 [installed]
aws-neuronx-tools/unknown,now 2.28.23.0-f1c114a9d amd64 [installed]

pip list | grep -i -e neuron -e torch -e transformers -e jax

libneuronxla                  2.2.15515.0+50c26cbd
neuronx-cc                    2.23.6484.0+3b612583
neuronx-distributed           0.17.26814+4b18de63
neuronx-distributed-inference 0.8.16251+f3ca5575
torch                         2.9.0
torch-neuronx                 2.9.0.2.12.22436+0f1dac25
torch-xla                     2.9.0
torchvision                   0.24.0
transformers                  4.57.6

Reproduction Steps

Minimal Reproducer

Self-contained script — no external dependencies beyond torch, nki, torch_neuronx:

"""
Minimal reproducer for NCC_IXLV002 compiler bug.

A simple LNC2 NKI kernel in which each core:
  - loads its own row from a [2, H] input (core 0 -> row 0, core 1 -> row 1)
  - exchanges that row with the other core via sendrecv
  - reduces with a tensor_tensor add
  - stores the result

When this kernel[2] is called 2+ times in a graph followed by a matmul
with output dim <= 128, neuronx-cc crashes with NCC_IXLV002 (core barrier
name mismatch in lnc_verifier).

Environment:
  - neuronx-cc 2.23.6484.0
  - trn2 (LNC2 default)

"""python

import torch
import torch.nn as nn
import nki
import nki.isa as nisa
import nki.language as nl
import torch_neuronx

HIDDEN = 2048


@nki.jit
def lnc2_reduce(x):
    """
    Minimal LNC2 kernel: each core loads its shard, sendrecv + add.
    Input:  x[2, H]   (2 rows, one per core)
    Output: out[1, H]  (sum of both rows)
    """
    H = x.shape[1]

    # Which core am I? (0 or 1)
    core_id = nl.program_id(axis=0)
    other_core = 1 - core_id

    # Each core loads its own row
    my_row = nl.ndarray((1, H), dtype=x.dtype, buffer=nl.sbuf)
    nisa.dma_copy(dst=my_row, src=x[core_id:core_id + 1, :])

    # Receive the other core's row via sendrecv
    recv_buf = nl.ndarray((1, H), dtype=x.dtype, buffer=nl.sbuf)
    nisa.sendrecv(
        src=my_row, dst=recv_buf,
        send_to_rank=other_core, recv_from_rank=other_core,
        pipe_id=0,
    )

    # Reduce: my_row + recv_buf
    nisa.tensor_tensor(my_row, my_row, recv_buf, op=nl.add)

    # Store result
    out = nl.ndarray((1, H), dtype=x.dtype, buffer=nl.shared_hbm)
    nisa.dma_copy(dst=out, src=my_row)
    return out


class ReduceLayer(nn.Module):
    """Wraps the LNC2 kernel + a projection matmul."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Parameter(torch.randn(HIDDEN, HIDDEN, dtype=torch.bfloat16))

    def forward(self, x):
        # x: [1, 1, HIDDEN]
        # Reshape for kernel: [2, HIDDEN] (LNC2 needs 2 rows)
        x_2d = x.view(1, HIDDEN).expand(2, HIDDEN).contiguous()
        # LNC2 kernel call
        reduced = lnc2_reduce[2](x_2d)
        # Project back: [1, HIDDEN] @ [HIDDEN, HIDDEN] -> [1, HIDDEN]
        result = torch.matmul(reduced, self.proj.data)
        return result.unsqueeze(0)  # -> [1, 1, HIDDEN]


class Model(nn.Module):
    def __init__(self, n_layers, out_dim):
        super().__init__()
        self.layers = nn.ModuleList([ReduceLayer() for _ in range(n_layers)])
        self.head = nn.Linear(HIDDEN, out_dim, bias=False).to(torch.bfloat16)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.head(x)


def trace_test(n_layers, out_dim):
    x = torch.randn(1, 1, HIDDEN, dtype=torch.bfloat16)
    model = Model(n_layers, out_dim)
    model.eval()
    try:
        torch_neuronx.trace(
            model, x,
            compiler_args=["--target=trn2", "--model-type", "transformer", "-O1"],
        )
        return "PASS"
    except Exception as e:
        err = str(e)
        if "NCC_IXLV002" in err or "neuronx-cc failed with 70" in err:
            return "NCC_IXLV002"
        return f"other: {type(e).__name__}: {err[:200]}"


def main():
    print(f"{'Config':<45} Result")
    print("-" * 60)

    configs = [
        (1, 128, "1 layer, out=128"),
        (1, 256, "1 layer, out=256"),
        (2, 256, "2 layers, out=256  (PASS expected)"),
        (2, 128, "2 layers, out=128  (FAIL expected)"),
        (2, 2048, "2 layers, out=2048 (PASS expected)"),
        (3, 128, "3 layers, out=128  (FAIL expected)"),
    ]

    for n_layers, out_dim, label in configs:
        r = trace_test(n_layers, out_dim)
        print(f"  {label:<43} {r}")

    print("-" * 60)


if __name__ == "__main__":
    main()

Test Results

Config               Expected   Actual
1 layer, out=128     PASS       ✅ PASS
1 layer, out=256     PASS       ✅ PASS
2 layers, out=256    PASS       ✅ PASS
2 layers, out=128    FAIL       NCC_IXLV002
2 layers, out=2048   PASS       ✅ PASS
3 layers, out=128    FAIL       NCC_IXLV002

Trigger Conditions (all must be true)

  1. NKI kernel uses nisa.sendrecv (LNC2 inter-core communication)
  2. Kernel is launched with grid [2] (LNC2 mode, one program per core)
  3. Kernel is called ≥ 2 times in the same traced graph
  4. Kernel output feeds a matmul with output dim ≤ 128 (out=256 and out=2048 compile cleanly in the results above)

Possible Solution

No response

Logs/Context/Additional Information

No response
