Skip to content

Commit db85cf0

Browse files
authored
feat: RIFE and FILM frame interpolation model support (CORE-29) (Comfy-Org#13258)
* initial RIFE support * Also support FILM * Better RAM usage, reduce FILM VRAM peak * Add model folder placeholder * Fix oom fallback frame loss * Remove torch.compile for now * Rename model input * Shorter input type name ---------
1 parent 91e1f45 commit db85cf0

6 files changed

Lines changed: 601 additions & 1 deletion

File tree

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
"""FILM: Frame Interpolation for Large Motion (ECCV 2022)."""
2+
3+
import torch
4+
import torch.nn as nn
5+
import torch.nn.functional as F
6+
7+
import comfy.ops
8+
9+
# Default `operations` namespace used by the layer constructors in this module
# (every constructor accepts an `operations=` override).
ops = comfy.ops.disable_weight_init
10+
11+
12+
class FilmConv2d(nn.Module):
    """Conv2d with optional LeakyReLU and FILM-style "same" padding.

    Odd kernel sizes rely on the convolution's own symmetric padding
    (``size // 2``). Even kernel sizes cannot be padded symmetrically, so the
    input is padded explicitly in ``forward`` and the convolution itself runs
    unpadded; the extra row/column goes to the bottom/right. In both cases the
    output keeps the spatial size of the input.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        size: Square kernel size.
        activation: When true, apply ``LeakyReLU(0.2)`` after the convolution.
        device, dtype: Forwarded to the convolution constructor.
        operations: Namespace providing the ``Conv2d`` class to instantiate.
    """

    def __init__(self, in_channels, out_channels, size, activation=True, device=None, dtype=None, operations=ops):
        super().__init__()
        self.even_pad = not size % 2
        # Asymmetric "same" padding for even kernels: size - 1 rows/columns in total,
        # smaller half before, larger half after. For size == 2 this is (0, 1), which
        # is what the fusion layers use; other even sizes now also preserve H and W.
        self._pad_before = (size - 1) // 2
        self._pad_after = size // 2
        self.conv = operations.Conv2d(
            in_channels, out_channels, kernel_size=size,
            padding=0 if self.even_pad else size // 2,
            device=device, dtype=dtype)
        self.activation = nn.LeakyReLU(0.2) if activation else None

    def forward(self, x):
        """Apply (optional) explicit padding, the convolution and (optional) activation."""
        if self.even_pad:
            # F.pad order for the last two dims is (left, right, top, bottom).
            x = F.pad(x, (self._pad_before, self._pad_after, self._pad_before, self._pad_after))
        x = self.conv(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
28+
29+
30+
def _warp_core(image, flow, grid_x, grid_y):
31+
dtype = image.dtype
32+
H, W = flow.shape[2], flow.shape[3]
33+
dx = flow[:, 0].float() / (W * 0.5)
34+
dy = flow[:, 1].float() / (H * 0.5)
35+
grid = torch.stack([grid_x[None, None, :] + dx, grid_y[None, :, None] + dy], dim=3)
36+
return F.grid_sample(image.float(), grid, mode="bilinear", padding_mode="border", align_corners=False).to(dtype)
37+
38+
39+
def build_image_pyramid(image, pyramid_levels):
    """Return ``[image, image/2, image/4, ...]`` built by repeated 2x2 average pooling.

    The first entry is ``image`` itself; the list always contains at least that
    entry and has ``pyramid_levels`` entries when ``pyramid_levels >= 1``.
    """
    levels = [image]
    while len(levels) < pyramid_levels:
        levels.append(F.avg_pool2d(levels[-1], 2, 2))
    return levels
45+
46+
47+
def flow_pyramid_synthesis(residual_pyramid):
    """Accumulate per-level residual flows into a full flow pyramid.

    ``residual_pyramid`` is ordered fine-to-coarse. Starting from the coarsest
    residual, each finer flow is the bilinearly upsampled coarser flow,
    doubled (flow vectors scale with resolution), plus that level's residual.
    The returned list is ordered fine-to-coarse like the input.
    """
    coarse_to_fine = residual_pyramid[::-1]
    current = coarse_to_fine[0]
    flows = [current]
    for residual in coarse_to_fine[1:]:
        upsampled = F.interpolate(current, size=residual.shape[2:4], mode="bilinear", scale_factor=None)
        # In-place ops act on the freshly interpolated tensor only.
        current = upsampled.mul_(2).add_(residual)
        flows.append(current)
    return flows[::-1]
55+
56+
57+
def multiply_pyramid(pyramid, scalar):
    """Scale every level of ``pyramid`` by a per-batch-element factor.

    ``scalar`` is indexed as a 1-D tensor and broadcast over the channel and
    spatial dimensions of each level. Returns a new list.
    """
    scaled = []
    for level in pyramid:
        scaled.append(level * scalar[:, None, None, None])
    return scaled
59+
60+
61+
def pyramid_warp(feature_pyramid, flow_pyramid, warp_fn):
    """Apply ``warp_fn(features, flow)`` level by level and return the results as a list.

    Levels are paired in order; pairing stops at the shorter of the two pyramids.
    """
    return list(map(warp_fn, feature_pyramid, flow_pyramid))
63+
64+
65+
def concatenate_pyramids(pyramid1, pyramid2):
    """Concatenate two pyramids level by level along the channel dimension.

    Pairing stops at the shorter of the two pyramids.
    """
    merged = []
    for first, second in zip(pyramid1, pyramid2):
        merged.append(torch.cat((first, second), dim=1))
    return merged
67+
68+
69+
class SubTreeExtractor(nn.Module):
    """Stack of double-conv stages that turns one image into a small feature pyramid.

    Stage ``i`` has ``channels << i`` output channels and consists of two 3x3
    ``FilmConv2d`` layers.
    """

    def __init__(self, in_channels=3, channels=64, n_layers=4, device=None, dtype=None, operations=ops):
        super().__init__()
        factory = dict(device=device, dtype=dtype, operations=operations)
        stages = nn.ModuleList()
        prev_ch = in_channels
        for depth in range(n_layers):
            width = channels << depth
            stages.append(nn.Sequential(
                FilmConv2d(prev_ch, width, 3, **factory),
                FilmConv2d(width, width, 3, **factory)))
            prev_ch = width
        self.convs = stages

    def forward(self, image, n):
        """Run every stage and return the list of stage outputs.

        The activation is average-pooled by 2 after each of the first
        ``n - 1`` stages. Note that all stages run regardless of ``n``; stages
        at index ``n`` and beyond operate on the un-pooled output of the
        previous stage.
        """
        outputs = []
        x = image
        last_pooled = n - 1
        for depth, stage in enumerate(self.convs):
            x = stage(x)
            outputs.append(x)
            if depth < last_pooled:
                x = F.avg_pool2d(x, 2, 2)
        return outputs
90+
91+
92+
class FeatureExtractor(nn.Module):
    """Builds a multi-scale feature pyramid from an image pyramid with one shared extractor.

    The same ``SubTreeExtractor`` is applied to every image-pyramid level;
    features that end up at the same resolution are concatenated channel-wise.
    """

    def __init__(self, in_channels=3, channels=64, sub_levels=4, device=None, dtype=None, operations=ops):
        super().__init__()
        self.extract_sublevels = SubTreeExtractor(
            in_channels, channels, sub_levels, device=device, dtype=dtype, operations=operations)
        self.sub_levels = sub_levels

    def forward(self, image_pyramid):
        """Return one feature tensor per image-pyramid level."""
        n_levels = len(image_pyramid)

        sub_pyramids = []
        for level, image in enumerate(image_pyramid):
            depth = min(n_levels - level, self.sub_levels)
            sub_pyramids.append(self.extract_sublevels(image, depth))

        feature_pyramid = []
        for level in range(n_levels):
            # Level `level` gathers stage j of the sub-pyramid rooted j levels finer.
            parts = [sub_pyramids[level][0]]
            for j in range(1, min(level, self.sub_levels - 1) + 1):
                parts.append(sub_pyramids[level - j][j])
            features = parts[0] if len(parts) == 1 else torch.cat(parts, dim=1)
            feature_pyramid.append(features)

            # Drop the oldest sub-pyramid once no later level can reference it.
            oldest = level - self.sub_levels + 1
            if oldest >= 0:
                sub_pyramids[oldest] = None
        return feature_pyramid
112+
113+
114+
class FlowEstimator(nn.Module):
    """Small conv stack that predicts a 2-channel (residual) flow from two feature maps.

    Layout: ``num_convs`` 3x3 convs with ``num_filters`` channels, one 1x1 conv
    down to ``num_filters // 2`` channels, and a final 1x1 conv without
    activation producing 2 channels.
    """

    def __init__(self, in_channels, num_convs, num_filters, device=None, dtype=None, operations=ops):
        super().__init__()
        factory = dict(device=device, dtype=dtype, operations=operations)
        layers = []
        width_in = in_channels
        for _ in range(num_convs):
            layers.append(FilmConv2d(width_in, num_filters, 3, **factory))
            width_in = num_filters
        half = num_filters // 2
        layers.append(FilmConv2d(width_in, half, 1, **factory))
        layers.append(FilmConv2d(half, 2, 1, activation=False, **factory))
        self._convs = nn.ModuleList(layers)

    def forward(self, features_a, features_b):
        """Concatenate both feature maps on the channel axis and run the conv stack."""
        out = torch.cat((features_a, features_b), dim=1)
        for layer in self._convs:
            out = layer(out)
        return out
129+
130+
131+
class PyramidFlowEstimator(nn.Module):
    """Coarse-to-fine residual flow prediction over a feature pyramid.

    One ``FlowEstimator`` is created per entry of ``flow_convs``. The last one
    (``_predictor``) is shared by all coarse levels; the others
    (``_predictors``, stored coarse-to-fine) are specialized for the finest levels.
    """

    def __init__(self, filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops):
        super().__init__()
        estimators = []
        width_in = filters << 1
        for idx, n_convs in enumerate(flow_convs):
            estimators.append(FlowEstimator(
                width_in, n_convs, flow_filters[idx], device=device, dtype=dtype, operations=operations))
            width_in += filters << (idx + 2)
        self._predictor = estimators[-1]
        self._predictors = nn.ModuleList(reversed(estimators[:-1]))

    def forward(self, feature_pyramid_a, feature_pyramid_b, warp_fn):
        """Return the residual flow pyramid (fine-to-coarse) between two feature pyramids.

        At each level the running flow is upsampled and doubled, the ``b``
        features are warped with it via ``warp_fn``, and the level's predictor
        estimates a residual that is recorded and added to the running flow.
        """
        levels = len(feature_pyramid_a)
        n_special = len(self._predictors)

        v = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1])
        residuals = [v]

        # Coarse-to-fine schedule: shared predictor for deep levels, then the specialized ones.
        schedule = []
        for level in range(levels - 2, n_special - 1, -1):
            schedule.append((level, self._predictor))
        for offset, specialized in enumerate(self._predictors):
            schedule.append((n_special - 1 - offset, specialized))

        for level, predictor in schedule:
            feats_a = feature_pyramid_a[level]
            v = F.interpolate(v, size=feats_a.shape[2:4], mode="bilinear").mul_(2)
            v_residual = predictor(feats_a, warp_fn(feature_pyramid_b[level], v))
            residuals.append(v_residual)
            v.add_(v_residual)
        return residuals[::-1]
