33import torch
44import torchaudio
55import torchaudio .functional as F_audio
6+ import soundfile as sf
67import glob
78import numpy as np
89import math
910import gc
1011
12+ from model_v5 import CGA_ResUNet
1113from model_v3 import DSCA_ResUNet_v3
1214
# ---------------------------------------------------------------------------------------------------------------------------------------------------------------

# INFER AND PROCESSING CONFIG
MODEL_VERSION = "v3"        # Which model family/checkpoint prefix to load ("v3" or "v5")
FORCE_CPU = False           # By default runs with GPU Acceleration ( CUDA )
MASK_MODE = "Soft"          # Available: "Soft", "Hard", "PowerMean" and "Hybrid"
DEBUG_MASK_PRED = False     # Set to True if you need to debug / predict the model's prediction on your samples.
SAVE_EXTENSION = "wave_16"  # Available: "flac", "wave_16" and "wave_32float"

# SMART CUTTER CONFIG ( Safe defaults. )
SILENCE_TARGET_DURATION = 0.100  # The target duration for silence gaps (e.g. 500ms gap / silence --> 100ms)
MIN_SEGMENT_DURATION_MS = 100    # Minimum length for detected spots to count as viable for cutting ( 100ms, safe default. )

# PREDICTION STABILIZATION
STABILITY_NOISE = False     # Injects subtle noise into pure silence to stabilize the model
STABILITY_DB_LEVEL = -75.0  # The dB level of the injected noise ( -80 is minimum; Model's limitation. )
STABILITY_FADE_MS = 1       # Fade duration (ms) for the injected noise edges to be softer
ENABLE_BRIDGING = True      # Filling of the mask/prediction gaps - Only use when and if you debug the mask output and notice gaps.

# PATHS
IN_DIR = "infer_input"
OUT_DIR = "infer_output"
CKPT_DIR = "ckpts"

# ---------------------------------------------------------------------------------------------------------------------------------------------------------------



# Established safe params, do not tweak these unless necessary and you know what you're doing.
SEARCH_WINDOW_MS = 25
FADE_DURATION_MS = 10
CUTTING_PROBABILITY = 0.5
SAFETY_BUFFER_MS = 5
SEGMENT_LEN = 8.0  # segmentation / chunking length, in seconds

4752def get_cosine_fade (length , device ):
@@ -70,47 +75,41 @@ def apply_fade(waveform, fade_samples, mode="both"):
7075
def inject_stability_noise(wav, sr, device):
    """
    Inject steady, neutral colored noise into regions of exact digital silence.

    Scans `wav` (shape [1, T]) for samples that are exactly 0.0, and replaces each
    contiguous silent run with low-level filtered noise at STABILITY_DB_LEVEL dBFS,
    with faded/windowed edges so the injection has no hard discontinuities.

    Args:
        wav:    mono waveform tensor, shape [1, T]. Mutated in place for the
                silent regions, and also returned.
        sr:     sample rate in Hz (used to convert STABILITY_FADE_MS to samples).
        device: torch device the tensors live on.

    Returns:
        The (in-place modified) waveform tensor.
    """
    noise_amp = 10 ** (STABILITY_DB_LEVEL / 20.0)

    # Boolean-as-float mask over samples: 1.0 where the signal is exactly zero.
    silence_mask = (wav.squeeze(0) == 0.0).float()
    # Edges of silent runs: diff == 1 marks a run start, diff == -1 marks a run end.
    diff = torch.diff(silence_mask, prepend=torch.tensor([0.0], device=device), append=torch.tensor([0.0], device=device))
    starts = torch.where(diff == 1)[0]
    ends = torch.where(diff == -1)[0]

    if len(starts) == 0:
        return wav

    raw_noise = torch.randn_like(wav)

    # One-pole coloring of the white noise. NOTE(review): lfilter's signature is
    # (waveform, a_coeffs, b_coeffs), so as written this applies b = [1, -alpha]
    # (a first-difference / high-emphasis FIR), not a [1, -alpha] lowpass pole —
    # confirm the coefficient order matches the intended noise color.
    alpha = 0.85
    neutral_noise = torchaudio.functional.lfilter(
        raw_noise,
        torch.tensor([1.0, 0.0], device=device),
        torch.tensor([1.0, -alpha], device=device)
    )
    neutral_noise *= noise_amp
    fade_samples = int(sr * (STABILITY_FADE_MS / 1000.0))

    for start, end in zip(starts, ends):
        length = end - start
        if length <= 0:
            continue

        noise_chunk = neutral_noise[:, start:end].clone()

        if length > fade_samples * 2:
            # Long enough gap: fade both edges of the injected noise.
            noise_chunk = apply_fade(noise_chunk, fade_samples, mode="both")
        else:
            # Tiny gap: window the whole chunk instead of overlapping fades.
            window = torch.hann_window(length, device=device)
            noise_chunk *= window

        wav[:, start:end] = noise_chunk

    return wav
