Merged
README.md (6 changes: 4 additions & 2 deletions)

@@ -35,7 +35,8 @@ All key knobs are exposed via YAML in the `opensr_srgan/configs` folder:

* **Model**: `in_channels`, `n_channels`, `n_blocks`, `scale`, ESRGAN knobs (`growth_channels`, `res_scale`, `out_channels`), `block_type ∈ {SRResNet, res, rcab, rrdb, lka}`
* **Losses**: `l1_weight`, `sam_weight`, `perceptual_weight`, `tv_weight`, `adv_loss_beta`
-* **Training**: `pretrain_g_only`, `g_pretrain_steps`, `adv_loss_ramp_steps`, `label_smoothing`, generator LR warmup (`Schedulers.g_warmup_steps`, `Schedulers.g_warmup_type`), discriminator cadence controls
+* **Training**: `pretrain_g_only`, `g_pretrain_steps` (`-1` keeps generator-only pretraining active indefinitely), `adv_loss_ramp_steps`, `label_smoothing`, generator LR warmup (`Schedulers.g_warmup_steps`, `Schedulers.g_warmup_type`), discriminator cadence controls
+* **Adversarial mode**: `Training.Losses.adv_loss_type` (`bce`/`wasserstein`) and optional `Training.Losses.relativistic_average_d` for BCE-based relativistic-average GAN updates
* **Data**: band order, normalization stats, crop sizes, augmentations

---
@@ -48,6 +49,7 @@ All key knobs are exposed via YAML in the `opensr_srgan/configs` folder:
* **EMA smoothing:** Enable `Training.EMA.enabled` to keep a shadow copy of the generator. Decay values in the 0.995–0.9999 range balance responsiveness with stability and are swapped in automatically for validation/inference.
* **Spectral normalization:** Optional for the SRGAN discriminator via `Discriminator.use_spectral_norm` to better control its Lipschitz constant and stabilize adversarial updates. [Miyato et al., 2018](https://arxiv.org/abs/1802.05957)
* **Wasserstein critic + R1 penalty:** Switch `Training.Losses.adv_loss_type: wasserstein` to enable a critic objective and pair it with the configurable `Training.Losses.r1_gamma` gradient penalty on real images for smoother discriminator updates. [Arjovsky et al., 2017](https://arxiv.org/abs/1701.07875); [Mescheder et al., 2018](https://arxiv.org/abs/1801.04406)
+* **Relativistic average GAN (BCE):** Set `Training.Losses.relativistic_average_d: true` to train D/G on relative real-vs-fake logits instead of absolute logits. This is supported in both Lightning training paths (PL1 and PL2).
The schedule and ramp make training **easier, safer, and more reproducible**.
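The warm-up, ramp, and cap described above amount to a simple piecewise schedule. A minimal Python sketch follows; the function name `adv_weight` and its signature are illustrative stand-ins for the trainer's internal `_adv_loss_weight()` logic:

```python
import math

def adv_weight(step: int, pretrain_steps: int, ramp_steps: int,
               beta: float, schedule: str = "cosine") -> float:
    """Sketch of the adversarial-weight schedule: zero during the
    generator-only warm-up, then ramped from 0 to `beta` over
    `ramp_steps` with a linear or half-cosine shape."""
    if step < pretrain_steps:                  # generator-only phase
        return 0.0
    t = min((step - pretrain_steps) / max(ramp_steps, 1), 1.0)
    if schedule == "cosine":                   # smooth ease-in/ease-out ramp
        t = 0.5 * (1.0 - math.cos(math.pi * t))
    return beta * t                            # capped at adv_loss_beta
```

With `g_pretrain_steps: 10000`, `adv_loss_ramp_steps: 5000`, and `adv_loss_beta: 1e-3`, the weight stays at zero until step 10000 and reaches `1e-3` at step 15000.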

---
@@ -59,7 +61,7 @@ The schedule and ramp make training **easier, safer, and more reproducible**.
| **Generators** | `SRResNet`, `res`, `rcab`, `rrdb`, `lka`, `esrgan`, `stochastic_gan` | `Generator.model_type`, depth via `Generator.n_blocks`, width via `Generator.n_channels`, kernels/scale plus ESRGAN-specific `growth_channels`, `res_scale`, `out_channels`. |
| **Discriminators** | `standard` `SRGAN`, `CNN`, `patchgan`, `esrgan` | `Discriminator.model_type`, granularity with `Discriminator.n_blocks`, spectral norm toggle via `Discriminator.use_spectral_norm`, ESRGAN-specific `base_channels`, `linear_size`. |
| **Content losses** | L1, Spectral Angle Mapper, VGG19/LPIPS perceptual metrics, Total Variation | Weighted by `Training.Losses.*` (e.g. `l1_weight`, `sam_weight`, `perceptual_weight`, `perceptual_metric`, `tv_weight`). |
-| **Adversarial loss** | BCE‑with‑logits on real/fake logits | Warmup via `Training.pretrain_g_only`, ramped by `adv_loss_ramp_steps`, capped at `adv_loss_beta`, optional label smoothing. |
+| **Adversarial loss** | BCE‑with‑logits or Wasserstein critic | Controlled by `Training.Losses.adv_loss_type`, warmup via `Training.pretrain_g_only`, ramped by `adv_loss_ramp_steps`, capped at `adv_loss_beta`, optional label smoothing. For BCE, enable `Training.Losses.relativistic_average_d` for RaGAN-style relative logits. |

The YAML keeps the SRGAN flexible: swap architectures or rebalance perceptual vs. spectral fidelity without touching the code.

docs/configuration.md (3 changes: 2 additions & 1 deletion)

@@ -85,7 +85,7 @@ If you need to reuse the same function for both directions (for example
| Key | Default | Description |
| --- | --- | --- |
| `pretrain_g_only` | `True` | Enable generator-only warm-up before adversarial updates. |
-| `g_pretrain_steps` | `10000` | Number of optimiser steps spent in the warm-up phase. |
+| `g_pretrain_steps` | `10000` | Number of optimiser steps spent in the warm-up phase (`-1` keeps generator-only pretraining active indefinitely when `pretrain_g_only: true`). |
| `adv_loss_ramp_steps` | `5000` | Duration of the adversarial weight ramp after the warm-up. |
| `label_smoothing` | `True` | Replaces target value 1.0 with 0.9 for real examples to stabilise discriminator training. |
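The warm-up gate and label smoothing in the table above reduce to a few lines of logic. A sketch under stated assumptions (the function names are illustrative, mirroring the behaviour of the trainer's `_pretrain_check()` and its smoothed discriminator targets):

```python
def pretrain_active(global_step: int, pretrain_g_only: bool = True,
                    g_pretrain_steps: int = 10000) -> bool:
    """True while the generator-only warm-up runs;
    g_pretrain_steps == -1 keeps it active indefinitely."""
    if not pretrain_g_only:
        return False
    return g_pretrain_steps == -1 or global_step < g_pretrain_steps

def real_target(label_smoothing: bool = True) -> float:
    """Discriminator target for real samples: 0.9 when smoothing is on."""
    return 0.9 if label_smoothing else 1.0
```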

@@ -109,6 +109,7 @@ stable validation imagery. The EMA is fully optional and controlled through the
| `adv_loss_beta` | `1e-3` | Target weight applied to the adversarial term after ramp-up. |
| `adv_loss_schedule` | `cosine` | Ramp shape (`linear` or `cosine`). |
| `adv_loss_type` | `bce` | Adversarial objective (`bce` for classic SRGAN logits, `wasserstein` for a non-saturating critic-style loss). |
+| `relativistic_average_d` | `False` | BCE-only switch for relativistic-average GAN training (real/fake logits are compared against each other's batch mean). Supported in both PL1 and PL2 training-step implementations. |
| `r1_gamma` | `0.0` | Strength of the R1 gradient penalty applied to real images (useful with Wasserstein critics). |
| `l1_weight` | `1.0` | Weight of the pixelwise L1 loss. |
| `sam_weight` | `0.05` | Weight of the spectral angle mapper loss. |
docs/trainer-details.md (8 changes: 4 additions & 4 deletions)

@@ -34,14 +34,14 @@ if pretrain_phase:
return dummy
```

-* `_pretrain_check()` compares `self.global_step` against `Training.g_pretrain_steps` to decide whether the generator-only warm-up is active. 【F:opensr_srgan/model/training_step_PL.py†L10-L46】
+* `_pretrain_check()` compares `self.global_step` against `Training.g_pretrain_steps` to decide whether the generator-only warm-up is active (`g_pretrain_steps: -1` keeps this phase active indefinitely). 【F:opensr_srgan/model/training_step_PL.py†L10-L46】
* The pretraining branch logs the instantaneous adversarial weight even though it stays unused until GAN training begins. This keeps dashboards continuous when you review historical runs.
* The discriminator receives a zero-valued tensor with `requires_grad=True` so Lightning's closure executes without mutating weights. Dummy logs (`discriminator/D(y)_prob`, `discriminator/D(G(x))_prob`) remain pinned to zero for clarity.

Once `_pretrain_check()` flips to `False`, the function splits into discriminator and generator updates:

-* **Discriminator (`optimizer_idx == 0`).** Real and fake logits are compared against smoothed targets, and the resulting BCE components are summed into `discriminator/adversarial_loss`. The helper logs running opinions (`discriminator/D(y)_prob`, `discriminator/D(G(x))_prob`) so you can diagnose mode collapse early. 【F:opensr_srgan/model/training_step_PL.py†L135-L195】
-* **Generator (`optimizer_idx == 1`).** The generator measures content metrics once, reuses them for logging, queries the adversarial signal (`adversarial_loss_criterion(sr_discriminated, ones)`), and multiplies it with `_adv_loss_weight()` before combining both parts into `generator/total_loss`. 【F:opensr_srgan/model/training_step_PL.py†L203-L247】
+* **Discriminator (`optimizer_idx == 0`).** Real and fake logits are compared against smoothed targets, and the resulting BCE components are summed into `discriminator/adversarial_loss`. If `Training.Losses.relativistic_average_d: true` (BCE mode), both terms are computed on relativistic logits (`D(real)-mean(D(fake))`, `D(fake)-mean(D(real))`) and additional relativistic confidence logs are emitted. 【F:opensr_srgan/model/training_step_PL.py†L135-L195】
+* **Generator (`optimizer_idx == 1`).** The generator measures content metrics once, reuses them for logging, queries the adversarial signal, and multiplies it with `_adv_loss_weight()` before combining both parts into `generator/total_loss`. In BCE + relativistic mode, generator adversarial loss is averaged from `BCE(D(fake)-mean(D(real)), 1)` and `BCE(D(real)-mean(D(fake)), 0)`. 【F:opensr_srgan/model/training_step_PL.py†L203-L247】
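The relativistic-average BCE terms above can be sketched in pure Python, with per-batch lists of logits standing in for tensors (helper names here are hypothetical, not the module's actual API):

```python
import math

def _bce_with_logits(logit: float, target: float) -> float:
    # BCE-with-logits on a single value (illustrative, not numerically hardened)
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))

def ragan_d_loss(real_logits, fake_logits):
    """Discriminator loss on relativistic logits:
    BCE(D(real) - mean(D(fake)), 1) + BCE(D(fake) - mean(D(real)), 0)."""
    mean_real = sum(real_logits) / len(real_logits)
    mean_fake = sum(fake_logits) / len(fake_logits)
    real_term = sum(_bce_with_logits(r - mean_fake, 1.0) for r in real_logits) / len(real_logits)
    fake_term = sum(_bce_with_logits(f - mean_real, 0.0) for f in fake_logits) / len(fake_logits)
    return real_term + fake_term

def ragan_g_loss(real_logits, fake_logits):
    """Generator loss: average of BCE(D(fake) - mean(D(real)), 1)
    and BCE(D(real) - mean(D(fake)), 0)."""
    mean_real = sum(real_logits) / len(real_logits)
    mean_fake = sum(fake_logits) / len(fake_logits)
    fake_term = sum(_bce_with_logits(f - mean_real, 1.0) for f in fake_logits) / len(fake_logits)
    real_term = sum(_bce_with_logits(r - mean_fake, 0.0) for r in real_logits) / len(real_logits)
    return 0.5 * (fake_term + real_term)
```

When real and fake logits are indistinguishable, the discriminator loss sits at the BCE chance level of 2·log 2.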

With `Training.Losses.adv_loss_type: wasserstein`, the same branches apply but swap the BCE terms for a critic objective: the discriminator minimises `mean(fake) - mean(real)` (plus any configured R1 penalty), and the generator minimises `-mean(D(G(x)))`. Logged probabilities remain sigmoid-squashed critic scores to keep dashboards comparable. Configure `Training.Losses.r1_gamma` to activate the real-image R1 gradient penalty popularised by Mescheder et al. for stabilising Wasserstein critics, and toggle `Discriminator.use_spectral_norm` when you want Miyato et al.'s spectral normalisation to enforce a tighter Lipschitz bound on SRGAN discriminators. 【F:opensr_srgan/model/training_step_PL.py†L129-L247】
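A minimal sketch of the critic objectives just described, in pure Python over lists of critic scores (function names are illustrative; the R1 term is noted in a comment but omitted because computing `||∇D(real)||²` requires autograd over real images):

```python
def wgan_d_loss(real_scores, fake_scores):
    """Critic minimises mean(fake) - mean(real); in the real trainer the
    configured R1 penalty r1_gamma/2 * ||grad D(real)||^2 is added on top."""
    return sum(fake_scores) / len(fake_scores) - sum(real_scores) / len(real_scores)

def wgan_g_loss(fake_scores):
    """Generator minimises -mean(D(G(x)))."""
    return -sum(fake_scores) / len(fake_scores)
```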

@@ -67,7 +67,7 @@ if pretrain_phase:
return content_loss
```

-The adversarial branch toggles each optimiser in turn, accumulates identical logs to the PL1.x path, and performs the EMA update after every generator step. 【F:opensr_srgan/model/training_step_PL.py†L336-L458】
+The adversarial branch toggles each optimiser in turn, accumulates identical logs to the PL1.x path (including optional relativistic BCE metrics via `Training.Losses.relativistic_average_d`), and performs the EMA update after every generator step. 【F:opensr_srgan/model/training_step_PL.py†L336-L458】
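The EMA update performed after each generator step follows the standard shadow-weight rule. A sketch over flat parameter lists (the real implementation tracks the generator's state dict rather than plain floats):

```python
def ema_update(shadow, params, decay=0.999):
    """One EMA step: shadow <- decay * shadow + (1 - decay) * param."""
    return [decay * s + (1.0 - decay) * p for s, p in zip(shadow, params)]
```

Higher decay values (toward 0.9999) make the shadow generator respond more slowly but validate more stably, matching the 0.995–0.9999 range recommended above.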

## Adversarial weight schedule

opensr_srgan/configs/config_10m.yaml (3 changes: 2 additions & 1 deletion)

@@ -44,7 +44,7 @@ Training:

# --- Pretraining and adversarial setup ---
pretrain_g_only: True # Train generator only for initial phase
-g_pretrain_steps: 20000 # Number of generator-only warmup steps
+g_pretrain_steps: 20000 # Number of generator-only warmup steps (-1 keeps pretraining active)
adv_loss_ramp_steps: 5000 # Gradual adversarial weight ramp steps
label_smoothing: True # Discriminator target smoothing (1.0 → 0.9)

@@ -57,6 +57,7 @@ Training:
Losses:
# --- GAN term ---
adv_loss_type: 'bce' # Adversarial objective: ['bce', 'wasserstein']
+relativistic_average_d: False # BCE-only: use relativistic-average GAN losses/logs
adv_loss_beta: 0.001 # Final adversarial loss weight after ramp-up - original 0.001
adv_loss_schedule: 'cosine' # Adversarial weight ramp type: ['linear', 'cosine']
r1_gamma: 0.0 # R1 gradient penalty strength on real images (0 disables)
opensr_srgan/configs/config_20m.yaml (3 changes: 2 additions & 1 deletion)

@@ -44,7 +44,7 @@ Training:

# --- Pretraining and adversarial setup ---
pretrain_g_only: True # Train generator only for initial phase
-g_pretrain_steps: 15000 # Number of generator-only warmup steps
+g_pretrain_steps: 15000 # Number of generator-only warmup steps (-1 keeps pretraining active)
adv_loss_ramp_steps: 2500 # Gradual adversarial weight ramp steps
label_smoothing: True # Discriminator target smoothing (1.0 → 0.9)

@@ -57,6 +57,7 @@ Training:
Losses:
# --- GAN term ---
adv_loss_type: 'bce' # Adversarial objective: ['bce', 'wasserstein']
+relativistic_average_d: False # BCE-only: use relativistic-average GAN losses/logs
adv_loss_beta: 1e-3 # Final adversarial loss weight after ramp-up
adv_loss_schedule: 'cosine' # Adversarial weight ramp type: ['linear', 'cosine']
r1_gamma: 0.0 # R1 gradient penalty strength on real images (0 disables)
opensr_srgan/configs/config_playgound.yaml (3 changes: 2 additions & 1 deletion)

@@ -44,7 +44,7 @@ Training:

# --- Pretraining and adversarial setup ---
pretrain_g_only: False # Train generator only for initial phase
-g_pretrain_steps: 100000 # Number of generator-only warmup steps
+g_pretrain_steps: 100000 # Number of generator-only warmup steps (-1 keeps pretraining active)
adv_loss_ramp_steps: 25000 # Gradual adversarial weight ramp steps
label_smoothing: True # Discriminator target smoothing (1.0 → 0.9)

@@ -57,6 +57,7 @@ Training:
Losses:
# --- GAN term ---
adv_loss_type: 'bce' # Adversarial objective: ['bce', 'wasserstein']
+relativistic_average_d: False # BCE-only: use relativistic-average GAN losses/logs
adv_loss_beta: 1e-3 # Final adversarial loss weight after ramp-up
adv_loss_schedule: 'cosine' # Adversarial weight ramp type: ['linear', 'cosine']
r1_gamma: 0.0 # R1 gradient penalty strength on real images (0 disables)
opensr_srgan/configs/config_training_example.yaml (5 changes: 3 additions & 2 deletions)

@@ -44,7 +44,7 @@ Training:

# --- Pretraining and adversarial setup ---
pretrain_g_only: True # Train generator only for initial phase
-g_pretrain_steps: 1000 # Number of generator-only warmup steps
+g_pretrain_steps: 1000 # Number of generator-only warmup steps (-1 keeps pretraining active)
adv_loss_ramp_steps: 500 # Gradual adversarial weight ramp steps
label_smoothing: True # Discriminator target smoothing (1.0 → 0.9)

@@ -57,6 +57,7 @@ Training:
Losses:
# --- GAN term ---
adv_loss_type: 'bce' # Adversarial objective: ['bce', 'wasserstein']
+relativistic_average_d: False # BCE-only: use relativistic-average GAN losses/logs
adv_loss_beta: 0.001 # Final adversarial loss weight after ramp-up - original 0.001
adv_loss_schedule: 'cosine' # Adversarial weight ramp type: ['linear', 'cosine']
r1_gamma: 0.0 # R1 gradient penalty strength on real images (0 disables)
@@ -130,4 +131,4 @@ Logging:
wandb:
enabled: False # Toggle Weights & Biases logging on/off
entity: "opensr" # W&B entity or team name
-project: "SRGAN_10m" # W&B project name
+project: "SRGAN_10m" # W&B project name
opensr_srgan/model/SRGAN.py (7 changes: 5 additions & 2 deletions)

@@ -169,6 +169,9 @@ def __init__(self, config="config.yaml", mode="train"):
self.adv_loss_type = str(
getattr(self.config.Training.Losses, "adv_loss_type", "bce")
).lower()
+        self.relativistic_average_d = bool(
+            getattr(self.config.Training.Losses, "relativistic_average_d", False)
+        )
if self.adv_loss_type not in {"bce", "wasserstein"}:
raise ValueError(
"Training.Losses.adv_loss_type must be either 'bce' or 'wasserstein'"
@@ -1091,14 +1094,14 @@ def _pretrain_check(self) -> bool:

Returns:
bool: True if the generator-only pretraining phase is active
-            (i.e., `global_step` < `g_pretrain_steps`), otherwise False.
+            (i.e., `global_step` < `g_pretrain_steps` or `g_pretrain_steps == -1`), otherwise False.

Notes:
- During pretraining, the discriminator is frozen and only the
generator is updated.
"""
if (
-            self.pretrain_g_only and self.global_step < self.g_pretrain_steps
+            self.pretrain_g_only and (self.global_step < self.g_pretrain_steps or self.g_pretrain_steps == -1)
): # true if pretraining active
return True
else: