Skip to content

Commit df22bcd

Browse files
Support loading the alpha channel of videos. (Comfy-Org#13564)
Not exposed in nodes yet.
1 parent 5e3f15a commit df22bcd

2 files changed

Lines changed: 22 additions & 8 deletions

File tree

comfy_api/latest/_input_impl/video_types.py

Lines changed: 20 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -240,19 +240,34 @@ def get_components_internal(self, container: InputContainer) -> VideoComponents:
240240
start_time = self.__start_time
241241
# Get video frames
242242
frames = []
243+
alphas = None
243244
start_pts = int(start_time / video_stream.time_base)
244245
end_pts = int((start_time + self.__duration) / video_stream.time_base)
245246
container.seek(start_pts, stream=video_stream)
247+
image_format = 'gbrpf32le'
246248
for frame in container.decode(video_stream):
249+
if alphas is None:
250+
for comp in frame.format.components:
251+
if comp.is_alpha:
252+
alphas = []
253+
image_format = 'gbrapf32le'
254+
break
255+
247256
if frame.pts < start_pts:
248257
continue
249258
if self.__duration and frame.pts >= end_pts:
250259
break
251-
img = frame.to_ndarray(format='gbrpf32le') # shape: (H, W, 3)
252-
img = torch.from_numpy(img)
253-
frames.append(img)
254260

255-
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
261+
img = frame.to_ndarray(format=image_format) # shape: (H, W, 4)
262+
if alphas is None:
263+
frames.append(torch.from_numpy(img))
264+
else:
265+
frames.append(torch.from_numpy(img[..., :-1]))
266+
alphas.append(torch.from_numpy(img[..., -1:]))
267+
268+
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 0, 0, 3)
269+
if alphas is not None:
270+
alphas = torch.stack(alphas) if len(alphas) > 0 else torch.zeros(0, 0, 0, 1)
256271

257272
# Get frame rate
258273
frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
@@ -295,7 +310,7 @@ def get_components_internal(self, container: InputContainer) -> VideoComponents:
295310
})
296311

297312
metadata = container.metadata
298-
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
313+
return VideoComponents(images=images, alpha=alphas, audio=audio, frame_rate=frame_rate, metadata=metadata)
299314

300315
def get_components(self) -> VideoComponents:
301316
if isinstance(self.__file, io.BytesIO):

comfy_api/latest/_util/video_types.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from enum import Enum
44
from fractions import Fraction
55
from typing import Optional
6-
from .._input import ImageInput, AudioInput
6+
from .._input import ImageInput, AudioInput, MaskInput
77

88
class VideoCodec(str, Enum):
99
AUTO = "auto"
@@ -48,5 +48,4 @@ class VideoComponents:
4848
frame_rate: Fraction
4949
audio: Optional[AudioInput] = None
5050
metadata: Optional[dict] = None
51-
52-
51+
alpha: Optional[MaskInput] = None

0 commit comments

Comments
 (0)