diff --git a/Tutorials/Images/DeeploySystem.png b/Tutorials/Images/DeeploySystem.png new file mode 100644 index 0000000000..2ad39a3f07 Binary files /dev/null and b/Tutorials/Images/DeeploySystem.png differ diff --git a/Tutorials/Images/EthLogoNeg.png b/Tutorials/Images/EthLogoNeg.png new file mode 100644 index 0000000000..0d3c5be5de Binary files /dev/null and b/Tutorials/Images/EthLogoNeg.png differ diff --git a/Tutorials/Images/Siracusa.png b/Tutorials/Images/Siracusa.png new file mode 100644 index 0000000000..88a5794681 Binary files /dev/null and b/Tutorials/Images/Siracusa.png differ diff --git a/Tutorials/Images/Victor_Jung_EDGEAIForumDeeploy_S5.png b/Tutorials/Images/Victor_Jung_EDGEAIForumDeeploy_S5.png new file mode 100644 index 0000000000..3e78baa68e Binary files /dev/null and b/Tutorials/Images/Victor_Jung_EDGEAIForumDeeploy_S5.png differ diff --git a/Tutorials/PartIII_skeletons/iLeakyReLU/README.md b/Tutorials/PartIII_skeletons/iLeakyReLU/README.md new file mode 100644 index 0000000000..a1c085f957 --- /dev/null +++ b/Tutorials/PartIII_skeletons/iLeakyReLU/README.md @@ -0,0 +1,25 @@ +# SoCDAML Part III - Student skeletons for `iLeakyReLU` + +These files are your starting points for the Part III lab. Each one +contains the surrounding boilerplate; the conceptually interesting +parts are marked with `TODO(student)` comments and short hints. + +| File | What's in it | What to do | +|------|--------------|------------| +| `generate.py` | Complete ONNX + golden-value generator | Run it (Step 1) | +| `iLeakyReLU.h` | Complete kernel header | Copy to `TargetLibraries/PULPOpen/inc/kernel/` (Step 3) | +| `iLeakyReLU.c` | Multi-core chunking provided; inner loop TODO | Fill the TODO, copy to `TargetLibraries/PULPOpen/src/` (Step 3) | +| `iLeakyReLU_simd.c` | SIMD chunking + load/max/store provided; one TODO line | Fill in Step 6b after the scalar works | +| `iLeakyReLUParser.py` | `parseNode` and `parseNodeCtxt` are TODO | Fill in, paste class into `Deeploy/Targets/Generic/Parsers.py` (Step 2) | +| `iLeakyReLUTemplate.py` | Mako template body is TODO | Fill in, copy to `Deeploy/Targets/PULPOpen/Templates/` (Step 4) | +| `iLeakyReLUTileConstraint.py` | Inherits `UnaryTileConstraint`; performance constraint TODO | Fill in (Step 5 + Step 6a), copy to `Deeploy/Targets/PULPOpen/TileConstraints/` | + +The **Binding** (in `Bindings.py`), **Mapper + PULPMapping entry** (in `Platform.py`), the **`TilingReadyNodeBindings` registration** (in `Tiler.py`), and the **aggregator include** in `DeeployPULPMath.h` are *not* shipped as paste-in snippets — you'll write them yourself with the markdown's guidance and `
` solutions. + +The companion `Deeploy/Tutorials/SoCDAML.md` (Part III) walks through +the six steps in order and includes collapsed solutions to peek at +when you're stuck. + +If you really need the answer key, look in +`Deeploy/Tutorials/PartIII_solution/iLeakyReLU/` -- but try the lab +first; you'll learn far more. diff --git a/Tutorials/PartIII_skeletons/iLeakyReLU/generate.py b/Tutorials/PartIII_skeletons/iLeakyReLU/generate.py new file mode 100644 index 0000000000..3c5db5b871 --- /dev/null +++ b/Tutorials/PartIII_skeletons/iLeakyReLU/generate.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# ---------------------------------------------------------------------- +# File: generate.py (SoCDAML Part III - Step 1, provided complete) +# +# Builds the single-node ONNX graph + golden tensors that DeeployTest's +# harness will use to validate your iLeakyReLU implementation. +# +# Run from this directory: +# python generate.py +# +# Outputs: +# network.onnx, inputs.npz, outputs.npz +# +# Quantization-friendly LeakyReLU formula used here: +# out[i] = x if x >= 0 +# (mul*x) >> shift otherwise +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import onnx +from onnx import TensorProto, helper + +SHAPE = (1, 16, 64, 64) +MUL = 1 +SHIFT = 3 +SEED = 0xC0FFEE + +def golden(x, mul, shift): + pos = x.astype(np.int32) + neg = (mul * pos) >> shift + out = np.where(pos >= 0, pos, neg) + return np.clip(out, -128, 127).astype(np.int8) + +def build_onnx(): + in_value = helper.make_tensor_value_info('data_in', TensorProto.INT8, SHAPE) + out_value = helper.make_tensor_value_info('data_out', TensorProto.INT8, SHAPE) + node = helper.make_node('iLeakyReLU', ['data_in'], ['data_out'], + name='iLeakyReLU_0', mul=MUL, shift=SHIFT) + graph = helper.make_graph([node], 'iLeakyReLU_single_node', + [in_value], [out_value]) + model = helper.make_model(graph, producer_name='SoCDAML-PartIII') + model.opset_import[0].version = 13 + model.ir_version = 7 + return model + +def main(): + rng = np.random.default_rng(SEED) + x = rng.integers(low=-128, high=127, size=SHAPE, dtype=np.int8) + y = golden(x, MUL, SHIFT) + onnx.save(build_onnx(), 'network.onnx') + np.savez('inputs.npz', data_in=x) + np.savez('outputs.npz', data_out=y) + print(f"OK: network.onnx, inputs.npz, outputs.npz " + f"(shape={SHAPE}, mul={MUL}, shift={SHIFT})") + +if __name__ == '__main__': + main() diff --git a/Tutorials/PartIII_skeletons/iLeakyReLU/iLeakyReLU.c b/Tutorials/PartIII_skeletons/iLeakyReLU/iLeakyReLU.c new file mode 100644 index 0000000000..5f0e3b89b3 --- /dev/null +++ b/Tutorials/PartIII_skeletons/iLeakyReLU/iLeakyReLU.c @@ -0,0 +1,33 @@ +/* ===================================================================== + * Title: iLeakyReLU.c (SoCDAML Part III - Step 3 skeleton) + * + * Plain-C int8 LeakyReLU. The per-core chunking boilerplate is provided. + * Fill in the inner loop body marked `TODO(student)`. + * + * Goal: out[i] = (in[i] >= 0) ? in[i] : ((mul * in[i]) >> shift) + * + * Hints: + * - Cast in[i] to int32_t before the multiply to avoid 8-bit overflow. + * - Cast the final result back to int8_t before storing. + * + * Drop into: TargetLibraries/PULPOpen/src/iLeakyReLU.c + * ===================================================================== */ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "DeeployPULPMath.h" +#include "pmsis.h" + +void PULPiLeakyReLU_i8_i8(int8_t *pIn, int8_t *pOut, uint32_t size, + int32_t mul, int32_t shift) { + uint32_t cid = pi_core_id(); + uint32_t nC = NUM_CORES; + uint32_t per = (size + nC - 1) / nC; + uint32_t start = cid * per; + uint32_t end = (start + per > size) ? size : (start + per); + + for (uint32_t i = start; i < end; i++) { + // TODO(student): compute pOut[i] from pIn[i], mul, shift. + // Replace the following line: + pOut[i] = 0; + } +} diff --git a/Tutorials/PartIII_skeletons/iLeakyReLU/iLeakyReLU.h b/Tutorials/PartIII_skeletons/iLeakyReLU/iLeakyReLU.h new file mode 100644 index 0000000000..130b78a93c --- /dev/null +++ b/Tutorials/PartIII_skeletons/iLeakyReLU/iLeakyReLU.h @@ -0,0 +1,18 @@ +/* ===================================================================== + * Title: iLeakyReLU.h (SoCDAML Part III - Step 3, provided) + * + * Header for the iLeakyReLU PULP kernel. + * Drop into: TargetLibraries/PULPOpen/inc/kernel/iLeakyReLU.h + * and add `#include "kernel/iLeakyReLU.h"` to DeeployPULPMath.h. + * ===================================================================== */ +/* SPDX-License-Identifier: Apache-2.0 */ + +#ifndef __DEEPLOY_KERNEL_ILEAKYRELU_H_ +#define __DEEPLOY_KERNEL_ILEAKYRELU_H_ + +#include "DeeployPULPMath.h" + +void PULPiLeakyReLU_i8_i8(int8_t *pIn, int8_t *pOut, uint32_t size, + int32_t mul, int32_t shift); + +#endif // __DEEPLOY_KERNEL_ILEAKYRELU_H_ diff --git a/Tutorials/PartIII_skeletons/iLeakyReLU/iLeakyReLUParser.py b/Tutorials/PartIII_skeletons/iLeakyReLU/iLeakyReLUParser.py new file mode 100644 index 0000000000..78651b3873 --- /dev/null +++ b/Tutorials/PartIII_skeletons/iLeakyReLU/iLeakyReLUParser.py @@ -0,0 +1,36 @@ +# ---------------------------------------------------------------------- +# File: iLeakyReLUParser.py (SoCDAML Part III - Step 2 skeleton) +# +# Paste this class into: +# Deeploy/Targets/Generic/Parsers.py +# +# Imports already present in that file (math, numpy as np, +# onnx_graphsurgeon as gs, NodeParser, NetworkContext). +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 + + +class iLeakyReLUParser(NodeParser): + + def __init__(self): + super().__init__() + + def parseNode(self, node: gs.Node) -> bool: + # TODO(student): return False if the node doesn't have exactly + # one input, exactly one output, and both 'mul' and 'shift' + # attributes. On success, store them into + # self.operatorRepresentation as ints and return True. + return False + + def parseNodeCtxt(self, + ctxt: NetworkContext, + node: gs.Node, + channels_first: bool = True): + # TODO(student): look up the input and output tensors from ctxt + # using node.inputs[0].name / node.outputs[0].name, and populate + # self.operatorRepresentation with: + # 'data_in' -> input tensor name + # 'data_out' -> output tensor name + # 'size' -> int(np.prod(input_shape)) + # Return (ctxt, True). + return ctxt, False diff --git a/Tutorials/PartIII_skeletons/iLeakyReLU/iLeakyReLUTemplate.py b/Tutorials/PartIII_skeletons/iLeakyReLU/iLeakyReLUTemplate.py new file mode 100644 index 0000000000..82440daed5 --- /dev/null +++ b/Tutorials/PartIII_skeletons/iLeakyReLU/iLeakyReLUTemplate.py @@ -0,0 +1,28 @@ +# ---------------------------------------------------------------------- +# File: iLeakyReLUTemplate.py (SoCDAML Part III - Step 4 skeleton) +# +# Drop this file into: +# Deeploy/Targets/PULPOpen/Templates/iLeakyReLUTemplate.py +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 + +from Deeploy.DeeployTypes import NodeTemplate + + +class _iLeakyReLUTemplate(NodeTemplate): + + def __init__(self, templateStr): + super().__init__(templateStr) + + +# TODO(student): fill in the Mako template body so it emits a single +# call to your C kernel: +# +# PULPiLeakyReLU_i8_i8(, , , , ); +# +# All five `${...}` substitutions correspond to keys you populated in +# the parser (or that Deeploy fills automatically for tensor names). +referenceTemplate = _iLeakyReLUTemplate(""" +// iLeakyReLU (Name: ${nodeName}, Op: ${nodeOp}) +// TODO(student): emit the kernel call here. +""") diff --git a/Tutorials/PartIII_skeletons/iLeakyReLU/iLeakyReLUTileConstraint.py b/Tutorials/PartIII_skeletons/iLeakyReLU/iLeakyReLUTileConstraint.py new file mode 100644 index 0000000000..3094f37939 --- /dev/null +++ b/Tutorials/PartIII_skeletons/iLeakyReLU/iLeakyReLUTileConstraint.py @@ -0,0 +1,33 @@ +# ---------------------------------------------------------------------- +# File: iLeakyReLUTileConstraint.py (SoCDAML Part III - Step 5+7a skeleton) +# +# Drop this file into: +# Deeploy/Targets/PULPOpen/TileConstraints/iLeakyReLUTileConstraint.py +# +# UnaryTileConstraint already implements the geometry and serializer +# you need for an elementwise op. You only have to subclass it. In +# Step 6a you'll add a performance constraint on top. +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict + +from Deeploy.DeeployTypes import NetworkContext +from Deeploy.Targets.Generic.TileConstraints.UnaryTileConstraint import UnaryTileConstraint +from Deeploy.TilingExtension.TilerModel import TilerModel + + +class iLeakyReLUTileConstraint(UnaryTileConstraint): + + @staticmethod + def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + tilerModel = UnaryTileConstraint.addGeometricalConstraint(tilerModel, parseDict, ctxt) + + # TODO(student, Step 6a): add a performance constraint so the + # innermost tile dim is a multiple of 16. Helpful API: + # tilerModel.addMinTileSizeConstraint(parseDict, name, + # tensorDimVar, modulo) + # See: Deeploy/Targets/Generic/TileConstraints/ConvTileConstraint.py + # for a usage example. + + return tilerModel diff --git a/Tutorials/PartIII_skeletons/iLeakyReLU/iLeakyReLU_simd.c b/Tutorials/PartIII_skeletons/iLeakyReLU/iLeakyReLU_simd.c new file mode 100644 index 0000000000..728e7bbcb7 --- /dev/null +++ b/Tutorials/PartIII_skeletons/iLeakyReLU/iLeakyReLU_simd.c @@ -0,0 +1,49 @@ +/* ===================================================================== + * Title: iLeakyReLU_simd.c (SoCDAML Part III - Step 6b skeleton) + * + * SIMD version of iLeakyReLU using XPULP packed 4x8b operations. + * The per-core chunking is provided. Fill in the inner SIMD body. + * + * Key identity (worth deriving on paper before reading hints below): + * LeakyReLU(x) = (x >= 0) ? x : (x >> shift) + * = max(x, x >> shift) + * because arithmetic right shift makes a negative value LESS negative + * (or zero) and doesn't change the sign of a non-negative value. + * + * Strategy hint (one path, two intrinsic-level operations per 4 lanes): + * - load v4s lane: v4s x = vIn[i]; + * - per-lane signed shift: v4s s = x >> shift; (GCC vector ext) + * - signed packed max: __builtin_pulp_max4(x, s); + * + * For the lab we assume `mul == 1` (the generator picks mul=1, shift=3). + * + * Drop into: TargetLibraries/PULPOpen/src/iLeakyReLU.c (overwrite scalar) + * ===================================================================== */ +/* SPDX-License-Identifier: Apache-2.0 */ + +#include "DeeployPULPMath.h" +#include "pmsis.h" + +void PULPiLeakyReLU_i8_i8(int8_t *pIn, int8_t *pOut, uint32_t size, + int32_t mul, int32_t shift) { + (void)mul; // SIMD path assumes mul == 1 + + uint32_t cid = pi_core_id(); + uint32_t nC = NUM_CORES; + uint32_t per = (size + nC - 1) / nC; + per &= ~0x3u; + uint32_t start = cid * per; + uint32_t end = (start + per > size) ? size : (start + per); + + v4s *vIn = (v4s *)(pIn + start); + v4s *vOut = (v4s *)(pOut + start); + uint32_t nVec = (end - start) >> 2; + + for (uint32_t i = 0; i < nVec; i++) { + v4s x = vIn[i]; + // TODO(student): one line to compute `s` from `x` and `shift`, + // one line to blend `x` and `s` with the packed + // signed max intrinsic and store it. + vOut[i] = x; // <- placeholder, replace + } +} diff --git a/Tutorials/PartIII_solution/iLeakyReLU/README.md b/Tutorials/PartIII_solution/iLeakyReLU/README.md new file mode 100644 index 0000000000..6169d9e35c --- /dev/null +++ b/Tutorials/PartIII_solution/iLeakyReLU/README.md @@ -0,0 +1,99 @@ +# SoCDAML Part III - TA reference solution for `iLeakyReLU` + +The complete working `iLeakyReLU` operator (parser, template, binding, +mapper, tile constraint, scalar kernel, SIMD kernel, ONNX + golden +artifacts, and a one-shot deploy script). Use it to demo the lab +end-to-end, and unblock students who get stuck. + +## What's in here + +| File | Purpose | +|------|---------| +| `generate.py` | Builds `network.onnx`, `inputs.npz`, `outputs.npz` for the single-node test | +| `network.onnx` | Single-node ONNX with op_type `iLeakyReLU` (`mul=1`, `shift=3`), shape `(1, 16, 64, 64)` | +| `inputs.npz` | Int64 input tensor named `input` | +| `outputs.npz` | Int64 golden output tensor named `output` | +| `iLeakyReLU.h` | Kernel header | +| `iLeakyReLU.c` | Scalar baseline kernel (Step 3) | +| `iLeakyReLU_simd.c` | XPULP SIMD kernel (Step 6b) | +| `iLeakyReLUParser.py` | Full parser class for `Deeploy/Targets/Generic/Parsers.py` | +| `iLeakyReLUTemplate.py` | Full Mako template for `Deeploy/Targets/PULPOpen/Templates/` | +| `iLeakyReLUTileConstraint.py` | Full tile + perf constraint for `Deeploy/Targets/PULPOpen/TileConstraints/` | +| `deploy.sh` | One-shot script that copies the kernel/template/constraint into the live tree AND patches `Parsers.py`, `Bindings.py`, `Tiler.py`, `Platform.py`, `DeeployPULPMath.h` to wire everything up | + +## Quick start (TA workflow) + +From this directory, inside the Singularity shell: + +```bash +# 1) (Re)generate the test artifacts +python generate.py + +# 2) Apply the SCALAR solution into the live source tree +./deploy.sh + +# 3) Verify (all four runs should report 0 errors and the cycle counts +# in the table below). See "Verification" section for the commands. + +# 4) Swap to the SIMD kernel for Step 6b +./deploy.sh simd + +# 5) Roll back the file copies if you ever need to clean up +./deploy.sh undo +# (note: the script-applied patches into Parsers.py / Bindings.py / +# Platform.py / Tiler.py / DeeployPULPMath.h are NOT auto-reverted; +# use `git checkout -- ` for those if needed) +``` + +`deploy.sh` is idempotent, i.e. running it a second time is a no-op for +the source patches. Re-running `./deploy.sh` after `./deploy.sh simd` +will overwrite the kernel back to scalar (and vice versa), so you can +flip between the two with one command. + +## Verification + +Reproduce every number in the lab's "Stacked speedup" table from +`DeeployTest/`: + +```bash +cd /app/Deeploy/Tutorials/PartIII_solution/iLeakyReLU +./deploy.sh +cd /app/Deeploy/DeeployTest + +echo "=== Baseline (1 core, scalar, untiled) ==="; python testRunner_siracusa.py -t Tests/iLeakyReLU --cores=1 2>&1 | grep -E "Runtime|Errors" +echo "=== Step 4 (8 cores, scalar, untiled) ==="; python testRunner_siracusa.py -t Tests/iLeakyReLU --cores=8 2>&1 | grep -E "Runtime|Errors" +echo "=== Step 5 (8 cores, scalar, tiled) ==="; python testRunner_tiled_siracusa.py -t Tests/iLeakyReLU --cores=8 --l1=32768 --defaultMemLevel=L2 2>&1 | grep -E "Runtime|Errors" + +cd /app/Deeploy/Tutorials/PartIII_solution/iLeakyReLU +./deploy.sh simd +cd /app/Deeploy/DeeployTest + +echo "=== Step 6 (8 cores, SIMD, tiled) ==="; python testRunner_tiled_siracusa.py -t Tests/iLeakyReLU --cores=8 --l1=32768 --defaultMemLevel=L2 2>&1 | grep -E "Runtime|Errors" +``` + +### Expected output + +Every run reports `Errors: 0 out of 65536`. Cycle counts: + +| Step | Configuration | Cycles | vs baseline | +|------|---|---|---| +| baseline | 1 core, scalar, untiled | **2 492 970** | 1.00× | +| Step 4 | 8 cores, scalar, untiled | **313 541** | 7.95× | +| Step 5 | 8 cores, scalar, tiled (`--l1=32768`) | **108 090** | 23.06× | +| Step 6 | 8 cores, SIMD, tiled (`--l1=32768`) | **43 005** | 57.97× | + +If any count drifts by more than a few percent or a run reports any +errors, something in the deploy is off. Try `./deploy.sh undo` plus +`git checkout --` on the patched source files, then re-deploy from +scratch. + +## Files NOT in this directory (live-tree edits applied by deploy.sh) + +`deploy.sh` modifies these files in the live tree. They are NOT +duplicated here, i.e. `deploy.sh` is the source of truth. + +- `Deeploy/Targets/Generic/Parsers.py`: appends `iLeakyReLUParser` +- `Deeploy/Targets/PULPOpen/Bindings.py`: appends `PULPiLeakyReLUBindings` +- `Deeploy/Targets/PULPOpen/Tiler.py`: appends `PULPiLeakyReLUTilingReadyBindings` +- `Deeploy/Targets/PULPOpen/Platform.py`: adds parser/layer imports, mapper, and `PULPMapping` entry +- `TargetLibraries/PULPOpen/inc/DeeployPULPMath.h`: adds the kernel include diff --git a/Tutorials/PartIII_solution/iLeakyReLU/deploy.sh b/Tutorials/PartIII_solution/iLeakyReLU/deploy.sh new file mode 100755 index 0000000000..668b3eb032 --- /dev/null +++ b/Tutorials/PartIII_solution/iLeakyReLU/deploy.sh @@ -0,0 +1,217 @@ +#!/bin/bash +# ---------------------------------------------------------------------- +# deploy.sh - apply the TA solution into the live Deeploy source tree. +# +# Run from this directory (.../Deeploy/Tutorials/PartIII_solution/iLeakyReLU/). +# Idempotent for the file copies; the source patches use grep guards so +# they only apply once. +# +# Usage: +# ./deploy.sh # apply scalar kernel (Step 3) +# ./deploy.sh simd # apply SIMD kernel (Step 6b) on top +# ./deploy.sh undo # remove the additions (best-effort) +# ---------------------------------------------------------------------- +set -euo pipefail +HERE="$(cd "$(dirname "$0")" && pwd)" +ROOT="$(cd "$HERE/../../.." && pwd)" + +MODE="${1:-scalar}" + +PARSERS="$ROOT/Deeploy/Targets/Generic/Parsers.py" +BINDINGS="$ROOT/Deeploy/Targets/PULPOpen/Bindings.py" +PLATFORM="$ROOT/Deeploy/Targets/PULPOpen/Platform.py" +TILER="$ROOT/Deeploy/Targets/PULPOpen/Tiler.py" +PULPMATH_H="$ROOT/TargetLibraries/PULPOpen/inc/DeeployPULPMath.h" +TEMPLATES_DIR="$ROOT/Deeploy/Targets/PULPOpen/Templates" +TILECONSTR_DIR="$ROOT/Deeploy/Targets/PULPOpen/TileConstraints" +KERNEL_SRC_DIR="$ROOT/TargetLibraries/PULPOpen/src" +KERNEL_INC_DIR="$ROOT/TargetLibraries/PULPOpen/inc/kernel" +TESTS_DIR="$ROOT/DeeployTest/Tests/iLeakyReLU" + +case "$MODE" in + undo) + echo "Undoing iLeakyReLU additions (file copies only)..." + rm -rf "$TESTS_DIR" + rm -f "$KERNEL_SRC_DIR/iLeakyReLU.c" + rm -f "$KERNEL_INC_DIR/iLeakyReLU.h" + rm -f "$TEMPLATES_DIR/iLeakyReLUTemplate.py" + rm -f "$TILECONSTR_DIR/iLeakyReLUTileConstraint.py" + echo "Note: hand-patches in Parsers.py / Bindings.py / Platform.py / Tiler.py / DeeployPULPMath.h were NOT removed." + echo "If you need a fully clean tree, use: git checkout -- Deeploy/Targets TargetLibraries/PULPOpen/inc/DeeployPULPMath.h" + exit 0 + ;; + scalar|simd) ;; + *) echo "Unknown mode '$MODE'. Try: scalar | simd | undo"; exit 1;; +esac + +echo "[1/6] Copy test artifacts -> $TESTS_DIR" +mkdir -p "$TESTS_DIR" +for f in network.onnx inputs.npz outputs.npz; do + if [ ! -f "$HERE/$f" ]; then + echo " ERROR: $f not found in $HERE - run 'python generate.py' first." >&2 + exit 1 + fi + cp "$HERE/$f" "$TESTS_DIR/" +done + +echo "[2/6] Copy kernel header -> $KERNEL_INC_DIR/iLeakyReLU.h" +cp "$HERE/iLeakyReLU.h" "$KERNEL_INC_DIR/iLeakyReLU.h" + +echo "[3/6] Copy kernel source ($MODE) -> $KERNEL_SRC_DIR/iLeakyReLU.c" +if [ "$MODE" = "simd" ]; then + cp "$HERE/iLeakyReLU_simd.c" "$KERNEL_SRC_DIR/iLeakyReLU.c" +else + cp "$HERE/iLeakyReLU.c" "$KERNEL_SRC_DIR/iLeakyReLU.c" +fi + +echo "[4/6] Copy template + tile constraint" +cp "$HERE/iLeakyReLUTemplate.py" "$TEMPLATES_DIR/iLeakyReLUTemplate.py" +cp "$HERE/iLeakyReLUTileConstraint.py" "$TILECONSTR_DIR/iLeakyReLUTileConstraint.py" + +echo "[5/6] Patch DeeployPULPMath.h (idempotent)" +if ! grep -q 'kernel/iLeakyReLU.h' "$PULPMATH_H"; then + # Insert before the final #endif + awk ' + /^#endif/ && !done { print "#include \"kernel/iLeakyReLU.h\""; done=1 } + { print } + ' "$PULPMATH_H" > "$PULPMATH_H.tmp" && mv "$PULPMATH_H.tmp" "$PULPMATH_H" +fi + +echo "[6/6] Patch Parsers.py / Bindings.py / Tiler.py / Platform.py (idempotent)" + +# --- Generic/Parsers.py: append iLeakyReLUParser if absent +if ! grep -q 'class iLeakyReLUParser' "$PARSERS"; then + cat >> "$PARSERS" <<'PARSER_EOF' + + +class iLeakyReLUParser(NodeParser): + + def __init__(self): + super().__init__() + + def parseNode(self, node: gs.Node) -> bool: + wellFormed = all([ + len(node.inputs) == 1, + len(node.outputs) == 1, + 'mul' in node.attrs, + 'shift' in node.attrs, + ]) + if not wellFormed: + return False + self.operatorRepresentation['mul'] = int(node.attrs['mul']) + self.operatorRepresentation['shift'] = int(node.attrs['shift']) + return True + + def parseNodeCtxt(self, + ctxt: NetworkContext, + node: gs.Node, + channels_first: bool = True): + data_in = ctxt.lookup(node.inputs[0].name) + data_out = ctxt.lookup(node.outputs[0].name) + self.operatorRepresentation['data_in'] = data_in.name + self.operatorRepresentation['data_out'] = data_out.name + self.operatorRepresentation['size'] = int(np.prod(data_in.shape)) + return ctxt, True +PARSER_EOF +fi + +# --- PULPOpen/Bindings.py: append PULPiLeakyReLUBindings if absent +if ! grep -q 'PULPiLeakyReLUBindings' "$BINDINGS"; then + # Add template import next to other PULPOpen template imports. + python3 - <> "$BINDINGS" <<'BIND_EOF' + + +PULPiLeakyReLUBindings = [ + NodeBinding( + ReluChecker([PointerClass(int8_t)], [PointerClass(int8_t)]), + iLeakyReLUTemplate.referenceTemplate, + ForkTransformer) +] +BIND_EOF +fi + +# --- PULPOpen/Tiler.py: append TilingReady bindings + import +if ! grep -q 'PULPiLeakyReLUTilingReadyBindings' "$TILER"; then + python3 - <= 0 +# (mul*x) >> shift otherwise +# With mul=1, shift=3 this approximates alpha = 0.125. +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import onnx +from onnx import TensorProto, helper + +SHAPE = (1, 16, 64, 64) # NCHW; 65 536 elements -> big enough that + # double-buffering's DMA/kernel overlap dominates + # per-tile bookkeeping, so DB visibly beats SB. +MUL = 1 +SHIFT = 3 +SEED = 0xC0FFEE + +def golden(x, mul, shift): + """Reference int8 LeakyReLU. Arithmetic right shift on negative ints + matches the C `>>` operator on signed integers on most platforms, + so we cast to int32, shift, then clip to int8.""" + pos = x.astype(np.int32) + neg = (mul * pos) >> shift + out = np.where(pos >= 0, pos, neg) + return np.clip(out, -128, 127).astype(np.int8) + +def build_onnx(): + in_value = helper.make_tensor_value_info('data_in', TensorProto.INT8, SHAPE) + out_value = helper.make_tensor_value_info('data_out', TensorProto.INT8, SHAPE) + + node = helper.make_node( + op_type = 'iLeakyReLU', + inputs = ['data_in'], + outputs = ['data_out'], + name = 'iLeakyReLU_0', + mul = MUL, + shift = SHIFT, + ) + + graph = helper.make_graph( + nodes = [node], + name = 'iLeakyReLU_single_node', + inputs = [in_value], + outputs = [out_value], + ) + + model = helper.make_model(graph, producer_name='SoCDAML-PartIII') + model.opset_import[0].version = 13 + model.ir_version = 7 + return model + +def main(): + rng = np.random.default_rng(SEED) + x = rng.integers(low=-128, high=127, size=SHAPE, dtype=np.int8) + y = golden(x, MUL, SHIFT) + + model = build_onnx() + onnx.save(model, 'network.onnx') + # Deeploy convention: npz tensors saved as int64 (the test harness + # casts to float64 then to the ONNX dtype). Storing int8 directly + # confuses the buffer-population path. + np.savez('inputs.npz', input=x.astype(np.int64)) + np.savez('outputs.npz', output=y.astype(np.int64)) + + print(f"Wrote network.onnx (shape={SHAPE}, mul={MUL}, shift={SHIFT})") + print(f"Wrote inputs.npz : keys=['input'] shape={x.shape} int64, int8-range=[{x.min()}, {x.max()}]") + print(f"Wrote outputs.npz : keys=['output'] shape={y.shape} int64, int8-range=[{y.min()}, {y.max()}]") + +if __name__ == '__main__': + main() diff --git a/Tutorials/PartIII_solution/iLeakyReLU/iLeakyReLU.c b/Tutorials/PartIII_solution/iLeakyReLU/iLeakyReLU.c new file mode 100644 index 0000000000..70b4926186 --- /dev/null +++ b/Tutorials/PartIII_solution/iLeakyReLU/iLeakyReLU.c @@ -0,0 +1,26 @@ +/* ===================================================================== + * Title: iLeakyReLU.c (scalar baseline) + * Description: int8 quantization-friendly LeakyReLU, plain C. + * SoCDAML Part III - TA reference solution, Step 3. + * ===================================================================== */ +/* Copyright (C) 2026 ETH Zurich and University of Bologna. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "DeeployPULPMath.h" +#include "pmsis.h" + +void PULPiLeakyReLU_i8_i8(int8_t *pIn, int8_t *pOut, uint32_t size, + int32_t mul, int32_t shift) { + uint32_t cid = pi_core_id(); + uint32_t nC = NUM_CORES; + uint32_t per = (size + nC - 1) / nC; + uint32_t start = cid * per; + uint32_t end = (start + per > size) ? size : (start + per); + + for (uint32_t i = start; i < end; i++) { + int32_t x = (int32_t)pIn[i]; + int32_t lo = (mul * x) >> shift; + pOut[i] = (int8_t)((x >= 0) ? x : lo); + } +} diff --git a/Tutorials/PartIII_solution/iLeakyReLU/iLeakyReLU.h b/Tutorials/PartIII_solution/iLeakyReLU/iLeakyReLU.h new file mode 100644 index 0000000000..844317464a --- /dev/null +++ b/Tutorials/PartIII_solution/iLeakyReLU/iLeakyReLU.h @@ -0,0 +1,20 @@ +/* ===================================================================== + * Title: iLeakyReLU.h + * Description: int8 quantization-friendly LeakyReLU. + * SoCDAML Part III - TA reference solution. + * + * out[i] = (in[i] >= 0) ? in[i] : ((mul * in[i]) >> shift) + * ===================================================================== */ +/* Copyright (C) 2026 ETH Zurich and University of Bologna. + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef __DEEPLOY_KERNEL_ILEAKYRELU_H_ +#define __DEEPLOY_KERNEL_ILEAKYRELU_H_ + +#include "DeeployPULPMath.h" + +void PULPiLeakyReLU_i8_i8(int8_t *pIn, int8_t *pOut, uint32_t size, + int32_t mul, int32_t shift); + +#endif // __DEEPLOY_KERNEL_ILEAKYRELU_H_ diff --git a/Tutorials/PartIII_solution/iLeakyReLU/iLeakyReLUParser.py b/Tutorials/PartIII_solution/iLeakyReLU/iLeakyReLUParser.py new file mode 100644 index 0000000000..e286da259a --- /dev/null +++ b/Tutorials/PartIII_solution/iLeakyReLU/iLeakyReLUParser.py @@ -0,0 +1,44 @@ +# ---------------------------------------------------------------------- +# File: iLeakyReLUParser.py +# +# SoCDAML Part III - TA reference solution. +# Parser for the iLeakyReLU op. Appended to: +# Deeploy/Targets/Generic/Parsers.py +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 + +# (in Generic/Parsers.py the following imports already exist: +# import math; import numpy as np; import onnx_graphsurgeon as gs; +# from Deeploy.DeeployTypes import NodeParser, NetworkContext) + + +class iLeakyReLUParser(NodeParser): + + def __init__(self): + super().__init__() + + def parseNode(self, node: gs.Node) -> bool: + wellFormed = all([ + len(node.inputs) == 1, + len(node.outputs) == 1, + 'mul' in node.attrs, + 'shift' in node.attrs, + ]) + + if not wellFormed: + return False + + self.operatorRepresentation['mul'] = int(node.attrs['mul']) + self.operatorRepresentation['shift'] = int(node.attrs['shift']) + return True + + def parseNodeCtxt(self, + ctxt: NetworkContext, + node: gs.Node, + channels_first: bool = True): + data_in = ctxt.lookup(node.inputs[0].name) + data_out = ctxt.lookup(node.outputs[0].name) + self.operatorRepresentation['data_in'] = data_in.name + self.operatorRepresentation['data_out'] = data_out.name + self.operatorRepresentation['size'] = int(np.prod(data_in.shape)) + return ctxt, True diff --git a/Tutorials/PartIII_solution/iLeakyReLU/iLeakyReLUTemplate.py b/Tutorials/PartIII_solution/iLeakyReLU/iLeakyReLUTemplate.py new file mode 100644 index 0000000000..031d08f635 --- /dev/null +++ b/Tutorials/PartIII_solution/iLeakyReLU/iLeakyReLUTemplate.py @@ -0,0 +1,24 @@ +# ---------------------------------------------------------------------- +# File: iLeakyReLUTemplate.py +# +# SoCDAML Part III - TA reference solution. +# Mako template that emits the call to the PULP iLeakyReLU C kernel. +# +# Drop this file into: +# Deeploy/Targets/PULPOpen/Templates/iLeakyReLUTemplate.py +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 + +from Deeploy.DeeployTypes import NodeTemplate + + +class _iLeakyReLUTemplate(NodeTemplate): + + def __init__(self, templateStr): + super().__init__(templateStr) + + +referenceTemplate = _iLeakyReLUTemplate(""" +// iLeakyReLU (Name: ${nodeName}, Op: ${nodeOp}) +PULPiLeakyReLU_i8_i8(${data_in}, ${data_out}, ${size}, ${mul}, ${shift}); +""") diff --git a/Tutorials/PartIII_solution/iLeakyReLU/iLeakyReLUTileConstraint.py b/Tutorials/PartIII_solution/iLeakyReLU/iLeakyReLUTileConstraint.py new file mode 100644 index 0000000000..3d0534a491 --- /dev/null +++ b/Tutorials/PartIII_solution/iLeakyReLU/iLeakyReLUTileConstraint.py @@ -0,0 +1,47 @@ +# ---------------------------------------------------------------------- +# File: iLeakyReLUTileConstraint.py +# +# SoCDAML Part III - TA reference solution. +# Tiling + performance constraint for the iLeakyReLU op. +# +# Drop this file into: +# Deeploy/Targets/PULPOpen/TileConstraints/iLeakyReLUTileConstraint.py +# ---------------------------------------------------------------------- +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict + +from Deeploy.DeeployTypes import NetworkContext +from Deeploy.Targets.Generic.TileConstraints.UnaryTileConstraint import UnaryTileConstraint +from Deeploy.TilingExtension.TilerModel import TilerModel + + +class iLeakyReLUTileConstraint(UnaryTileConstraint): + """ + Geometry is inherited from UnaryTileConstraint (input shape == output + shape per axis; one shared cube per output tile). On top of that we + add the Step 6a performance constraint: the innermost (last) tile + dim must be a multiple of 16 so the 4-byte SIMD kernel can vectorize + the per-core chunk without a tail iteration. + """ + + @staticmethod + def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + tilerModel = UnaryTileConstraint.addGeometricalConstraint(tilerModel, parseDict, ctxt) + + inputBufferName = parseDict['data_in'] + inputShape = ctxt.lookup(inputBufferName).shape + lastDim = len(inputShape) - 1 + lastDimVar = tilerModel.getTensorDimVar(tensorName=inputBufferName, dimIdx=lastDim) + + # Force the tiled inner dimension to be a multiple of 16. This + # ensures (per-core chunk) is a multiple of 4 once split across + # 8 cores -> the v4s SIMD inner loop is always tail-free. + # addMinTileSizeConstraint reads parseDict[varName] as the + # original axis size, so inject it here. + if inputShape[lastDim] >= 16: + dimKey = f'dim_{lastDim}' + parseDict[dimKey] = int(inputShape[lastDim]) + tilerModel.addMinTileSizeConstraint(parseDict, dimKey, lastDimVar, 16) + + return tilerModel diff --git a/Tutorials/PartIII_solution/iLeakyReLU/iLeakyReLU_simd.c b/Tutorials/PartIII_solution/iLeakyReLU/iLeakyReLU_simd.c new file mode 100644 index 0000000000..0682c634fb --- /dev/null +++ b/Tutorials/PartIII_solution/iLeakyReLU/iLeakyReLU_simd.c @@ -0,0 +1,59 @@ +/* ===================================================================== + * Title: iLeakyReLU_simd.c (XPULP SIMD) + * Description: int8 LeakyReLU optimized with packed 4x8b PULP intrinsics. + * SoCDAML Part III - TA reference solution, Step 6. + * + * Key identity used: + * LeakyReLU_shift(x) = (x >= 0) ? x : (x >> shift) + * = max(x, x >> shift) + * because: + * - x >= 0 => x >= x >> shift (shift toward zero of positive number) + * - x < 0 => x <= x >> shift (arith shift makes negative LESS negative) + * + * We use the GCC vector extension: `v4s s = x >> shift;` is a packed + * per-lane arithmetic right shift, and __builtin_pulp_max4 is a single + * XPULP signed packed-byte max. So the entire inner loop is: + * load v4s -> packed shift -> packed max -> store v4s + * + * Note: This SIMD path ignores the `mul` parameter (assumes mul == 1). + * Our generator script picks mul=1, shift=3 (alpha ~= 0.125), so this + * is identical to the scalar formula. To support arbitrary mul you would + * need a packed multiply, which loses the clean 4x speedup. + * ===================================================================== */ +/* Copyright (C) 2026 ETH Zurich and University of Bologna. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "DeeployPULPMath.h" +#include "pmsis.h" + +void PULPiLeakyReLU_i8_i8(int8_t *pIn, int8_t *pOut, uint32_t size, + int32_t mul, int32_t shift) { + (void)mul; // SIMD path assumes mul == 1 + + uint32_t cid = pi_core_id(); + uint32_t nC = NUM_CORES; + uint32_t per = (size + nC - 1) / nC; + // Round per-core chunk down to a multiple of 4 so the SIMD loop is + // tail-free. The Step 6a tile-size constraint already guarantees that + // size is a multiple of 16. + per &= ~0x3u; + uint32_t start = cid * per; + uint32_t end = (start + per > size) ? size : (start + per); + + v4s *vIn = (v4s *)(pIn + start); + v4s *vOut = (v4s *)(pOut + start); + uint32_t nVec = (end - start) >> 2; + + for (uint32_t i = 0; i < nVec; i++) { + v4s x = vIn[i]; + v4s s = x >> shift; // packed per-lane arith shift + vOut[i] = __builtin_pulp_max4(x, s); // max(x, x>>shift) = LeakyReLU + } + + // Scalar tail (only when perf constraint is not installed) + for (uint32_t i = start + (nVec << 2); i < end; i++) { + int32_t xs = (int32_t)pIn[i]; + pOut[i] = (int8_t)((xs >= 0) ? xs : (xs >> shift)); + } +} diff --git a/Tutorials/PartIII_solution/iLeakyReLU/inputs.npz b/Tutorials/PartIII_solution/iLeakyReLU/inputs.npz new file mode 100644 index 0000000000..908c7855cf Binary files /dev/null and b/Tutorials/PartIII_solution/iLeakyReLU/inputs.npz differ diff --git a/Tutorials/PartIII_solution/iLeakyReLU/network.onnx b/Tutorials/PartIII_solution/iLeakyReLU/network.onnx new file mode 100644 index 0000000000..0d3ce91ab7 --- /dev/null +++ b/Tutorials/PartIII_solution/iLeakyReLU/network.onnx @@ -0,0 +1,19 @@ +SoCDAML-PartIII: +G +data_indata_out iLeakyReLU_0" +iLeakyReLU* + +mul* +shiftiLeakyReLU_single_nodeZ! +data_in + + + +@ +@b" +data_out + + + +@ +@B \ No newline at end of file diff --git a/Tutorials/PartIII_solution/iLeakyReLU/outputs.npz b/Tutorials/PartIII_solution/iLeakyReLU/outputs.npz new file mode 100644 index 0000000000..c7045f77fc Binary files /dev/null and b/Tutorials/PartIII_solution/iLeakyReLU/outputs.npz differ diff --git a/Tutorials/SoCDAML.md b/Tutorials/SoCDAML.md new file mode 100644 index 0000000000..54e42aa5b5 --- /dev/null +++ b/Tutorials/SoCDAML.md @@ -0,0 +1,489 @@ +
+ Image +
+

