Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,25 @@
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
add_nullable_tensors,
)
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor


class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
def _cast_tensor_for_input(self, tensor: torch.Tensor | None, input: torch.Tensor) -> torch.Tensor | None:
    """Move `tensor` to the input's device and, for ordinary floating-point tensors,
    match the input's dtype. GGML-quantized tensors keep their packed dtype.
    """
    moved = cast_to_device(tensor, input.device)
    if moved is None:
        return None
    if isinstance(moved, GGMLTensor):
        # Quantized weights must not be dtype-cast here.
        return moved
    if input.is_floating_point() and moved.is_floating_point() and moved.dtype != input.dtype:
        return moved.to(dtype=input.dtype)
    return moved

def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
weight = cast_to_device(self.weight, input.device)
bias = cast_to_device(self.bias, input.device)
weight = self._cast_tensor_for_input(self.weight, input)
bias = self._cast_tensor_for_input(self.bias, input)

# Prepare the original parameters for the patch aggregation.
orig_params = {"weight": weight, "bias": bias}
Expand All @@ -25,19 +38,37 @@ def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
device=input.device,
)

weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None))
bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None))
residual_weight = self._cast_tensor_for_input(aggregated_param_residuals.get("weight", None), input)
residual_bias = self._cast_tensor_for_input(aggregated_param_residuals.get("bias", None), input)
weight = add_nullable_tensors(weight, residual_weight)
bias = add_nullable_tensors(bias, residual_bias)
return self._conv_forward(input, weight, bias)

def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
    """Conv forward with weight/bias moved to the input's device (and dtype-matched
    where safe; see `_cast_tensor_for_input`).
    """
    # The earlier plain `cast_to_device(...)` assignments were dead stores
    # (immediately overwritten) and have been removed.
    weight = self._cast_tensor_for_input(self.weight, input)
    bias = self._cast_tensor_for_input(self.bias, input)
    return self._conv_forward(input, weight, bias)

def forward(self, input: torch.Tensor) -> torch.Tensor:
    """Dispatch the conv forward by priority: sidecar patches, then device
    autocasting, then a dtype-mismatch fixup path (skipped for GGML-quantized
    params), and finally the stock Conv2d forward.
    """
    if len(self._patches_and_weights) > 0:
        return self._autocast_forward_with_patches(input)
    if self._device_autocasting_enabled:
        return self._autocast_forward(input)

    def needs_dtype_fixup(param: torch.Tensor | None) -> bool:
        # GGML tensors keep their quantized dtype, so they never need a fixup.
        return (
            param is not None
            and param.is_floating_point()
            and not isinstance(param, GGMLTensor)
            and param.dtype != input.dtype
        )

    if input.is_floating_point() and (needs_dtype_fixup(self.weight) or needs_dtype_fixup(self.bias)):
        weight = self._cast_tensor_for_input(self.weight, input)
        bias = self._cast_tensor_for_input(self.bias, input)
        return self._conv_forward(input, weight, bias)

    return super().forward(input)
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor


def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
Expand Down Expand Up @@ -57,34 +58,64 @@ def autocast_linear_forward_sidecar_patches(

# Finally, apply any remaining patches.
if len(unprocessed_patches_and_weights) > 0:
weight, bias = orig_module._cast_weight_bias_for_input(input)
# Prepare the original parameters for the patch aggregation.
orig_params = {"weight": orig_module.weight, "bias": orig_module.bias}
orig_params = {"weight": weight, "bias": bias}
# Filter out None values.
orig_params = {k: v for k, v in orig_params.items() if v is not None}

aggregated_param_residuals = orig_module._aggregate_patch_parameters(
unprocessed_patches_and_weights, orig_params=orig_params, device=input.device
)
output += torch.nn.functional.linear(
input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None)
)
residual_weight = orig_module._cast_tensor_for_input(aggregated_param_residuals["weight"], input)
residual_bias = orig_module._cast_tensor_for_input(aggregated_param_residuals.get("bias", None), input)
assert residual_weight is not None
output += torch.nn.functional.linear(input, residual_weight, residual_bias)

return output


class CustomLinear(torch.nn.Linear, CustomModuleMixin):
def _cast_tensor_for_input(self, tensor: torch.Tensor | None, input: torch.Tensor) -> torch.Tensor | None:
    """Move `tensor` onto the input's device; match the input's floating-point
    dtype unless the tensor is GGML-quantized (those keep their packed dtype).
    """
    on_device = cast_to_device(tensor, input.device)
    if on_device is None or isinstance(on_device, GGMLTensor):
        return on_device
    should_cast_dtype = (
        input.is_floating_point() and on_device.is_floating_point() and on_device.dtype != input.dtype
    )
    return on_device.to(dtype=input.dtype) if should_cast_dtype else on_device

