diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl index 6e638a3275c..b801ddfc183 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl @@ -64,6 +64,9 @@ $else: #include "broadcasting_utils.h" #include "indexing_utils.h" +$if MASK_PADDING: + #define MASK_PADDING + layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} @@ -140,11 +143,26 @@ void main() { other_texel = other_texel.xxxx; } - write_texel_lpos( - t_out, - lpos, - VEC4_OUT_T(op(in_texel, other_texel, alpha)), - out_axis_map); + VEC4_OUT_T out_texel = VEC4_OUT_T(op(in_texel, other_texel, alpha)); + +#ifdef MASK_PADDING + // Zero the padding elements of the last texel to prevent NaN propagation. + // When the packed dimension size is not a multiple of 4, the last texel + // along that dimension has padding in components nspill..3. For division + // operations these elements (computed as 0/0) can yield NaN values that propagate through reductions.
+ const int nspill = mod4(out_sizes[packed_dim]); + const int texels_per_batch = divup4(out_sizes[packed_dim]); + const bool is_last_texel = (lpos[packed_dim] % texels_per_batch) == (texels_per_batch - 1); + + if (is_last_texel && nspill > 0) { + // Overwrite components nspill..3 (the padding) with 0 so no NaN is written out + [[unroll]] for (int i = nspill; i < 4; i++) { + out_texel[i] = 0; + } + } +#endif + + write_texel_lpos(t_out, lpos, out_texel, out_axis_map); } #endif diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml index 70793628d80..ee96b5c05b4 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml @@ -10,6 +10,7 @@ binary_op: NDIM: 3 DTYPE: float PACKING: C_packed + MASK_PADDING: 0 generate_variant_forall: STORAGE: - VALUE: texture3d @@ -26,10 +27,12 @@ binary_op: OPERATOR: X * Y - NAME: binary_div OPERATOR: X / Y + MASK_PADDING: 1 - NAME: binary_pow OPERATOR: pow(X, Y) - NAME: binary_floor_divide OPERATOR: floor(X / Y) + MASK_PADDING: 1 - NAME: binary_minimum OPERATOR: min(X, Y) - NAME: binary_eq_int32