@@ -212,12 +211,20 @@ def SmartCutter(waveform, mask, sr=48000):
212211 waveform = waveform .cpu ()
213212 mask = mask .cpu ()
214213
215- target_size = waveform .shape [1 ]
214+ if ENABLE_BRIDGING :
215+ # Gap bridge operation
216+ bridge_frames = 5 # At 100fps, 50ms is ~5 frames
217+ mask = mask .view (1 , 1 , - 1 )
218+
219+ mask = torch .nn .functional .max_pool1d (mask , bridge_frames , 1 , bridge_frames // 2 ) # Dilation
220+ mask = - torch .nn .functional .max_pool1d (- mask , bridge_frames , 1 , bridge_frames // 2 ) # Erosion
216221
217222 # Interpolate the low-res mask up to the full audio resolution.
218223 if mask .dim () == 1 : mask = mask .view (1 , 1 , - 1 )
219224 elif mask .dim () == 2 : mask = mask .unsqueeze (1 )
220225
226+ target_size = waveform .shape [1 ]
227+
221228 mask_full = torch .nn .functional .interpolate (
222229 mask , size = target_size , mode = 'linear' , align_corners = True
223230 ).squeeze ()
@@ -275,7 +282,7 @@ def process_grid_aligned(model, transform, waveform, sr, hop_length, device, sta
275282 # Processes audio in overlapping chunks and averages the results.
276283 total_samples = waveform .shape [1 ]
277284
278- CHUNK_SEC = TARGET_STEP_SEC
285+ CHUNK_SEC = SEGMENT_LEN
279286 OVERLAP_SEC = CHUNK_SEC / 2
280287
281288 chunk_samples = int (CHUNK_SEC * sr )
@@ -329,10 +336,10 @@ def process_grid_aligned(model, transform, waveform, sr, hop_length, device, sta
329336
330337 # Ensure we don't go out of bounds
331338 if start_frame + frames_per_chunk > mask_accumulator .shape [1 ]:
332- # Expand CPU buffer dynamically if needed
333- extra = (start_frame + frames_per_chunk ) - mask_accumulator .shape [1 ]
334- mask_accumulator = torch .nn .functional .pad (mask_accumulator , (0 , extra ))
335- weight_accumulator = torch .nn .functional .pad (weight_accumulator , (0 , extra ))
339+ # Expand CPU buffer dynamically if needed
340+ extra = (start_frame + frames_per_chunk ) - mask_accumulator .shape [1 ]
341+ mask_accumulator = torch .nn .functional .pad (mask_accumulator , (0 , extra ))
342+ weight_accumulator = torch .nn .functional .pad (weight_accumulator , (0 , extra ))
336343
337344 # Accumulate on CPU
338345 current_pred_cpu = raw_mask .cpu () # Move pred to CPU
@@ -375,7 +382,8 @@ def _run_inference(model, mel_transform, wav_chunk, device, input_buffer):
375382 mask_2d = model (input_buffer [:, :, :, :current_frames ])
376383
377384 # Collapse 2D output (freq/time) to 1D (time) based on strategy.
378- if MASK_MODE == "Soft" : mask_pred = torch .mean (mask_2d , dim = 2 )
385+ if MASK_MODE == "Soft" :
386+ mask_pred = torch .mean (mask_2d , dim = 2 )
379387 elif MASK_MODE == "Hybrid" :
380388 soft_mask = torch .mean (mask_2d , dim = 2 )
381389 hard_mask = torch .max (mask_2d , dim = 2 )[0 ]
@@ -430,15 +438,23 @@ def processing():
430438
431439 # Select that channel and keep dimensions as [1, Time]
432440 wav = wav [best_ch_idx ].unsqueeze (0 )
433-
434441 print (f" -> Converted Stereo to Mono (Selected Ch { best_ch_idx } , DC: { dc_offsets [best_ch_idx ]:.6f} )" )
435442
443+ wav_for_inference = wav .clone ()
444+
445+ # Safety norm on input
446+ input_peak = torch .abs (wav_for_inference ).max ()
447+ if input_peak > 0 :
448+ target_peak = 0.9
449+ wav_for_inference = wav_for_inference * (target_peak / input_peak )
450+
436451 if STABILITY_NOISE :
437- wav = inject_stability_noise (wav , sr , wav .device )
452+ wav_for_inference = inject_stability_noise (wav_for_inference , sr , wav .device )
438453
439454 # Dynamic model loading based on Sample Rate.
440455 current_hop = sr // 100
441456 if sr not in loaded_models :
457+
442458 # Release previous model if switching SR
443459 if len (loaded_models ) > 0 :
444460 print ("Unloading previous model to free VRAM..." )
@@ -447,32 +463,48 @@ def processing():
447463 if device .type == 'cuda' :
448464 torch .cuda .empty_cache ()
449465
450- model_path = os .path .join (CKPT_DIR , f"model_ { sr } .pth" )
466+ model_path = os .path .join (CKPT_DIR , f"{ MODEL_VERSION } _model_ { sr } .pth" )
451467 if not os .path .exists (model_path ):
452- print (f"Skipping { fname } : No model for { sr } Hz" )
468+ print (f"Skipping { fname } : No { MODEL_VERSION } model for { sr } Hz" )
453469 continue
454470
455- print (f"Loading { sr } Hz model..." )
456- model = DSCA_ResUNet_v3 (n_channels = 2 ).to (device )
471+ print (f"Loading { sr } Hz { MODEL_VERSION } model ..." )
472+
473+ if MODEL_VERSION == "v3" :
474+ model = DSCA_ResUNet_v3 (n_channels = 2 , n_classes = 1 ).to (device ) # v3
475+ elif MODEL_VERSION == "v5" :
476+ model = CGA_ResUNet (n_channels = 2 , n_classes = 1 ).to (device ) # v5
477+ else :
478+ print (f"'{ MODEL_VERSION } ' is not a valid model version choice. Exiting." )
479+ sys .exit (1 )
480+
457481 model .load_state_dict (torch .load (model_path , map_location = device ))
458482 model .eval ()
459483
484+ # mel transform config
485+ if sr in [48000 , 40000 ]:
486+ N_FFT = 2048
487+ N_MELS = 160
488+ else :
489+ N_FFT = 1024 # for 32khz model variant.
490+ N_MELS = 128
491+
460492 mel_transform = torchaudio .transforms .MelSpectrogram (
461- sample_rate = sr , n_mels = N_MELS , n_fft = 2048 , hop_length = current_hop
493+ sample_rate = sr , n_mels = N_MELS , n_fft = N_FFT , hop_length = current_hop
462494 ).to (device )
463495
464496 # we're pre-allocating a static buffer for the model input
465497 # Shape: [Batch, Channels, Mel_Bins, Frames_per_60s_chunk]
466498 # Channels = 2 (Mel + Delta)
467- dummy_frames = int (math .ceil ((TARGET_STEP_SEC * sr ) / current_hop )) + 5
499+ dummy_frames = int (math .ceil ((SEGMENT_LEN * sr ) / current_hop )) + 5
468500 static_buffer = torch .zeros ((1 , 2 , N_MELS , dummy_frames ), device = device )
469501
470502 loaded_models [sr ] = (model , mel_transform , static_buffer )
471503
472504 curr_model , curr_mel_transform , curr_buffer = loaded_models [sr ]
473505
474506 # Inference (GPU-accelerated, CPU accumulation)
475- mel_mask = process_grid_aligned (curr_model , curr_mel_transform , wav , sr , current_hop , device , curr_buffer )
507+ mel_mask = process_grid_aligned (curr_model , curr_mel_transform , wav_for_inference , sr , current_hop , device , curr_buffer )
476508
477509 if device .type == 'cuda' :
478510 torch .cuda .empty_cache ()
@@ -492,17 +524,23 @@ def processing():
492524 debug_wav = wav + (debug_noise * (binary_mask_interpolated > CUTTING_PROBABILITY ).float ())
493525 torchaudio .save (os .path .join (OUT_DIR , "debug_" + os .path .basename (f_path )), debug_wav , sr )
494526
495- out_path = os .path .join (OUT_DIR , fname )
496527
497528 # Volume Normalization
498529 peak = torch .abs (cleaned ).max ()
499530 if peak >= 0.95 :
500531 scale_factor = 0.95 / peak .item ()
501532 cleaned = cleaned * scale_factor
502533
534+ # Output path construction
535+ file_stem = os .path .splitext (fname )[0 ]
536+ if SAVE_EXTENSION == "flac" :
537+ out_path = os .path .join (OUT_DIR , file_stem + ".flac" )
538+ elif "wave" in SAVE_EXTENSION :
539+ out_path = os .path .join (OUT_DIR , file_stem + ".wav" )
540+
503541 # Saving
504542 if SAVE_EXTENSION == "flac" :
505- torchaudio .save (out_path , cleaned , sr , bits_per_sample = 16 )
543+ torchaudio .save (out_path , cleaned , sr , format = "flac" , backend = "soundfile" )
506544 elif SAVE_EXTENSION == "wave_16" :
507545 torchaudio .save (out_path , cleaned , sr , encoding = "PCM_S" , bits_per_sample = 16 )
508546 elif SAVE_EXTENSION == "wave_32float" :
@@ -518,6 +556,7 @@ def processing():
518556 # inputs and outputs
519557 if 'wav' in locals (): del wav
520558 if 'cleaned' in locals (): del cleaned
559+ if 'wav_for_inference' in locals (): del wav_for_inference
521560
522561 # masks and intermediate tensors
523562 if 'binary_mask' in locals (): del binary_mask
0 commit comments