@@ -240,19 +240,34 @@ def get_components_internal(self, container: InputContainer) -> VideoComponents:
240240 start_time = self .__start_time
241241 # Get video frames
242242 frames = []
243+ alphas = None
243244 start_pts = int (start_time / video_stream .time_base )
244245 end_pts = int ((start_time + self .__duration ) / video_stream .time_base )
245246 container .seek (start_pts , stream = video_stream )
247+ image_format = 'gbrpf32le'
246248 for frame in container .decode (video_stream ):
249+ if alphas is None :
250+ for comp in frame .format .components :
251+ if comp .is_alpha :
252+ alphas = []
253+ image_format = 'gbrapf32le'
254+ break
255+
247256 if frame .pts < start_pts :
248257 continue
249258 if self .__duration and frame .pts >= end_pts :
250259 break
251- img = frame .to_ndarray (format = 'gbrpf32le' ) # shape: (H, W, 3)
252- img = torch .from_numpy (img )
253- frames .append (img )
254260
255- images = torch .stack (frames ) if len (frames ) > 0 else torch .zeros (0 , 3 , 0 , 0 )
261+ img = frame .to_ndarray (format = image_format ) # shape: (H, W, 4)
262+ if alphas is None :
263+ frames .append (torch .from_numpy (img ))
264+ else :
265+ frames .append (torch .from_numpy (img [..., :- 1 ]))
266+ alphas .append (torch .from_numpy (img [..., - 1 :]))
267+
268+ images = torch .stack (frames ) if len (frames ) > 0 else torch .zeros (0 , 0 , 0 , 3 )
269+ if alphas is not None :
270+ alphas = torch .stack (alphas ) if len (alphas ) > 0 else torch .zeros (0 , 0 , 0 , 1 )
256271
257272 # Get frame rate
258273 frame_rate = Fraction (video_stream .average_rate ) if video_stream .average_rate else Fraction (1 )
@@ -295,7 +310,7 @@ def get_components_internal(self, container: InputContainer) -> VideoComponents:
295310 })
296311
297312 metadata = container .metadata
298- return VideoComponents (images = images , audio = audio , frame_rate = frame_rate , metadata = metadata )
313+ return VideoComponents (images = images , alpha = alphas , audio = audio , frame_rate = frame_rate , metadata = metadata )
299314
300315 def get_components (self ) -> VideoComponents :
301316 if isinstance (self .__file , io .BytesIO ):
0 commit comments