Skip to content

Commit 9c8ac4a

Browse files
committed
Merge branch 'master' into Fix_unsafe_dag_release_upon_error
2 parents 2f0ec89 + cb7f664 commit 9c8ac4a

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

tests/flow/tests_tensorflow.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -638,12 +638,12 @@ def run(name=model_name, output_name='output{1}'):
638638
'INPUTS', 'input{1}', 'OUTPUTS', output_name)
639639

640640
# Running thrice since minbatchsize = 2
641-
p1 = mp.Process(target=run)
642-
p1.start()
643-
p2 = mp.Process(target=run)
644-
p2.start()
645-
p3 = mp.Process(target=run)
646-
p3.start()
641+
# The third process will hang until termintation or until a new process will execute the model with the same properties.
642+
processes = []
643+
for i in range(3):
644+
p = mp.Process(target=run)
645+
p.start()
646+
processes.append(p)
647647

648648
time.sleep(3)
649649

@@ -655,9 +655,10 @@ def run(name=model_name, output_name='output{1}'):
655655

656656
p1b = mp.Process(target=run, args=(another_model_name, 'final1{1}'))
657657
p1b.start()
658-
659658
run(another_model_name, 'final2{1}')
660659

660+
p1b.join()
661+
661662
_, dtype, _, shape, _, data = con.execute_command('AI.TENSORGET', 'final1{1}', 'META', 'BLOB')
662663
dtype_map = {b'FLOAT': np.float32}
663664
tensor = np.frombuffer(data, dtype=dtype_map[dtype]).reshape(shape)
@@ -667,7 +668,8 @@ def run(name=model_name, output_name='output{1}'):
667668

668669
env.assertEqual(label, 'giant_panda')
669670

670-
p3.terminate()
671+
for p in processes:
672+
p.terminate()
671673

672674

673675
@skip_if_no_TF

0 commit comments

Comments
 (0)