Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion invokeai/app/invocations/flux_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def _run_diffusion(
)
context.logger.info(
f"DyPE enabled: resolution={self.width}x{self.height}, preset={self.dype_preset}, "
f"method={dype_config.method}, scale={dype_config.dype_scale:.2f}, "
f"scale={dype_config.dype_scale:.2f}, "
f"exponent={dype_config.dype_exponent:.2f}, start_sigma={dype_config.dype_start_sigma:.2f}, "
f"base_resolution={dype_config.base_resolution}"
)
Expand Down
13 changes: 7 additions & 6 deletions invokeai/backend/flux/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,18 @@ def denoise(
timestep = scheduler.timesteps[step_index]
# Convert scheduler timestep (0-1000) to normalized (0-1) for the model
t_curr = timestep.item() / scheduler.config.num_train_timesteps
dype_sigma = DyPEExtension.resolve_step_sigma(
fallback_sigma=t_curr,
step_index=step_index,
scheduler_sigmas=getattr(scheduler, "sigmas", None),
)
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

# DyPE: Update step state for timestep-dependent scaling
if dype_extension is not None and dype_embedder is not None:
dype_extension.update_step_state(
embedder=dype_embedder,
timestep=t_curr,
timestep_index=user_step,
total_steps=total_steps,
sigma=dype_sigma,
)

# For Heun scheduler, track if we're in first or second order step
Expand Down Expand Up @@ -264,9 +267,7 @@ def denoise(
if dype_extension is not None and dype_embedder is not None:
dype_extension.update_step_state(
embedder=dype_embedder,
timestep=t_curr,
timestep_index=step_index,
total_steps=total_steps,
sigma=t_curr,
)

t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
Expand Down
6 changes: 3 additions & 3 deletions invokeai/backend/flux/dype/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Dynamic Position Extrapolation (DyPE) for FLUX models.

DyPE enables high-resolution image generation (4K+) with pretrained FLUX models
by dynamically scaling RoPE position embeddings during the denoising process.
DyPE enables high-resolution image generation with pretrained FLUX models by
dynamically modulating RoPE extrapolation during denoising.

Based on: https://github.com/wildminder/ComfyUI-DyPE
Based on the official DyPE project: https://github.com/guyyariv/DyPE
"""

from invokeai.backend.flux.dype.base import DyPEConfig
Expand Down
181 changes: 18 additions & 163 deletions invokeai/backend/flux/dype/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""DyPE base configuration and utilities."""
"""DyPE base configuration and utilities for FLUX vision_yarn RoPE."""

import math
from dataclasses import dataclass
from typing import Literal

import torch
from torch import Tensor
Expand All @@ -14,72 +12,39 @@ class DyPEConfig:

enable_dype: bool = True
base_resolution: int = 1024 # Native training resolution
method: Literal["vision_yarn", "yarn", "ntk", "base"] = "vision_yarn"
dype_scale: float = 2.0 # Magnitude λs (0.0-8.0)
dype_exponent: float = 2.0 # Decay speed λt (0.0-1000.0)
dype_start_sigma: float = 1.0 # When DyPE decay starts


def get_mscale(scale: float, mscale_factor: float = 1.0) -> float:
    """Return the magnitude scaling factor for a resolution scale.

    Scales at or below 1.0 map to exactly 1.0. Larger scales grow
    logarithmically as ``mscale_factor * ln(scale) + 1.0``.

    Args:
        scale: The resolution scaling factor.
        mscale_factor: Multiplier applied to the logarithmic term.

    Returns:
        The magnitude scaling factor.
    """
    return 1.0 if scale <= 1.0 else mscale_factor * math.log(scale) + 1.0


def get_timestep_mscale(
scale: float,
def get_timestep_kappa(
current_sigma: float,
dype_scale: float,
dype_exponent: float,
dype_start_sigma: float,
) -> float:
"""Calculate timestep-dependent magnitude scaling.
"""Calculate the paper-style DyPE scheduler value κ(t).

The key insight of DyPE: early steps focus on low frequencies (global structure),
late steps on high frequencies (details). This function modulates the scaling
based on the current timestep/sigma.
late steps on high frequencies (details). DyPE expresses this as a direct
timestep scheduler over the positional extrapolation strength:

κ(t) = λs * t^λt

Args:
scale: Resolution scaling factor
current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
dype_scale: DyPE magnitude (λs)
dype_exponent: DyPE decay speed (λt)
dype_start_sigma: Sigma threshold to start decay

Returns:
Timestep-modulated scaling factor
Timestep scheduler value κ(t)
"""
if scale <= 1.0:
return 1.0

# Normalize sigma to [0, 1] range relative to start_sigma
if current_sigma >= dype_start_sigma:
t_normalized = 1.0
else:
t_normalized = current_sigma / dype_start_sigma

# Apply exponential decay: stronger extrapolation early, weaker late
# decay = exp(-λt * (1 - t)) where t=1 is early (high sigma), t=0 is late
decay = math.exp(-dype_exponent * (1.0 - t_normalized))

# Base mscale from resolution
base_mscale = get_mscale(scale)
if dype_scale <= 0.0 or dype_start_sigma <= 0.0:
return 0.0

# Interpolate between base_mscale and 1.0 based on decay and dype_scale
# When decay=1 (early): use scaled value
# When decay=0 (late): use base value
scaled_mscale = 1.0 + (base_mscale - 1.0) * dype_scale * decay

return scaled_mscale
t_normalized = max(0.0, min(current_sigma / dype_start_sigma, 1.0))
return dype_scale * (t_normalized**dype_exponent)


def compute_vision_yarn_freqs(
Expand Down Expand Up @@ -117,35 +82,23 @@ def compute_vision_yarn_freqs(
"""
assert dim % 2 == 0

# Use the larger scale for NTK calculation
scale = max(scale_h, scale_w)

device = pos.device
dtype = torch.float64 if device.type != "mps" else torch.float32

# NTK-aware theta scaling: extends position coverage for high-res
# Formula: theta_scaled = theta * scale^(dim/(dim-2))
# This increases the wavelength of position encodings proportionally
# DyPE applies a direct timestep scheduler to the NTK extrapolation exponent.
# Early steps keep strong extrapolation; late steps relax smoothly back
# toward the training-time RoPE.
if scale > 1.0:
ntk_alpha = scale ** (dim / (dim - 2))

# Apply timestep-dependent DyPE modulation
# mscale controls how strongly we apply the NTK extrapolation
# Early steps (high sigma): stronger extrapolation for global structure
# Late steps (low sigma): weaker extrapolation for fine details
mscale = get_timestep_mscale(
scale=scale,
ntk_exponent = dim / (dim - 2)
kappa = get_timestep_kappa(
current_sigma=current_sigma,
dype_scale=dype_config.dype_scale,
dype_exponent=dype_config.dype_exponent,
dype_start_sigma=dype_config.dype_start_sigma,
)

# Modulate NTK alpha by mscale
# When mscale > 1: interpolate towards stronger extrapolation
# When mscale = 1: use base NTK alpha
modulated_alpha = 1.0 + (ntk_alpha - 1.0) * mscale
scaled_theta = theta * modulated_alpha
scaled_theta = theta * (scale ** (ntk_exponent * kappa))
else:
scaled_theta = theta

Expand All @@ -160,101 +113,3 @@ def compute_vision_yarn_freqs(
sin = torch.sin(angles)

return cos.to(pos.dtype), sin.to(pos.dtype)


def compute_yarn_freqs(
    pos: Tensor,
    dim: int,
    theta: int,
    scale: float,
    current_sigma: float,
    dype_config: DyPEConfig,
) -> tuple[Tensor, Tensor]:
    """Build RoPE cos/sin tables with YARN/NTK theta scaling and DyPE modulation.

    For ``scale > 1.0`` the RoPE base is multiplied by an NTK factor
    ``scale ** (dim / (dim - 2))`` whose excess over 1.0 is weighted by the
    timestep-dependent value from ``get_timestep_mscale``. For
    ``scale <= 1.0`` the base ``theta`` is used unchanged.

    Args:
        pos: Position tensor.
        dim: Embedding dimension (must be even).
        theta: RoPE base frequency.
        scale: Uniform scaling factor.
        current_sigma: Current noise level (1.0 = full noise, 0.0 = clean).
        dype_config: DyPE configuration supplying scale, exponent and start sigma.

    Returns:
        Tuple of (cos, sin) frequency tensors, cast back to ``pos.dtype``.
    """
    assert dim % 2 == 0

    target_device = pos.device
    # Intermediate math runs in float64, except on MPS devices where float32
    # is used (presumably because MPS lacks float64 support -- TODO confirm).
    work_dtype = torch.float32 if target_device.type == "mps" else torch.float64

    effective_theta = theta
    if scale > 1.0:
        base_alpha = scale ** (dim / (dim - 2))

        # Timestep-dependent weight for the NTK extrapolation.
        step_mscale = get_timestep_mscale(
            scale=scale,
            current_sigma=current_sigma,
            dype_scale=dype_config.dype_scale,
            dype_exponent=dype_config.dype_exponent,
            dype_start_sigma=dype_config.dype_start_sigma,
        )

        # Blend between the unscaled base (factor 1.0) and the NTK factor.
        effective_theta = theta * (1.0 + (base_alpha - 1.0) * step_mscale)

    exponents = torch.arange(0, dim, 2, dtype=work_dtype, device=target_device) / dim
    inv_freqs = 1.0 / (effective_theta**exponents)

    # Outer product of positions and inverse frequencies over a new last axis.
    phase = torch.einsum("...n,d->...nd", pos.to(work_dtype), inv_freqs)

    out_dtype = pos.dtype
    return torch.cos(phase).to(out_dtype), torch.sin(phase).to(out_dtype)


def compute_ntk_freqs(
    pos: Tensor,
    dim: int,
    theta: int,
    scale: float,
) -> tuple[Tensor, Tensor]:
    """Build RoPE cos/sin tables using NTK theta scaling.

    The RoPE base is multiplied by ``scale ** (dim / (dim - 2))``; there is no
    timestep dependency. The result has one extra trailing axis of size
    ``dim // 2`` relative to ``pos``.

    Args:
        pos: Position tensor.
        dim: Embedding dimension (must be even).
        theta: RoPE base frequency.
        scale: Scaling factor.

    Returns:
        Tuple of (cos, sin) frequency tensors, cast back to ``pos.dtype``.
    """
    assert dim % 2 == 0

    # Intermediate math runs in float64, except on MPS devices where float32
    # is used (presumably because MPS lacks float64 support -- TODO confirm).
    work_dtype = torch.float32 if pos.device.type == "mps" else torch.float64

    # NTK scaling of the RoPE base.
    ntk_theta = theta * (scale ** (dim / (dim - 2)))

    exponents = torch.arange(0, dim, 2, dtype=work_dtype, device=pos.device) / dim
    inv_freqs = 1.0 / (ntk_theta**exponents)

    # Outer product of positions and inverse frequencies over a new last axis.
    phase = torch.einsum("...n,d->...nd", pos.to(work_dtype), inv_freqs)

    return torch.cos(phase).to(pos.dtype), torch.sin(phase).to(pos.dtype)
29 changes: 12 additions & 17 deletions invokeai/backend/flux/dype/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class DyPEPresetConfig:
"""Preset configuration values."""

base_resolution: int
method: str
dype_scale: float
dype_exponent: float
dype_start_sigma: float
Expand All @@ -41,7 +40,6 @@ class DyPEPresetConfig:
DYPE_PRESETS: dict[DyPEPreset, DyPEPresetConfig] = {
DYPE_PRESET_4K: DyPEPresetConfig(
base_resolution=1024,
method="vision_yarn",
dype_scale=2.0,
dype_exponent=2.0,
dype_start_sigma=1.0,
Expand Down Expand Up @@ -84,7 +82,6 @@ def get_dype_config_for_resolution(
return DyPEConfig(
enable_dype=True,
base_resolution=base_resolution,
method="vision_yarn",
dype_scale=dynamic_dype_scale,
dype_exponent=2.0,
dype_start_sigma=1.0,
Expand All @@ -111,24 +108,24 @@ def get_dype_config_for_area(
return None

area_ratio = area / base_area
effective_side_ratio = math.sqrt(area_ratio) # 1.0 at base, 2.0 at 2K (if base is 1K)

# Strength: 0 at base area, 8 at sat_area, clamped thereafter.
sat_area = 2027520 # Determined by experimentation where a vertical line appears
sat_side_ratio = math.sqrt(sat_area / base_area)
dynamic_dype_scale = 8.0 * (effective_side_ratio - 1.0) / (sat_side_ratio - 1.0)
effective_side_ratio = math.sqrt(area_ratio)
aspect_ratio = max(width, height) / min(width, height)
aspect_attenuation = 1.0 if aspect_ratio <= 2.0 else 2.0 / aspect_ratio

# Retune area mode to be "auto, but area-aware" instead of dramatically
# stronger than auto. This keeps it closer to the paper-style core DyPE.
dynamic_dype_scale = 2.4 * effective_side_ratio
dynamic_dype_scale *= aspect_attenuation
dynamic_dype_scale = max(0.0, min(dynamic_dype_scale, 8.0))

# Continuous exponent schedule:
# r=1 -> 0.5, r=2 -> 1.0, r=4 -> 2.0 (exact), smoothly varying in between.
x = math.log2(effective_side_ratio)
dype_exponent = 0.25 * (x**2) + 0.25 * x + 0.5
dype_exponent = max(0.5, min(dype_exponent, 2.0))
# Use a narrower, higher exponent range than the old area heuristic so the
# paper-style scheduler decays more conservatively and artifacts are reduced.
exponent_progress = max(0.0, min(effective_side_ratio - 1.0, 1.0))
dype_exponent = 1.25 + 0.75 * exponent_progress

return DyPEConfig(
enable_dype=True,
base_resolution=base_resolution,
method="vision_yarn",
dype_scale=dynamic_dype_scale,
dype_exponent=dype_exponent,
dype_start_sigma=1.0,
Expand Down Expand Up @@ -165,7 +162,6 @@ def get_dype_config_from_preset(
return DyPEConfig(
enable_dype=True,
base_resolution=1024,
method="vision_yarn",
dype_scale=custom_scale if custom_scale is not None else dynamic_dype_scale,
dype_exponent=custom_exponent if custom_exponent is not None else 2.0,
dype_start_sigma=1.0,
Expand Down Expand Up @@ -196,7 +192,6 @@ def get_dype_config_from_preset(
return DyPEConfig(
enable_dype=True,
base_resolution=preset_config.base_resolution,
method=preset_config.method,
dype_scale=preset_config.dype_scale,
dype_exponent=preset_config.dype_exponent,
dype_start_sigma=preset_config.dype_start_sigma,
Expand Down
Loading
Loading