From fdc2889ea9acf430dfd118327c27b296d245c93c Mon Sep 17 00:00:00 2001 From: JPPhoto Date: Sat, 21 Mar 2026 13:36:01 -0500 Subject: [PATCH] Align DyPE with paper --- invokeai/app/invocations/flux_denoise.py | 2 +- invokeai/backend/flux/denoise.py | 13 +- invokeai/backend/flux/dype/__init__.py | 6 +- invokeai/backend/flux/dype/base.py | 181 ++----------- invokeai/backend/flux/dype/presets.py | 29 +- invokeai/backend/flux/dype/rope.py | 42 +-- .../backend/flux/extensions/dype_extension.py | 38 ++- tests/backend/flux/dype/test_dype.py | 205 ++++++++++---- tests/backend/flux/test_denoise.py | 249 ++++++++++++++++++ 9 files changed, 478 insertions(+), 287 deletions(-) create mode 100644 tests/backend/flux/test_denoise.py diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index d6102b105b3..ab7970fbd08 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -477,7 +477,7 @@ def _run_diffusion( ) context.logger.info( f"DyPE enabled: resolution={self.width}x{self.height}, preset={self.dype_preset}, " - f"method={dype_config.method}, scale={dype_config.dype_scale:.2f}, " + f"scale={dype_config.dype_scale:.2f}, " f"exponent={dype_config.dype_exponent:.2f}, start_sigma={dype_config.dype_start_sigma:.2f}, " f"base_resolution={dype_config.base_resolution}" ) diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py index 30d075a5270..a0ae4cfb3f3 100644 --- a/invokeai/backend/flux/denoise.py +++ b/invokeai/backend/flux/denoise.py @@ -96,15 +96,18 @@ def denoise( timestep = scheduler.timesteps[step_index] # Convert scheduler timestep (0-1000) to normalized (0-1) for the model t_curr = timestep.item() / scheduler.config.num_train_timesteps + dype_sigma = DyPEExtension.resolve_step_sigma( + fallback_sigma=t_curr, + step_index=step_index, + scheduler_sigmas=getattr(scheduler, "sigmas", None), + ) t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) # DyPE: Update step state for timestep-dependent scaling if dype_extension is not None and dype_embedder is not None: dype_extension.update_step_state( embedder=dype_embedder, - timestep=t_curr, - timestep_index=user_step, - total_steps=total_steps, + sigma=dype_sigma, ) # For Heun scheduler, track if we're in first or second order step @@ -264,9 +267,7 @@ def denoise( if dype_extension is not None and dype_embedder is not None: dype_extension.update_step_state( embedder=dype_embedder, - timestep=t_curr, - timestep_index=step_index, - total_steps=total_steps, + sigma=t_curr, ) t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) diff --git a/invokeai/backend/flux/dype/__init__.py b/invokeai/backend/flux/dype/__init__.py index eebcfc45df1..7af50625dd7 100644 --- a/invokeai/backend/flux/dype/__init__.py +++ b/invokeai/backend/flux/dype/__init__.py @@ -1,9 +1,9 @@ """Dynamic Position Extrapolation (DyPE) for FLUX models. -DyPE enables high-resolution image generation (4K+) with pretrained FLUX models -by dynamically scaling RoPE position embeddings during the denoising process. +DyPE enables high-resolution image generation with pretrained FLUX models by +dynamically modulating RoPE extrapolation during denoising. -Based on: https://github.com/wildminder/ComfyUI-DyPE +Based on the official DyPE project: https://github.com/guyyariv/DyPE """ from invokeai.backend.flux.dype.base import DyPEConfig diff --git a/invokeai/backend/flux/dype/base.py b/invokeai/backend/flux/dype/base.py index 7b25a7f71f3..6c3fc42fa2c 100644 --- a/invokeai/backend/flux/dype/base.py +++ b/invokeai/backend/flux/dype/base.py @@ -1,8 +1,6 @@ -"""DyPE base configuration and utilities.""" +"""DyPE base configuration and utilities for FLUX vision_yarn RoPE.""" -import math from dataclasses import dataclass -from typing import Literal import torch from torch import Tensor @@ -14,72 +12,39 @@ class DyPEConfig: enable_dype: bool = True base_resolution: int = 1024 # Native training resolution - method: Literal["vision_yarn", "yarn", "ntk", "base"] = "vision_yarn" dype_scale: float = 2.0 # Magnitude λs (0.0-8.0) dype_exponent: float = 2.0 # Decay speed λt (0.0-1000.0) dype_start_sigma: float = 1.0 # When DyPE decay starts -def get_mscale(scale: float, mscale_factor: float = 1.0) -> float: - """Calculate magnitude scaling factor. - - Args: - scale: The resolution scaling factor - mscale_factor: Adjustment factor for the scaling - - Returns: - The magnitude scaling factor - """ - if scale <= 1.0: - return 1.0 - return mscale_factor * math.log(scale) + 1.0 - - -def get_timestep_mscale( - scale: float, +def get_timestep_kappa( current_sigma: float, dype_scale: float, dype_exponent: float, dype_start_sigma: float, ) -> float: - """Calculate timestep-dependent magnitude scaling. + """Calculate the paper-style DyPE scheduler value κ(t). The key insight of DyPE: early steps focus on low frequencies (global structure), - late steps on high frequencies (details). This function modulates the scaling - based on the current timestep/sigma. + late steps on high frequencies (details). DyPE expresses this as a direct + timestep scheduler over the positional extrapolation strength: + + κ(t) = λs * t^λt Args: - scale: Resolution scaling factor current_sigma: Current noise level (1.0 = full noise, 0.0 = clean) dype_scale: DyPE magnitude (λs) dype_exponent: DyPE decay speed (λt) dype_start_sigma: Sigma threshold to start decay Returns: - Timestep-modulated scaling factor + Timestep scheduler value κ(t) """ - if scale <= 1.0: - return 1.0 - - # Normalize sigma to [0, 1] range relative to start_sigma - if current_sigma >= dype_start_sigma: - t_normalized = 1.0 - else: - t_normalized = current_sigma / dype_start_sigma - - # Apply exponential decay: stronger extrapolation early, weaker late - # decay = exp(-λt * (1 - t)) where t=1 is early (high sigma), t=0 is late - decay = math.exp(-dype_exponent * (1.0 - t_normalized)) - - # Base mscale from resolution - base_mscale = get_mscale(scale) + if dype_scale <= 0.0 or dype_start_sigma <= 0.0: + return 0.0 - # Interpolate between base_mscale and 1.0 based on decay and dype_scale - # When decay=1 (early): use scaled value - # When decay=0 (late): use base value - scaled_mscale = 1.0 + (base_mscale - 1.0) * dype_scale * decay - - return scaled_mscale + t_normalized = max(0.0, min(current_sigma / dype_start_sigma, 1.0)) + return dype_scale * (t_normalized**dype_exponent) def compute_vision_yarn_freqs( @@ -117,35 +82,23 @@ def compute_vision_yarn_freqs( """ assert dim % 2 == 0 - # Use the larger scale for NTK calculation scale = max(scale_h, scale_w) device = pos.device dtype = torch.float64 if device.type != "mps" else torch.float32 - # NTK-aware theta scaling: extends position coverage for high-res - # Formula: theta_scaled = theta * scale^(dim/(dim-2)) - # This increases the wavelength of position encodings proportionally + # DyPE applies a direct timestep scheduler to the NTK extrapolation exponent. + # Early steps keep strong extrapolation; late steps relax smoothly back + # toward the training-time RoPE. if scale > 1.0: - ntk_alpha = scale ** (dim / (dim - 2)) - - # Apply timestep-dependent DyPE modulation - # mscale controls how strongly we apply the NTK extrapolation - # Early steps (high sigma): stronger extrapolation for global structure - # Late steps (low sigma): weaker extrapolation for fine details - mscale = get_timestep_mscale( - scale=scale, + ntk_exponent = dim / (dim - 2) + kappa = get_timestep_kappa( current_sigma=current_sigma, dype_scale=dype_config.dype_scale, dype_exponent=dype_config.dype_exponent, dype_start_sigma=dype_config.dype_start_sigma, ) - - # Modulate NTK alpha by mscale - # When mscale > 1: interpolate towards stronger extrapolation - # When mscale = 1: use base NTK alpha - modulated_alpha = 1.0 + (ntk_alpha - 1.0) * mscale - scaled_theta = theta * modulated_alpha + scaled_theta = theta * (scale ** (ntk_exponent * kappa)) else: scaled_theta = theta @@ -160,101 +113,3 @@ def compute_vision_yarn_freqs( sin = torch.sin(angles) return cos.to(pos.dtype), sin.to(pos.dtype) - - -def compute_yarn_freqs( - pos: Tensor, - dim: int, - theta: int, - scale: float, - current_sigma: float, - dype_config: DyPEConfig, -) -> tuple[Tensor, Tensor]: - """Compute RoPE frequencies using YARN/NTK method. - - Uses NTK-aware theta scaling for high-resolution support with - timestep-dependent DyPE modulation. - - Args: - pos: Position tensor - dim: Embedding dimension - theta: RoPE base frequency - scale: Uniform scaling factor - current_sigma: Current noise level (1.0 = full noise, 0.0 = clean) - dype_config: DyPE configuration - - Returns: - Tuple of (cos, sin) frequency tensors - """ - assert dim % 2 == 0 - - device = pos.device - dtype = torch.float64 if device.type != "mps" else torch.float32 - - # NTK-aware theta scaling with DyPE modulation - if scale > 1.0: - ntk_alpha = scale ** (dim / (dim - 2)) - - # Apply timestep-dependent DyPE modulation - mscale = get_timestep_mscale( - scale=scale, - current_sigma=current_sigma, - dype_scale=dype_config.dype_scale, - dype_exponent=dype_config.dype_exponent, - dype_start_sigma=dype_config.dype_start_sigma, - ) - - # Modulate NTK alpha by mscale - modulated_alpha = 1.0 + (ntk_alpha - 1.0) * mscale - scaled_theta = theta * modulated_alpha - else: - scaled_theta = theta - - freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim - freqs = 1.0 / (scaled_theta**freq_seq) - - angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs) - - cos = torch.cos(angles) - sin = torch.sin(angles) - - return cos.to(pos.dtype), sin.to(pos.dtype) - - -def compute_ntk_freqs( - pos: Tensor, - dim: int, - theta: int, - scale: float, -) -> tuple[Tensor, Tensor]: - """Compute RoPE frequencies using NTK method. - - Neural Tangent Kernel approach - continuous frequency scaling without - timestep dependency. - - Args: - pos: Position tensor - dim: Embedding dimension - theta: RoPE base frequency - scale: Scaling factor - - Returns: - Tuple of (cos, sin) frequency tensors - """ - assert dim % 2 == 0 - - device = pos.device - dtype = torch.float64 if device.type != "mps" else torch.float32 - - # NTK scaling - scaled_theta = theta * (scale ** (dim / (dim - 2))) - - freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim - freqs = 1.0 / (scaled_theta**freq_seq) - - angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs) - - cos = torch.cos(angles) - sin = torch.sin(angles) - - return cos.to(pos.dtype), sin.to(pos.dtype) diff --git a/invokeai/backend/flux/dype/presets.py b/invokeai/backend/flux/dype/presets.py index 7805f4364d4..48a714b007a 100644 --- a/invokeai/backend/flux/dype/presets.py +++ b/invokeai/backend/flux/dype/presets.py @@ -31,7 +31,6 @@ class DyPEPresetConfig: """Preset configuration values.""" base_resolution: int - method: str dype_scale: float dype_exponent: float dype_start_sigma: float @@ -41,7 +40,6 @@ class DyPEPresetConfig: DYPE_PRESETS: dict[DyPEPreset, DyPEPresetConfig] = { DYPE_PRESET_4K: DyPEPresetConfig( base_resolution=1024, - method="vision_yarn", dype_scale=2.0, dype_exponent=2.0, dype_start_sigma=1.0, @@ -84,7 +82,6 @@ def get_dype_config_for_resolution( return DyPEConfig( enable_dype=True, base_resolution=base_resolution, - method="vision_yarn", dype_scale=dynamic_dype_scale, dype_exponent=2.0, dype_start_sigma=1.0, @@ -111,24 +108,24 @@ def get_dype_config_for_area( return None area_ratio = area / base_area - effective_side_ratio = math.sqrt(area_ratio) # 1.0 at base, 2.0 at 2K (if base is 1K) - - # Strength: 0 at base area, 8 at sat_area, clamped thereafter. - sat_area = 2027520 # Determined by experimentation where a vertical line appears - sat_side_ratio = math.sqrt(sat_area / base_area) - dynamic_dype_scale = 8.0 * (effective_side_ratio - 1.0) / (sat_side_ratio - 1.0) + effective_side_ratio = math.sqrt(area_ratio) + aspect_ratio = max(width, height) / min(width, height) + aspect_attenuation = 1.0 if aspect_ratio <= 2.0 else 2.0 / aspect_ratio + + # Retune area mode to be "auto, but area-aware" instead of dramatically + # stronger than auto. This keeps it closer to the paper-style core DyPE. + dynamic_dype_scale = 2.4 * effective_side_ratio + dynamic_dype_scale *= aspect_attenuation dynamic_dype_scale = max(0.0, min(dynamic_dype_scale, 8.0)) - # Continuous exponent schedule: - # r=1 -> 0.5, r=2 -> 1.0, r=4 -> 2.0 (exact), smoothly varying in between. - x = math.log2(effective_side_ratio) - dype_exponent = 0.25 * (x**2) + 0.25 * x + 0.5 - dype_exponent = max(0.5, min(dype_exponent, 2.0)) + # Use a narrower, higher exponent range than the old area heuristic so the + # paper-style scheduler decays more conservatively and artifacts are reduced. + exponent_progress = max(0.0, min(effective_side_ratio - 1.0, 1.0)) + dype_exponent = 1.25 + 0.75 * exponent_progress return DyPEConfig( enable_dype=True, base_resolution=base_resolution, - method="vision_yarn", dype_scale=dynamic_dype_scale, dype_exponent=dype_exponent, dype_start_sigma=1.0, @@ -165,7 +162,6 @@ def get_dype_config_from_preset( return DyPEConfig( enable_dype=True, base_resolution=1024, - method="vision_yarn", dype_scale=custom_scale if custom_scale is not None else dynamic_dype_scale, dype_exponent=custom_exponent if custom_exponent is not None else 2.0, dype_start_sigma=1.0, @@ -196,7 +192,6 @@ def get_dype_config_from_preset( return DyPEConfig( enable_dype=True, base_resolution=preset_config.base_resolution, - method=preset_config.method, dype_scale=preset_config.dype_scale, dype_exponent=preset_config.dype_exponent, dype_start_sigma=preset_config.dype_start_sigma, diff --git a/invokeai/backend/flux/dype/rope.py b/invokeai/backend/flux/dype/rope.py index f6a1594f6be..980b768cbc0 100644 --- a/invokeai/backend/flux/dype/rope.py +++ b/invokeai/backend/flux/dype/rope.py @@ -6,9 +6,7 @@ from invokeai.backend.flux.dype.base import ( DyPEConfig, - compute_ntk_freqs, compute_vision_yarn_freqs, - compute_yarn_freqs, ) @@ -50,37 +48,15 @@ def rope_dype( if not dype_config.enable_dype or scale <= 1.0: return _rope_base(pos, dim, theta) - # Select method and compute frequencies - method = dype_config.method - - if method == "vision_yarn": - cos, sin = compute_vision_yarn_freqs( - pos=pos, - dim=dim, - theta=theta, - scale_h=scale_h, - scale_w=scale_w, - current_sigma=current_sigma, - dype_config=dype_config, - ) - elif method == "yarn": - cos, sin = compute_yarn_freqs( - pos=pos, - dim=dim, - theta=theta, - scale=scale, - current_sigma=current_sigma, - dype_config=dype_config, - ) - elif method == "ntk": - cos, sin = compute_ntk_freqs( - pos=pos, - dim=dim, - theta=theta, - scale=scale, - ) - else: # "base" - return _rope_base(pos, dim, theta) + cos, sin = compute_vision_yarn_freqs( + pos=pos, + dim=dim, + theta=theta, + scale_h=scale_h, + scale_w=scale_w, + current_sigma=current_sigma, + dype_config=dype_config, + ) # Construct rotation matrix from cos/sin # Output shape: (batch, seq_len, dim/2, 2, 2) diff --git a/invokeai/backend/flux/extensions/dype_extension.py b/invokeai/backend/flux/extensions/dype_extension.py index db27c053dd3..af01a305b7b 100644 --- a/invokeai/backend/flux/extensions/dype_extension.py +++ b/invokeai/backend/flux/extensions/dype_extension.py @@ -1,7 +1,9 @@ """DyPE extension for FLUX denoising pipeline.""" from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Sequence + +import torch from invokeai.backend.flux.dype.base import DyPEConfig from invokeai.backend.flux.dype.embed import DyPEEmbedND @@ -59,9 +61,7 @@ def patch_model(self, model: "Flux") -> tuple[DyPEEmbedND, object]: def update_step_state( self, embedder: DyPEEmbedND, - timestep: float, - timestep_index: int, - total_steps: int, + sigma: float, ) -> None: """Update the step state in the DyPE embedder. @@ -70,16 +70,38 @@ def update_step_state( Args: embedder: The DyPE embedder to update - timestep: Current timestep value (sigma/noise level) - timestep_index: Current step index (0-based) - total_steps: Total number of denoising steps + sigma: Current noise level for the active denoising step """ embedder.set_step_state( - sigma=timestep, + sigma=sigma, height=self.target_height, width=self.target_width, ) + @staticmethod + def resolve_step_sigma( + fallback_sigma: float, + step_index: int, + scheduler_sigmas: Sequence[float] | torch.Tensor | None, + ) -> float: + """Resolve the actual sigma for the current denoising step. + + Diffusers schedulers may expose both normalized timesteps and the underlying + sigma sequence. DyPE should follow the noise schedule, so prefer + ``scheduler.sigmas`` when available and fall back to the provided value + otherwise. + """ + if scheduler_sigmas is None: + return fallback_sigma + + if step_index >= len(scheduler_sigmas): + return fallback_sigma + + sigma = scheduler_sigmas[step_index] + if isinstance(sigma, torch.Tensor): + return float(sigma.item()) + return float(sigma) + @staticmethod def restore_model(model: "Flux", original_embedder: object) -> None: """Restore the original position embedder. diff --git a/tests/backend/flux/dype/test_dype.py b/tests/backend/flux/dype/test_dype.py index cc0b99011cd..c66cc3b72a5 100644 --- a/tests/backend/flux/dype/test_dype.py +++ b/tests/backend/flux/dype/test_dype.py @@ -4,11 +4,8 @@ from invokeai.backend.flux.dype.base import ( DyPEConfig, - compute_ntk_freqs, compute_vision_yarn_freqs, - compute_yarn_freqs, - get_mscale, - get_timestep_mscale, + get_timestep_kappa, ) from invokeai.backend.flux.dype.embed import DyPEEmbedND from invokeai.backend.flux.dype.presets import ( @@ -23,6 +20,7 @@ get_dype_config_from_preset, ) from invokeai.backend.flux.dype.rope import rope_dype +from invokeai.backend.flux.extensions.dype_extension import DyPEExtension class TestDyPEConfig: @@ -32,7 +30,6 @@ def test_default_values(self): config = DyPEConfig() assert config.enable_dype is True assert config.base_resolution == 1024 - assert config.method == "vision_yarn" assert config.dype_scale == 2.0 assert config.dype_exponent == 2.0 assert config.dype_start_sigma == 1.0 @@ -41,63 +38,74 @@ def test_custom_values(self): config = DyPEConfig( enable_dype=False, base_resolution=512, - method="yarn", dype_scale=4.0, dype_exponent=3.0, dype_start_sigma=0.5, ) assert config.enable_dype is False assert config.base_resolution == 512 - assert config.method == "yarn" assert config.dype_scale == 4.0 -class TestMscale: - """Tests for mscale calculation functions.""" +class TestDyPEExtension: + """Tests for DyPE extension helpers.""" - def test_get_mscale_no_scaling(self): - """When scale <= 1.0, mscale should be 1.0.""" - assert get_mscale(1.0) == 1.0 - assert get_mscale(0.5) == 1.0 + def test_resolve_step_sigma_prefers_scheduler_sigmas_tensor(self): + sigma = DyPEExtension.resolve_step_sigma( + fallback_sigma=0.42, + step_index=1, + scheduler_sigmas=torch.tensor([1.0, 0.75, 0.5]), + ) + assert sigma == 0.75 - def test_get_mscale_with_scaling(self): - """When scale > 1.0, mscale should increase.""" - mscale_2x = get_mscale(2.0) - mscale_4x = get_mscale(4.0) + def test_resolve_step_sigma_falls_back_without_scheduler_sigmas(self): + sigma = DyPEExtension.resolve_step_sigma( + fallback_sigma=0.42, + step_index=1, + scheduler_sigmas=None, + ) + assert sigma == 0.42 - assert mscale_2x > 1.0 - assert mscale_4x > mscale_2x - def test_get_timestep_mscale_no_scaling(self): - """When scale <= 1.0, timestep_mscale should be 1.0.""" - result = get_timestep_mscale( - scale=1.0, - current_sigma=0.5, +class TestKappa: + """Tests for the DyPE timestep scheduler.""" + + def test_get_timestep_kappa_clamps_to_zero_without_scale(self): + assert ( + get_timestep_kappa( + current_sigma=0.5, + dype_scale=0.0, + dype_exponent=2.0, + dype_start_sigma=1.0, + ) + == 0.0 + ) + + def test_get_timestep_kappa_is_stronger_early(self): + early_kappa = get_timestep_kappa( + current_sigma=1.0, dype_scale=2.0, dype_exponent=2.0, dype_start_sigma=1.0, ) - assert result == 1.0 - - def test_get_timestep_mscale_high_sigma(self): - """Early steps (high sigma) should have stronger scaling.""" - early_mscale = get_timestep_mscale( - scale=2.0, - current_sigma=1.0, # Early step + late_kappa = get_timestep_kappa( + current_sigma=0.1, dype_scale=2.0, dype_exponent=2.0, dype_start_sigma=1.0, ) - late_mscale = get_timestep_mscale( - scale=2.0, - current_sigma=0.1, # Late step + + assert early_kappa == 2.0 + assert late_kappa < early_kappa + + def test_get_timestep_kappa_clamps_above_start_sigma(self): + kappa = get_timestep_kappa( + current_sigma=2.0, dype_scale=2.0, dype_exponent=2.0, dype_start_sigma=1.0, ) - - # Early steps should have larger mscale than late steps - assert early_mscale >= late_mscale + assert kappa == 2.0 class TestRopeDype: @@ -156,6 +164,47 @@ def test_rope_dype_no_scaling(self): # Results should be different when scaling is applied assert not torch.allclose(result_no_scale, result_with_scale) + def test_rope_dype_late_stage_moves_toward_base_rope(self): + """Late-stage DyPE should be closer to base RoPE than early-stage DyPE.""" + pos = torch.arange(16).unsqueeze(0).float() + dim = 32 + theta = 10000 + + config = DyPEConfig(base_resolution=1024) + + base_result = rope_dype( + pos=pos, + dim=dim, + theta=theta, + current_sigma=1.0, + target_height=1024, + target_width=1024, + dype_config=config, + ) + early_result = rope_dype( + pos=pos, + dim=dim, + theta=theta, + current_sigma=1.0, + target_height=2048, + target_width=2048, + dype_config=config, + ) + late_result = rope_dype( + pos=pos, + dim=dim, + theta=theta, + current_sigma=0.05, + target_height=2048, + target_width=2048, + dype_config=config, + ) + + early_delta = torch.mean(torch.abs(early_result - base_result)) + late_delta = torch.mean(torch.abs(late_result - base_result)) + + assert late_delta < early_delta + class TestDyPEEmbedND: """Tests for DyPEEmbedND module.""" @@ -245,7 +294,6 @@ def test_get_dype_config_for_resolution_above_threshold(self): ) assert config is not None assert config.enable_dype is True - assert config.method == "vision_yarn" def test_get_dype_config_for_resolution_dynamic_scale(self): """Higher resolution should result in higher dype_scale.""" @@ -283,7 +331,57 @@ def test_get_dype_config_for_area_above_threshold(self): ) assert config is not None assert config.enable_dype is True - assert config.method == "vision_yarn" + + def test_get_dype_config_for_area_penalizes_extreme_aspect_ratios(self): + balanced_extreme = get_dype_config_for_area( + width=2304, + height=1152, + base_resolution=1024, + ) + extreme = get_dype_config_for_area( + width=2304, + height=960, + base_resolution=1024, + ) + balanced_same_area = get_dype_config_for_area( + width=2048, + height=1080, + base_resolution=1024, + ) + + assert balanced_extreme is not None + assert extreme is not None + assert balanced_same_area is not None + assert extreme.dype_scale < balanced_extreme.dype_scale + assert extreme.dype_scale < balanced_same_area.dype_scale + + def test_get_dype_config_for_area_is_closer_to_auto_strength(self): + area = get_dype_config_for_area( + width=1728, + height=1152, + base_resolution=1024, + ) + auto = get_dype_config_for_resolution( + width=1728, + height=1152, + base_resolution=1024, + activation_threshold=1536, + ) + + assert area is not None + assert auto is not None + assert area.dype_scale > auto.dype_scale * 0.9 + assert area.dype_scale < auto.dype_scale * 1.1 + + def test_get_dype_config_for_area_uses_higher_exponent_than_old_curve(self): + config = get_dype_config_for_area( + width=1536, + height=1024, + base_resolution=1024, + ) + + assert config is not None + assert 1.25 <= config.dype_exponent <= 2.0 def test_get_dype_config_from_preset_area(self): """Preset AREA should use area-based config.""" @@ -374,33 +472,28 @@ def test_compute_vision_yarn_freqs_shape(self): assert cos.shape[0] == 1 # batch assert cos.shape[1] == 16 # seq_len - def test_compute_yarn_freqs_shape(self): - """Test yarn frequency computation shape.""" + def test_compute_vision_yarn_freqs_reverts_to_base_rope_at_zero_sigma(self): pos = torch.arange(16).unsqueeze(0).float() config = DyPEConfig() - cos, sin = compute_yarn_freqs( + dy_cos, dy_sin = compute_vision_yarn_freqs( pos=pos, dim=32, theta=10000, - scale=2.0, - current_sigma=0.5, + scale_h=2.0, + scale_w=2.0, + current_sigma=0.0, dype_config=config, ) - - assert cos.shape == sin.shape - assert cos.shape[0] == 1 - - def test_compute_ntk_freqs_shape(self): - """Test ntk frequency computation shape.""" - pos = torch.arange(16).unsqueeze(0).float() - - cos, sin = compute_ntk_freqs( + base_cos, base_sin = compute_vision_yarn_freqs( pos=pos, dim=32, theta=10000, - scale=2.0, + scale_h=1.0, + scale_w=1.0, + current_sigma=0.0, + dype_config=config, ) - assert cos.shape == sin.shape - assert cos.shape[0] == 1 + assert torch.allclose(dy_cos, base_cos) + assert torch.allclose(dy_sin, base_sin) diff --git a/tests/backend/flux/test_denoise.py b/tests/backend/flux/test_denoise.py new file mode 100644 index 00000000000..47065cc21f3 --- /dev/null +++ b/tests/backend/flux/test_denoise.py @@ -0,0 +1,249 @@ +from types import SimpleNamespace + +import pytest +import torch + +from invokeai.backend.flux.denoise import denoise +from invokeai.backend.flux.schedulers import FLUX_SCHEDULER_MAP + + +class _FakeFluxModel: + def __call__( + self, + img: torch.Tensor, + img_ids: torch.Tensor, + txt: torch.Tensor, + txt_ids: torch.Tensor, + y: torch.Tensor, + timesteps: torch.Tensor, + guidance: torch.Tensor, + timestep_index: int, + total_num_timesteps: int, + controlnet_double_block_residuals: list[torch.Tensor] | None, + controlnet_single_block_residuals: list[torch.Tensor] | None, + ip_adapter_extensions: list[object], + regional_prompting_extension: object, + ) -> torch.Tensor: + return torch.zeros_like(img) + + +class _FakeDyPEExtension: + def __init__(self) -> None: + self.sigmas: list[float] = [] + + def patch_model(self, model: object) -> tuple[object, None]: + return object(), None + + def update_step_state(self, embedder: object, sigma: float) -> None: + self.sigmas.append(sigma) + + +class _FakeScheduler: + def __init__(self) -> None: + self.config = SimpleNamespace(num_train_timesteps=1000) + self.timesteps = torch.tensor([], dtype=torch.float32) + self.sigmas = torch.tensor([], dtype=torch.float32) + + def set_timesteps(self, sigmas: list[float], device: torch.device) -> None: + del device + self.sigmas = torch.tensor(sigmas, dtype=torch.float32) + self.timesteps = torch.tensor([900.0, 400.0], dtype=torch.float32) + + def step(self, model_output: torch.Tensor, timestep: torch.Tensor, sample: torch.Tensor) -> SimpleNamespace: + del model_output, timestep + return SimpleNamespace(prev_sample=sample) + + +class _FakeHeunScheduler: + def __init__(self) -> None: + self.config = SimpleNamespace(num_train_timesteps=1000) + self.timesteps = torch.tensor([], dtype=torch.float32) + self.sigmas = torch.tensor([], dtype=torch.float32) + self.state_in_first_order = True + self._step_index = 0 + + def set_timesteps(self, sigmas: list[float], device: torch.device) -> None: + del device + # Duplicate each user-facing step to mimic a second-order scheduler. + self.sigmas = torch.tensor([1.0, 1.0, 0.25, 0.25, 0.0], dtype=torch.float32) + self.timesteps = torch.tensor([900.0, 850.0, 400.0, 350.0], dtype=torch.float32) + self._step_index = 0 + self.state_in_first_order = True + + def step(self, model_output: torch.Tensor, timestep: torch.Tensor, sample: torch.Tensor) -> SimpleNamespace: + del model_output, timestep + self._step_index += 1 + self.state_in_first_order = self._step_index % 2 == 0 + return SimpleNamespace(prev_sample=sample) + + +class _FakePbar: + def update(self, value: int) -> None: + del value + + def close(self) -> None: + return None + + +def _fake_tqdm(iterable=None, **kwargs): + del kwargs + if iterable is None: + return _FakePbar() + return iterable + + +def _build_regional_prompting_extension(batch_size: int) -> SimpleNamespace: + return SimpleNamespace( + regional_text_conditioning=SimpleNamespace( + t5_embeddings=torch.zeros(batch_size, 1, 4), + t5_txt_ids=torch.zeros(batch_size, 1, 3), + clip_embeddings=torch.zeros(batch_size, 4), + ) + ) + + +def test_denoise_euler_path_updates_dype_with_sigma(monkeypatch): + monkeypatch.setattr("invokeai.backend.flux.denoise.tqdm", _fake_tqdm) + + model = _FakeFluxModel() + dype_extension = _FakeDyPEExtension() + img = torch.zeros(1, 2, 4) + img_ids = torch.zeros(1, 2, 3) + regional_prompting_extension = _build_regional_prompting_extension(batch_size=1) + callback_steps: list[int] = [] + + result = denoise( + model=model, + img=img, + img_ids=img_ids, + pos_regional_prompting_extension=regional_prompting_extension, + neg_regional_prompting_extension=None, + timesteps=[1.0, 0.5, 0.0], + step_callback=lambda state: callback_steps.append(state.step), + guidance=1.0, + cfg_scale=[1.0, 1.0], + inpaint_extension=None, + controlnet_extensions=[], + pos_ip_adapter_extensions=[], + neg_ip_adapter_extensions=[], + img_cond=None, + img_cond_seq=None, + img_cond_seq_ids=None, + dype_extension=dype_extension, + scheduler=None, + ) + + assert torch.equal(result, img) + assert dype_extension.sigmas == [1.0, 0.5] + assert callback_steps == [1, 2] + + +def test_denoise_scheduler_path_prefers_scheduler_sigmas_for_dype(monkeypatch): + monkeypatch.setattr("invokeai.backend.flux.denoise.tqdm", _fake_tqdm) + + model = _FakeFluxModel() + scheduler = _FakeScheduler() + dype_extension = _FakeDyPEExtension() + img = torch.zeros(1, 2, 4) + img_ids = torch.zeros(1, 2, 3) + regional_prompting_extension = _build_regional_prompting_extension(batch_size=1) + + denoise( + model=model, + img=img, + img_ids=img_ids, + pos_regional_prompting_extension=regional_prompting_extension, + neg_regional_prompting_extension=None, + timesteps=[1.0, 0.25, 0.0], + step_callback=lambda state: None, + guidance=1.0, + cfg_scale=[1.0, 1.0], + inpaint_extension=None, + controlnet_extensions=[], + pos_ip_adapter_extensions=[], + neg_ip_adapter_extensions=[], + img_cond=None, + img_cond_seq=None, + img_cond_seq_ids=None, + dype_extension=dype_extension, + scheduler=scheduler, + ) + + # Scheduler timesteps normalize to [0.9, 0.4], so this asserts the scheduler + # sigma sequence is what DyPE actually consumes. + assert dype_extension.sigmas == [1.0, 0.25] + + +def test_denoise_heun_scheduler_path_uses_internal_scheduler_sigmas(monkeypatch): + monkeypatch.setattr("invokeai.backend.flux.denoise.tqdm", _fake_tqdm) + + model = _FakeFluxModel() + scheduler = _FakeHeunScheduler() + dype_extension = _FakeDyPEExtension() + img = torch.zeros(1, 2, 4) + img_ids = torch.zeros(1, 2, 3) + regional_prompting_extension = _build_regional_prompting_extension(batch_size=1) + callback_steps: list[int] = [] + + denoise( + model=model, + img=img, + img_ids=img_ids, + pos_regional_prompting_extension=regional_prompting_extension, + neg_regional_prompting_extension=None, + timesteps=[1.0, 0.25, 0.0], + step_callback=lambda state: callback_steps.append(state.step), + guidance=1.0, + cfg_scale=[1.0, 1.0], + inpaint_extension=None, + controlnet_extensions=[], + pos_ip_adapter_extensions=[], + neg_ip_adapter_extensions=[], + img_cond=None, + img_cond_seq=None, + img_cond_seq_ids=None, + dype_extension=dype_extension, + scheduler=scheduler, + ) + + assert dype_extension.sigmas == [1.0, 1.0, 0.25, 0.25] + assert callback_steps == [1, 2] + + +@pytest.mark.parametrize("scheduler_name", sorted(FLUX_SCHEDULER_MAP)) +def test_denoise_real_flux_schedulers_update_dype_from_internal_sigma_schedule(monkeypatch, scheduler_name): + monkeypatch.setattr("invokeai.backend.flux.denoise.tqdm", _fake_tqdm) + + model = _FakeFluxModel() + scheduler = FLUX_SCHEDULER_MAP[scheduler_name](num_train_timesteps=1000) + dype_extension = _FakeDyPEExtension() + img = torch.zeros(1, 2, 4) + img_ids = torch.zeros(1, 2, 3) + regional_prompting_extension = _build_regional_prompting_extension(batch_size=1) + callback_steps: list[int] = [] + + denoise( + model=model, + img=img, + img_ids=img_ids, + pos_regional_prompting_extension=regional_prompting_extension, + neg_regional_prompting_extension=None, + timesteps=[1.0, 0.25, 0.0], + step_callback=lambda state: callback_steps.append(state.step), + guidance=1.0, + cfg_scale=[1.0, 1.0], + inpaint_extension=None, + controlnet_extensions=[], + pos_ip_adapter_extensions=[], + neg_ip_adapter_extensions=[], + img_cond=None, + img_cond_seq=None, + img_cond_seq_ids=None, + dype_extension=dype_extension, + scheduler=scheduler, + ) + + assert dype_extension.sigmas + expected_sigmas = [float(sigma) for sigma in scheduler.sigmas[: len(dype_extension.sigmas)]] + assert dype_extension.sigmas == expected_sigmas + assert callback_steps