diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py index 91f08fb96be..eac3549b5ab 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py @@ -7,12 +7,25 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import ( add_nullable_tensors, ) +from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin): + def _cast_tensor_for_input(self, tensor: torch.Tensor | None, input: torch.Tensor) -> torch.Tensor | None: + tensor = cast_to_device(tensor, input.device) + if ( + tensor is not None + and input.is_floating_point() + and tensor.is_floating_point() + and not isinstance(tensor, GGMLTensor) + and tensor.dtype != input.dtype + ): + tensor = tensor.to(dtype=input.dtype) + return tensor + def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor: - weight = cast_to_device(self.weight, input.device) - bias = cast_to_device(self.bias, input.device) + weight = self._cast_tensor_for_input(self.weight, input) + bias = self._cast_tensor_for_input(self.bias, input) # Prepare the original parameters for the patch aggregation. 
orig_params = {"weight": weight, "bias": bias} @@ -25,13 +38,15 @@ def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor: device=input.device, ) - weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None)) - bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None)) + residual_weight = self._cast_tensor_for_input(aggregated_param_residuals.get("weight", None), input) + residual_bias = self._cast_tensor_for_input(aggregated_param_residuals.get("bias", None), input) + weight = add_nullable_tensors(weight, residual_weight) + bias = add_nullable_tensors(bias, residual_bias) return self._conv_forward(input, weight, bias) def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: - weight = cast_to_device(self.weight, input.device) - bias = cast_to_device(self.bias, input.device) + weight = self._cast_tensor_for_input(self.weight, input) + bias = self._cast_tensor_for_input(self.bias, input) return self._conv_forward(input, weight, bias) def forward(self, input: torch.Tensor) -> torch.Tensor: @@ -39,5 +54,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return self._autocast_forward_with_patches(input) elif self._device_autocasting_enabled: return self._autocast_forward(input) + elif input.is_floating_point() and ( + ( + self.weight.is_floating_point() + and not isinstance(self.weight, GGMLTensor) + and self.weight.dtype != input.dtype + ) + or ( + self.bias is not None + and self.bias.is_floating_point() + and not isinstance(self.bias, GGMLTensor) + and self.bias.dtype != input.dtype + ) + ): + weight = self._cast_tensor_for_input(self.weight, input) + bias = self._cast_tensor_for_input(self.bias, input) + return self._conv_forward(input, weight, bias) else: return super().forward(input) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py 
b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py index 99f70646db2..77227583cd9 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py @@ -9,6 +9,7 @@ from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer from invokeai.backend.patches.layers.lora_layer import LoRALayer +from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor: @@ -57,28 +58,47 @@ def autocast_linear_forward_sidecar_patches( # Finally, apply any remaining patches. if len(unprocessed_patches_and_weights) > 0: + weight, bias = orig_module._cast_weight_bias_for_input(input) # Prepare the original parameters for the patch aggregation. - orig_params = {"weight": orig_module.weight, "bias": orig_module.bias} + orig_params = {"weight": weight, "bias": bias} # Filter out None values. 
orig_params = {k: v for k, v in orig_params.items() if v is not None} aggregated_param_residuals = orig_module._aggregate_patch_parameters( unprocessed_patches_and_weights, orig_params=orig_params, device=input.device ) - output += torch.nn.functional.linear( - input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None) - ) + residual_weight = orig_module._cast_tensor_for_input(aggregated_param_residuals["weight"], input) + residual_bias = orig_module._cast_tensor_for_input(aggregated_param_residuals.get("bias", None), input) + assert residual_weight is not None + output += torch.nn.functional.linear(input, residual_weight, residual_bias) return output class CustomLinear(torch.nn.Linear, CustomModuleMixin): + def _cast_tensor_for_input(self, tensor: torch.Tensor | None, input: torch.Tensor) -> torch.Tensor | None: + tensor = cast_to_device(tensor, input.device) + if ( + tensor is not None + and input.is_floating_point() + and tensor.is_floating_point() + and not isinstance(tensor, GGMLTensor) + and tensor.dtype != input.dtype + ): + tensor = tensor.to(dtype=input.dtype) + return tensor + + def _cast_weight_bias_for_input(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: + weight = self._cast_tensor_for_input(self.weight, input) + bias = self._cast_tensor_for_input(self.bias, input) + assert weight is not None + return weight, bias + def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor: return autocast_linear_forward_sidecar_patches(self, input, self._patches_and_weights) def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: - weight = cast_to_device(self.weight, input.device) - bias = cast_to_device(self.bias, input.device) + weight, bias = self._cast_weight_bias_for_input(input) return torch.nn.functional.linear(input, weight, bias) def forward(self, input: torch.Tensor) -> torch.Tensor: @@ -86,5 +106,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return 
self._autocast_forward_with_patches(input) elif self._device_autocasting_enabled: return self._autocast_forward(input) + elif input.is_floating_point() and ( + (self.weight.is_floating_point() and not isinstance(self.weight, GGMLTensor) and self.weight.dtype != input.dtype) + or ( + self.bias is not None + and self.bias.is_floating_point() + and not isinstance(self.bias, GGMLTensor) + and self.bias.dtype != input.dtype + ) + ): + weight, bias = self._cast_weight_bias_for_input(input) + return torch.nn.functional.linear(input, weight, bias) else: return super().forward(input) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py index 0563f3cb366..1a5b8585473 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py @@ -49,7 +49,9 @@ def _aggregate_patch_parameters( # parameters. But, of course, any sub-layers that need to access the actual values of the parameters will fail. for param_name in orig_params.keys(): param = orig_params[param_name] - if type(param) is torch.nn.Parameter and type(param.data) is torch.Tensor: + if isinstance(param, torch.nn.Parameter) and type(param.data) is torch.Tensor: + pass + elif type(param) is torch.Tensor: + pass elif type(param) is GGMLTensor: # Move to device and dequantize here.
Doing it in the patch layer can result in redundant casts / diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py index c1e77c333bb..15d2ba61ef4 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py @@ -1,4 +1,5 @@ import copy +from collections.abc import Callable import gguf import pytest @@ -124,6 +125,67 @@ def unwrap_single_custom_layer(layer: torch.nn.Module): return unwrap_custom_layer(layer, orig_layer_type) +class ZeroParamPatch(BaseLayerPatch): + """A minimal parameter patch that exercises the aggregated sidecar patch path.""" + + def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]: + return {name: torch.zeros_like(param) for name, param in orig_parameters.items()} + + def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None): + return self + + def calc_size(self) -> int: + return 0 + + +def _cpu_dtype_supported( + layer_factory: Callable[[], torch.nn.Module], + input_factory: Callable[[torch.dtype], torch.Tensor], + dtype: torch.dtype, +) -> bool: + try: + layer = layer_factory().to(dtype=dtype) + input_tensor = input_factory(dtype) + with torch.no_grad(): + _ = layer(input_tensor) + return True + except (RuntimeError, TypeError, NotImplementedError): + return False + + +def _cpu_dtype_param( + dtype: torch.dtype, + layer_factory: Callable[[], torch.nn.Module], + input_factory: Callable[[torch.dtype], torch.Tensor], +): + supported = _cpu_dtype_supported(layer_factory, input_factory, dtype) + return pytest.param( + dtype, + id=str(dtype).removeprefix("torch."), + marks=pytest.mark.skipif(not supported, reason=f"CPU 
{dtype} is not supported for this op"), + ) + + +LINEAR_CPU_MIXED_DTYPE_PARAMS = [ + _cpu_dtype_param(torch.bfloat16, lambda: torch.nn.Linear(8, 16), lambda dtype: torch.randn(2, 8, dtype=dtype)), + _cpu_dtype_param(torch.float16, lambda: torch.nn.Linear(8, 16), lambda dtype: torch.randn(2, 8, dtype=dtype)), +] + + +CONV2D_CPU_MIXED_DTYPE_PARAMS = [ + _cpu_dtype_param( + torch.bfloat16, + lambda: torch.nn.Conv2d(8, 16, 3), + lambda dtype: torch.randn(2, 8, 5, 5, dtype=dtype), + ), + _cpu_dtype_param( + torch.float16, + lambda: torch.nn.Conv2d(8, 16, 3), + lambda dtype: torch.randn(2, 8, 5, 5, dtype=dtype), + ), +] + + def test_isinstance(layer_under_test: LayerUnderTest): """Test that isinstance() and type() behave as expected after wrapping a layer in a custom layer.""" orig_layer, _, _ = layer_under_test @@ -550,3 +612,67 @@ def test_quantized_linear_sidecar_patches_with_autocast_from_cpu_to_device( # Assert that the outputs with and without autocasting are the same. assert torch.allclose(expected_output, autocast_output, atol=1e-6) + + +@pytest.mark.parametrize("dtype", LINEAR_CPU_MIXED_DTYPE_PARAMS) +@torch.no_grad() +def test_linear_mixed_dtype_inference_without_patches(dtype: torch.dtype): + layer = wrap_single_custom_layer(torch.nn.Linear(8, 16)) + input = torch.randn(2, 8, dtype=dtype) + + output = layer(input) + + assert output.dtype == input.dtype + assert output.shape == (2, 16) + + +@pytest.mark.parametrize("dtype", LINEAR_CPU_MIXED_DTYPE_PARAMS) +@torch.no_grad() +def test_linear_mixed_dtype_inference_without_patches_bias_only_mismatch(dtype: torch.dtype): + layer = torch.nn.Linear(8, 16).to(dtype=dtype) + layer.bias = torch.nn.Parameter(layer.bias.detach().to(torch.float32)) + layer = wrap_single_custom_layer(layer) + input = torch.randn(2, 8, dtype=dtype) + + output = layer(input) + + assert output.dtype == input.dtype + assert output.shape == (2, 16) + + +@pytest.mark.parametrize("dtype", CONV2D_CPU_MIXED_DTYPE_PARAMS) +@torch.no_grad() +def 
test_conv2d_mixed_dtype_inference_without_patches(dtype: torch.dtype): + layer = wrap_single_custom_layer(torch.nn.Conv2d(8, 16, 3)) + input = torch.randn(2, 8, 5, 5, dtype=dtype) + + output = layer(input) + + assert output.dtype == input.dtype + assert output.shape == (2, 16, 3, 3) + + +@pytest.mark.parametrize("dtype", LINEAR_CPU_MIXED_DTYPE_PARAMS) +@torch.no_grad() +def test_linear_mixed_dtype_sidecar_parameter_patch(dtype: torch.dtype): + layer = wrap_single_custom_layer(torch.nn.Linear(8, 16)) + layer.add_patch(ZeroParamPatch(), 1.0) + input = torch.randn(2, 8, dtype=dtype) + + output = layer(input) + + assert output.dtype == input.dtype + assert output.shape == (2, 16) + + +@pytest.mark.parametrize("dtype", CONV2D_CPU_MIXED_DTYPE_PARAMS) +@torch.no_grad() +def test_conv2d_mixed_dtype_sidecar_parameter_patch(dtype: torch.dtype): + layer = wrap_single_custom_layer(torch.nn.Conv2d(8, 16, 3)) + layer.add_patch(ZeroParamPatch(), 1.0) + input = torch.randn(2, 8, 5, 5, dtype=dtype) + + output = layer(input) + + assert output.dtype == input.dtype + assert output.shape == (2, 16, 3, 3)