|
| 1 | +"""FILM: Frame Interpolation for Large Motion (ECCV 2022).""" |
| 2 | + |
| 3 | +import torch |
| 4 | +import torch.nn as nn |
| 5 | +import torch.nn.functional as F |
| 6 | + |
| 7 | +import comfy.ops |
| 8 | + |
| 9 | +ops = comfy.ops.disable_weight_init |
| 10 | + |
| 11 | + |
| 12 | +class FilmConv2d(nn.Module): |
| 13 | + """Conv2d with optional LeakyReLU and FILM-style padding.""" |
| 14 | + |
| 15 | + def __init__(self, in_channels, out_channels, size, activation=True, device=None, dtype=None, operations=ops): |
| 16 | + super().__init__() |
| 17 | + self.even_pad = not size % 2 |
| 18 | + self.conv = operations.Conv2d(in_channels, out_channels, kernel_size=size, padding=size // 2 if size % 2 else 0, device=device, dtype=dtype) |
| 19 | + self.activation = nn.LeakyReLU(0.2) if activation else None |
| 20 | + |
| 21 | + def forward(self, x): |
| 22 | + if self.even_pad: |
| 23 | + x = F.pad(x, (0, 1, 0, 1)) |
| 24 | + x = self.conv(x) |
| 25 | + if self.activation is not None: |
| 26 | + x = self.activation(x) |
| 27 | + return x |
| 28 | + |
| 29 | + |
| 30 | +def _warp_core(image, flow, grid_x, grid_y): |
| 31 | + dtype = image.dtype |
| 32 | + H, W = flow.shape[2], flow.shape[3] |
| 33 | + dx = flow[:, 0].float() / (W * 0.5) |
| 34 | + dy = flow[:, 1].float() / (H * 0.5) |
| 35 | + grid = torch.stack([grid_x[None, None, :] + dx, grid_y[None, :, None] + dy], dim=3) |
| 36 | + return F.grid_sample(image.float(), grid, mode="bilinear", padding_mode="border", align_corners=False).to(dtype) |
| 37 | + |
| 38 | + |
| 39 | +def build_image_pyramid(image, pyramid_levels): |
| 40 | + pyramid = [image] |
| 41 | + for _ in range(1, pyramid_levels): |
| 42 | + image = F.avg_pool2d(image, 2, 2) |
| 43 | + pyramid.append(image) |
| 44 | + return pyramid |
| 45 | + |
| 46 | + |
| 47 | +def flow_pyramid_synthesis(residual_pyramid): |
| 48 | + flow = residual_pyramid[-1] |
| 49 | + flow_pyramid = [flow] |
| 50 | + for residual_flow in residual_pyramid[:-1][::-1]: |
| 51 | + flow = F.interpolate(flow, size=residual_flow.shape[2:4], mode="bilinear", scale_factor=None).mul_(2).add_(residual_flow) |
| 52 | + flow_pyramid.append(flow) |
| 53 | + flow_pyramid.reverse() |
| 54 | + return flow_pyramid |
| 55 | + |
| 56 | + |
| 57 | +def multiply_pyramid(pyramid, scalar): |
| 58 | + return [image * scalar[:, None, None, None] for image in pyramid] |
| 59 | + |
| 60 | + |
| 61 | +def pyramid_warp(feature_pyramid, flow_pyramid, warp_fn): |
| 62 | + return [warp_fn(features, flow) for features, flow in zip(feature_pyramid, flow_pyramid)] |
| 63 | + |
| 64 | + |
| 65 | +def concatenate_pyramids(pyramid1, pyramid2): |
| 66 | + return [torch.cat([f1, f2], dim=1) for f1, f2 in zip(pyramid1, pyramid2)] |
| 67 | + |
| 68 | + |
| 69 | +class SubTreeExtractor(nn.Module): |
| 70 | + def __init__(self, in_channels=3, channels=64, n_layers=4, device=None, dtype=None, operations=ops): |
| 71 | + super().__init__() |
| 72 | + convs = [] |
| 73 | + for i in range(n_layers): |
| 74 | + out_ch = channels << i |
| 75 | + convs.append(nn.Sequential( |
| 76 | + FilmConv2d(in_channels, out_ch, 3, device=device, dtype=dtype, operations=operations), |
| 77 | + FilmConv2d(out_ch, out_ch, 3, device=device, dtype=dtype, operations=operations))) |
| 78 | + in_channels = out_ch |
| 79 | + self.convs = nn.ModuleList(convs) |
| 80 | + |
| 81 | + def forward(self, image, n): |
| 82 | + head = image |
| 83 | + pyramid = [] |
| 84 | + for i, layer in enumerate(self.convs): |
| 85 | + head = layer(head) |
| 86 | + pyramid.append(head) |
| 87 | + if i < n - 1: |
| 88 | + head = F.avg_pool2d(head, 2, 2) |
| 89 | + return pyramid |
| 90 | + |
| 91 | + |
| 92 | +class FeatureExtractor(nn.Module): |
| 93 | + def __init__(self, in_channels=3, channels=64, sub_levels=4, device=None, dtype=None, operations=ops): |
| 94 | + super().__init__() |
| 95 | + self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels, device=device, dtype=dtype, operations=operations) |
| 96 | + self.sub_levels = sub_levels |
| 97 | + |
| 98 | + def forward(self, image_pyramid): |
| 99 | + sub_pyramids = [self.extract_sublevels(image_pyramid[i], min(len(image_pyramid) - i, self.sub_levels)) |
| 100 | + for i in range(len(image_pyramid))] |
| 101 | + feature_pyramid = [] |
| 102 | + for i in range(len(image_pyramid)): |
| 103 | + features = sub_pyramids[i][0] |
| 104 | + for j in range(1, self.sub_levels): |
| 105 | + if j <= i: |
| 106 | + features = torch.cat([features, sub_pyramids[i - j][j]], dim=1) |
| 107 | + feature_pyramid.append(features) |
| 108 | + # Free sub-pyramids no longer needed by future levels |
| 109 | + if i >= self.sub_levels - 1: |
| 110 | + sub_pyramids[i - self.sub_levels + 1] = None |
| 111 | + return feature_pyramid |
| 112 | + |
| 113 | + |
| 114 | +class FlowEstimator(nn.Module): |
| 115 | + def __init__(self, in_channels, num_convs, num_filters, device=None, dtype=None, operations=ops): |
| 116 | + super().__init__() |
| 117 | + self._convs = nn.ModuleList() |
| 118 | + for _ in range(num_convs): |
| 119 | + self._convs.append(FilmConv2d(in_channels, num_filters, 3, device=device, dtype=dtype, operations=operations)) |
| 120 | + in_channels = num_filters |
| 121 | + self._convs.append(FilmConv2d(in_channels, num_filters // 2, 1, device=device, dtype=dtype, operations=operations)) |
| 122 | + self._convs.append(FilmConv2d(num_filters // 2, 2, 1, activation=False, device=device, dtype=dtype, operations=operations)) |
| 123 | + |
| 124 | + def forward(self, features_a, features_b): |
| 125 | + net = torch.cat([features_a, features_b], dim=1) |
| 126 | + for conv in self._convs: |
| 127 | + net = conv(net) |
| 128 | + return net |
| 129 | + |
| 130 | + |
| 131 | +class PyramidFlowEstimator(nn.Module): |
| 132 | + def __init__(self, filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops): |
| 133 | + super().__init__() |
| 134 | + in_channels = filters << 1 |
| 135 | + predictors = [] |
| 136 | + for i in range(len(flow_convs)): |
| 137 | + predictors.append(FlowEstimator(in_channels, flow_convs[i], flow_filters[i], device=device, dtype=dtype, operations=operations)) |
| 138 | + in_channels += filters << (i + 2) |
| 139 | + self._predictor = predictors[-1] |
| 140 | + self._predictors = nn.ModuleList(predictors[:-1][::-1]) |
| 141 | + |
| 142 | + def forward(self, feature_pyramid_a, feature_pyramid_b, warp_fn): |
| 143 | + levels = len(feature_pyramid_a) |
| 144 | + v = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1]) |
| 145 | + residuals = [v] |
| 146 | + # Coarse-to-fine: shared predictor for deep levels, then specialized predictors for fine levels |
| 147 | + steps = [(i, self._predictor) for i in range(levels - 2, len(self._predictors) - 1, -1)] |
| 148 | + steps += [(len(self._predictors) - 1 - k, p) for k, p in enumerate(self._predictors)] |
| 149 | + for i, predictor in steps: |
| 150 | + v = F.interpolate(v, size=feature_pyramid_a[i].shape[2:4], mode="bilinear").mul_(2) |
| 151 | + v_residual = predictor(feature_pyramid_a[i], warp_fn(feature_pyramid_b[i], v)) |
| 152 | + residuals.append(v_residual) |
| 153 | + v = v.add_(v_residual) |
| 154 | + residuals.reverse() |
| 155 | + return residuals |
| 156 | + |
| 157 | + |
| 158 | +def _get_fusion_channels(level, filters): |
| 159 | + # Per direction: multi-scale features + RGB image (3ch) + flow (2ch), doubled for both directions |
| 160 | + return (sum(filters << i for i in range(level)) + 3 + 2) * 2 |
| 161 | + |
| 162 | + |
| 163 | +class Fusion(nn.Module): |
| 164 | + def __init__(self, n_layers=4, specialized_layers=3, filters=64, device=None, dtype=None, operations=ops): |
| 165 | + super().__init__() |
| 166 | + self.output_conv = operations.Conv2d(filters, 3, kernel_size=1, device=device, dtype=dtype) |
| 167 | + self.convs = nn.ModuleList() |
| 168 | + in_channels = _get_fusion_channels(n_layers, filters) |
| 169 | + increase = 0 |
| 170 | + for i in range(n_layers)[::-1]: |
| 171 | + num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers) |
| 172 | + self.convs.append(nn.ModuleList([ |
| 173 | + FilmConv2d(in_channels, num_filters, 2, activation=False, device=device, dtype=dtype, operations=operations), |
| 174 | + FilmConv2d(in_channels + (increase or num_filters), num_filters, 3, device=device, dtype=dtype, operations=operations), |
| 175 | + FilmConv2d(num_filters, num_filters, 3, device=device, dtype=dtype, operations=operations)])) |
| 176 | + in_channels = num_filters |
| 177 | + increase = _get_fusion_channels(i, filters) - num_filters // 2 |
| 178 | + |
| 179 | + def forward(self, pyramid): |
| 180 | + net = pyramid[-1] |
| 181 | + for k, layers in enumerate(self.convs): |
| 182 | + i = len(self.convs) - 1 - k |
| 183 | + net = layers[0](F.interpolate(net, size=pyramid[i].shape[2:4], mode="nearest")) |
| 184 | + net = layers[2](layers[1](torch.cat([pyramid[i], net], dim=1))) |
| 185 | + return self.output_conv(net) |
| 186 | + |
| 187 | + |
| 188 | +class FILMNet(nn.Module): |
| 189 | + def __init__(self, pyramid_levels=7, fusion_pyramid_levels=5, specialized_levels=3, sub_levels=4, |
| 190 | + filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops): |
| 191 | + super().__init__() |
| 192 | + self.pyramid_levels = pyramid_levels |
| 193 | + self.fusion_pyramid_levels = fusion_pyramid_levels |
| 194 | + self.extract = FeatureExtractor(3, filters, sub_levels, device=device, dtype=dtype, operations=operations) |
| 195 | + self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters, device=device, dtype=dtype, operations=operations) |
| 196 | + self.fuse = Fusion(sub_levels, specialized_levels, filters, device=device, dtype=dtype, operations=operations) |
| 197 | + self._warp_grids = {} |
| 198 | + |
| 199 | + def get_dtype(self): |
| 200 | + return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype |
| 201 | + |
| 202 | + def _build_warp_grids(self, H, W, device): |
| 203 | + """Pre-compute warp grids for all pyramid levels.""" |
| 204 | + if (H, W) in self._warp_grids: |
| 205 | + return |
| 206 | + self._warp_grids = {} # clear old resolution grids to prevent memory leaks |
| 207 | + for _ in range(self.pyramid_levels): |
| 208 | + self._warp_grids[(H, W)] = ( |
| 209 | + torch.linspace(-(1 - 1 / W), 1 - 1 / W, W, dtype=torch.float32, device=device), |
| 210 | + torch.linspace(-(1 - 1 / H), 1 - 1 / H, H, dtype=torch.float32, device=device), |
| 211 | + ) |
| 212 | + H, W = H // 2, W // 2 |
| 213 | + |
| 214 | + def warp(self, image, flow): |
| 215 | + grid_x, grid_y = self._warp_grids[(flow.shape[2], flow.shape[3])] |
| 216 | + return _warp_core(image, flow, grid_x, grid_y) |
| 217 | + |
| 218 | + def extract_features(self, img): |
| 219 | + """Extract image and feature pyramids for a single frame. Can be cached across pairs.""" |
| 220 | + image_pyramid = build_image_pyramid(img, self.pyramid_levels) |
| 221 | + feature_pyramid = self.extract(image_pyramid) |
| 222 | + return image_pyramid, feature_pyramid |
| 223 | + |
| 224 | + def forward(self, img0, img1, timestep=0.5, cache=None): |
| 225 | + # FILM uses a scalar timestep per batch element (spatially-varying timesteps not supported) |
| 226 | + t = timestep.mean(dim=(1, 2, 3)).item() if isinstance(timestep, torch.Tensor) else timestep |
| 227 | + return self.forward_multi_timestep(img0, img1, [t], cache=cache) |
| 228 | + |
| 229 | + def forward_multi_timestep(self, img0, img1, timesteps, cache=None): |
| 230 | + """Compute flow once, synthesize at multiple timesteps. Expects batch=1 inputs.""" |
| 231 | + self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device) |
| 232 | + |
| 233 | + image_pyr0, feat_pyr0 = cache["img0"] if cache and "img0" in cache else self.extract_features(img0) |
| 234 | + image_pyr1, feat_pyr1 = cache["img1"] if cache and "img1" in cache else self.extract_features(img1) |
| 235 | + |
| 236 | + fwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr0, feat_pyr1, self.warp))[:self.fusion_pyramid_levels] |
| 237 | + bwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr1, feat_pyr0, self.warp))[:self.fusion_pyramid_levels] |
| 238 | + |
| 239 | + # Build warp targets and free full pyramids (only first fpl levels needed from here) |
| 240 | + fpl = self.fusion_pyramid_levels |
| 241 | + p2w = [concatenate_pyramids(image_pyr0[:fpl], feat_pyr0[:fpl]), |
| 242 | + concatenate_pyramids(image_pyr1[:fpl], feat_pyr1[:fpl])] |
| 243 | + del image_pyr0, image_pyr1, feat_pyr0, feat_pyr1 |
| 244 | + |
| 245 | + results = [] |
| 246 | + dt_tensors = torch.tensor(timesteps, device=img0.device, dtype=img0.dtype) |
| 247 | + for idx in range(len(timesteps)): |
| 248 | + batch_dt = dt_tensors[idx:idx + 1] |
| 249 | + bwd_scaled = multiply_pyramid(bwd_flow, batch_dt) |
| 250 | + fwd_scaled = multiply_pyramid(fwd_flow, 1 - batch_dt) |
| 251 | + fwd_warped = pyramid_warp(p2w[0], bwd_scaled, self.warp) |
| 252 | + bwd_warped = pyramid_warp(p2w[1], fwd_scaled, self.warp) |
| 253 | + aligned = [torch.cat([fw, bw, bf, ff], dim=1) |
| 254 | + for fw, bw, bf, ff in zip(fwd_warped, bwd_warped, bwd_scaled, fwd_scaled)] |
| 255 | + del fwd_warped, bwd_warped, bwd_scaled, fwd_scaled |
| 256 | + results.append(self.fuse(aligned)) |
| 257 | + del aligned |
| 258 | + return torch.cat(results, dim=0) |
0 commit comments