diff --git a/tests/unit/aggregation/test_values.py b/tests/unit/aggregation/test_values.py index 860f313d..42faca91 100644 --- a/tests/unit/aggregation/test_values.py +++ b/tests/unit/aggregation/test_values.py @@ -51,7 +51,7 @@ ], ) -AGGREGATOR_PARAMETRIZATIONS = [ +AGGREGATOR_PARAMETRIZATIONS: list[tuple] = [ (AlignedMTL(), J_base, tensor([0.2133, 0.9673, 0.9673])), (ConFIG(), J_base, tensor([0.1588, 2.0706, 2.0706])), (Constant(tensor([1.0, 2.0])), J_base, tensor([8.0, 3.0, 3.0])), @@ -71,7 +71,7 @@ G_base = J_base @ J_base.T G_Krum = J_Krum @ J_Krum.T -WEIGHTING_PARAMETRIZATIONS = [ +WEIGHTING_PARAMETRIZATIONS: list[tuple] = [ (AlignedMTLWeighting(), G_base, tensor([0.5591, 0.4083])), (ConstantWeighting(tensor([1.0, 2.0])), G_base, tensor([1.0, 2.0])), (DualProjWeighting(), G_base, tensor([0.6109, 0.5000])), diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index 2db0b7f3..91278f07 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -191,7 +191,7 @@ def dummy_backward_pre_hook(_module, _grad_output) -> Tensor: assert not _has_forward_hook(module) -_PARAMETRIZATIONS = [ +_PARAMETRIZATIONS: list[tuple] = [ (AlignedMTL(), True), (DualProj(), True), (IMTLG(), True),