
nisa.activation with nl.copy does not add bias #1296

@benlimpa

Description


Describe the bug

When calling nisa.activation with the nl.copy operator, the bias term is not added to the input. This does not occur with the other operators I have tested, such as nl.relu.
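For context, my reading of the NKI docs is that nisa.activation computes op(data * scale + bias) elementwise, with the bias broadcast along the free dimension (one value per partition). A minimal NumPy sketch of that expected behavior (activation_ref is my own illustrative name, not part of NKI):

```python
import numpy as np

def activation_ref(op, data, bias=0.0, scale=1.0):
    """Reference model of the documented semantics: op(data * scale + bias).

    bias is shaped (P, 1) so it broadcasts across the free dimension,
    adding one value per partition row.
    """
    return op(data * scale + np.asarray(bias))

data = np.zeros((4, 4), dtype=np.float32)
bias = np.ones((4, 1), dtype=np.float32)

# With op = identity (the nl.copy case), the bias should still be added:
out_copy = activation_ref(lambda x: x, data, bias=bias)   # expected: all ones

# With op = relu, the result should match for these non-negative inputs:
out_relu = activation_ref(lambda x: np.maximum(x, 0.0), data, bias=bias)
```

Under this model, both calls should return a tile of ones; on hardware, only the nl.relu path does.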

Model Name

N/A

Describe the workload type

N/A

Instance Type

trn1.2xlarge

Release version

NKI version: 0.2.0

neuronx-cc -V:

NeuronX Compiler version 2.23.6484.0+3b612583

Python version 3.12.3
HWM version 2.23.0.6484+3b612583
NumPy version 2.4.2

Running on AMI ami-06906839c7114cbff
Running in region usw2-az4

apt list --installed | grep -i -e neuron:

aws-neuronx-collectives/unknown,now 2.30.59.0-f5cdefb39 amd64 [installed]
aws-neuronx-dkms/unknown,now 2.26.5.0 all [installed,upgradable to: 2.26.10.0]
aws-neuronx-oci-hook/unknown,now 2.14.102.0 amd64 [installed]
aws-neuronx-runtime-lib/unknown,now 2.30.51.0-faafe26f0 amd64 [installed]
aws-neuronx-tools/unknown,now 2.28.23.0-f1c114a9d amd64 [installed]

pip list | grep -i -e neuron -e torch -e transformers -e jax:

libneuronxla                             2.2.15515.0+50c26cbd
neuronx-cc                               2.23.6484.0+3b612583
neuronx-distributed                      0.17.26814+4b18de63
torch                                    2.9.0
torch-neuronx                            2.9.0.2.12.22436+0f1dac25
torch-xla                                2.9.0
torchvision                              0.24.0

Reproduction Steps

Code to reproduce:

"""
Minimal reproducer: nisa.activation with op=nl.copy ignores bias.

Expected: zeros + ones_bias = ones
Actual:   output remains zeros when op=nl.copy; works correctly with op=nl.relu

Platform: trn1 / NKI
"""

import nki
import nki.isa as nisa
import nki.language as nl
import torch
import torch_xla

P, N = 4, 4


@nki.jit(platform_target="trn1")
def kernel_copy(data, bias_in):
    result = nl.ndarray((P, N), dtype=nl.float32, buffer=nl.shared_hbm)
    data_tile = nl.ndarray((P, N), dtype=nl.float32, buffer=nl.sbuf)
    nisa.dma_copy(dst=data_tile, src=data)
    bias_tile = nl.ndarray((P, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.dma_copy(dst=bias_tile, src=bias_in)
    out_tile = nl.ndarray((P, N), dtype=nl.float32, buffer=nl.sbuf)

    nisa.activation(dst=out_tile, op=nl.copy, data=data_tile, bias=bias_tile)

    nisa.dma_copy(dst=result, src=out_tile)
    return result


@nki.jit(platform_target="trn1")
def kernel_relu(data, bias_in):
    result = nl.ndarray((P, N), dtype=nl.float32, buffer=nl.shared_hbm)
    data_tile = nl.ndarray((P, N), dtype=nl.float32, buffer=nl.sbuf)
    nisa.dma_copy(dst=data_tile, src=data)
    bias_tile = nl.ndarray((P, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.dma_copy(dst=bias_tile, src=bias_in)
    out_tile = nl.ndarray((P, N), dtype=nl.float32, buffer=nl.sbuf)

    nisa.activation(dst=out_tile, op=nl.relu, data=data_tile, bias=bias_tile)

    nisa.dma_copy(dst=result, src=out_tile)
    return result


device = torch_xla.device()
data = torch.zeros((P, N), dtype=torch.float32).to(device)
bias = torch.ones((P, 1), dtype=torch.float32).to(device)

out_copy = kernel_copy(data, bias).to("cpu")
out_relu = kernel_relu(data, bias).to("cpu")

print("op=nl.copy output (should be all 1s, but is all 0s):\n", out_copy)
print("op=nl.relu output (should be all 1s):\n", out_relu)

expected = torch.ones(P, N)
assert torch.allclose(out_relu, expected)

# Bug: op=nl.copy does not add the bias, so the output stays zeros instead of ones.
actual_buggy = torch.zeros(P, N)
assert torch.allclose(out_copy, actual_buggy)

Regression Issue

[ ] Select this option if this issue appears to be a regression.

Possible Solution

No response

Logs/Context/Additional Information

No response

Metadata

Assignees

No one assigned

Labels

Trn1, bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
