From 5f9717c81bbf37333e9ce9301297759c1e37ff05 Mon Sep 17 00:00:00 2001
From: Andrew Gibiansky
Date: Fri, 10 Apr 2020 14:33:52 -0700
Subject: [PATCH 1/2] Do not load model if it's already loaded

Not only is this slow, this will break on GPUs because TensorFlow does
not release its allocated memory. So if you try to run a catalog with
multiple files on a GPU, the first file will succeed, and the second
file will give you an OOM error.
---
 align/align.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/align/align.py b/align/align.py
index d53f7c6..56507c1 100644
--- a/align/align.py
+++ b/align/align.py
@@ -63,9 +63,10 @@ def read_script(script_path):
 
 def init_stt(output_graph_path, lm_path, trie_path):
     global model
-    model = deepspeech.Model(output_graph_path, BEAM_WIDTH)
-    model.enableDecoderWithLM(lm_path, trie_path, LM_ALPHA, LM_BETA)
-    logging.debug('Process {}: Loaded models'.format(os.getpid()))
+    if model is None:
+        model = deepspeech.Model(output_graph_path, BEAM_WIDTH)
+        model.enableDecoderWithLM(lm_path, trie_path, LM_ALPHA, LM_BETA)
+        logging.debug('Process {}: Loaded models'.format(os.getpid()))
 
 
 def stt(sample):

From a6400b6692608eeb8b77f78670c043a1c11b1eaf Mon Sep 17 00:00:00 2001
From: Andrew Gibiansky
Date: Fri, 10 Apr 2020 14:46:08 -0700
Subject: [PATCH 2/2] Wait for pool to close to deallocate resources

Otherwise using this on GPU fails, because TensorFlow constantly
allocates the entire heap and doesn't let go of it
---
 align/align.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/align/align.py b/align/align.py
index 56507c1..8949b77 100644
--- a/align/align.py
+++ b/align/align.py
@@ -499,10 +499,10 @@ def pre_filter():
 
     samples = list(progress(pre_filter(), desc='VAD splitting'))
 
-    pool = multiprocessing.Pool(initializer=init_stt,
+    with multiprocessing.Pool(initializer=init_stt,
                                initargs=(output_graph_path, lm_path, trie_path),
-                               processes=args.stt_workers)
-    transcripts = list(progress(pool.imap(stt, samples), desc='Transcribing', total=len(samples)))
+                               processes=args.stt_workers) as pool:
+        transcripts = list(progress(pool.imap(stt, samples), desc='Transcribing', total=len(samples)))
 
     fragments = []
     for time_start, time_end, segment_transcript in transcripts: