forked from Ghadjeres/DeepBach
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmusescore_flask_server.py
More file actions
328 lines (265 loc) · 10.4 KB
/
musescore_flask_server.py
File metadata and controls
328 lines (265 loc) · 10.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
import math
import os
import pickle
import click
import tempfile
from glob import glob
import subprocess
import music21
import numpy as np
from flask import Flask, request, make_response, jsonify
from music21 import musicxml, converter
from tqdm import tqdm
import torch
import logging
from logging import handlers as logging_handlers
from DatasetManager.chorale_dataset import ChoraleDataset
from DatasetManager.dataset_manager import DatasetManager
from DatasetManager.metadata import FermataMetadata, TickMetadata, KeyMetadata
from DeepBach.model_manager import DeepBach
# --- Flask app setup and module-level generation state ---
UPLOAD_FOLDER = '/tmp'
ALLOWED_EXTENSIONS = {'xml', 'mxl', 'mid', 'midi'}
app = Flask(__name__)

# Globals populated by init_app() before the server starts serving requests.
deepbach = None
_tensor_metadata = None
_num_iterations = None
_sequence_length_ticks = None
_ticks_per_quarter = None
_tensor_sheet = None

# TODO use this parameter or extract it from the metadata somehow
timesignature = music21.meter.TimeSignature('4/4')

# generation parameters
# todo put in click?
batch_size_per_voice = 8

# NOTE(review): _ticks_per_quarter is still None at this point, so
# TickMetadata is built with subdivision=None; init_app() assigns
# _ticks_per_quarter later but never rebuilds this list — confirm
# TickMetadata tolerates a None subdivision.
metadatas = [
    FermataMetadata(),
    TickMetadata(subdivision=_ticks_per_quarter),
    KeyMetadata()
]

# Headers attached to every MusicXML/plain-text response.
response_headers = {"Content-Type": "text/html",
                    "charset": "utf-8"
                    }
@click.command()
@click.option('--note_embedding_dim', default=20,
              help='size of the note embeddings')
@click.option('--meta_embedding_dim', default=20,
              help='size of the metadata embeddings')
@click.option('--num_layers', default=2,
              help='number of layers of the LSTMs')
@click.option('--lstm_hidden_size', default=256,
              help='hidden size of the LSTMs')
@click.option('--dropout_lstm', default=0.5,
              help='amount of dropout between LSTM layers')
@click.option('--linear_hidden_size', default=256,
              help='hidden size of the Linear layers')
@click.option('--num_iterations', default=100,
              help='number of parallel pseudo-Gibbs sampling iterations (for a single update)')
@click.option('--sequence_length_ticks', default=64,
              help='length of the generated chorale (in ticks)')
@click.option('--ticks_per_quarter', default=4,
              help='number of ticks per quarter note')
@click.option('--port', default=5000,
              help='port to serve on')
def init_app(note_embedding_dim,
             meta_embedding_dim,
             num_layers,
             lstm_hidden_size,
             dropout_lstm,
             linear_hidden_size,
             num_iterations,
             sequence_length_ticks,
             ticks_per_quarter,
             port
             ):
    """
    Parse CLI options, load the dataset and the trained DeepBach model,
    then start the Flask server (blocking).
    """
    # Fix: the original stacked two identical @click.option('--dropout_lstm')
    # decorators; only one is kept.
    global metadatas
    global _sequence_length_ticks
    global _num_iterations
    global _ticks_per_quarter
    global bach_chorales_dataset
    _ticks_per_quarter = ticks_per_quarter
    _sequence_length_ticks = sequence_length_ticks
    _num_iterations = num_iterations

    # Rebuild the metadata list now that the real subdivision is known:
    # the module-level list was created while _ticks_per_quarter was None.
    metadatas = [
        FermataMetadata(),
        TickMetadata(subdivision=ticks_per_quarter),
        KeyMetadata()
    ]

    dataset_manager = DatasetManager()
    chorale_dataset_kwargs = {
        'voice_ids': [0, 1, 2, 3],
        'metadatas': metadatas,
        'sequences_size': 8,
        'subdivision': 4
    }
    _bach_chorales_dataset: ChoraleDataset = dataset_manager.get_dataset(
        name='bach_chorales',
        **chorale_dataset_kwargs
    )
    bach_chorales_dataset = _bach_chorales_dataset
    # The generated length must align with whole dataset subdivisions.
    assert sequence_length_ticks % bach_chorales_dataset.subdivision == 0

    global deepbach
    deepbach = DeepBach(
        dataset=bach_chorales_dataset,
        note_embedding_dim=note_embedding_dim,
        meta_embedding_dim=meta_embedding_dim,
        num_layers=num_layers,
        lstm_hidden_size=lstm_hidden_size,
        dropout_lstm=dropout_lstm,
        linear_hidden_size=linear_hidden_size
    )
    deepbach.load()
    # Only move the model to GPU when one is actually available, so the
    # server can also run on CPU-only machines.
    if torch.cuda.is_available():
        deepbach.cuda()

    # launch the script
    # use threaded=True to fix Chrome/Chromium engine hanging on requests
    # [https://stackoverflow.com/a/30670626]
    local_only = False
    if local_only:
        # accessible only locally:
        app.run(threaded=True)
    else:
        # accessible from outside:
        app.run(host='0.0.0.0', port=port, threaded=True)
