Skip to content

Commit 19e1a5f

Browse files
committed
2.1 update
- retrained much better v3 models - revamp of inference and training scripts
1 parent 4557d6b commit 19e1a5f

10 files changed

Lines changed: 1422 additions & 613 deletions

File tree

inference.py

Lines changed: 97 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3,45 +3,50 @@
33
import torch
44
import torchaudio
55
import torchaudio.functional as F_audio
6+
import soundfile as sf
67
import glob
78
import numpy as np
89
import math
910
import gc
1011

12+
from model_v5 import CGA_ResUNet
1113
from model_v3 import DSCA_ResUNet_v3
1214

13-
# --- CONFIG ---
14-
IN_DIR = "infer_input"
15-
OUT_DIR = "infer_output"
16-
CKPT_DIR = "ckpts"
17-
N_MELS = 160
15+
# ---------------------------------------------------------------------------------------------------------------------------------------------------------------
1816

19-
20-
# --- INFER AND PROCESSING CONFIG ---
17+
# INFER AND PROCESSING CONFIG
18+
MODEL_VERSION = "v3"
2119
FORCE_CPU = False # By default runs with GPU Acceleration ( CUDA )
2220
MASK_MODE = "Soft" # Available: "Soft", "Hard", "PowerMean" and "Hybrid"
2321
DEBUG_MASK_PRED = False # Set to True if you need to debug / inspect the model's predictions on your samples.
2422
SAVE_EXTENSION = "wave_16" # Available: "flac", "wave_16" and "wave_32float"
2523

24+
# SMART CUTTER CONFIG ( Safe defaults. )
25+
SILENCE_TARGET_DURATION = 0.100 # The target duration for silence gaps (e.g. 500ms gap / silence --> 100ms)
26+
MIN_SEGMENT_DURATION_MS = 100 # Minimum length for detected spots to count as viable for cutting ( 100ms, safe default. )
27+
28+
# PREDICTION STABILIZATION
29+
STABILITY_NOISE = False # Injects subtle noise into pure silence to stabilize the model
30+
STABILITY_DB_LEVEL = -75.0 # The dB level of the injected noise ( -80 is minimum; Model's limitation. )
31+
STABILITY_FADE_MS = 1 # Fade duration (ms) for the injected noise edges to be softer
32+
ENABLE_BRIDGING = True # Filling of the mask/prediction gaps - Only use when and if you debug the mask output and notice gaps.
33+
34+
# PATHS
35+
IN_DIR = "infer_input"
36+
OUT_DIR = "infer_output"
37+
CKPT_DIR = "ckpts"
2638

27-
# --- SMART CUTTER CONFIG ( Safe defaults. ) ---
28-
SILENCE_TARGET_DURATION = 0.100 # The target duration for silence gaps (e.g. 500ms gap / silence --> 100ms)
29-
MIN_SEGMENT_DURATION_MS = 100 # Minimum length for detected spots to count as viable for cutting ( 100ms, safe default. )
39+
# ---------------------------------------------------------------------------------------------------------------------------------------------------------------
3040

3141

3242

33-
# Do not tweak these!
43+
# Established safe params, do not tweak these unless necessary and you know what you're doing.
3444
SEARCH_WINDOW_MS = 25
3545
FADE_DURATION_MS = 10
3646
CUTTING_PROBABILITY = 0.5
3747
SAFETY_BUFFER_MS = 5
38-
TARGET_STEP_SEC = 60.0 # segmentation / chunking length
39-
MARGIN_SEC = 2.0
48+
SEGMENT_LEN = 8.0
4049

41-
# Not properly supported yet. Likely in next / more proper models' revision.
42-
STABILITY_NOISE = False # Injects subtle noise into pure silence to stabilize the model
43-
STABILITY_DB_LEVEL = -70.0 # The dB level of the injected noise
44-
STABILITY_FADE_MS = 4 # Fade duration (ms) for the injected noise edges to be softer
4550

4651