156+
157+
158+
def _get_fusion_channels(level, filters):
159+
# Per direction: multi-scale features + RGB image (3ch) + flow (2ch), doubled for both directions
160+
return (sum(filters << i for i in range(level)) + 3 + 2) * 2
161+
162+
163+
class Fusion(nn.Module):
    """Decoder that fuses an aligned pyramid into a 3-channel image, coarse to fine.

    Each decoder step holds three ``FilmConv2d`` layers: a 2x2 conv (no
    activation) applied after nearest-neighbour upsampling, and two 3x3 convs
    applied after concatenating the skip input from the pyramid.
    """

    def __init__(self, n_layers=4, specialized_layers=3, filters=64, device=None, dtype=None, operations=ops):
        super().__init__()
        factory = dict(device=device, dtype=dtype, operations=operations)
        self.output_conv = operations.Conv2d(filters, 3, kernel_size=1, device=device, dtype=dtype)
        self.convs = nn.ModuleList()

        width_in = _get_fusion_channels(n_layers, filters)
        increase = 0
        for level in reversed(range(n_layers)):
            # Filter count doubles per level up to the specialized depth, then stays flat.
            width = filters << min(level, specialized_layers)
            skip_extra = increase if increase else width
            self.convs.append(nn.ModuleList([
                FilmConv2d(width_in, width, 2, activation=False, **factory),
                FilmConv2d(width_in + skip_extra, width, 3, **factory),
                FilmConv2d(width, width, 3, **factory),
            ]))
            width_in = width
            increase = _get_fusion_channels(level, filters) - width // 2

    def forward(self, pyramid):
        """Decode ``pyramid`` (fine-to-coarse list) starting from its last entry."""
        net = pyramid[-1]
        top = len(self.convs) - 1
        for step, (up_conv, merge_conv, refine_conv) in enumerate(self.convs):
            skip = pyramid[top - step]
            upsampled = F.interpolate(net, size=skip.shape[2:4], mode="nearest")
            net = up_conv(upsampled)
            net = merge_conv(torch.cat([skip, net], dim=1))
            net = refine_conv(net)
        return self.output_conv(net)
