-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcompute_sigma_max.py
More file actions
executable file
·24 lines (20 loc) · 1009 Bytes
/
compute_sigma_max.py
File metadata and controls
executable file
·24 lines (20 loc) · 1009 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from pathlib import Path
import sys; sys.path.append(str(Path(__file__).parent.parent.resolve()))
from diffusha.algo.ddpm_base import make_beta_schedule
# Schedule settings to sweep. A wider sweep used previously was
# timesteps = [10, 20, 30, 50, 100].
timesteps = [10, 20, 50]
beta_mins = [1e-4]
beta_maxs = [0.05]


def compute_sigmas(betas):
    """Return the per-step noise-to-signal ratio sigma_t for a beta schedule.

    sigma_t = sqrt(1 - alpha_bar_t) / sqrt(alpha_bar_t), where
    alpha_bar_t is the cumulative product of (1 - beta) along dim 0.

    Args:
        betas: tensor of per-step betas; presumably 1-D of length T as
            returned by ``make_beta_schedule`` — TODO confirm.

    Returns:
        Tensor with the same shape as ``betas`` holding sigma_t per step.
    """
    alphas_prod = torch.cumprod(1 - betas, 0)
    return torch.sqrt(1 - alphas_prod) / torch.sqrt(alphas_prod)


def main():
    """Print sigma at the first and last step for every swept configuration."""
    for beta_min in beta_mins:
        for beta_max in beta_maxs:
            for T in timesteps:
                # Pass the swept beta_min through so the reported value is the
                # one actually used to build the schedule.
                betas = make_beta_schedule(
                    schedule='sigmoid', n_timesteps=T, start=beta_min, end=beta_max
                )
                sigmas = compute_sigmas(betas)
                # If every beta lies in (0, 1), alpha_bar strictly decreases, so
                # sigma strictly increases and the first/last entries are its
                # minimum/maximum.
                sigma_min = sigmas[0].item()
                sigma_max = sigmas[-1].item()
                print(f'timesteps: {T}\tbeta_min: {beta_min}\tbeta_max: {beta_max}'
                      f'\tsigma_min: {sigma_min}\tsigma_max: {sigma_max}')


if __name__ == '__main__':
    main()