def get_fermatas_tensor(metadata_tensor: torch.Tensor) -> torch.Tensor:
    """
    Extract the fermatas tensor from a metadata tensor.

    :param metadata_tensor: metadata tensor whose first axis indexes voices;
        each per-voice slice has shape
        `(sequence_duration, len(metadatas) + 1)` (the extra column is the
        voice-index metadata)
    :return: 1-D tensor with the fermata value at every time step
    """
    # Locate the fermata column by exact type, without instantiating a
    # throwaway FermataMetadata object just to compare classes.
    fermatas_index = next(
        index for index, metadata in enumerate(metadatas)
        if type(metadata) is FermataMetadata
    )
    # fermatas are shared across all voices so we only consider the first voice
    soprano_voice_metadata = metadata_tensor[0]
    # Extract fermatas for all steps
    return soprano_voice_metadata[:, fermatas_index]
def allowed_file(filename):
    """Return True iff *filename* has an extension in ALLOWED_EXTENSIONS."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1]
    return extension.lower() in ALLOWED_EXTENSIONS
def compose_from_scratch():
    """
    Return a new, generated sheet
    Usage:
    - Request: empty, generation is done in an unconstrained fashion
    - Response: a sheet, MusicXML
    """
    global deepbach
    global _sequence_length_ticks
    global _num_iterations
    global _tensor_sheet
    global _tensor_metadata
    # Use more iterations for the initial generation step
    # FIXME hardcoded 4/4 time-signature
    num_measures_generation = math.floor(
        _sequence_length_ticks / deepbach.dataset.subdivision)
    # HACK hardcoded reduction factor of 3
    initial_num_iterations = math.floor(
        _num_iterations * num_measures_generation / 3)
    generated_sheet, _tensor_sheet, _tensor_metadata = deepbach.generation(
        num_iterations=initial_num_iterations,
        sequence_length_ticks=_sequence_length_ticks,
    )
    return generated_sheet
@app.route('/compose', methods=['POST'])
def compose():
    """
    Generate or partially regenerate a chorale.

    Request form fields:
    - ``start_tick`` / ``end_tick``: selection boundaries in MIDI ticks
      (120 ticks per sixteenth note)
    - ``start_staff`` / ``end_staff``: selected voice (staff) range
    - ``file_path``: path to the ``.mxl`` file exported by the MuseScore
      plugin

    If no selection is given (both ticks are 0) a whole new chorale is
    generated from scratch; otherwise only the selected region is resampled.
    Returns the resulting sheet as MusicXML.
    """
    global deepbach
    global _num_iterations
    global _sequence_length_ticks
    global _tensor_sheet
    global _tensor_metadata
    global bach_chorales_dataset

    # MuseScore expresses the selection in MIDI ticks; convert to sixteenths.
    NUM_MIDI_TICKS_IN_SIXTEENTH_NOTE = 120
    start_tick_selection = int(float(
        request.form['start_tick']) / NUM_MIDI_TICKS_IN_SIXTEENTH_NOTE)
    end_tick_selection = int(
        float(request.form['end_tick']) / NUM_MIDI_TICKS_IN_SIXTEENTH_NOTE)

    file_path = request.form['file_path']
    root, ext = os.path.splitext(file_path)
    # renamed from `dir` — do not shadow the builtin
    extraction_dir = os.path.dirname(file_path)
    # explicit check instead of `assert` (asserts are stripped under -O)
    if ext != '.mxl':
        raise ValueError(f'expected an .mxl file, got {ext!r}')
    xml_file = f'{root}.xml'

    # if no selection REGENERATE and set chorale length
    if start_tick_selection == 0 and end_tick_selection == 0:
        generated_sheet = compose_from_scratch()
        generated_sheet.write('xml', xml_file)
        return sheet_to_response(generated_sheet)

    # --- Parse request ---
    # file_path points to an mxl (compressed) file: extract it in place.
    # SECURITY fix: pass argv as a list with shell=False — file_path comes
    # straight from the request and must never be interpolated into a shell
    # command line (command injection).
    subprocess.run(['unzip', '-o', file_path, '-d', extraction_dir])
    music21_parsed_chorale = converter.parse(xml_file)
    _tensor_sheet, _tensor_metadata = \
        bach_chorales_dataset.transposed_score_and_metadata_tensors(
            music21_parsed_chorale, semi_tone=0)

    start_voice_index = int(request.form['start_staff'])
    end_voice_index = int(request.form['end_staff']) + 1
    time_index_range_ticks = [start_tick_selection, end_tick_selection]
    region_length = end_tick_selection - start_tick_selection

    # compute batch_size_per_voice: smaller regions get smaller batches
    if region_length <= 8:
        batch_size_per_voice = 2
    elif region_length <= 16:
        batch_size_per_voice = 4
    else:
        batch_size_per_voice = 8
    num_total_iterations = int(
        _num_iterations * region_length / batch_size_per_voice)
    fermatas_tensor = get_fermatas_tensor(_tensor_metadata)

    # --- Generate ---
    (output_sheet,
     _tensor_sheet,
     _tensor_metadata) = deepbach.generation(
        tensor_chorale=_tensor_sheet,
        tensor_metadata=_tensor_metadata,
        temperature=1.,
        batch_size_per_voice=batch_size_per_voice,
        num_iterations=num_total_iterations,
        sequence_length_ticks=_sequence_length_ticks,
        time_index_range_ticks=time_index_range_ticks,
        fermatas=fermatas_tensor,
        voice_index_range=[start_voice_index, end_voice_index],
        random_init=True
    )
    output_sheet.write('xml', xml_file)
    return sheet_to_response(sheet=output_sheet)
# NOTE(review): this is a re-definition duplicating the get_fermatas_tensor
# defined earlier in the file; being later in the module, this copy is the
# one bound at runtime. Consider removing one of the two.
def get_fermatas_tensor(metadata_tensor: torch.Tensor) -> torch.Tensor:
    """
    Extract the fermatas tensor from a metadata tensor
    """
    fermatas_index = [m.__class__ for m in metadatas].index(
        FermataMetadata().__class__)
    # fermatas are shared across all voices so we only consider the first voice
    soprano_voice_metadata = metadata_tensor[0]
    # `soprano_voice_metadata` has shape
    # `(sequence_duration, len(metadatas) + 1)` (accounting for the voice
    # index metadata)
    # Extract fermatas for all steps
    return soprano_voice_metadata[:, fermatas_index]
def sheet_to_response(sheet):
    """Serialize *sheet* to MusicXML and wrap it in an HTTP response."""
    exporter = musicxml.m21ToXml.GeneralObjectExporter(sheet)
    xml_chorale_string = exporter.parse()
    return make_response((xml_chorale_string, response_headers))
@app.route('/test', methods=['POST', 'GET'])
def test_generation():
    """Health-check endpoint: echoes 'TEST' and logs POSTed requests."""
    if request.method == 'POST':
        print(request)
    return make_response(('TEST', response_headers))
@app.route('/models', methods=['GET'])
def get_models():
    """Return the list of available model names as JSON."""
    return jsonify(['Deepbach'])
@app.route('/current_model', methods=['POST', 'PUT'])
def current_model_update():
    """The model is fixed at startup; switching at runtime is unsupported."""
    return 'Model is only loaded once'
@app.route('/current_model', methods=['GET'])
def current_model_get():
    """Report the name of the (single, fixed) loaded model."""
    return 'DeepBach'
if __name__ == '__main__':
    # Log to a rotating file (up to 5 backups of ~10 KB each) instead of
    # stderr only.
    file_handler = logging_handlers.RotatingFileHandler(
        'app.log', maxBytes=10000, backupCount=5)
    app.logger.addHandler(file_handler)
    app.logger.setLevel(logging.INFO)
    # init_app is a click command: it parses CLI options, loads the model
    # and starts the Flask server (blocking call).
    init_app()