4752
def get_cosine_fade(length, device):
@@ -70,47 +75,41 @@ def apply_fade(waveform, fade_samples, mode="both"):
7075

7176
def inject_stability_noise(wav, sr, device):
7277
"""
73-
Finds exact digital silence (0.0) and injects soft-enveloped noise.
78+
Injects steady, neutral colored noise
7479
"""
75-
# 1. Create Noise Tensor
7680
noise_amp = 10 ** (STABILITY_DB_LEVEL / 20.0)
77-
78-
# Identify silence regions (1 for silence, 0 for audio)
79-
# wav is [1, T]
81+
8082
silence_mask = (wav.squeeze(0) == 0.0).float()
81-
82-
# Find edges: 1 (start of silence), -1 (end of silence)
8383
diff = torch.diff(silence_mask, prepend=torch.tensor([0.0], device=device), append=torch.tensor([0.0], device=device))
84-
8584
starts = torch.where(diff == 1)[0]
8685
ends = torch.where(diff == -1)[0]
87-
86+
8887
if len(starts) == 0:
8988
return wav
9089

91-
# Generate a master noise tensor for efficiency
92-
full_noise = torch.randn_like(wav) * noise_amp
93-
90+
raw_noise = torch.randn_like(wav)
91+
92+
alpha = 0.85
93+
neutral_noise = torchaudio.functional.lfilter(
94+
raw_noise,
95+
torch.tensor([1.0, 0.0], device=device),
96+
torch.tensor([1.0, -alpha], device=device)
97+
)
98+
neutral_noise *= noise_amp
9499
fade_samples = int(sr * (STABILITY_FADE_MS / 1000.0))
95-
96-
# Process each silence gap
100+
97101
for start, end in zip(starts, ends):
98102
length = end - start
99103
if length <= 0: continue
100-
101-
# Extract noise chunk
102-
noise_chunk = full_noise[:, start:end].clone()
103-
104-
# Apply fade to noise chunk so it doesn't have hard edges
105-
# If the gap is tiny, the fades might overlap, apply_fade handles that gracefully-ish or we skip.
104+
105+
noise_chunk = neutral_noise[:, start:end].clone()
106+
106107
if length > fade_samples * 2:
107108
noise_chunk = apply_fade(noise_chunk, fade_samples, mode="both")
108109
else:
109-
# For tiny gaps, just window it completely
110110
window = torch.hann_window(length, device=device)
111111
noise_chunk *= window
112-
113-
# Inject
112+
114113
wav[:, start:end] = noise_chunk
115114

116115
return wav
@@ -212,12 +211,20 @@ def SmartCutter(waveform, mask, sr=48000):
212211
waveform = waveform.cpu()
213212
mask = mask.cpu()
214213

