-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsample.py
More file actions
106 lines (90 loc) · 3.31 KB
/
sample.py
File metadata and controls
106 lines (90 loc) · 3.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import argparse
import os
import torch
import torch.distributed as dist
import yaml
from torchvision.utils import make_grid, save_image
from ema_pytorch import EMA
from model.models import get_models_class
from utils import Config, init_seeds, gather_tensor, print0
def get_default_steps(model_type, steps):
    """Resolve the number of sampler steps to use.

    An explicit ``steps`` (anything other than ``None``, including 0) is
    returned unchanged. Otherwise the built-in default for ``model_type`` is
    used; the only registered default is 18 for ``'EDM'``.

    Raises:
        KeyError: if ``steps`` is ``None`` and ``model_type`` has no default.
    """
    if steps is None:
        default_steps_by_model = {'EDM': 18}
        steps = default_steps_by_model[model_type]
    return steps
# ===== sampling =====
def _load_ema_model(cfg, ep, device):
    """Build the diffusion model described by ``cfg`` and load its EMA weights.

    Reads ``<cfg.save_dir>/ckpts/model_<ep>.pth``, restores the checkpoint's
    ``'EMA'`` entry into an ``EMA`` wrapper around a freshly built model, and
    returns the wrapped EMA model on ``device`` in eval mode.
    """
    DIFFUSION, NETWORK = get_models_class(cfg.model_type, cfg.net_type)
    diff = DIFFUSION(
        nn_model=NETWORK(**cfg.network),
        **cfg.diffusion,
        device=device,
    )
    diff.to(device)
    target = os.path.join(cfg.save_dir, "ckpts", f"model_{ep}.pth")
    print0("loading model at", target)
    checkpoint = torch.load(target, map_location=device)
    ema = EMA(diff, beta=cfg.ema, update_after_step=0, update_every=1)
    ema.to(device)
    ema.load_state_dict(checkpoint['EMA'])
    model = ema.ema_model
    model.eval()
    return model


def sample(opt):
    """Generate images from an EMA checkpoint using all distributed processes.

    Each batch draws 400 images in total, split evenly across processes. The
    per-process samples are gathered, and the main process (global rank 0)
    writes one 20-column grid per batch (``grid_<batch>.png``). It also writes
    every individual image as ``pngs/<index>.png``, numbered consecutively
    across batches.

    Args:
        opt: parsed command-line namespace with ``config`` (path to the YAML
            training config), ``local_rank``, ``use_amp``, ``steps`` (``None``
            selects the model type's default), ``eta``, ``batches`` and
            ``epoch`` (``-1`` selects the config's ``n_epoch - 1``).

    Raises:
        NotImplementedError: if the config's ``model_type`` is not ``'EDM'``.
        ValueError: if the world size does not divide the per-batch count.
        FileExistsError: on the main process, if the output directory exists
            already; earlier results are never silently mixed with new ones.
    """
    samples_per_batch = 400  # images per batch, summed over all processes
    grid_nrow = 20  # 400 images -> a 20 x 20 grid

    local_rank = opt.local_rank
    use_amp = opt.use_amp
    eta = opt.eta
    ep = opt.epoch

    # NOTE(review): full_load assumes the config file is trusted local input;
    # switch to yaml.safe_load if configs can come from elsewhere.
    with open(opt.config, 'r') as f:
        raw_cfg = yaml.full_load(f)
    print0(raw_cfg)
    # Kept separate from ``opt`` so the CLI namespace is not shadowed.
    cfg = Config(raw_cfg)
    if ep == -1:
        ep = cfg.n_epoch - 1
    device = f"cuda:{local_rank}"
    steps = get_default_steps(cfg.model_type, opt.steps)
    # Only the EDM sampler is wired up; fail on every rank before loading weights.
    if cfg.model_type != 'EDM':
        raise NotImplementedError(f"sampling is not implemented for model_type={cfg.model_type!r}")

    # Validate once, up front. An ``assert`` would be stripped under ``python -O``.
    world_size = dist.get_world_size()
    if samples_per_batch % world_size != 0:
        raise ValueError(
            f"world size ({world_size}) must divide the number of samples "
            f"per batch ({samples_per_batch})"
        )
    samples_per_process = samples_per_batch // world_size

    model = _load_ema_model(cfg, ep, device)

    # Use the global rank: on multi-node jobs every node has a local rank 0,
    # which would race on makedirs and write duplicate output.
    is_main = dist.get_rank() == 0
    if is_main:
        gen_dir = os.path.join(cfg.save_dir, f"EMAgenerated_ep{ep}_edm_steps{steps}_eta{eta}")
        os.makedirs(gen_dir)
        gen_dir_png = os.path.join(gen_dir, "pngs")
        os.makedirs(gen_dir_png)

    next_index = 0
    for batch in range(opt.batches):
        with torch.no_grad():
            x_gen = model.edm_sample(
                n_sample=samples_per_process,
                size=cfg.network['image_shape'],
                notqdm=not is_main,
                use_amp=use_amp,
                steps=steps,
                eta=eta,
            )
        dist.barrier()
        x_gen = gather_tensor(x_gen)
        if not is_main:
            continue
        x_gen = x_gen.cpu()
        save_image(make_grid(x_gen, nrow=grid_nrow), os.path.join(gen_dir, f"grid_{batch}.png"))
        # Write this batch's images now instead of holding every batch in
        # memory until the end; numbering continues across batches.
        for img in x_gen:
            save_image(img, os.path.join(gen_dir_png, f"{next_index}.png"))
            next_index += 1
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required: sampling cannot proceed without the training config.
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--use_amp", action='store_true', default=False)
    parser.add_argument("--steps", type=int, default=None)
    parser.add_argument("--eta", type=float, default=0.0)
    parser.add_argument("--batches", type=int, default=125)
    parser.add_argument("--epoch", type=int, default=-1)
    opt = parser.parse_args()
    # LOCAL_RANK (and RANK) are expected from the distributed launcher, e.g. torchrun.
    opt.local_rank = int(os.environ['LOCAL_RANK'])
    print0(opt)
    # Seed index is the global rank so processes on different nodes do not share
    # a seed (and thus draw identical samples); falls back to the local rank when
    # RANK is unset. NOTE(review): assumes init_seeds derives a per-process seed
    # from ``no`` — confirm.
    init_seeds(no=int(os.environ.get('RANK', opt.local_rank)))
    # Bind this process to its GPU before the NCCL process group is created.
    torch.cuda.set_device(opt.local_rank)
    dist.init_process_group(backend='nccl')
    try:
        sample(opt)
    finally:
        dist.destroy_process_group()