Skip to content

Commit 3cbf015

Browse files
Read audio and video at the same time in video loader node. (Comfy-Org#13591)
1 parent 64b8457 commit 3cbf015

1 file changed

Lines changed: 81 additions & 52 deletions

File tree

comfy_api/latest/_input_impl/video_types.py

Lines changed: 81 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
import math
1313
import torch
1414
from .._util import VideoContainer, VideoCodec, VideoComponents
15+
import logging
1516

1617

1718
def container_to_output_format(container_format: str | None) -> str | None:
@@ -238,32 +239,86 @@ def get_components_internal(self, container: InputContainer) -> VideoComponents:
238239
start_time = max(self._get_raw_duration() + self.__start_time, 0)
239240
else:
240241
start_time = self.__start_time
242+
241243
# Get video frames
242244
frames = []
245+
audio_frames = []
243246
alphas = None
244247
start_pts = int(start_time / video_stream.time_base)
245248
end_pts = int((start_time + self.__duration) / video_stream.time_base)
246-
container.seek(start_pts, stream=video_stream)
249+
250+
if start_pts != 0:
251+
container.seek(start_pts, stream=video_stream)
252+
247253
image_format = 'gbrpf32le'
248-
for frame in container.decode(video_stream):
249-
if alphas is None:
250-
for comp in frame.format.components:
251-
if comp.is_alpha:
252-
alphas = []
253-
image_format = 'gbrapf32le'
254-
break
254+
audio = None
255+
256+
streams = [video_stream]
257+
has_first_audio_frame = False
258+
checked_alpha = False
255259

256-
if frame.pts < start_pts:
257-
continue
258-
if self.__duration and frame.pts >= end_pts:
260+
# Default to False so we decode until EOF if duration is 0
261+
video_done = False
262+
audio_done = True
263+
264+
if len(container.streams.audio):
265+
audio_stream = container.streams.audio[-1]
266+
streams += [audio_stream]
267+
resampler = av.audio.resampler.AudioResampler(format='fltp')
268+
audio_done = False
269+
270+
for packet in container.demux(*streams):
271+
if video_done and audio_done:
259272
break
260273

261-
img = frame.to_ndarray(format=image_format) # shape: (H, W, 4)
262-
if alphas is None:
263-
frames.append(torch.from_numpy(img))
264-
else:
265-
frames.append(torch.from_numpy(img[..., :-1]))
266-
alphas.append(torch.from_numpy(img[..., -1:]))
274+
if packet.stream.type == "video":
275+
if video_done:
276+
continue
277+
try:
278+
for frame in packet.decode():
279+
if frame.pts < start_pts:
280+
continue
281+
if self.__duration and frame.pts >= end_pts:
282+
video_done = True
283+
break
284+
285+
if not checked_alpha:
286+
for comp in frame.format.components:
287+
if comp.is_alpha:
288+
alphas = []
289+
image_format = 'gbrapf32le'
290+
break
291+
checked_alpha = True
292+
293+
img = frame.to_ndarray(format=image_format) # shape: (H, W, 4)
294+
if alphas is None:
295+
frames.append(torch.from_numpy(img))
296+
else:
297+
frames.append(torch.from_numpy(img[..., :-1]))
298+
alphas.append(torch.from_numpy(img[..., -1:]))
299+
except av.error.InvalidDataError:
300+
logging.info("pyav decode error")
301+
302+
elif packet.stream.type == "audio":
303+
if audio_done:
304+
continue
305+
306+
aframes = itertools.chain.from_iterable(
307+
map(resampler.resample, packet.decode())
308+
)
309+
for frame in aframes:
310+
if self.__duration and frame.time > start_time + self.__duration:
311+
audio_done = True
312+
break
313+
314+
if not has_first_audio_frame:
315+
offset_seconds = start_time - frame.pts * audio_stream.time_base
316+
to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
317+
if to_skip < frame.samples:
318+
has_first_audio_frame = True
319+
audio_frames.append(frame.to_ndarray()[..., to_skip:])
320+
else:
321+
audio_frames.append(frame.to_ndarray())
267322

268323
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 0, 0, 3)
269324
if alphas is not None:
@@ -272,42 +327,16 @@ def get_components_internal(self, container: InputContainer) -> VideoComponents:
272327
# Get frame rate
273328
frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
274329

275-
# Get audio if available
276-
audio = None
277-
container.seek(start_pts, stream=video_stream)
278-
# Use last stream for consistency
279-
if len(container.streams.audio):
280-
audio_stream = container.streams.audio[-1]
281-
audio_frames = []
282-
resample = av.audio.resampler.AudioResampler(format='fltp').resample
283-
frames = itertools.chain.from_iterable(
284-
map(resample, container.decode(audio_stream))
285-
)
286-
287-
has_first_frame = False
288-
for frame in frames:
289-
offset_seconds = start_time - frame.pts * audio_stream.time_base
290-
to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
291-
if to_skip < frame.samples:
292-
has_first_frame = True
293-
break
294-
if has_first_frame:
295-
audio_frames.append(frame.to_ndarray()[..., to_skip:])
330+
if len(audio_frames) > 0:
331+
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
332+
if self.__duration:
333+
audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
296334

297-
for frame in frames:
298-
if self.__duration and frame.time > start_time + self.__duration:
299-
break
300-
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
301-
if len(audio_frames) > 0:
302-
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
303-
if self.__duration:
304-
audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
305-
306-
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
307-
audio = AudioInput({
308-
"waveform": audio_tensor,
309-
"sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
310-
})
335+
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
336+
audio = AudioInput({
337+
"waveform": audio_tensor,
338+
"sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
339+
})
311340

312341
metadata = container.metadata
313342
return VideoComponents(images=images, alpha=alphas, audio=audio, frame_rate=frame_rate, metadata=metadata)

0 commit comments

Comments
 (0)