215-
target_size = waveform.shape[1]
214+
if ENABLE_BRIDGING:
215+
# Gap bridge operation
216+
bridge_frames = 5 # At 100fps, 50ms is ~5 frames
217+
mask = mask.view(1, 1, -1)
218+
219+
mask = torch.nn.functional.max_pool1d(mask, bridge_frames, 1, bridge_frames//2) # Dilation
220+
mask = -torch.nn.functional.max_pool1d(-mask, bridge_frames, 1, bridge_frames//2) # Erosion
216221

217222
# Interpolate the low-res mask up to the full audio resolution.
218223
if mask.dim() == 1: mask = mask.view(1, 1, -1)
219224
elif mask.dim() == 2: mask = mask.unsqueeze(1)
220225

226+
target_size = waveform.shape[1]
227+
221228
mask_full = torch.nn.functional.interpolate(
222229
mask, size=target_size, mode='linear', align_corners=True
223230
).squeeze()
@@ -275,7 +282,7 @@ def process_grid_aligned(model, transform, waveform, sr, hop_length, device, sta
275282
# Processes audio in overlapping chunks and averages the results.
276283
total_samples = waveform.shape[1]
277284

278-
CHUNK_SEC = TARGET_STEP_SEC
285+
CHUNK_SEC = SEGMENT_LEN
279286
OVERLAP_SEC = CHUNK_SEC / 2
280287

281288
chunk_samples = int(CHUNK_SEC * sr)
@@ -329,10 +336,10 @@ def process_grid_aligned(model, transform, waveform, sr, hop_length, device, sta
329336

330337
# Ensure we don't go out of bounds
331338
if start_frame + frames_per_chunk > mask_accumulator.shape[1]:
332-
# Expand CPU buffer dynamically if needed
333-
extra = (start_frame + frames_per_chunk) - mask_accumulator.shape[1]
334-
mask_accumulator = torch.nn.functional.pad(mask_accumulator, (0, extra))
335-
weight_accumulator = torch.nn.functional.pad(weight_accumulator, (0, extra))
339+
# Expand CPU buffer dynamically if needed
340+
extra = (start_frame + frames_per_chunk) - mask_accumulator.shape[1]
341+
mask_accumulator = torch.nn.functional.pad(mask_accumulator, (0, extra))
342+
weight_accumulator = torch.nn.functional.pad(weight_accumulator, (0, extra))
336343

337344
# Accumulate on CPU
338345
current_pred_cpu = raw_mask.cpu() # Move pred to CPU
@@ -375,7 +382,8 @@ def _run_inference(model, mel_transform, wav_chunk, device, input_buffer):
375382
mask_2d = model(input_buffer[:, :, :, :current_frames])
376383

377384
# Collapse 2D output (freq/time) to 1D (time) based on strategy.
378-
if MASK_MODE == "Soft": mask_pred = torch.mean(mask_2d, dim=2)
385+
if MASK_MODE == "Soft":
386+
mask_pred = torch.mean(mask_2d, dim=2)
379387
elif MASK_MODE == "Hybrid":
380388
soft_mask = torch.mean(mask_2d, dim=2)
381389
hard_mask = torch.max(mask_2d, dim=2)[0]
@@ -430,15 +438,23 @@ def processing():
430438

431439
# Select that channel and keep dimensions as [1, Time]
432440
wav = wav[best_ch_idx].unsqueeze(0)
433-
434441
print(f" -> Converted Stereo to Mono (Selected Ch {best_ch_idx}, DC: {dc_offsets[best_ch_idx]:.6f})")
435442

443+
wav_for_inference = wav.clone()
444+
445+
# Safety norm on input
446+
input_peak = torch.abs(wav_for_inference).max()
447+
if input_peak > 0:
448+
target_peak = 0.9
449+
wav_for_inference = wav_for_inference * (target_peak / input_peak)
450+
436451
if STABILITY_NOISE:
437-
wav = inject_stability_noise(wav, sr, wav.device)
452+
wav_for_inference = inject_stability_noise(wav_for_inference, sr, wav.device)
438453

439454
# Dynamic model loading based on Sample Rate.
440455
current_hop = sr // 100
441456
if sr not in loaded_models:
457+
442458
# Release previous model if switching SR
443459
if len(loaded_models) > 0:
444460
print("Unloading previous model to free VRAM...")
@@ -447,32 +463,48 @@ def processing():
447463
if device.type == 'cuda':
448464
torch.cuda.empty_cache()
449465

450-
model_path = os.path.join(CKPT_DIR, f"model_{sr}.pth")
466+
model_path = os.path.join(CKPT_DIR, f"{MODEL_VERSION}_model_{sr}.pth")
451467
if not os.path.exists(model_path):
452-
print(f"Skipping {fname}: No model for {sr}Hz")
468+
print(f"Skipping {fname}: No {MODEL_VERSION} model for {sr}Hz")
453469
continue
454470

455-
print(f"Loading {sr}Hz model...")
456-
model = DSCA_ResUNet_v3(n_channels=2).to(device)
471+
print(f"Loading {sr}Hz {MODEL_VERSION} model ...")
472+
473+
if MODEL_VERSION == "v3":
474+
model = DSCA_ResUNet_v3(n_channels=2, n_classes=1).to(device) # v3
475+
elif MODEL_VERSION == "v5":
476+
model = CGA_ResUNet(n_channels=2, n_classes=1).to(device) # v5
477+
else:
478+
print(f"'{MODEL_VERSION}' is not a valid model version choice. Exiting.")
479+
sys.exit(1)
480+
457481
model.load_state_dict(torch.load(model_path, map_location=device))
458482
model.eval()
459483

484+
# mel transform config
485+
if sr in [48000, 40000]:
486+
N_FFT = 2048
487+
N_MELS = 160
488+
else:
489+
N_FFT = 1024 # for 32khz model variant.
490+
N_MELS = 128
491+
460492
mel_transform = torchaudio.transforms.MelSpectrogram(
461-
sample_rate=sr, n_mels=N_MELS, n_fft=2048, hop_length=current_hop
493+
sample_rate=sr, n_mels=N_MELS, n_fft=N_FFT, hop_length=current_hop
462494
).to(device)
463495

464496
# we're pre-allocating a static buffer for the model input
465497
# Shape: [Batch, Channels, Mel_Bins, Frames_per_60s_chunk]
466498
# Channels = 2 (Mel + Delta)
467-
dummy_frames = int(math.ceil((TARGET_STEP_SEC * sr) / current_hop)) + 5
499+
dummy_frames = int(math.ceil((SEGMENT_LEN * sr) / current_hop)) + 5
468500
static_buffer = torch.zeros((1, 2, N_MELS, dummy_frames), device=device)
469501

470502
loaded_models[sr] = (model, mel_transform, static_buffer)
471503

472504
curr_model, curr_mel_transform, curr_buffer = loaded_models[sr]
473505

474506
# Inference (GPU-accelerated, CPU accumulation)
475-
mel_mask = process_grid_aligned(curr_model, curr_mel_transform, wav, sr, current_hop, device, curr_buffer)
507+
mel_mask = process_grid_aligned(curr_model, curr_mel_transform, wav_for_inference, sr, current_hop, device, curr_buffer)
476508

477509
if device.type == 'cuda':
478510
torch.cuda.empty_cache()
@@ -492,17 +524,23 @@ def processing():
492524
debug_wav = wav + (debug_noise * (binary_mask_interpolated > CUTTING_PROBABILITY).float())
493525
torchaudio.save(os.path.join(OUT_DIR, "debug_" + os.path.basename(f_path)), debug_wav, sr)
494526

495-
out_path = os.path.join(OUT_DIR, fname)
496527

497528
# Volume Normalization
498529
peak = torch.abs(cleaned).max()
499530
if peak >= 0.95:
500531
scale_factor = 0.95 / peak.item()
501532
cleaned = cleaned * scale_factor
502533

534+
# Output path construction
535+
file_stem = os.path.splitext(fname)[0]
536+
if SAVE_EXTENSION == "flac":
537+
out_path = os.path.join(OUT_DIR, file_stem + ".flac")
538+
elif "wave" in SAVE_EXTENSION:
539+
out_path = os.path.join(OUT_DIR, file_stem + ".wav")
540+
503541
# Saving
504542
if SAVE_EXTENSION == "flac":
505-
torchaudio.save(out_path, cleaned, sr, bits_per_sample=16)
543+
torchaudio.save(out_path, cleaned, sr, format="flac", backend="soundfile")
506544
elif SAVE_EXTENSION == "wave_16":
507545
torchaudio.save(out_path, cleaned, sr, encoding="PCM_S", bits_per_sample=16)
508546
elif SAVE_EXTENSION == "wave_32float":
@@ -518,6 +556,7 @@ def processing():
518556
# inputs and outputs
519557
if 'wav' in locals(): del wav
520558
if 'cleaned' in locals(): del cleaned
559+
if 'wav_for_inference' in locals(): del wav_for_inference
521560

522561
# masks and intermediate tensors
523562
if 'binary_mask' in locals(): del binary_mask

0 commit comments

Comments
 (0)