Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions AudioLoader/Music.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,3 +697,105 @@ def available_groups(self, group):





class Nsynth(Dataset):
    """Dataset class for Nsynth dataset.

    Every ``.wav`` file found in the ``audio`` folder of the chosen split is
    one sample; its labels are looked up in that split's ``examples.json``
    using the file name without extension as the key.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.

        split (str):
            Choose different dataset splits such as ``"train"``, ``"valid"`` or ``"test"``.
            Matching is case-insensitive. (default: ``"train"``).

        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).

    Raises:
        AssertionError: If ``split`` is not one of the supported splits.
        FileNotFoundError: If the archive is absent and ``download`` is ``False``.
        Exception: If downloading or extracting the archive fails.
    """

    # MD5 checksums handed to ``download_url`` for each split's archive.
    _CHECKSUMS = {
        "test": "5e6f8719bf7e16ad0a00d518b78af77d",
        "train": "fde6665a93865503ba598b9fac388660",
        "valid": "87e94a00a19b6dbc99cf6d4c0c0cae87",
    }

    # Keys copied verbatim from a sample's ``examples.json`` entry.
    _LABEL_FIELDS = (
        "note",
        "note_str",
        "instrument",
        "instrument_str",
        "pitch",
        "velocity",
        "qualities",
        "qualities_str",
        "instrument_family",
        "instrument_family_str",
        "instrument_source",
        "instrument_source_str",
    )

    def __init__(self, root, split="train", download=False):
        # Validate before touching the file system. AssertionError is raised
        # explicitly (instead of ``assert``) so the check survives ``python -O``
        # while keeping the exception type existing callers may rely on.
        normalized_split = split.lower()
        if normalized_split not in self._CHECKSUMS:
            raise AssertionError(f"split={split} is not present in this dataset")
        # Archive, folder and checksum names are all lower-case, so use the
        # normalized value everywhere from here on.
        split = normalized_split

        url = f"http://download.magenta.tensorflow.org/datasets/nsynth/nsynth-{split}.jsonwav.tar.gz"
        archive_name = f"nsynth-{split}.jsonwav.tar.gz"
        download_path = os.path.join(root, "Nsynth")
        archive_path = os.path.join(download_path, archive_name)
        self._path = os.path.join(download_path, f"nsynth-{split}", "audio")

        has_archive = os.path.isfile(archive_path)
        is_extracted = os.path.exists(self._path)

        # NOTE(review): an extracted dataset whose archive has been deleted is
        # still treated as "missing" (re-downloaded or rejected), exactly as
        # before. Confirm whether that case should be accepted instead.
        if has_archive and is_extracted:
            print(f"Dataset archive exists, all files are extracted. Using all file from {self._path} ")
        elif has_archive:
            # Archive present but not unpacked yet.
            if download:
                print(f"Dataset archive exists, extracting archive:{archive_path}")
            else:
                print(f'archive:{archive_path} exists, extracting...')
            extract_archive(archive_path)
            print(f"Using all file from {self._path} ")
        elif download:
            os.makedirs(download_path, exist_ok=True)
            try:
                download_url(url, download_path, hash_value=self._CHECKSUMS[split], hash_type='md5')
                extract_archive(archive_path)
            except Exception as err:
                # Keep the original failure as the cause so it stays visible.
                raise Exception(
                    'Auto download fails. '
                    'You may want to download it manually from:\n'
                    f'{url}\n'
                    f'Then, put it inside {download_path}'
                ) from err
            print(f"All files are extracted. Using all file from {self._path} ")
        else:
            raise FileNotFoundError(
                f"Dataset not found at {self._path}, please specify the correct location or set `download=True`"
            )

        print(f'Using all data at {self._path}')
        self._walker = glob.glob(f"{self._path}/*.wav")

        # Load the label file; the context manager closes the handle promptly.
        label_path = os.path.join(download_path, f"nsynth-{split}", "examples.json")
        with open(label_path, "r") as label_file:
            self.labels = json.load(label_file)

    def __getitem__(self, n):
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            dict: Mapping with the keys ``path``, ``waveform``, ``sample_rate``,
            ``note``, ``note_str``, ``instrument``, ``instrument_str``, ``pitch``,
            ``velocity``, ``qualities``, ``qualities_str``, ``instrument_family``,
            ``instrument_family_str``, ``instrument_source`` and
            ``instrument_source_str``. The first three come from the audio file;
            the rest are copied from the sample's entry in ``examples.json``.
        """
        file_path = self._walker[n]
        waveform, sample_rate = torchaudio.load(file_path)

        # Labels are keyed by the file name up to its first dot.
        sample_key = os.path.basename(file_path).split(".")[0]
        feature_dict = self.labels[sample_key]

        batch = {
            'path': file_path,
            'waveform': waveform,
            'sample_rate': sample_rate,
        }
        for field in self._LABEL_FIELDS:
            batch[field] = feature_dict[field]

        return batch

    def __len__(self) -> int:
        """Return the number of ``.wav`` files discovered for this split."""
        return len(self._walker)