Commit 142fa77

Onnx benchmark (#856)
* Add script for running onnx benchmark
* Add documentation of onnx benchmark to performance section
* Add support for using ctrl-C to stop the run and exit gracefully.
1 parent 862fdae commit 142fa77

4 files changed, +219 -4 lines changed


docs/performance.md

Lines changed: 25 additions & 2 deletions
@@ -3,9 +3,11 @@
 To get an early sense of what RedisAI is capable of, you can test it with:
 - [`redis-benchmark`](https://redis.io/topics/benchmarks): Redis includes the redis-benchmark utility that simulates running commands done by N clients at the same time sending M total queries (it is similar to the Apache's ab utility).
 
-- [`memtier_benchmark`](https://github.com/RedisLabs/memtier_benchmark): from [Redis Labs](https://redislabs.com/) is a NoSQL Redis and Memcache traffic generation and benchmarking tool.
+- [`memtier_benchmark`](https://github.com/RedisLabs/memtier_benchmark): from [Redis](https://redislabs.com/) is a NoSQL Redis and Memcache traffic generation and benchmarking tool.
 
-- [`aibench`](https://github.com/RedisAI/aibench): a collection of Go programs that are used to generate datasets and then benchmark the inference performance of various Model Servers.
+- `onnx_benchmark`: a quick tool that benchmarks the inference performance of the ONNXRuntime backend for different model sizes.
+
+- [`aibench`](https://github.com/RedisAI/aibench): a collection of Go programs that are used to generate datasets and then benchmark the inference performance of various Model Servers.
 
 
 This page is intended to provide clarity on how to obtain the benchmark numbers and links to the most recent results. We encourage developers, data scientists, and architects to run these benchmarks for themselves on their particular hardware, datasets, and Model Servers and pull request this documentation with links for the actual numbers.
@@ -65,6 +67,27 @@ The following example will:
 memtier_benchmark --clients 50 --threads 4 --requests 10000 --pipeline 1 --json-out-file results.json --command "AI.MODELEXECUTE model_key INPUTS input_count input1 ... OUTPUTS output_count output1 ..." --command "AI.SCRIPTEXECUTE script_key entry_point INPUTS input_count input1 ... OUTPUTS output_count output1 ..."
 ```
 
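As a concrete sketch, the template above could be filled in as follows; the key names `mymodel`, `in1` and `out1` are hypothetical and assume a model with a single input and a single output tensor has already been stored:

```
memtier_benchmark --clients 50 --threads 4 --requests 10000 --pipeline 1 --json-out-file results.json --command "AI.MODELEXECUTE mymodel INPUTS 1 in1 OUTPUTS 1 out1"
```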
+## Using onnx_benchmark
+
+`onnx_benchmark` is a simple Python script for loading and benchmarking RedisAI+ONNXRuntime performance on CPU, using a single shard. It uses the following three well-known models:
+1. "small" model - [mnist](https://en.wikipedia.org/wiki/MNIST_database) (26.5 KB)
+2. "medium" model - [inception v2](https://towardsdatascience.com/a-simple-guide-to-the-versions-of-the-inception-network-7fc52b863202) (45 MB)
+3. "large" model - [bert-base-cased](https://huggingface.co/bert-base-cased) (433 MB)
+
+To simulate a situation where memory consumption is high from the beginning, the script loads the mnist model under 50 different keys, the inception model under 20 different keys, and the bert model once.
+It then executes parallel and sequential inference sessions of all three models and prints the performance results to the screen.
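Under the hood this is done with plain RedisAI commands; a rough redis-cli equivalent for storing one of the 50 mnist copies and its input tensor is sketched below (the key names follow the script's hash-tag scheme, and the file paths are placeholders for wherever the test data resides):

```
redis-cli -x AI.MODELSTORE "mnist{1}0" ONNX CPU BLOB < path/to/mnist.onnx
redis-cli -x AI.TENSORSET "mnist_in{1}" FLOAT 1 1 28 28 BLOB < path/to/one.raw
```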
+
+The script accepts the following arguments:
+- `--num_threads` The number of RedisAI working threads that can execute sessions in parallel. Default value: 1.
+- `--num_parallel_clients` The number of parallel clients that send consecutive run requests per model. Default value: 20.
+- `--num_runs_mnist` The number of requests per client running mnist sessions. Default value: 500.
+- `--num_runs_inception` The number of requests per client running inception sessions. Default value: 50.
+- `--num_runs_bert` The number of requests per client running bert sessions. Default value: 5.
+
+To run the benchmark, first build RedisAI for CPU as described in the [quick start](quickstart.md) section. The following command runs `onnx_benchmark` from the RedisAI root directory (using the default arguments):
+
+```
+python3 tests/flow/onnx_benchmark.py --num_threads 1 --num_parallel_clients 20 --num_runs_mnist 500 --num_runs_inception 50 --num_runs_bert 5
+```
+
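As a usage note, the defaults can be overridden selectively; the values below are only illustrative, e.g. exercising a larger thread pool with a lighter bert workload:

```
python3 tests/flow/onnx_benchmark.py --num_threads 4 --num_runs_bert 2
```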
 ## Using aibench
 
 _AIBench_ is a collection of Go programs that are used to generate datasets and then benchmark the inference performance of various Model Servers. The intent is to make the AIBench extensible so that a variety of use cases and Model Servers can be included and benchmarked.

tests/flow/includes.py

Lines changed: 1 addition & 2 deletions
@@ -25,9 +25,8 @@
 TEST_ONNX = os.environ.get("TEST_ONNX") != "0" and os.environ.get("WITH_ORT") != "0"
 COV = os.environ.get("COV") != "0" and os.environ.get("COV") != "0"
 DEVICE = os.environ.get('DEVICE', 'CPU').upper().encode('utf-8', 'ignore').decode('utf-8')
+print(f'\nRunning inference sessions on {DEVICE}\n')
 VALGRIND = os.environ.get("VALGRIND") == "1"
-print("Running tests on {}\n".format(DEVICE))
-print("Using a max of {} iterations per test\n".format(MAX_ITERATIONS))
 # change this to make inference tests longer
 MAX_TRANSACTIONS=100

tests/flow/onnx_benchmark.py

Lines changed: 190 additions & 0 deletions
@@ -0,0 +1,190 @@
import os

from RLTest import Env
from includes import *
import shutil
import argparse
import signal
from redis import RedisError
# sys, time, threading and numpy may also be pulled in via `from includes import *`;
# they are imported explicitly here since the script uses them directly.
import sys
import time
import threading
import numpy as np

terminate_flag = 0
parent_pid = os.getpid()


# This should capture user SIGINT signals (such as keyboard ctrl-c). Since we are using multi-processing,
# this handler will be inherited by all the running processes. Note that every process will get the signal,
# as all of them are in the same process group.
def handler(signum, frame):
    global terminate_flag
    terminate_flag = 1
    global parent_pid
    if os.getpid() == parent_pid:  # print it only once
        print("\nReceived user interrupt. Shutting down...")


def _exit():
    # remove the logs that were auto generated by redis
    shutil.rmtree('logs', ignore_errors=True)
    print("from exit\n")
    sys.exit(1)


def run_benchmark(env, num_runs_mnist, num_runs_inception, num_runs_bert, num_parallel_clients):
    global terminate_flag
    con = get_connection(env, '{1}')

    print("Loading ONNX models...")
    model_pb = load_file_content('mnist.onnx')
    sample_raw = load_file_content('one.raw')
    inception_pb = load_file_content('inception-v2-9.onnx')
    _, _, _, _, img = load_mobilenet_v2_test_data()
    bert_pb = load_file_content('bert-base-cased.onnx')
    bert_in_data = np.random.randint(-2, 1, size=(10, 100), dtype=np.int64)

    # store the mnist model under 50 different keys
    for i in range(50):
        if terminate_flag == 1:
            _exit()
        ret = con.execute_command('AI.MODELSTORE', 'mnist{1}'+str(i), 'ONNX', DEVICE, 'BLOB', model_pb)
        env.assertEqual(ret, b'OK')
    con.execute_command('AI.TENSORSET', 'mnist_in{1}', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw)

    # store the inception model under 20 different keys
    for i in range(20):
        if terminate_flag == 1:
            _exit()
        ret = con.execute_command('AI.MODELSTORE', 'inception{1}'+str(i), 'ONNX', DEVICE, 'BLOB', inception_pb)
        env.assertEqual(ret, b'OK')

    backends_info = get_info_section(con, 'backends_info')
    print(f'Done. ONNX memory consumption is: {backends_info["ai_onnxruntime_memory"]} bytes')

    ret = con.execute_command('AI.TENSORSET', 'inception_in{1}', 'FLOAT', 1, 3, 224, 224, 'BLOB', img.tobytes())
    env.assertEqual(ret, b'OK')
    ret = con.execute_command('AI.MODELSTORE', 'bert{1}', 'ONNX', DEVICE, 'BLOB', bert_pb)
    env.assertEqual(ret, b'OK')
    ret = con.execute_command('AI.TENSORSET', 'bert_in{1}', 'INT64', 10, 100, 'BLOB', bert_in_data.tobytes())
    env.assertEqual(ret, b'OK')

    def run_parallel_onnx_sessions(con, model, input, num_runs):
        for _ in range(num_runs):
            if terminate_flag == 1:
                return
            # If the user is terminating the benchmark, redis-server will receive a termination signal as well,
            # and a RedisError exception will be thrown (and caught)
            try:
                if model == 'bert{1}':
                    ret = con.execute_command('AI.MODELEXECUTE', model, 'INPUTS', 3, input, input, input,
                                              'OUTPUTS', 2, 'res{1}', 'res2{1}')
                else:
                    ret = con.execute_command('AI.MODELEXECUTE', model, 'INPUTS', 1, input, 'OUTPUTS', 1, 'res{1}')
                env.assertEqual(ret, b'OK')
            except RedisError:
                return

    def run_mnist():
        run_test_multiproc(env, '{1}', num_parallel_clients, run_parallel_onnx_sessions,
                           ('mnist{1}0', 'mnist_in{1}', num_runs_mnist))

    def run_bert():
        run_test_multiproc(env, '{1}', num_parallel_clients, run_parallel_onnx_sessions,
                           ('bert{1}', 'bert_in{1}', num_runs_bert))

    # run only mnist
    mnist_total_requests_count = num_runs_mnist*num_parallel_clients
    print(f'\nRunning {num_runs_mnist} consecutive executions of mnist from {num_parallel_clients} parallel clients...')
    start_time = time.time()
    run_test_multiproc(env, '{1}', num_parallel_clients, run_parallel_onnx_sessions,
                       ('mnist{1}0', 'mnist_in{1}', num_runs_mnist))
    if terminate_flag == 1:
        _exit()
    print(f'Done. Total execution time for {mnist_total_requests_count} requests: {time.time()-start_time} seconds')
    # AI.INFO reply index 11 is the value of the model's 'duration' stat - cumulative execution time in microseconds
    mnist_time = con.execute_command('AI.INFO', 'mnist{1}0')[11]
    print("Average serving time per mnist run session is: {} seconds"
          .format(float(mnist_time)/1000000/mnist_total_requests_count))

    # run only inception
    inception_total_requests_count = num_runs_inception*num_parallel_clients
    print(f'\nRunning {num_runs_inception} consecutive executions of inception from {num_parallel_clients} parallel clients...')
    start_time = time.time()
    run_test_multiproc(env, '{1}', num_parallel_clients, run_parallel_onnx_sessions,
                       ('inception{1}0', 'inception_in{1}', num_runs_inception))
    if terminate_flag == 1:
        _exit()
    print(f'Done. Total execution time for {inception_total_requests_count} requests: {time.time()-start_time} seconds')
    inception_time = con.execute_command('AI.INFO', 'inception{1}0')[11]
    print("Average serving time per inception run session is: {} seconds"
          .format(float(inception_time)/1000000/inception_total_requests_count))

    # run only bert
    bert_total_requests_count = num_runs_bert*num_parallel_clients
    print(f'\nRunning {num_runs_bert} consecutive executions of bert from {num_parallel_clients} parallel clients...')
    start_time = time.time()
    run_test_multiproc(env, '{1}', num_parallel_clients, run_parallel_onnx_sessions, ('bert{1}', 'bert_in{1}', num_runs_bert))
    if terminate_flag == 1:
        _exit()
    print(f'Done. Total execution time for {bert_total_requests_count} requests: {time.time()-start_time} seconds')
    bert_time = con.execute_command('AI.INFO', 'bert{1}')[11]
    print("Average serving time per bert run session is: {} seconds"
          .format(float(bert_time)/1000000/bert_total_requests_count))

    con.execute_command('AI.INFO', 'mnist{1}0', 'RESETSTAT')
    con.execute_command('AI.INFO', 'inception{1}0', 'RESETSTAT')
    con.execute_command('AI.INFO', 'bert{1}', 'RESETSTAT')

    # run all 3 models in parallel
    total_requests_count = mnist_total_requests_count+inception_total_requests_count+bert_total_requests_count
    print(f'\nRunning requests for all 3 models from {3*num_parallel_clients} parallel clients...')
    start_time = time.time()
    t = threading.Thread(target=run_mnist)
    t.start()
    t2 = threading.Thread(target=run_bert)
    t2.start()
    run_test_multiproc(env, '{1}', num_parallel_clients, run_parallel_onnx_sessions,
                       ('inception{1}0', 'inception_in{1}', num_runs_inception))
    t.join()
    t2.join()
    if terminate_flag == 1:
        _exit()
    print(f'Done. Total execution time for {total_requests_count} requests: {time.time()-start_time} seconds')
    mnist_info = con.execute_command('AI.INFO', 'mnist{1}0')[11]
    inception_info = con.execute_command('AI.INFO', 'inception{1}0')[11]
    bert_info = con.execute_command('AI.INFO', 'bert{1}')[11]
    total_time = mnist_info+inception_info+bert_info
    print("Average serving time per run session is: {} seconds"
          .format(float(total_time)/1000000/total_requests_count))


if __name__ == '__main__':

    # set a handler for user interrupt signal
    signal.signal(signal.SIGINT, handler)

    # parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_threads", default='1',
                        help='The number of RedisAI working threads that can execute sessions in parallel')
    parser.add_argument("--num_runs_mnist", type=int, default=500,
                        help='The number of requests per client that is running mnist run sessions')
    parser.add_argument("--num_runs_inception", type=int, default=50,
                        help='The number of requests per client that is running inception run sessions')
    parser.add_argument("--num_runs_bert", type=int, default=5,
                        help='The number of requests per client that is running bert run sessions')
    parser.add_argument("--num_parallel_clients", type=int, default=20,
                        help='The number of parallel clients that send consecutive run requests per model')
    args = parser.parse_args()

    terminate_flag = 0
    print(f'Running ONNX benchmark on RedisAI, using {args.num_threads} working threads')
    env = Env(module='install-cpu/redisai.so',
              moduleArgs='MODEL_EXECUTION_TIMEOUT 50000 THREADS_PER_QUEUE '+args.num_threads, logDir='logs')

    # If the user is terminating the benchmark, redis-server will receive a termination signal as well,
    # and a RedisError exception will be thrown (and caught)
    try:
        run_benchmark(env, num_runs_mnist=args.num_runs_mnist, num_runs_inception=args.num_runs_inception,
                      num_runs_bert=args.num_runs_bert, num_parallel_clients=args.num_parallel_clients)
        env.stop()
    except RedisError as e:
        pass
    finally:
        # remove the logs that were auto generated by redis
        shutil.rmtree('logs', ignore_errors=True)
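The averages printed by the script come from the cumulative `duration` statistic (in microseconds) that RedisAI keeps per model key and that the script clears with `RESETSTAT`. Assuming the benchmark's key naming and a server that is still running, the same counters can be inspected manually, for example:

```
redis-cli AI.INFO "mnist{1}0"
redis-cli AI.INFO "bert{1}" RESETSTAT
```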
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c8b29b06415e08f3d0de97e47ec94ccd6ce6ed52cef0cc1202bb13b3cdff4d45
size 433311846