186+
187+
188+
class FILMNet(nn.Module):
    """FILM interpolation network: feature extraction, bidirectional flow, fusion.

    Args:
        pyramid_levels: Number of levels in the image/feature/flow pyramids.
        fusion_pyramid_levels: Number of (finest) levels passed on to fusion.
        specialized_levels: Forwarded to ``Fusion`` as its specialized depth.
        sub_levels: Depth of the shared sub-tree feature extractor; also the
            number of fusion decoder steps.
        filters: Base channel count.
        flow_convs, flow_filters: Per-predictor conv count / width for the flow estimators.
        device, dtype, operations: Forwarded to the sub-modules.
    """

    def __init__(self, pyramid_levels=7, fusion_pyramid_levels=5, specialized_levels=3, sub_levels=4,
                 filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops):
        super().__init__()
        self.pyramid_levels = pyramid_levels
        self.fusion_pyramid_levels = fusion_pyramid_levels
        self.extract = FeatureExtractor(3, filters, sub_levels, device=device, dtype=dtype, operations=operations)
        self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters, device=device, dtype=dtype, operations=operations)
        self.fuse = Fusion(sub_levels, specialized_levels, filters, device=device, dtype=dtype, operations=operations)
        # (H, W) -> (grid_x, grid_y) for every pyramid level of the current input size.
        self._warp_grids = {}
        # (H, W, device) of the finest level the cached grids were built for.
        self._warp_grid_base = None

    def get_dtype(self):
        """Return the dtype of the first convolution's weights."""
        return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype

    def _build_warp_grids(self, H, W, device):
        """Pre-compute warp grids for all pyramid levels of an ``H`` x ``W`` input.

        The cache is valid only for the exact base resolution and device it was
        built for. Checking merely whether ``(H, W)`` is among the cached keys
        is not enough: a smaller input can match a coarse level of a previous
        larger input, in which case its own coarsest levels would be missing.
        """
        base = (H, W, None if device is None else torch.device(device))
        if base == self._warp_grid_base:
            return
        grids = {}
        for _ in range(self.pyramid_levels):
            # Pixel-centre coordinates in [-1, 1], matching align_corners=False sampling.
            grids[(H, W)] = (
                torch.linspace(-(1 - 1 / W), 1 - 1 / W, W, dtype=torch.float32, device=device),
                torch.linspace(-(1 - 1 / H), 1 - 1 / H, H, dtype=torch.float32, device=device),
            )
            H, W = H // 2, W // 2
        # Replace (not extend) the cache so grids of old resolutions are released.
        self._warp_grids = grids
        self._warp_grid_base = base

    def warp(self, image, flow):
        """Warp ``image`` by ``flow`` using the cached grid matching the flow's spatial size."""
        grid_x, grid_y = self._warp_grids[(flow.shape[2], flow.shape[3])]
        return _warp_core(image, flow, grid_x, grid_y)

    def extract_features(self, img):
        """Extract image and feature pyramids for a single frame. Can be cached across pairs."""
        image_pyramid = build_image_pyramid(img, self.pyramid_levels)
        feature_pyramid = self.extract(image_pyramid)
        return image_pyramid, feature_pyramid

    def forward(self, img0, img1, timestep=0.5, cache=None):
        """Interpolate a single frame between ``img0`` and ``img1`` at ``timestep``.

        A tensor ``timestep`` is reduced to one scalar; ``.item()`` requires the
        reduced tensor to have a single element, i.e. a batch of one.
        """
        # FILM uses a scalar timestep per batch element (spatially-varying timesteps not supported)
        t = timestep.mean(dim=(1, 2, 3)).item() if isinstance(timestep, torch.Tensor) else timestep
        return self.forward_multi_timestep(img0, img1, [t], cache=cache)

    def forward_multi_timestep(self, img0, img1, timesteps, cache=None):
        """Compute flow once, synthesize at multiple timesteps. Expects batch=1 inputs.

        Args:
            img0, img1: The two input frames.
            timesteps: Sequence of scalar timesteps.
            cache: Optional dict; entries ``"img0"`` / ``"img1"`` hold the
                ``(image_pyramid, feature_pyramid)`` pair returned by
                ``extract_features`` and are used instead of recomputing it.

        Returns:
            The fused outputs for all timesteps, concatenated along dim 0.
        """
        self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device)

        image_pyr0, feat_pyr0 = cache["img0"] if cache and "img0" in cache else self.extract_features(img0)
        image_pyr1, feat_pyr1 = cache["img1"] if cache and "img1" in cache else self.extract_features(img1)

        # Flow pyramids in both directions, truncated to the levels fusion consumes.
        fwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr0, feat_pyr1, self.warp))[:self.fusion_pyramid_levels]
        bwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr1, feat_pyr0, self.warp))[:self.fusion_pyramid_levels]

        # Build warp targets and free full pyramids (only first fpl levels needed from here)
        fpl = self.fusion_pyramid_levels
        p2w = [concatenate_pyramids(image_pyr0[:fpl], feat_pyr0[:fpl]),
               concatenate_pyramids(image_pyr1[:fpl], feat_pyr1[:fpl])]
        del image_pyr0, image_pyr1, feat_pyr0, feat_pyr1

        results = []
        dt_tensors = torch.tensor(timesteps, device=img0.device, dtype=img0.dtype)
        for idx in range(len(timesteps)):
            batch_dt = dt_tensors[idx:idx + 1]
            # Scale each flow by the fraction of the motion needed at this timestep.
            bwd_scaled = multiply_pyramid(bwd_flow, batch_dt)
            fwd_scaled = multiply_pyramid(fwd_flow, 1 - batch_dt)
            fwd_warped = pyramid_warp(p2w[0], bwd_scaled, self.warp)
            bwd_warped = pyramid_warp(p2w[1], fwd_scaled, self.warp)
            # Fusion input per level: both warped (image + features) stacks and both flows.
            aligned = [torch.cat([fw, bw, bf, ff], dim=1)
                       for fw, bw, bf, ff in zip(fwd_warped, bwd_warped, bwd_scaled, fwd_scaled)]
            # Release intermediates before fusion to keep peak memory down.
            del fwd_warped, bwd_warped, bwd_scaled, fwd_scaled
            results.append(self.fuse(aligned))
            del aligned
        return torch.cat(results, dim=0)
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
import comfy.ops
6+
7+
# Default `operations` namespace used by the layer constructors in this module
# (every constructor accepts an `operations=` override).
ops = comfy.ops.disable_weight_init
8+
9+
10+
def _warp(img, flow, warp_grids):
11+
B, _, H, W = img.shape
12+
base_grid, flow_div = warp_grids[(H, W)]
13+
flow_norm = torch.cat([flow[:, 0:1] / flow_div[0], flow[:, 1:2] / flow_div[1]], 1).float()
14+
grid = (base_grid.expand(B, -1, -1, -1) + flow_norm).permute(0, 2, 3, 1)
15+
return F.grid_sample(img.float(), grid, mode="bilinear", padding_mode="border", align_corners=True).to(img.dtype)
16+
17+
18+
class Head(nn.Module):
    """Per-frame feature encoder: stride-2 conv, two 3x3 convs, stride-2 transposed conv.

    Takes a 3-channel image and produces ``out_ch`` feature channels; a
    LeakyReLU(0.2) follows each of the first three convolutions.
    """

    def __init__(self, out_ch=4, device=None, dtype=None, operations=ops):
        super().__init__()
        factory = dict(device=device, dtype=dtype)
        self.cnn0 = operations.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, **factory)
        self.cnn1 = operations.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, **factory)
        self.cnn2 = operations.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, **factory)
        self.cnn3 = operations.ConvTranspose2d(16, out_ch, kernel_size=4, stride=2, padding=1, **factory)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        """Encode ``x`` and return the output of the final transposed convolution."""
        for conv in (self.cnn0, self.cnn1, self.cnn2):
            x = self.relu(conv(x))
        return self.cnn3(x)
