@@ -106,12 +106,12 @@ def execute(cls, vae, image, latent, strength, bypass=False) -> io.NodeOutput:
106106 if bypass :
107107 return (latent ,)
108108
109- samples = latent ["samples" ]
109+ samples = latent ["samples" ]. clone ()
110110 _ , height_scale_factor , width_scale_factor = (
111111 vae .downscale_index_formula
112112 )
113113
114- batch , _ , latent_frames , latent_height , latent_width = samples .shape
114+ _ , _ , _ , latent_height , latent_width = samples .shape
115115 width = latent_width * width_scale_factor
116116 height = latent_height * height_scale_factor
117117
@@ -124,11 +124,7 @@ def execute(cls, vae, image, latent, strength, bypass=False) -> io.NodeOutput:
124124
125125 samples [:, :, :t .shape [2 ]] = t
126126
127- conditioning_latent_frames_mask = torch .ones (
128- (batch , 1 , latent_frames , 1 , 1 ),
129- dtype = torch .float32 ,
130- device = samples .device ,
131- )
127+ conditioning_latent_frames_mask = get_noise_mask (latent )
132128 conditioning_latent_frames_mask [:, :, :t .shape [2 ]] = 1.0 - strength
133129
134130 return io .NodeOutput ({"samples" : samples , "noise_mask" : conditioning_latent_frames_mask })
0 commit comments