def _cast_weight_bias_for_input(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Return (weight, bias) prepared for `input`'s device/dtype."""
    prepared_weight = self._cast_tensor_for_input(self.weight, input)
    prepared_bias = self._cast_tensor_for_input(self.bias, input)
    # nn.Linear always carries a weight, so this only guards the type-checker.
    assert prepared_weight is not None
    return prepared_weight, prepared_bias

def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
    # Delegate to the sidecar-patch forward, which applies the registered
    # patches on top of the device/dtype-prepared base parameters.
    return autocast_linear_forward_sidecar_patches(self, input, self._patches_and_weights)

def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
    """Linear forward with weight/bias moved to the input's device (and
    dtype-matched where safe; see `_cast_weight_bias_for_input`).
    """
    # The earlier plain `cast_to_device(...)` assignments were dead stores
    # (immediately overwritten) and have been removed.
    weight, bias = self._cast_weight_bias_for_input(input)
    return torch.nn.functional.linear(input, weight, bias)

def forward(self, input: torch.Tensor) -> torch.Tensor:
    """Dispatch the linear forward by priority: sidecar patches, then device
    autocasting, then a dtype-mismatch fixup path, and finally the stock
    Linear forward.

    Fix: the weight condition now excludes GGML-quantized weights (matching
    the bias condition below and CustomConv2d.forward). Without the guard, a
    GGML weight with a mismatched dtype would pointlessly enter the fixup
    branch even though `_cast_tensor_for_input` never casts GGML tensors.
    """
    if len(self._patches_and_weights) > 0:
        return self._autocast_forward_with_patches(input)
    elif self._device_autocasting_enabled:
        return self._autocast_forward(input)
    elif input.is_floating_point() and (
        (
            self.weight.is_floating_point()
            and not isinstance(self.weight, GGMLTensor)
            and self.weight.dtype != input.dtype
        )
        or (
            self.bias is not None
            and self.bias.is_floating_point()
            and not isinstance(self.bias, GGMLTensor)
            and self.bias.dtype != input.dtype
        )
    ):
        weight, bias = self._cast_weight_bias_for_input(input)
        return torch.nn.functional.linear(input, weight, bias)
    else:
        return super().forward(input)
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def _aggregate_patch_parameters(
# parameters. But, of course, any sub-layers that need to access the actual values of the parameters will fail.
for param_name in orig_params.keys():
param = orig_params[param_name]
if type(param) is torch.nn.Parameter and type(param.data) is torch.Tensor:
if isinstance(param, torch.nn.Parameter) and type(param.data) is torch.Tensor:
pass
elif type(param) is torch.Tensor:
pass
elif type(param) is GGMLTensor:
# Move to device and dequantize here. Doing it in the patch layer can result in redundant casts /
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
from collections.abc import Callable

import gguf
import pytest
Expand Down Expand Up @@ -124,6 +125,67 @@ def unwrap_single_custom_layer(layer: torch.nn.Module):
return unwrap_custom_layer(layer, orig_layer_type)


class ZeroParamPatch(BaseLayerPatch):
    """A minimal parameter patch that exercises the aggregated sidecar patch path."""

    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
        # Emit an all-zero residual matching each original parameter's shape/dtype.
        residuals: dict[str, torch.Tensor] = {}
        for param_name, param in orig_parameters.items():
            residuals[param_name] = torch.zeros_like(param)
        return residuals

    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
        # Nothing to move: the patch holds no tensors of its own.
        return self

    def calc_size(self) -> int:
        # No stored tensors, so no memory footprint.
        return 0


def _cpu_dtype_supported(
layer_factory: Callable[[], torch.nn.Module],
input_factory: Callable[[torch.dtype], torch.Tensor],
dtype: torch.dtype,
) -> bool:
try:
layer = layer_factory().to(dtype=dtype)
input_tensor = input_factory(dtype)
with torch.no_grad():
_ = layer(input_tensor)
return True
except (RuntimeError, TypeError, NotImplementedError):
return False


def _cpu_dtype_param(
    dtype: torch.dtype,
    layer_factory: Callable[[], torch.nn.Module],
    input_factory: Callable[[torch.dtype], torch.Tensor],
):
    """Build a pytest param for `dtype`, skipped when the op can't run on CPU."""
    skip_mark = pytest.mark.skipif(
        not _cpu_dtype_supported(layer_factory, input_factory, dtype),
        reason=f"CPU {dtype} is not supported for this op",
    )
    return pytest.param(dtype, id=str(dtype).removeprefix("torch."), marks=skip_mark)


# Dtype parametrization sets for the mixed-dtype CPU tests below. Each entry is
# skipped at collection time when the op cannot run in that dtype on this CPU.
LINEAR_CPU_MIXED_DTYPE_PARAMS = [
    _cpu_dtype_param(torch.bfloat16, lambda: torch.nn.Linear(8, 16), lambda dtype: torch.randn(2, 8, dtype=dtype)),
    _cpu_dtype_param(torch.float16, lambda: torch.nn.Linear(8, 16), lambda dtype: torch.randn(2, 8, dtype=dtype)),
]


CONV2D_CPU_MIXED_DTYPE_PARAMS = [
    _cpu_dtype_param(
        torch.bfloat16,
        lambda: torch.nn.Conv2d(8, 16, 3),
        lambda dtype: torch.randn(2, 8, 5, 5, dtype=dtype),
    ),
    _cpu_dtype_param(
        torch.float16,
        lambda: torch.nn.Conv2d(8, 16, 3),
        lambda dtype: torch.randn(2, 8, 5, 5, dtype=dtype),
    ),
]


def test_isinstance(layer_under_test: LayerUnderTest):
"""Test that isinstance() and type() behave as expected after wrapping a layer in a custom layer."""
orig_layer, _, _ = layer_under_test
Expand Down Expand Up @@ -550,3 +612,67 @@ def test_quantized_linear_sidecar_patches_with_autocast_from_cpu_to_device(

# Assert that the outputs with and without autocasting are the same.
assert torch.allclose(expected_output, autocast_output, atol=1e-6)


@pytest.mark.parametrize("dtype", LINEAR_CPU_MIXED_DTYPE_PARAMS)
@torch.no_grad()
def test_linear_mixed_dtype_inference_without_patches(dtype: torch.dtype):
    """A default-precision custom Linear should accept a lower-precision input
    and produce output in the input's dtype."""
    custom_linear = wrap_single_custom_layer(torch.nn.Linear(8, 16))
    x = torch.randn(2, 8, dtype=dtype)

    result = custom_linear(x)

    assert result.shape == (2, 16)
    assert result.dtype == x.dtype


@pytest.mark.parametrize("dtype", LINEAR_CPU_MIXED_DTYPE_PARAMS)
@torch.no_grad()
def test_linear_mixed_dtype_inference_without_patches_bias_only_mismatch(dtype: torch.dtype):
    """Only the bias dtype differs from the input; the forward should still
    succeed and match the input's dtype."""
    base_layer = torch.nn.Linear(8, 16).to(dtype=dtype)
    base_layer.bias = torch.nn.Parameter(base_layer.bias.detach().to(torch.float32))
    custom_linear = wrap_single_custom_layer(base_layer)
    x = torch.randn(2, 8, dtype=dtype)

    result = custom_linear(x)

    assert result.dtype == x.dtype
    assert result.shape == (2, 16)


@pytest.mark.parametrize("dtype", CONV2D_CPU_MIXED_DTYPE_PARAMS)
@torch.no_grad()
def test_conv2d_mixed_dtype_inference_without_patches(dtype: torch.dtype):
    """A default-precision custom Conv2d should accept a lower-precision input
    and produce output in the input's dtype."""
    custom_conv = wrap_single_custom_layer(torch.nn.Conv2d(8, 16, 3))
    x = torch.randn(2, 8, 5, 5, dtype=dtype)

    result = custom_conv(x)

    assert result.dtype == x.dtype
    assert result.shape == (2, 16, 3, 3)


@pytest.mark.parametrize("dtype", LINEAR_CPU_MIXED_DTYPE_PARAMS)
@torch.no_grad()
def test_linear_mixed_dtype_sidecar_parameter_patch(dtype: torch.dtype):
    """The aggregated sidecar-patch path should also handle a lower-precision
    input against a default-precision Linear."""
    custom_linear = wrap_single_custom_layer(torch.nn.Linear(8, 16))
    custom_linear.add_patch(ZeroParamPatch(), 1.0)
    x = torch.randn(2, 8, dtype=dtype)

    result = custom_linear(x)

    assert result.dtype == x.dtype
    assert result.shape == (2, 16)


@pytest.mark.parametrize("dtype", CONV2D_CPU_MIXED_DTYPE_PARAMS)
@torch.no_grad()
def test_conv2d_mixed_dtype_sidecar_parameter_patch(dtype: torch.dtype):
    """The aggregated sidecar-patch path should also handle a lower-precision
    input against a default-precision Conv2d."""
    custom_conv = wrap_single_custom_layer(torch.nn.Conv2d(8, 16, 3))
    custom_conv.add_patch(ZeroParamPatch(), 1.0)
    x = torch.randn(2, 8, 5, 5, dtype=dtype)

    result = custom_conv(x)

    assert result.dtype == x.dtype
    assert result.shape == (2, 16, 3, 3)
Loading