Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions invokeai/backend/model_manager/configs/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,14 +703,25 @@ def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
- diffusion_model.layers.X.attention.to_k.lora_down.weight (DoRA format)
- diffusion_model.layers.X.attention.to_k.lora_A.weight (PEFT format)
- diffusion_model.layers.X.attention.to_k.dora_scale (DoRA scale)
- lora_unet__layers_X_attention_to_k.lora_down.weight (Kohya format)
"""
from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import (
is_state_dict_likely_z_image_kohya_lora,
)

state_dict = mod.load_state_dict()

# Check for Z-Image specific LoRA patterns
# Check for Kohya format first
if is_state_dict_likely_z_image_kohya_lora(state_dict):
return

# Check for Z-Image specific LoRA patterns (dot-notation formats)
has_z_image_lora_keys = state_dict_has_any_keys_starting_with(
state_dict,
{
"diffusion_model.layers.", # Z-Image S3-DiT layer pattern
"diffusion_model.context_refiner.",
"diffusion_model.noise_refiner.",
},
)

Expand Down Expand Up @@ -738,15 +749,26 @@ def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
Z-Image uses S3-DiT architecture with layer names like:
- diffusion_model.layers.0.attention.to_k.lora_A.weight
- diffusion_model.layers.0.feed_forward.w1.lora_A.weight
- lora_unet__layers_0_attention_to_k.lora_down.weight (Kohya format)
"""
from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import (
is_state_dict_likely_z_image_kohya_lora,
)

state_dict = mod.load_state_dict()

# Check for Z-Image transformer layer patterns
# Check for Kohya format
if is_state_dict_likely_z_image_kohya_lora(state_dict):
return BaseModelType.ZImage

# Check for Z-Image transformer layer patterns (dot-notation formats)
# Z-Image uses diffusion_model.layers.X structure (unlike Flux which uses double_blocks/single_blocks)
has_z_image_keys = state_dict_has_any_keys_starting_with(
state_dict,
{
"diffusion_model.layers.", # Z-Image S3-DiT layer pattern
"diffusion_model.context_refiner.",
"diffusion_model.noise_refiner.",
},
)

Expand Down
9 changes: 6 additions & 3 deletions invokeai/backend/model_manager/configs/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,17 +155,20 @@ def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
".lora_A.weight",
".lora_B.weight",
".dora_scale",
".alpha",
)

# First pass: check if any key has LoRA suffixes - if so, this is a LoRA not a main model
for key in state_dict.keys():
if isinstance(key, int):
continue

# If we find any LoRA-specific keys, this is not a main model
if key.endswith(lora_suffixes):
return False

# Check for Z-Image specific key prefixes
# Second pass: check for Z-Image specific key parts
for key in state_dict.keys():
if isinstance(key, int):
continue
# Handle both direct keys (cap_embedder.0.weight) and
# ComfyUI-style keys (model.diffusion_model.cap_embedder.0.weight)
key_parts = key.split(".")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Z-Image LoRA conversion utilities.

Z-Image uses S3-DiT transformer architecture with Qwen3 text encoder.
LoRAs for Z-Image typically follow the diffusers PEFT format.
LoRAs for Z-Image typically follow the diffusers PEFT format or Kohya format.
"""

from typing import Dict
import re
from typing import Any, Dict

import torch

Expand All @@ -16,13 +17,39 @@
)
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw

# Regex for Kohya-format Z-Image transformer keys.
# Example keys:
# lora_unet__layers_0_attention_to_k.alpha
# lora_unet__layers_0_attention_to_k.lora_down.weight
# lora_unet__context_refiner_0_feed_forward_w1.lora_up.weight
# lora_unet__noise_refiner_1_attention_to_v.lora_down.weight
Z_IMAGE_KOHYA_TRANSFORMER_KEY_REGEX = (
r"lora_unet__(layers|context_refiner|noise_refiner)_(\d+)_(attention|feed_forward)_(to_k|to_q|to_v|w1|w2|w3)"
)


def is_state_dict_likely_z_image_kohya_lora(state_dict: dict[str | int, Any]) -> bool:
"""Checks if the provided state dict is likely a Z-Image LoRA in Kohya format.