Institut für Integrierte Systeme
+ Integrated Systems Laboratory

+
+
+ + +# SoCDAML: Neural Network Deeployment on the PULP Platform +Author: *Victor J.B Jung*
+ *Viviane Potocnik* +Date: 28th May 2026 + +## Installation + +A Singularity container file (extension .sif) with Deeploy and its dependencies has been installed on the system; Build the sandbox container with `singularity build --sandbox /scratch/$USER/DeeployContainer/ /home/soc_042fs25/deeploy-container-socdaml.sif` + +Then you can find Deeploy's source code in `/scratch/$USER/DeeployContainer/app/Deeploy`. To spawn a shell from the container, from your home, run `singularity shell --writable --cleanenv --contain /scratch/$USER/DeeployContainer/`. Then you can navigate to the `DeeployTest` folder with `cd /app/Deeploy/DeeployTest`. + +From the `DeeployTest` folder, you can use the `testRunner` to compile ONNXs and execute the output code using the appropriate simulators. + +To validate your installation, you can run a simple Add node on each platform: +``` +python testRunner_generic.py -t Tests/Adder +python testRunner_cortexm.py -t Tests/Adder +python testRunner_mempool.py -t Tests/Adder +python testRunner_snitch.py -t Tests/Adder/ +python testRunner_siracusa.py -t Tests/Adder --cores=8 +``` +Once all these basic tests are passed, we can jump into the basics of Deeploy. + +## I : Deeploy 101 + +Deeploy is a compiler that transforms static computational graph (represented with the [ONNX format](https://onnx.ai/onnx/operators/)) into bare-metal and (hopefully) optimized [C](https://www.c-language.org/). More specifically, it generates an application that can be deployed on the desired platform. + +Hence, Deeploy's inputs are: +- An ONNX file describing your neural network. +- Input tensors. +- Expected output tensors generated with your favorite framework (ONNXRuntime or Torch, for instance). + +Deeploy is shipped with a comprehensive testing framework conveniently named DeeployTest. This testing framework contains Test Runners for end-to-end testing of your network on a given platform. More specifically, a Test Runner compiles a given ONNX file, builds the project, feeds the inputs into the compiled neural network, and compares the output with the golden values to ensure correctness. + +If you followed this tutorial correctly, you already used Test Runners (e.g., `testRunner_siracusa.py`) to validate the Deeploy installation! We will dive into the details of the Test Runners CLI very soon, but first, let's look at the tools and libraries used downstream in Deeploy. + +The figure below gives an overview of the deployment stack. As you can see, there are several steps to take before actually running the application. For the build system (*e.g.,* the tool to organize compilation and linking), we use [CMake](https://cmake.org/). The default C compiler shipped with Deeploy is [LLVM 15](https://llvm.org/), but it supports GCC, given that you provide a local installation. To generate the Application Binary, we link the Network Code with the necessary Kernel Libraries and a Standard C Library (here [Picolibc](https://github.com/picolibc/picolibc)). Then, we feed this Application Binary to the appropriate simulator; from there, you can verify the correctness and benchmark the application. + +

