diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index feba4f6f072..cbabadcfc80 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -793,17 +793,30 @@ def register_ported_op_all_packed_dims(): # Ported ops that support their own prepacking. -@update_features( - [ - exir_ops.edge.aten.embedding.default, - exir_ops.edge.aten._native_batch_norm_legit_no_training.default, - ] -) -def register_ported_ops_with_prepacking(): +@update_features(exir_ops.edge.aten.embedding.default) +def register_embedding_op(): + return OpFeatures( + inputs_storage=utils.CHANNELS_PACKED_TEXTURE, + supports_prepacking=True, + supports_resize=True, + ) + + +@update_features(exir_ops.edge.aten._native_batch_norm_legit_no_training.default) +def register_batch_norm_op(): + def check_batch_norm_node(node: torch.fx.Node) -> bool: + x = node.args[0] + if not isinstance(x, torch.fx.Node): + return False + x_shape = x.meta["val"].size() + # Only support 4-D input tensors + return len(x_shape) == 4 + return OpFeatures( inputs_storage=utils.CHANNELS_PACKED_TEXTURE, supports_prepacking=True, supports_resize=True, + are_node_inputs_supported_fn=check_batch_norm_node, )