Kohya Z-Image LoRAs have keys like:
- lora_unet__layers_0_attention_to_k.lora_down.weight
- lora_unet__context_refiner_0_feed_forward_w1.alpha
- lora_unet__noise_refiner_1_attention_to_v.lora_up.weight
"""
return any(
isinstance(k, str) and re.match(Z_IMAGE_KOHYA_TRANSFORMER_KEY_REGEX, k.split(".")[0]) for k in state_dict.keys()
)


def is_state_dict_likely_z_image_lora(state_dict: dict[str | int, torch.Tensor]) -> bool:
"""Checks if the provided state dict is likely a Z-Image LoRA.

Z-Image LoRAs can have keys for transformer and/or Qwen3 text encoder.
They may use various prefixes depending on the training framework.
"""
if is_state_dict_likely_z_image_kohya_lora(state_dict):
return True

str_keys = [k for k in state_dict.keys() if isinstance(k, str)]

# Check for Z-Image transformer keys (S3-DiT architecture)
Expand Down Expand Up @@ -57,6 +84,7 @@ def lora_model_from_z_image_state_dict(
- "transformer." or "base_model.model.transformer." for diffusers PEFT format
- "diffusion_model." for some training frameworks
- "text_encoder." or "base_model.model.text_encoder." for Qwen3 encoder
- "lora_unet__" for Kohya format (underscores instead of dots)

Args:
state_dict: The LoRA state dict
Expand All @@ -65,6 +93,10 @@ def lora_model_from_z_image_state_dict(
Returns:
A ModelPatchRaw containing the LoRA layers
"""
# If Kohya format, convert keys first then process normally
if is_state_dict_likely_z_image_kohya_lora(state_dict):
state_dict = _convert_z_image_kohya_state_dict(state_dict)

layers: dict[str, BaseLayerPatch] = {}

# Group keys by layer
Expand Down Expand Up @@ -120,6 +152,45 @@ def lora_model_from_z_image_state_dict(
return ModelPatchRaw(layers=layers)


def _convert_z_image_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Converts a Kohya-format Z-Image LoRA state dict to diffusion_model dot-notation.

Example key conversions:
- lora_unet__layers_0_attention_to_k.lora_down.weight -> diffusion_model.layers.0.attention.to_k.lora_down.weight
- lora_unet__context_refiner_0_feed_forward_w1.alpha -> diffusion_model.context_refiner.0.feed_forward.w1.alpha
- lora_unet__noise_refiner_1_attention_to_v.lora_up.weight -> diffusion_model.noise_refiner.1.attention.to_v.lora_up.weight
"""
converted: Dict[str, torch.Tensor] = {}
for key, value in state_dict.items():
if not isinstance(key, str) or not key.startswith("lora_unet__"):
converted[key] = value
continue

# Split into layer name and param suffix (e.g. "lora_down.weight", "alpha")
layer_name, _, param_suffix = key.partition(".")

# Strip lora_unet__ prefix
remainder = layer_name[len("lora_unet__") :]

# Convert Kohya underscore format to dot-notation using the known structure
match = re.match(
r"(layers|context_refiner|noise_refiner)_(\d+)_(attention|feed_forward)_(to_k|to_q|to_v|w1|w2|w3)$",
remainder,
)
if match:
block, idx, submodule, param = match.groups()
new_layer = f"diffusion_model.{block}.{idx}.{submodule}.{param}"
else:
# Fallback: keep original key for unrecognized patterns
converted[key] = value
continue

new_key = f"{new_layer}.{param_suffix}" if param_suffix else new_layer
converted[new_key] = value

return converted


def _get_lora_layer_values(layer_dict: dict[str, torch.Tensor], alpha: float | None) -> dict[str, torch.Tensor]:
"""Convert layer dict keys from PEFT format to internal format."""
if "lora_A.weight" in layer_dict:
Expand Down
Loading