+ Description +

+ +You can visualize the ONNX graphs using [Netron](https://netron.app/). Either use the web interface or install the python package with `pip install netron`. If you are using _VSCode_ you can install the [`vscode-netron`](https://marketplace.visualstudio.com/items?itemName=vincent-templier.vscode-netron) plugin. + +> ✅ **Task:** Visualize the ONNX graph of the `Adder`, `MobileNetv2`, and `Transformer` + +The ONNX graphs are in `DeeployTest/Tests//network.onnx`. The networks are increasing in complexity, `Adder` is a single node network for unit testing, while `MobileNetv2` is a simple sequential network mostly made of convolutions. Finally, the `Transformer` network showcases a typical transformer block used in Encoder and Decoder networks. If you want to peek at a complex network, you can visualize `microLlama/microLlama128`. + +Now that we understand Deeploy's input, let's check the output-generated code! + +> ✅ **Task:** Take a look at the code generated by Deeploy for the Generic platform. + +The generated code is located in the following directory: `DeeployTest/TEST_/Tests`, and the `Network.c` file is the interesting one. + +The generated code is trivial for the `Adder` graph; we simply use the template for the `Add` node of the Generic platform. You can find the template declaration in `Deeploy/Targets/Generic/Templates/AddTemplate.py`. + +Now, if you want to look at something a bit more complex, run `python testRunner_generic.py -t ./Tests/miniMobileNetv2` (from `DeeployTest`) and look at the generated code. There are two interesting points you can notice: +- We hoist the constants at the top of the file. +- In the `RunNetwork` function, we sequentially have node templates to execute the operands and malloc/free to manage the memory. You can open the ONNX graph of `miniMobileNetv2` on the side to try to match the nodes of the graph with their generated code. + +> ✅ **Task:** Visualize the effect of passes on the ONNX graph for the Siracusa platform. + +Deeploy applies passes on the ONNX graph to transform its topology and optimize its execution. Let's visualize the effect of the passes used in the Siracusa Platform. First, let's execute our `miniMobileNetv2` on Siracusa with `python testRunner_siracusa.py -t ./Tests/miniMobileNetv2`. You can find the original ONNX graph at `DeeployTest/Tests/miniMobileNetv2/network.onnx`, and the transformed ONNX graph at `DeeployTest/TEST_SIRACUSA/Tests/miniMobileNetv2/deeployStates/backend_post_binding.onnx`. Open both ONNX graphs side by side to compare them. + +You can notice the effect of two passes on the graph: +- One pass fuses the `Conv` and `RequantShift` nodes. This is a common technique named [Operator Fusion](https://medium.com/data-science/how-pytorch-2-0-accelerates-deep-learning-with-operator-fusion-and-cpu-gpu-code-generation-35132a85bd26) and used in many DNN compilers. +- Another pass is adding a `Transpose` node before the `RequantizedConv` in order to align the tensor layout from CHW to HWC (where C = Channels, H = Height, and W = Width). The HWC tensor layout is required to use optimized Convolution kernels (to learn more, check out [this blog post](https://www.intel.com/content/www/us/en/developer/articles/technical/pytorch-vision-models-with-channels-last-on-cpu.html)). + +Now that you understand the basics of Deeploy let's jump into the optimized deployment of a small language model on the Siracusa SoC. + +## II : Micro Llama on Siracusa + +### Transformers 101 + +In this section, we will study the optimization of the deployment of a small language model. To fully understand this section, you need some basic understanding of Transformer's architecture and Language Model inference mode. If you need a refresher on Transformer's architecture, check out the *Transformer Basics* section of [Lilian Weng's blog post](https://lilianweng.github.io/posts/2023-01-27-the-transformer-family-v2/#transformer-basics). + +Now, Language Models have two inference modes: +- The **Parallel Mode** (AKA *Prefill Mode*) is used to process the tokens of the prompts in parallel and generate the KV cache of the prompt and the first token of the Language Model's "reply". This mode contains mostly GEMMs. +- The **Autoregressive Mode** generates the rest of the Language Model's reply. It uses the KV cache from the previous step, generates a new KV cache entry, and predicts the next token. This mode contains mostly GEMVs. + +To summarize, to generate a Language Model reply of $N$ tokens, there is: +- One **Parallel Mode** inference to process the prompt and generate the first token. +- $N-1$ **Autoregressive Mode** inferences to generate the rest of the tokens. + +The slide below visually represents the **Parallel Mode** and **Autoregressive Mode**. + +

+ Description +

+ +### The Siracusa Platform + +Let's also quickly refresh our knowledge of the Siracusa platform to understand what kind of hardware we must deploy on. Below is the high-level block diagram of Siracusa, compute-wise we will mainly use: +- The cluster of RV32 cores, they are modified to be great at crunching numbers. They feature [SIMD](), hardware loops (see the [RI5CY user manual](https://www.pulp-platform.org/docs/ri5cy_user_manual.pdf), p17), and the [XPULP](https://pulp-platform.org/docs/hipeac/acaces2021/04_PULP_Accelerators.pdf) ISA extensions. +- The [NEUREKA](https://github.com/pulp-platform/neureka) NPU, an accelerator targeting integer convolutions. + +In terms of memories, we have: +- L3: An off-chip RAM (not shown on the block diagram) of 16MB capacity. The L3 has its own DMA that can transfer data to L2. +- Neural Memory Subsystem (NMS): An SRAM/MRAM-based *Weight Memory* to store constants with a direct link to the NPU. +- L2: An on-chip SRAM-based L2 memory of 2MB. +- L1: A TCDM memory of size 256KB. + +The on-chip DMA indicated on the block diagram can transfer data between the Weight Memory, the L2, and the L1. + +

+ Description +

+ +Now that you understand the hardware and the kind of workload we want to execute. Let's deploy using various optimizations to study their impact. The first parameter we can play with is the number of cores from the RV32 cluster to use. + +> ✅ **Task:** Measure and compare the runtime of the `microLlama128` model using 1 and 8 cores. Compute the speedup ratio; why is it not 8? + +*Hint:* `python testRunner_siracusa.py --help` will list and explain the available flags. + +
+ Solution + + > If you run `python testRunner_siracusa.py -t Tests/microLlama/microLlama128 --cores=1` and then `python testRunner_siracusa.py -t Tests/microLlama/microLlama128 --cores=8`, you should measure a runtime of ~16,1M cycles for 1 core and 3.1M cycles for 8 cores. + > + > The speedup ratio is obtained via $\frac{\text{Runtime 1 cores}}{\text{Runtime 8 cores}} = 5.2$. Hence, using 8 cores instead of 1 leads to a 5.2 times speedup. + > + > So why is the speedup ratio below 8? Mostly because all data movement is not overlapped with computation. Additionally, some kernels are probably not optimally parallelized for this specific network. +
+ +### Tiling Basics + +It's due time to talk about data movement now! We use all 8 cluster cores, which is great, but where do these cores fetch the data from? By default, when using `testRunner_siracusa.py`, all data is in L2; there is no tiling, and cores read and write data directly to/from L2. As the L2 memory is "further away" from the cluster, load/store takes several cycles, which is non-optimal. + +What we really want is to use the L1 memory, which provides 1 cycle latency load/store! But as the capacity is relatively small (256KB), we need to **tile our layers**. Tiling operands for an accelerator featuring only scratchpad memories is not trivial (unlike in architectures with data caches). For each layer, the compiler has to decide on tile size, a tiling schedule, a buffering strategy (single buffer, double buffer, etc...), and a memory allocation strategy. Then, the compiler must generate the code to configure and launch each transfer and place barriers accordingly to maximize concurrency. + +The good news is that Deeploy can already do that! So, let's generate and run some tiled code to see the impact of tiling on the runtime. + +> ✅ **Task:** Get familiar with the CLI arguments of `testRunner_tiled_siracusa.py`, then run `microLlama64_parallel` with different configurations. Find one "bad" and one "good" configuration, and explain why. + +*Hint:* Use the `--help` flag to list and explain the available flags. + +
+ Solution + + > Bad configuration: `python testRunner_tiled_siracusa.py -t Tests/microLlama/microLlama64_parallel --cores=8 --l1 8000 --defaultMemLevel=L2` -> Runtime: 47.5 MCycles + > + > Good configuration `python testRunner_tiled_siracusa.py -t Tests/microLlama/microLlama64_parallel --cores=8 --l1 64000 --defaultMemLevel=L2`: -> Runtime: 35.3 MCycles + > + > Justification: As the size of the L1 memory gets smaller, tiles also get smaller and smaller. Smaller tiles usually mean that it's harder to keep the core properly utilized. + +
+ +### Profiling the Execution + +To measure the effect of some optimizations in more detail, you can use the `--profileTiling=L2` flag. This flag will enable a code transformation that will insert print displaying the runtime of several critical code sections. For instance, profiling an *Integer Layer Normalization* layer from L2 with two tiles will return the print the following: +``` +[INTEGER_RMSNORM L2][SB][0 ops][Tile 0] Input DMA took 489 cycles +[INTEGER_RMSNORM L2][SB][0 ops][Tile 0] Kernel took 43305 cycles +[INTEGER_RMSNORM L2][SB][0 ops][Tile 0] Output DMA took 534 cycles +[INTEGER_RMSNORM L2][SB][0 ops][Tile 1] Input DMA took 82 cycles +[INTEGER_RMSNORM L2][SB][0 ops][Tile 1] Kernel took 3254 cycles +[INTEGER_RMSNORM L2][SB][0 ops][Tile 1] Output DMA took 49 cycles +``` +With this profiling trace, you can clearly measure the overhead of DMA transfers. When the profiling is turned ON, the total runtime of the application will encompass the prints. + +### Using the NPU and the Neural Memory Subsystem (NMS) + +To use the NPU, you can use the `testRunner_tiled_siracusa_w_neureka.py`. The Linear layers will automatically be executed by the NPU. To enable the NMS, use the `--neureka-wmem` flag. When the NMS is enabled, the constant tensors used by the accelerator will be placed in the Weight Memory. + +> ✅ **Task:** Execute Micro Llama in parallel and autoregressive mode using the NPU, derive the speedup at the model level and at the layer level compared to execution without NPU. + +*Hint:* Save the profiling traces somewhere to reason about them later on. + +> ✅ **Task:** Why does the NPU bring more speedup in parallel mode than in autoregressive mode? + +
+ Solution + + > The runtime in parallel mode with NPU is obtained with: + > + >` + python testRunner_tiled_siracusa_w_neureka.py -t Tests/microLlama/microLlama64_parallel --cores=8 --l1 64000 --defaultMemLevel=L2 + ` + > + > And returns 28.6 MCycles of runtime. The runtime without NPU was measured above and is 35.3 MCycles. Hence, the speedup is ~1.23 times. + > + > We apply the same methodology on `microLlama64` and get a speedup of ~1.04 times. + > + > Now, why is the speedup lesser in autoregressive mode compared to parallel mode? This is because the parallel mode is composed mainly of GEMM, while the autoregressive mode uses GEMV. With GEMV, the accelerator is underutilized as the [operational intensity](https://spcl.inf.ethz.ch/Teaching/2013-dphpc/lecture9-6up.pdf) of GEMV is very low, especially compared to GEMM. + > + > Additionally, in autoregressive mode (unlike in parallel mode), you have to load the KV cache, which requires lots of data movement not accelerated by the NPU. + +
+
+ +> ✅ **Task:** Benchmark the effect of the NMS on the model runtime and at the layer level. Do you notice any speedup? If yes, where does it come from? + +
+ Solution + + > Using the NMS brings the runtime from 857 to 780 KCycles for the autoregressive mode and from 28.6 to 28.3 MCycles for the parallel mode. By inspecting the trace, you can notice that the NMS drastically reduces the time spent on input DMA transfers for the layers offloaded to the NPU. + > + > This is the profiling trace for a layer without using the NMS: + ``` + [RequantizedPwConv_L2][SB][32771 ops][Tile 0] Input DMA took 2037 cycles + [RequantizedPwConv_L2][SB][32771 ops][Tile 0] Kernel took 2649 cycles + [RequantizedPwConv_L2][SB][32771 ops][Tile 0] Output DMA took 50 cycles + ``` + > And this is with the NMS activated: + ``` + [RequantizedPwConv_L2][SB][32771 ops][Tile 0] Input DMA took 125 cycles + [RequantizedPwConv_L2][SB][32771 ops][Tile 0] Kernel took 2595 cycles + [RequantizedPwConv_L2][SB][32771 ops][Tile 0] Output DMA took 56 cycles + ``` +
+
+ +> ✅ **Task:** Why does the autoregressive mode benefit more from the NMS than the parallel mode? + +
+ Solution + + > Using the NMS relaxes the memory boundness of the NPU. In the GEMM, we are not in a memory-bound regime, and the DMA transfer overhead is negligible with regard to the total runtime. In the autoregressive mode, we spend a lot of time on DMA transfers; hence, providing more bandwidth to the accelerator is very beneficial. + +
+
+ +## III : Adding a New Operator + +So far you've used Deeploy as a black box: you fed in ONNX graphs and looked at the C it spat out. In this last hour you'll open the box and add your own operator from scratch, which will be the an int8 LeakyReLU. You will be walking through every stage of the compiler that Parts I and II merely showed you in passing. By the end you'll have written a parser, a C kernel, a Mako template, a tiling constraint and (if you're quick) an XPULP SIMD intrinsic version. We stay on the Siracusa platform throughout — same target as Part II — so every `testRunner_*` command below uses the Siracusa runner. + +> 💡 **Recommended background:** the internal Deeploy training guide (Parts 1-2) covers the main classes (Parser / Mapper / Binding / Template / TypeChecker / TileConstraint) you're about to touch. Reference PRs to skim: [#25](https://github.com/pulp-platform/Deeploy/pull/25) (basic op on Generic), [#26](https://github.com/pulp-platform/Deeploy/pull/26) (adding tiling + PULP), [#29](https://github.com/pulp-platform/Deeploy/pull/29) (multi-op for a real model). + +### The operator + +`iLeakyReLU` is an elementwise unary that approximates the standard LeakyReLU using only integer arithmetic: + +$$ +\text{out}[i] = \begin{cases} \text{in}[i] & \text{if } \text{in}[i] \ge 0 \\ \lfloor (\text{mul} \cdot \text{in}[i]) / 2^{\text{shift}} \rfloor & \text{otherwise} \end{cases} +$$ + +With `mul=1, shift=3` you get a slope of $\alpha \approx 0.125$, which is close enough to the standard 0.01 that quantized networks tolerate well. + +### What we provide + +A starting kit lives under `Deeploy/Tutorials/PartIII_skeletons/iLeakyReLU/`. Each file contains the surrounding boilerplate plus `TODO(student)` markers. You'll fill the blanks **in place** (no need to copy them anywhere yet). In the steps below, each file then gets *installed* into a specific location in the live source tree (every skeleton's header comment names that destination). If you get stuck, the full reference is in `Deeploy/Tutorials/PartIII_solution/iLeakyReLU/`. We rely on your independence, and that only peek **after** you've tried. Otherwise you won't have any learning effect. + +> ✅ **Task:** Open every file in `Deeploy/Tutorials/PartIII_skeletons/iLeakyReLU/` and read its header comment. Note where each one will eventually be installed (e.g. parser → `Deeploy/Targets/Generic/Parsers.py`, kernel → `TargetLibraries/PULPOpen/src/`). Don't edit anything yet. Just get an idea of how operators are structured in Deeploy. + +### Step 1: Generate the ONNX graph + golden values + +The script `generate.py` (already complete) builds a single-node ONNX with the `op_type` `iLeakyReLU` plus matching `inputs.npz` / `outputs.npz`. Run it once and check the produced files: + +``` +cd Deeploy/Tutorials/PartIII_skeletons/iLeakyReLU +python generate.py +mkdir -p ../../../DeeployTest/Tests/iLeakyReLU +cp network.onnx inputs.npz outputs.npz ../../../DeeployTest/Tests/iLeakyReLU/ +``` + +> ✅ **Task:** Open `network.onnx` in Netron and check that the node has op_type `iLeakyReLU` and `mul`/`shift` attributes. + +### Step 2: Write the parser + +Open `iLeakyReLUParser.py` and fill in `parseNode` (validate attrs + inputs) and `parseNodeCtxt` (extract input/output tensor names and `size`). Paste the finished class into `Deeploy/Targets/Generic/Parsers.py`. + +Test in *verbose* mode: +``` +cd DeeployTest +python testRunner_siracusa.py -t Tests/iLeakyReLU --cores=8 -vv +``` + +This first run will fail later in the pipeline (no template/binding/kernel yet) but you should see your parser fire and accept the node. Use `-vvv` if you want even more diagnostics from the build system and simulator. + +
+ Hint + + > Pattern to copy: `iHardswishParser` in `Deeploy/Targets/Generic/Parsers.py` (~line 787). Its only attrs are `one_over_six / three / six` — same shape as your `mul / shift`. The `iRMSNormParser` higher up in the same file is also useful. + +
+ +### Step 3: Write the C kernel (plain C) + +In `iLeakyReLU.c` the per-core chunking is given. Fill the inner loop: +```c +int32_t x = (int32_t)pIn[i]; +int32_t lo = (mul * x) >> shift; +pOut[i] = (int8_t)((x >= 0) ? x : lo); +``` + +Drop the finished `.c` into `TargetLibraries/PULPOpen/src/`. Drop the header (`iLeakyReLU.h`, already complete) into `TargetLibraries/PULPOpen/inc/kernel/`. Then add **one line** to `TargetLibraries/PULPOpen/inc/DeeployPULPMath.h`: +```c +#include "kernel/iLeakyReLU.h" +``` + +> ⚠️ The PULPOpen CMakeLists auto-globs `src/**`, so you don't need to touch it. You **do** need that aggregator include in `DeeployPULPMath.h` though. + +### Step 4: Template, binding, mapper + +Three small pieces wire the parser to the kernel. + +**1. Template.** Fill in the Mako body of `iLeakyReLUTemplate.py` so it emits a single call to your C kernel. Drop the finished file into `Deeploy/Targets/PULPOpen/Templates/`. Pattern to copy: `Deeploy/Targets/PULPOpen/Templates/iSoftmaxTemplate.py`. + +
+ Solution + + > ```python + > referenceTemplate = _iLeakyReLUTemplate(""" + > // iLeakyReLU (Name: ${nodeName}, Op: ${nodeOp}) + > PULPiLeakyReLU_i8_i8(${data_in}, ${data_out}, ${size}, ${mul}, ${shift}); + > """) + > ``` + > Mako `${...}` substitutions come straight from `self.operatorRepresentation` (populated by your parser). `nodeName` / `nodeOp` are auto-filled by Deeploy. + +
+ +**2. Binding.** In `Deeploy/Targets/PULPOpen/Bindings.py`, define a `PULPiLeakyReLUBindings` list. A binding is a 3-tuple of *(TypeChecker, Template, CodeTransformation)*. For our `int8 → int8` op, reuse `ReluChecker` (same `int8 → int8` signature) and `ForkTransformer` (forks the kernel call across the 8 cluster cores). Also add the matching import for your template. + +
+ Solution + + > Near the other `from Deeploy.Targets.PULPOpen.Templates import` line, add: + > ```python + > from Deeploy.Targets.PULPOpen.Templates import iLeakyReLUTemplate + > ``` + > Then append the binding list: + > ```python + > PULPiLeakyReLUBindings = [ + > NodeBinding( + > ReluChecker([PointerClass(int8_t)], [PointerClass(int8_t)]), + > iLeakyReLUTemplate.referenceTemplate, + > ForkTransformer) + > ] + > ``` + > **Why `ReluChecker`?** It's the simplest existing checker that accepts an `int8` input and produces an `int8` output. You could write your own `iLeakyReLUChecker`, but the signature would be identical, so reusing the existing one keeps the lab focused on the kernel side. **Why `ForkTransformer`?** It wraps the emitted kernel call into `pi_cl_team_fork(NUM_CORES, ...)`, which is exactly what our multi-core kernel expects. + +
+ +**3. Mapper.** In `Deeploy/Targets/PULPOpen/Platform.py`, define `iLeakyReLUMapper` (a `NodeMapper` that pairs your parser with the binding list) and register the ONNX op name in `PULPMapping`. Reuse `iHardswishLayer` (a trivial `ONNXLayer` that does no extra shape/cost work, i.e. same shape as ours). + +
+ Solution + + > Imports near the existing Hardswish ones: + > ```python + > from Deeploy.Targets.Generic.Parsers import iLeakyReLUParser # add to the list + > from Deeploy.Targets.Generic.Layers import iHardswishLayer # already imported + > ``` + > Mapper definition (next to `iHardswishMapper`): + > ```python + > iLeakyReLUMapper = NodeMapper(iLeakyReLUParser(), PULPiLeakyReLUBindings) + > ``` + > `PULPMapping` entry (next to `'iHardswish'`): + > ```python + > 'iLeakyReLU': iHardswishLayer([iLeakyReLUMapper]), + > ``` + +
+ +Test untiled execution on Siracusa: +``` +python testRunner_siracusa.py -t Tests/iLeakyReLU --cores=8 +``` +Do you observe any mismatches? How many cycles does the execution take? + +### Step 5: Tiling constraint + +Open `iLeakyReLUTileConstraint.py`. It already subclasses `UnaryTileConstraint`, so the geometry (input dim == output dim per axis) and the schedule serializer come for free. Leave the body empty for now (the performance constraint comes in Step 6a). + +Drop the file into `Deeploy/Targets/PULPOpen/TileConstraints/`. Then **register the tiling-ready binding** in `Deeploy/Targets/PULPOpen/Tiler.py`: wrap your binding list with `TilingReadyNodeBindings(...)` so Deeploy knows which constraint to apply, and finally update the mapper in `Platform.py` to use the tiling-ready variant. + +
+ Solution + + > In `Tiler.py`, add the imports near the other tile-constraint imports: + > ```python + > from Deeploy.Targets.PULPOpen.TileConstraints.iLeakyReLUTileConstraint \ + > import iLeakyReLUTileConstraint + > from Deeploy.Targets.PULPOpen.Bindings import PULPiLeakyReLUBindings + > ``` + > Then append the binding bundle: + > ```python + > PULPiLeakyReLUTilingReadyBindings = TilingReadyNodeBindings( + > nodeBindings = PULPiLeakyReLUBindings, + > tileConstraint = iLeakyReLUTileConstraint()) + > ``` + > In `Platform.py`, change the mapper to use the tiling-ready bindings: + > ```python + > iLeakyReLUMapper = NodeMapper(iLeakyReLUParser(), PULPiLeakyReLUTilingReadyBindings) + > ``` + > Reference pattern: `PULPiHardswishTilingReadyBindings` in the same file. + +
+ +Run the tiled flow: +``` +python testRunner_tiled_siracusa.py -t Tests/iLeakyReLU --cores=8 --l1=32768 --defaultMemLevel=L2 +``` + +
+ Hint on the constraint itself + + > If you want a worked example of a unary quantized op, see `Deeploy/Targets/Generic/TileConstraints/iHardswishTileConstraint.py`. + +
+ +How long does the execution take, i.e. how many cycles? What do you observe? Did you expect this result? + +### Step 6: Add a performance constraint, then go SIMD + +In this final step you'll add a tile-size constraint that aligns work with the SIMD width, then swap the plain-C kernel for a PULP-intrinsics version. + +**(a) Performance constraint.** Go back to `iLeakyReLUTileConstraint.py` and add the multiple-of-16 constraint. `addMinTileSizeConstraint` looks up `parseDict[varName]` as the original axis size, so the parser must expose it. The easiest is to inject it from inside the constraint: + +```python +inputShape = ctxt.lookup(parseDict['data_in']).shape +lastDim = len(inputShape) - 1 +lastDimVar = tilerModel.getTensorDimVar(tensorName=parseDict['data_in'], dimIdx=lastDim) +if inputShape[lastDim] >= 16: + dimKey = f'dim_{lastDim}' + parseDict[dimKey] = int(inputShape[lastDim]) + tilerModel.addMinTileSizeConstraint(parseDict, dimKey, lastDimVar, 16) +``` + +Re-run with `--profileTiling`. The tile shape on the innermost dim now snaps to a multiple of 16; the per-core chunk is therefore a multiple of 4, i.e. exactly what the SIMD kernel needs. + +**(b) PULP SIMD intrinsics.** Replace the scalar kernel with `iLeakyReLU_simd.c`. The trick: LeakyReLU has a closed-form identity that fits the XPULP intrinsic set perfectly. Because arithmetic right shift makes a negative value *less* negative (or zero) and doesn't change the sign of a non-negative value: + +$$\text{LeakyReLU}(x) = \max(x,\; x \gg \text{shift})$$ + +So if you compute `x >> shift` on a packed `v4s` and feed both into `__builtin_pulp_max4`, you get LeakyReLU branch-free in just two packed operations per 4 lanes: load → packed shift → packed max → store: + +```c +v4s x = vIn[i]; +v4s s = x >> shift; // GCC vector ext: per-lane shift +vOut[i] = __builtin_pulp_max4(x, s); // single packed signed max +``` + +The SIMD kernel ignores `mul` (assumes `mul == 1`); the generator picks `mul=1, shift=3` so the formula is identical. + +Re-run with `--profileTiling`. Compare per-tile kernel cycles to your scalar baseline. + +> ✅ **Task:** Quantify the speedup vs the scalar kernel. Why isn't it exactly 4×? + +
+ Solution + + > In our reference run (`--l1=32768`, shape `(1,16,64,64)`) the end-to-end runtime drops from **108 090 cycles (scalar)** to **43 005 cycles (SIMD)** — a clean **2.51×**. Why not exactly 4×? Two reasons. **(1)** The XPULP V2 toolchain doesn't expose a packed-byte *arithmetic* right shift builtin, so `v4s s = x >> shift` is lowered by the compiler to four scalar lane shifts. The real SIMD wins come from packed `v4s` loads/stores (1 instruction vs 4) and `__builtin_pulp_max4` (1 instruction vs 4 compare+select). **(2)** Even if the shift were packed, end-to-end time also includes DMA traffic and per-tile bookkeeping, which don't shrink with SIMD. To approach 4× you'd need either a hardware packed-byte shift (the `pv.sra.sci.b` instruction *exists* in XPULP V2 but isn't exposed as a builtin in this toolchain), inline assembly against it, or a different formulation (e.g. constant-shift lookup, or a kernel using `__builtin_pulp_avgu4` restricted to `shift=1`). The full intrinsics inventory lives in `TargetLibraries/PULPOpen/third_party/pulp-nn-mixed/XpulpV2/32bit/include/pulp_nn_utils.h`. + +
+ +### Stacked speedup + +To wrap up, measure your own cycle counts at each step and compute the speedups vs the single-core untiled baseline and step-to-step. Grab the missing baseline numbers with: + +``` +python testRunner_siracusa.py -t Tests/iLeakyReLU --cores=1 # baseline +python testRunner_siracusa.py -t Tests/iLeakyReLU --cores=8 # Step 4 +python testRunner_tiled_siracusa.py -t Tests/iLeakyReLU --cores=8 --l1=32768 --defaultMemLevel=L2 # Step 5 (scalar) +python testRunner_tiled_siracusa.py -t Tests/iLeakyReLU --cores=8 --l1=32768 --defaultMemLevel=L2 # Step 6 (after deploying SIMD kernel) +``` + +> ✅ **Task:** Build a table comparing each step's cycle count to the baseline and to the previous step. Which transformation contributes the most? Is SIMD or parallelism the bigger lever for this op? + +
+ Solution + + > Our reference run on shape `(1, 16, 64, 64)` = 65 536 elements with `--l1=32768`: + > + > | Step | Configuration | Cycles | vs baseline | vs previous step | + > |------|---|---|---|---| + > | baseline | 1 core, scalar, untiled | 2 492 970 | 1.00× | — | + > | Step 4 | 8 cores, scalar, untiled | 313 541 | **7.95×** | 7.95× | + > | Step 5 | 8 cores, scalar, tiled | 108 090 | **23.06×** | 2.90× | + > | Step 6 | 8 cores, SIMD, tiled | 43 005 | **57.97×** | 2.51× | + > + > Most of the win comes from parallelizing across cores (Step 4) and moving the working set into L1 (Step 5). SIMD is the last lever to pull and contributes ~2.5× on top. The takeaway: for memory-bound elementwise ops, **getting data close to the compute (Step 5)** and **using all the cores (Step 4)** dwarf the SIMD win. Always choose your optimization order accordingly when you tackle a new operator. + +
+ +--- + +Congratulations! You just added a brand-new operator to Deeploy and traced it from ONNX all the way to optimized SIMD-accelerated C on the Siracusa cluster. The same workflow scales to any new ONNX operator you'd want to deploy. + +Et voilà, this is the end of the tutorial. Thank you for following it until the end. If you are interested in learning more about Deeploy or the SoCs we develop at the [PULP Platform](https://pulp-platform.org/), please reach out! diff --git a/docs/img/DeeployLogoGreen.svg b/docs/img/DeeployLogoGreen.svg new file mode 100644 index 0000000000..63bb29b067 --- /dev/null +++ b/docs/img/DeeployLogoGreen.svg @@ -0,0 +1,22 @@ + + + + + + + Deeploy + \ No newline at end of file