1212import math
1313import torch
1414from .._util import VideoContainer , VideoCodec , VideoComponents
15+ import logging
1516
1617
1718def container_to_output_format (container_format : str | None ) -> str | None :
@@ -238,32 +239,86 @@ def get_components_internal(self, container: InputContainer) -> VideoComponents:
238239 start_time = max (self ._get_raw_duration () + self .__start_time , 0 )
239240 else :
240241 start_time = self .__start_time
242+
241243 # Get video frames
242244 frames = []
245+ audio_frames = []
243246 alphas = None
244247 start_pts = int (start_time / video_stream .time_base )
245248 end_pts = int ((start_time + self .__duration ) / video_stream .time_base )
246- container .seek (start_pts , stream = video_stream )
249+
250+ if start_pts != 0 :
251+ container .seek (start_pts , stream = video_stream )
252+
247253 image_format = 'gbrpf32le'
248- for frame in container .decode (video_stream ):
249- if alphas is None :
250- for comp in frame .format .components :
251- if comp .is_alpha :
252- alphas = []
253- image_format = 'gbrapf32le'
254- break
254+ audio = None
255+
256+ streams = [video_stream ]
257+ has_first_audio_frame = False
258+ checked_alpha = False
255259
256- if frame .pts < start_pts :
257- continue
258- if self .__duration and frame .pts >= end_pts :
260+ # Default to False so we decode until EOF if duration is 0
261+ video_done = False
262+ audio_done = True
263+
264+ if len (container .streams .audio ):
265+ audio_stream = container .streams .audio [- 1 ]
266+ streams += [audio_stream ]
267+ resampler = av .audio .resampler .AudioResampler (format = 'fltp' )
268+ audio_done = False
269+
270+ for packet in container .demux (* streams ):
271+ if video_done and audio_done :
259272 break
260273
261- img = frame .to_ndarray (format = image_format ) # shape: (H, W, 4)
262- if alphas is None :
263- frames .append (torch .from_numpy (img ))
264- else :
265- frames .append (torch .from_numpy (img [..., :- 1 ]))
266- alphas .append (torch .from_numpy (img [..., - 1 :]))
274+ if packet .stream .type == "video" :
275+ if video_done :
276+ continue
277+ try :
278+ for frame in packet .decode ():
279+ if frame .pts < start_pts :
280+ continue
281+ if self .__duration and frame .pts >= end_pts :
282+ video_done = True
283+ break
284+
285+ if not checked_alpha :
286+ for comp in frame .format .components :
287+ if comp .is_alpha :
288+ alphas = []
289+ image_format = 'gbrapf32le'
290+ break
291+ checked_alpha = True
292+
293+ img = frame .to_ndarray (format = image_format ) # shape: (H, W, 4)
294+ if alphas is None :
295+ frames .append (torch .from_numpy (img ))
296+ else :
297+ frames .append (torch .from_numpy (img [..., :- 1 ]))
298+ alphas .append (torch .from_numpy (img [..., - 1 :]))
299+ except av .error .InvalidDataError :
300+ logging .info ("pyav decode error" )
301+
302+ elif packet .stream .type == "audio" :
303+ if audio_done :
304+ continue
305+
306+ aframes = itertools .chain .from_iterable (
307+ map (resampler .resample , packet .decode ())
308+ )
309+ for frame in aframes :
310+ if self .__duration and frame .time > start_time + self .__duration :
311+ audio_done = True
312+ break
313+
314+ if not has_first_audio_frame :
315+ offset_seconds = start_time - frame .pts * audio_stream .time_base
316+ to_skip = max (0 , int (offset_seconds * audio_stream .sample_rate ))
317+ if to_skip < frame .samples :
318+ has_first_audio_frame = True
319+ audio_frames .append (frame .to_ndarray ()[..., to_skip :])
320+ else :
321+ audio_frames .append (frame .to_ndarray ())
267322
268323 images = torch .stack (frames ) if len (frames ) > 0 else torch .zeros (0 , 0 , 0 , 3 )
269324 if alphas is not None :
@@ -272,42 +327,16 @@ def get_components_internal(self, container: InputContainer) -> VideoComponents:
272327 # Get frame rate
273328 frame_rate = Fraction (video_stream .average_rate ) if video_stream .average_rate else Fraction (1 )
274329
275- # Get audio if available
276- audio = None
277- container .seek (start_pts , stream = video_stream )
278- # Use last stream for consistency
279- if len (container .streams .audio ):
280- audio_stream = container .streams .audio [- 1 ]
281- audio_frames = []
282- resample = av .audio .resampler .AudioResampler (format = 'fltp' ).resample
283- frames = itertools .chain .from_iterable (
284- map (resample , container .decode (audio_stream ))
285- )
286-
287- has_first_frame = False
288- for frame in frames :
289- offset_seconds = start_time - frame .pts * audio_stream .time_base
290- to_skip = max (0 , int (offset_seconds * audio_stream .sample_rate ))
291- if to_skip < frame .samples :
292- has_first_frame = True
293- break
294- if has_first_frame :
295- audio_frames .append (frame .to_ndarray ()[..., to_skip :])
330+ if len (audio_frames ) > 0 :
331+ audio_data = np .concatenate (audio_frames , axis = 1 ) # shape: (channels, total_samples)
332+ if self .__duration :
333+ audio_data = audio_data [..., :int (self .__duration * audio_stream .sample_rate )]
296334
297- for frame in frames :
298- if self .__duration and frame .time > start_time + self .__duration :
299- break
300- audio_frames .append (frame .to_ndarray ()) # shape: (channels, samples)
301- if len (audio_frames ) > 0 :
302- audio_data = np .concatenate (audio_frames , axis = 1 ) # shape: (channels, total_samples)
303- if self .__duration :
304- audio_data = audio_data [..., :int (self .__duration * audio_stream .sample_rate )]
305-
306- audio_tensor = torch .from_numpy (audio_data ).unsqueeze (0 ) # shape: (1, channels, total_samples)
307- audio = AudioInput ({
308- "waveform" : audio_tensor ,
309- "sample_rate" : int (audio_stream .sample_rate ) if audio_stream .sample_rate else 1 ,
310- })
335+ audio_tensor = torch .from_numpy (audio_data ).unsqueeze (0 ) # shape: (1, channels, total_samples)
336+ audio = AudioInput ({
337+ "waveform" : audio_tensor ,
338+ "sample_rate" : int (audio_stream .sample_rate ) if audio_stream .sample_rate else 1 ,
339+ })
311340
312341 metadata = container .metadata
313342 return VideoComponents (images = images , alpha = alphas , audio = audio , frame_rate = frame_rate , metadata = metadata )
0 commit comments