71 changes: 55 additions & 16 deletions backends/cadence/aot/tests/test_quantizer_ops.py
@@ -64,14 +64,15 @@
# Test case definitions for quantizer annotation tests.
# Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
# Adding a new quantizer test only requires adding a tuple to this list.
# Note: Use None in expected_input_qspecs to skip comparison for that input (e.g., for DerivedQuantizationSpec).
QUANTIZER_ANNOTATION_TEST_CASES: list[
tuple[
str,
GraphBuilderFn,
CadenceQuantizer,
OpOverload,
QuantizationSpec,
list[QuantizationSpec],
list[QuantizationSpec | None],
]
] = [
(
@@ -192,6 +193,16 @@
# For relu: only input_activation
[qconfig_A8W8.input_activation],
),
(
"default_addmm_A8W8",
lambda self: self._build_addmm_graph(),
CadenceDefaultQuantizer(),
torch.ops.aten.addmm.default,
qconfig_A8W8.output_activation,
# For addmm: [bias (DerivedQuantizationSpec), mat1, mat2]
# Use None to skip comparison for bias since it's a DerivedQuantizationSpec (see the sketch below this list)
[None, qconfig_A8W8.input_activation, qconfig_A8W8.weight],
),
]
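
Why the bias entry above is None: in PT2 export quantization, the bias qspec for ops like addmm is typically a DerivedQuantizationSpec whose qparams are computed from the activation and weight observers at annotation time, so it holds node references and a callback that a statically constructed expected value cannot be compared against with assertEqual. A minimal, self-contained sketch of the conventional derivation callback (the observer choice and the qparam formula here are illustrative assumptions, not necessarily the Cadence quantizer's exact configuration):

import torch
from torch.ao.quantization.observer import MinMaxObserver

def derive_bias_qparams(obs_or_fqs):
    # Conventional bias derivation: bias_scale = act_scale * weight_scale, zero_point = 0.
    act_obs, weight_obs = obs_or_fqs
    act_scale, _ = act_obs.calculate_qparams()
    weight_scale, _ = weight_obs.calculate_qparams()
    return act_scale * weight_scale, torch.zeros_like(act_scale, dtype=torch.int32)

# Exercise the callback with two observers that have each seen some data.
act_obs, weight_obs = MinMaxObserver(), MinMaxObserver()
act_obs(torch.randn(4, 10))
weight_obs(torch.randn(10, 5))
bias_scale, bias_zero_point = derive_bias_qparams([act_obs, weight_obs])

A DerivedQuantizationSpec then bundles such a callback with the graph edges it derives from, which is why the test skips a direct comparison for that argument.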

# Derive the set of tested quantizer classes from the test cases.
@@ -408,6 +419,31 @@ def _build_relu_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")
return gm, relu_nodes[0]

def _build_addmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
"""Build a simple graph with an addmm operation."""
builder = GraphBuilder()
# addmm: bias + (mat1 @ mat2)
# args: (bias, mat1, mat2)
bias = builder.placeholder("bias", torch.randn(5))
mat1 = builder.placeholder("mat1", torch.randn(1, 10))
mat2 = builder.placeholder("mat2", torch.randn(10, 5))
addmm = builder.call_operator(
op=torch.ops.aten.addmm.default,
args=(bias, mat1, mat2),
meta=NodeMetadata(
{"source_fn_stack": [("addmm", torch.ops.aten.addmm.default)]}
),
)
builder.output([addmm])
gm = builder.get_graph_module()

addmm_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.addmm.default,
)
self.assertEqual(len(addmm_nodes), 1, "Should find exactly one addmm node")
return gm, addmm_nodes[0]

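As a sanity check on the argument order this graph (and the expected qspec list) relies on, a standalone sketch, independent of the test file: aten.addmm takes (bias, mat1, mat2) positionally and computes bias + mat1 @ mat2, which is why expected_input_qspecs is ordered [bias, input_activation, weight].

import torch

bias, mat1, mat2 = torch.randn(5), torch.randn(1, 10), torch.randn(10, 5)
out = torch.ops.aten.addmm.default(bias, mat1, mat2)
# Positional order is (bias, mat1, mat2); bias broadcasts over the rows of mat1 @ mat2.
assert torch.allclose(out, bias + mat1 @ mat2)
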
@parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
def test_quantizer_annotation(
self,
@@ -416,7 +452,7 @@ def test_quantizer_annotation(
quantizer: CadenceQuantizer,
target: OpOverload,
expected_output_qspec: QuantizationSpec,
expected_input_qspecs: list[QuantizationSpec],
expected_input_qspecs: list[QuantizationSpec | None],
) -> None:
"""Parameterized test for quantizer annotations."""
gm, op_node = graph_builder_fn(self)
@@ -431,21 +467,24 @@

# Verify input annotations
self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
[Review comment — Copilot AI, Dec 30, 2025]
The length validation assumes that the number of entries in input_qspec_map equals the length of expected_input_qspecs. However, since the new iteration logic uses argument indices to access expected_input_qspecs, there could be an IndexError if an operation has tensor arguments at non-consecutive positions (e.g., if stride or padding scalars appear between tensor arguments). While this is unlikely given typical PyTorch API design, consider documenting that expected_input_qspecs must have length equal to the maximum argument index of quantized inputs plus one, or handle sparse argument indices more explicitly.
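
One way to address that concern, sketched against the test's own locals (op_node, expected_input_qspecs); this guard is illustrative only and not part of the change:

# Hypothetical guard, placed before the comparison loop: require that
# expected_input_qspecs covers the highest tensor-argument index.
tensor_arg_indices = [
    i for i, arg in enumerate(op_node.args) if isinstance(arg, torch.fx.Node)
]
self.assertEqual(
    len(expected_input_qspecs),
    max(tensor_arg_indices) + 1,
    "expected_input_qspecs must cover the highest tensor-arg index",
)
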
for i, (input_node, input_qspec) in enumerate(
annotation.input_qspec_map.items()
):
expected_arg = op_node.args[i]
assert isinstance(expected_arg, torch.fx.Node)
self.assertEqual(
input_node,
expected_arg,
f"Input node mismatch at index {i}",
)
self.assertEqual(
input_qspec,
expected_input_qspecs[i],
f"Input qspec mismatch at index {i}",
for input_node, input_qspec in annotation.input_qspec_map.items():
# Find the index of this input node in the op's args
arg_index = None
for i, arg in enumerate(op_node.args):
if arg is input_node:
arg_index = i
break
self.assertIsNotNone(
arg_index,
f"Input node {input_node} not found in op_node.args",
)
# Skip comparison if expected qspec is None (e.g., for DerivedQuantizationSpec)
if expected_input_qspecs[arg_index] is not None:
self.assertEqual(
input_qspec,
expected_input_qspecs[arg_index],
f"Input qspec mismatch at arg index {arg_index}",
)

def test_all_quantizers_have_annotation_tests(self) -> None:
"""Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""