NKI 2 nl.zeros Does Not Support the buffer= Parameter #1282

@nandeeka

Description


Describe the bug

The nl.zeros function no longer appears to accept the buffer= parameter, even though the documentation (https://awsdocs-neuron.readthedocs-hosted.com/en/latest/nki/api/generated/nki.language.zeros.html#nki.language.zeros) lists it. Passing it makes kernel tracing fail with "unexpected keyword argument 'buffer'". I see this error on a Deep Learning AMI (DLAMI) launched today.

Model Name

None (custom NKI kernel)

Describe the workload type

None (custom NKI kernel)

Instance Type

trn1.2xlarge

Release version

aws-neuronx-collectives/now 2.29.41.0-681fef5f5 amd64 [installed,local]
aws-neuronx-dkms/now 2.25.4.0 all [installed,local]
aws-neuronx-oci-hook/now 2.13.52.0 amd64 [installed,local]
aws-neuronx-runtime-lib/now 2.29.40.0-f954cd7a5 amd64 [installed,local]
aws-neuronx-tools/now 2.27.33.0-5d9c0b901 amd64 [installed,local]
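The list above covers only the system (apt) packages. The tracing error comes from the Python-side NKI frontend in the virtual environment (`/opt/aws_neuronx_venv_pytorch_2_9/lib/python3.12/site-packages/nki` in the traceback), so the pip package versions are probably the more relevant ones. Here is a minimal sketch for collecting them with the standard library. The distribution names in `CANDIDATES` are assumptions and may not match what the DLAMI venv actually installs, so adjust them as needed.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed pip distribution names; these are guesses, so edit them for your venv.
CANDIDATES = ["nki", "neuronx-cc", "torch-neuronx", "torch-xla", "torch"]


def collect_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in collect_versions(CANDIDATES).items():
        print(f"{name}: {ver or 'not installed'}")
```

Run it with the venv's interpreter (for example, `/opt/aws_neuronx_venv_pytorch_2_9/bin/python`) so that it reports the packages the kernel actually imports.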

Reproduction Steps

Here is an example of a kernel experiencing this issue:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import nki
import nki.isa as nisa
import nki.language as nl
import os

os.environ["NEURON_FRAMEWORK_DEBUG"] = "1"
os.environ["XLA_IR_DEBUG"] = "1"
os.environ["XLA_HLO_DEBUG"] = "1"

@nki.jit
def unary(AH):
    AS = nl.ndarray((128, 512), dtype=nl.float16, buffer=nl.sbuf)
    nisa.dma_copy(AS[:, :], AH[:, :])
    # Tracing fails on this line: buffer= is rejected by nl.zeros
    ZS = nl.zeros((128, 512), dtype=nl.float16, buffer=nl.sbuf)
    nisa.activation(ZS[:, :], nl.square, AS[:, :])
    ZH = nl.ndarray((128, 512), dtype=nl.float16, buffer=nl.shared_hbm)
    nisa.dma_copy(ZH[:, :], ZS[:, :])
    return ZH

class NKIKernel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, A):
        return unary(A)

def main():
    from torch_xla.core import xla_model as xm

    torch.manual_seed(0)
    device = xm.xla_device()

    model = NKIKernel().to(device)

    M, N = 512, 128

    A = torch.randn(N, M, dtype=torch.float16).to(device)

    output = model(A)

    print(output)

    xm.mark_step()


if __name__ == "__main__":
    main()
```
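A possible workaround, untested here: in this kernel, `nisa.activation` writes every element of `ZS`, so zero-initialization is not actually required. `ZS` could therefore be allocated with `nl.ndarray((128, 512), dtype=nl.float16, buffer=nl.sbuf)`, the same call used for `AS`, which the tracer did not flag. The snippet below is only a NumPy analogy of that reasoning. It assumes `np.empty` stands in for `nl.ndarray` (uninitialized storage) and `np.zeros` stands in for `nl.zeros`. It is not NKI code and does not exercise the Neuron compiler.

```python
import numpy as np

# NumPy analogy only: np.empty ~ nl.ndarray (uninitialized buffer),
# np.zeros ~ nl.zeros (zero-filled buffer). Not NKI code.
SHAPE = (128, 512)  # same tile shape as the kernel in the repro


def square_into(buf, src):
    """Overwrite every element of buf with src**2, like the activation step."""
    np.square(src, out=buf)
    return buf


rng = np.random.default_rng(0)
a = rng.standard_normal(SHAPE).astype(np.float16)

from_zeros = square_into(np.zeros(SHAPE, dtype=np.float16), a)
from_empty = square_into(np.empty(SHAPE, dtype=np.float16), a)

# The write covers the whole buffer, so the initial contents never matter.
print("identical:", np.array_equal(from_zeros, from_empty))
```

This only holds when the first write to the buffer covers every element. A kernel that accumulates into `ZS`, or writes only part of it, would still need a real zero fill.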

Regression Issue

  • Select this option if this issue appears to be a regression.

Possible Solution

No response

Logs/Context/Additional Information

Here is the output for the above script:

```text
/home/ubuntu/nki-kernels/out/../src/compiled/unary.py:34: DeprecationWarning: Use torch_xla.device instead
  device = xm.xla_device()
2026-Feb-23 01:51:43.0736 5042:5066 [0] int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t**):213 CCOM WARN NET/OFI Failed to initialize sendrecv protocol
2026-Feb-23 01:51:43.0746 5042:5066 [0] int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t**):354 CCOM WARN NET/OFI aws-ofi-nccl initialization failed
2026-Feb-23 01:51:43.0755 5042:5066 [0] ncclResult_t nccl_net_ofi_init_no_atexit_fini_v6(ncclDebugLogger_t):183 CCOM WARN NET/OFI Initializing plugin failed
2026-Feb-23 01:51:43.0765 5042:5066 [0] net_plugin.cc:97 CCOM WARN OFI plugin initNet() failed is EFA enabled?
/opt/aws_neuronx_venv_pytorch_2_9/lib/python3.12/site-packages/nki/compiler/backends/neuron/TraceKernel.py:94: UserWarning: NEURON_PLATFORM_TARGET_OVERRIDE is deprecated. Please pass platform_target to the kernel.
  warnings.warn(
The Python AST is located at: python_ast_tmp.name='/tmp/unaryfa_f5thz_python_ast.klir'
The KLR format is located at: final_klir_filepath='/tmp/unaryuezechwh.klir'
=========== errors from kernel tracing  =========== 

/home/ubuntu/nki-kernels/out/../src/compiled/unary.py:17:
    ZS = nl.zeros((128, 512), dtype=nl.float16, buffer=nl.sbuf)
         ^-- unexpected keyword argument 'buffer' in builtin function 'builtin_lang_zeros'
Traceback (most recent call last):
  File "/home/ubuntu/nki-kernels/out/../src/compiled/unary.py", line 50, in <module>
    main()
  File "/home/ubuntu/nki-kernels/out/../src/compiled/unary.py", line 42, in main
    output = model(A)
             ^^^^^^^^
  File "/opt/aws_neuronx_venv_pytorch_2_9/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/aws_neuronx_venv_pytorch_2_9/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/nki-kernels/out/../src/compiled/unary.py", line 28, in forward
    return unary(A)
           ^^^^^^^^
  File "/opt/aws_neuronx_venv_pytorch_2_9/lib/python3.12/site-packages/nki/compile.py", line 100, in __call__
    return kernel(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/aws_neuronx_venv_pytorch_2_9/lib/python3.12/site-packages/nki/_torch_xla.py", line 102, in __call__
    config = self.dump_config_with_boundargs(boundargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/aws_neuronx_venv_pytorch_2_9/lib/python3.12/site-packages/nki/compiler/backends/neuron/FrameworkKernel.py", line 264, in dump_config_with_boundargs
    _, klir_binary, metadata = self.specialize_and_call(boundargs)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/aws_neuronx_venv_pytorch_2_9/lib/python3.12/site-packages/nki/compiler/backends/neuron/TraceKernel.py", line 281, in specialize_and_call
    raise Exception("Error(s) during tracing:\n" + "\n".join(metadata["errors"]))
Exception: Error(s) during tracing:

/home/ubuntu/nki-kernels/out/../src/compiled/unary.py:17:
    ZS = nl.zeros((128, 512), dtype=nl.float16, buffer=nl.sbuf)
         ^-- unexpected keyword argument 'buffer' in builtin function 'builtin_lang_zeros'
```

Labels: NKI, bug