32+
33+
34+
class ResConv(nn.Module):
    """Residual 3x3 conv block with a learnable per-channel scale on the conv branch.

    Computes ``LeakyReLU(x + conv(x) * beta)`` where ``beta`` has shape
    ``(1, c, 1, 1)`` and is initialized to ones.
    """

    def __init__(self, c, device=None, dtype=None, operations=ops):
        super().__init__()
        self.conv = operations.Conv2d(c, c, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1), device=device, dtype=dtype))
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        """Return the activated sum of ``x`` and the scaled conv branch."""
        branch = self.conv(x)
        summed = torch.addcmul(x, branch, self.beta)
        return self.relu(summed)
43+
44+
45+
class IFBlock(nn.Module):
    """One flow-refinement block operating at a reduced resolution.

    The input is downscaled by ``scale``, passed through two stride-2 convs,
    eight ``ResConv`` blocks and an upsampling head (transposed conv +
    PixelShuffle) that yields 13 channels, then resized back by ``scale``.
    """

    def __init__(self, in_planes, c=64, device=None, dtype=None, operations=ops):
        super().__init__()
        factory = dict(device=device, dtype=dtype)
        half = c // 2
        down0 = nn.Sequential(
            operations.Conv2d(in_planes, half, kernel_size=3, stride=2, padding=1, **factory),
            nn.LeakyReLU(0.2, True))
        down1 = nn.Sequential(
            operations.Conv2d(half, c, kernel_size=3, stride=2, padding=1, **factory),
            nn.LeakyReLU(0.2, True))
        self.conv0 = nn.Sequential(down0, down1)
        self.convblock = nn.Sequential(*[ResConv(c, operations=operations, **factory) for _ in range(8)])
        # 4 * 13 channels before PixelShuffle(2) -> 13 channels after it.
        self.lastconv = nn.Sequential(
            operations.ConvTranspose2d(c, 4 * 13, kernel_size=4, stride=2, padding=1, **factory),
            nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        """Return ``(flow, mask, features)`` predicted from ``x`` (and the current ``flow``).

        The 13 output channels are split into 4 flow channels (multiplied by
        ``scale``), 1 mask channel and the remaining 8 feature channels.
        """
        inv_scale = 1.0 / scale
        x = F.interpolate(x, scale_factor=inv_scale, mode="bilinear")
        if flow is not None:
            # Flow magnitudes shrink with the resolution they are expressed in.
            flow = F.interpolate(flow, scale_factor=inv_scale, mode="bilinear").div_(scale)
            x = torch.cat((x, flow), 1)
        feat = self.convblock(self.conv0(x))
        out = F.interpolate(self.lastconv(feat), scale_factor=scale, mode="bilinear")
        flow_out = out[:, :4] * scale
        mask_out = out[:, 4:5]
        feat_out = out[:, 5:]
        return flow_out, mask_out, feat_out
62+
63+
64+
class IFNet(nn.Module):
    """RIFE interpolation network: a frame encoder plus five coarse-to-fine flow blocks.

    Args:
        head_ch: Number of feature channels produced by the frame encoder.
        channels: Width of each of the five ``IFBlock`` stages.
        device, dtype, operations: Forwarded to the sub-modules.
    """

    def __init__(self, head_ch=4, channels=(192, 128, 96, 64, 32), device=None, dtype=None, operations=ops):
        super().__init__()
        self.encode = Head(out_ch=head_ch, device=device, dtype=dtype, operations=operations)
        # First block input: 2 images (6) + timestep (1) + 2 encoded frames.
        # Later blocks: 2 warped images + timestep + mask (8), flow (4), block features (8), 2 encoded frames.
        block_in = [7 + 2 * head_ch] + [8 + 4 + 8 + 2 * head_ch] * 4
        self.blocks = nn.ModuleList([IFBlock(block_in[i], channels[i], device=device, dtype=dtype, operations=operations) for i in range(5)])
        self.scale_list = [16, 8, 4, 2, 1]
        # Not used inside this class; presumably the alignment callers pad inputs to - verify against caller.
        self.pad_align = 64
        # (H, W) -> (base_grid, flow_div) for the current input size only.
        self._warp_grids = {}
        # (H, W, device) the cached grid was built for.
        self._warp_grid_key = None

    def get_dtype(self):
        """Return the dtype of the encoder's first convolution weights."""
        return self.encode.cnn0.weight.dtype

    def _build_warp_grids(self, H, W, device):
        """Build (or reuse) the base sampling grid and flow divisors for ``H`` x ``W``.

        The cache is reused only when both the resolution and the device match;
        otherwise a grid built on another device would be combined with the
        flow tensor in ``_warp``.
        """
        key = (H, W, None if device is None else torch.device(device))
        if key == self._warp_grid_key and (H, W) in self._warp_grids:
            return
        self._warp_grids = {}  # clear old resolution grids to prevent memory leaks
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(-1.0, 1.0, H, device=device, dtype=torch.float32),
            torch.linspace(-1.0, 1.0, W, device=device, dtype=torch.float32), indexing="ij")
        self._warp_grids[(H, W)] = (
            torch.stack((grid_x, grid_y), dim=0).unsqueeze(0),
            # Pixels -> normalized units for align_corners=True sampling.
            torch.tensor([(W - 1.0) / 2.0, (H - 1.0) / 2.0], dtype=torch.float32, device=device))
        self._warp_grid_key = key

    def warp(self, img, flow):
        """Warp ``img`` by ``flow`` using the cached grid for the image's spatial size."""
        return _warp(img, flow, self._warp_grids)

    def extract_features(self, img):
        """Extract head features for a single frame. Can be cached across pairs."""
        return self.encode(img)

    def forward(self, img0, img1, timestep=0.5, cache=None):
        """Interpolate between ``img0`` and ``img1`` at ``timestep``.

        Args:
            img0, img1: The two input frames.
            timestep: A scalar (expanded to a one-channel map matching the
                input's batch and spatial size) or an already-built tensor.
            cache: Optional dict; entries ``"img0"`` / ``"img1"`` hold
                pre-computed encoder features that are expanded to the batch
                size instead of re-encoding the frame.

        Returns:
            The blend of both warped frames weighted by ``sigmoid(mask)``.
        """
        if not isinstance(timestep, torch.Tensor):
            timestep = torch.full((img0.shape[0], 1, img0.shape[2], img0.shape[3]), timestep, device=img0.device, dtype=img0.dtype)

        self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device)

        B = img0.shape[0]
        f0 = cache["img0"].expand(B, -1, -1, -1) if cache and "img0" in cache else self.encode(img0)
        f1 = cache["img1"].expand(B, -1, -1, -1) if cache and "img1" in cache else self.encode(img1)
        flow = mask = feat = None
        warped_img0, warped_img1 = img0, img1
        for i, block in enumerate(self.blocks):
            if flow is None:
                # First block: no flow yet, predict it from the raw frames and features.
                flow, mask, feat = block(torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
            else:
                # Later blocks predict a flow delta from the currently warped inputs.
                fd, mask, feat = block(
                    torch.cat((warped_img0, warped_img1, self.warp(f0, flow[:, :2]), self.warp(f1, flow[:, 2:4]), timestep, mask, feat), 1),
                    flow, scale=self.scale_list[i])
                flow = flow.add_(fd)
            # Flow channels 0-1 warp img0, channels 2-3 warp img1.
            warped_img0 = self.warp(img0, flow[:, :2])
            warped_img1 = self.warp(img1, flow[:, 2:4])
        return torch.lerp(warped_img1, warped_img0, torch.sigmoid(mask))
117+
118+
119+
def detect_rife_config(state_dict):
    """Infer ``(head_ch, channels)`` for ``IFNet`` from a checkpoint's weight shapes.

    ``head_ch`` is read from the encoder's final transposed convolution and the
    per-block widths from the second down-sampling conv of each block.

    Raises:
        ValueError: If fewer than five block weights are present.
    """
    # ConvTranspose2d weight layout: (in_ch, out_ch, kH, kW)
    head_ch = state_dict["encode.cnn3.weight"].shape[1]
    block_keys = [f"blocks.{i}.conv0.1.0.weight" for i in range(5)]
    channels = [state_dict[key].shape[0] for key in block_keys if key in state_dict]
    if len(channels) != 5:
        raise ValueError(f"Unsupported RIFE model: expected 5 blocks, found {len(channels)}")
    return head_ch, channels

0 commit comments

Comments
 (0)