Commit 0a9d0f9

AOF rewrite fix and test (#754)

* Fix the AOF rewrite callback for the model type and start testing it (for TF). Currently it looks like something is not working properly. Note: running properly with slaves required changes in RLTest.
* Extend the tests to cover every backend and config.
1 parent 9c7d6db commit 0a9d0f9

File tree: 2 files changed (+203, −33 lines)


src/serialization/AOF/rai_aof_rewrite.c

Lines changed: 44 additions & 32 deletions
@@ -41,23 +41,10 @@ void RAI_AOFRewriteModel(RedisModuleIO *aof, RedisModuleString *key, void *value
         return;
     }
 
-    // AI.MODELSET model_key backend device [INPUTS name1 name2 ... OUTPUTS name1 name2 ...]
-    // model_blob
-
-    RedisModuleString **inputs_ = array_new(RedisModuleString *, model->ninputs);
-    RedisModuleString **outputs_ = array_new(RedisModuleString *, model->noutputs);
-
-    RedisModuleCtx *ctx = RedisModule_GetContextFromIO(aof);
-
-    for (size_t i = 0; i < model->ninputs; i++) {
-        inputs_ = array_append(
-            inputs_, RedisModule_CreateString(ctx, model->inputs[i], strlen(model->inputs[i])));
-    }
-
-    for (size_t i = 0; i < model->noutputs; i++) {
-        outputs_ = array_append(
-            outputs_, RedisModule_CreateString(ctx, model->outputs[i], strlen(model->outputs[i])));
-    }
+    // AI.MODELSTORE model_key backend device [TAG tag]
+    // [BATCHSIZE n [MINBATCHSIZE m [MINBATCHTIMEOUT t]]]
+    // [INPUTS <input_count> name1 name2 ... OUTPUTS <output_count> name1 name2 ...]
+    // BLOB model_blob
 
     long long chunk_size = getModelChunkSize();
     const size_t n_chunks = len / chunk_size + 1;
@@ -66,7 +53,7 @@ void RAI_AOFRewriteModel(RedisModuleIO *aof, RedisModuleString *key, void *value
     for (size_t i = 0; i < n_chunks; i++) {
         size_t chunk_len = i < n_chunks - 1 ? chunk_size : len % chunk_size;
         buffers_ = array_append(buffers_,
-                                RedisModule_CreateString(ctx, buffer + i * chunk_size, chunk_len));
+                                RedisModule_CreateString(NULL, buffer + i * chunk_size, chunk_len));
     }
 
     if (buffer) {
@@ -75,29 +62,54 @@ void RAI_AOFRewriteModel(RedisModuleIO *aof, RedisModuleString *key, void *value
 
     const char *backendstr = RAI_BackendName(model->backend);
 
-    RedisModule_EmitAOF(aof, "AI.MODELSET", "slccclclcvcvcv", key, backendstr, model->devicestr,
-                        model->tag, "BATCHSIZE", model->opts.batchsize, "MINBATCHSIZE",
-                        model->opts.minbatchsize, "INPUTS", inputs_, model->ninputs, "OUTPUTS",
-                        outputs_, model->noutputs, "BLOB", buffers_, n_chunks);
+    if (model->backend != RAI_BACKEND_TENSORFLOW) {
+
+        RedisModule_EmitAOF(aof, "AI.MODELSTORE", "scccsclclclcv", key, backendstr,
+                            model->devicestr, "TAG", model->tag, "BATCHSIZE", model->opts.batchsize,
+                            "MINBATCHSIZE", model->opts.minbatchsize, "MINBATCHTIMEOUT",
+                            model->opts.minbatchtimeout, "BLOB", buffers_, n_chunks);
+    } else {
+        // For TF backend, the command should contain INPUTS and OUTPUTS names.
+        // Create RedisModuleString* arrays from the char* arrays, so we can send a proper vector
+        // to RedisModule_EmitAOF.
+        array_new_on_stack(RedisModuleString *, 5, inputs_);
+        array_new_on_stack(RedisModuleString *, 5, outputs_);
+
+        for (size_t i = 0; i < model->ninputs; i++) {
+            inputs_ = array_append(inputs_, RedisModule_CreateString(NULL, model->inputs[i],
+                                                                     strlen(model->inputs[i])));
+        }
+        for (size_t i = 0; i < model->noutputs; i++) {
+            outputs_ = array_append(outputs_, RedisModule_CreateString(NULL, model->outputs[i],
+                                                                       strlen(model->outputs[i])));
+        }
 
-    for (size_t i = 0; i < model->ninputs; i++) {
-        RedisModule_FreeString(ctx, inputs_[i]);
-    }
-    array_free(inputs_);
+        RedisModule_EmitAOF(aof, "AI.MODELSTORE", "scccsclclclclvclvcv", key, backendstr,
+                            model->devicestr, "TAG", model->tag, "BATCHSIZE", model->opts.batchsize,
+                            "MINBATCHSIZE", model->opts.minbatchsize, "MINBATCHTIMEOUT",
+                            model->opts.minbatchtimeout, "INPUTS", model->ninputs, inputs_,
+                            model->ninputs, "OUTPUTS", model->noutputs, outputs_, model->noutputs,
+                            "BLOB", buffers_, n_chunks);
 
-    for (size_t i = 0; i < model->noutputs; i++) {
-        RedisModule_FreeString(ctx, outputs_[i]);
+        for (size_t i = 0; i < model->ninputs; i++) {
+            RedisModule_FreeString(NULL, inputs_[i]);
+        }
+        array_free(inputs_);
+
+        for (size_t i = 0; i < model->noutputs; i++) {
+            RedisModule_FreeString(NULL, outputs_[i]);
+        }
+        array_free(outputs_);
     }
-    array_free(outputs_);
 
     for (size_t i = 0; i < n_chunks; i++) {
-        RedisModule_FreeString(ctx, buffers_[i]);
+        RedisModule_FreeString(NULL, buffers_[i]);
    }
     array_free(buffers_);
 }
 
 void RAI_AOFRewriteScript(RedisModuleIO *aof, RedisModuleString *key, void *value) {
     RAI_Script *script = (RAI_Script *)value;
-    RedisModule_EmitAOF(aof, "AI.SCRIPTSET", "scccc", key, script->devicestr, script->tag, "SOURCE",
-                        script->scriptdef);
+    RedisModule_EmitAOF(aof, "AI.SCRIPTSET", "sccscc", key, script->devicestr, "TAG", script->tag,
+                        "SOURCE", script->scriptdef);
 }
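
For context on the format strings above: in the Redis module API, "s" passes a RedisModuleString, "c" a C string, "l" a long long, and "v" a vector of RedisModuleStrings followed by its length, which is why only the TF branch builds inputs_/outputs_ arrays before calling RedisModule_EmitAOF. Decoded, the rewrite replays one of two AI.MODELSTORE shapes depending on the backend. The Python sketch below shows roughly equivalent client-side calls with redis-py; the key names, tag and single-argument BLOB are illustrative assumptions (the rewrite itself may split the blob into several chunk arguments, per getModelChunkSize), while the tensor names 'a', 'b' and 'mul' match the graph.pb fixture used in the tests.

# A minimal sketch, not part of the commit: the two command shapes the new rewrite emits,
# assuming a local Redis server with the RedisAI module loaded and the test fixtures on disk.
import redis

r = redis.Redis()

with open("linear_iris.onnx", "rb") as f:   # any non-TF model blob
    onnx_blob = f.read()
with open("graph.pb", "rb") as f:           # a TF graph with named inputs/outputs
    tf_blob = f.read()

# Non-TF backends: no INPUTS/OUTPUTS section (format string "scccsclclclcv").
r.execute_command('AI.MODELSTORE', 'my_onnx_model', 'ONNX', 'CPU', 'TAG', 'v1',
                  'BATCHSIZE', 4, 'MINBATCHSIZE', 2, 'MINBATCHTIMEOUT', 1000,
                  'BLOB', onnx_blob)

# TF backend: input/output tensor names, each list preceded by its count
# (format string "scccsclclclclvclvcv").
r.execute_command('AI.MODELSTORE', 'my_tf_model', 'TF', 'CPU', 'TAG', 'v1',
                  'BATCHSIZE', 4, 'MINBATCHSIZE', 2, 'MINBATCHTIMEOUT', 1000,
                  'INPUTS', 2, 'a', 'b', 'OUTPUTS', 1, 'mul',
                  'BLOB', tf_blob)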

tests/flow/test_serializations.py

Lines changed: 159 additions & 1 deletion
@@ -40,7 +40,8 @@ def torch_script_run(env, script_key):
 
     con.execute_command('AI.TENSORSET', 'b{1}', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
 
-    con.execute_command('AI.SCRIPTRUN', script_key, 'bar', 'INPUTS', 'a{1}', 'b{1}', 'OUTPUTS', 'c{1}')
+    con.execute_command('AI.SCRIPTEXECUTE', script_key, 'bar', 'KEYS', 1, '{1}', 'INPUTS', 2, 'a{1}', 'b{1}',
+                        'OUTPUTS', 1, 'c{1}')
 
     ensureSlaveSynced(con, env)
 
@@ -216,3 +217,160 @@ def test_v2_tensor(self):
         self.env.assertEqual([tensor_type, tensor_shape], [b"INT32", [2, 1]])
         values = con.execute_command('AI.TENSORGET', key_name, 'VALUES')
         self.env.assertEqual(values, [1, 2])
+
+
+class TestAofRewrite:
+
+    def __init__(self):
+        self.env = Env(useAof=True)
+
+    def test_aof_rewrite_tf_model(self):
+        key_name = "tf_graph{1}"
+        con = self.env.getConnection()
+        tf_model = load_file_content("graph.pb")
+        con.execute_command('AI.MODELSTORE', key_name, 'TF', 'CPU', 'TAG', 'TF_GRAPH', 'batchsize', 4, 'minbatchsize', 2,
+                            'minbatchtimeout', 1000, 'INPUTS', 2, 'a', 'b', 'OUTPUTS', 1, 'mul', 'BLOB', tf_model)
+
+        # Redis should save the stored model by calling the AOF rewrite callback and then reload from AOF.
+        self.env.restartAndReload()
+        _, backend, _, device, _, tag, _, batchsize, _, minbatchsize, _, inputs, _, outputs, _, minbatchtimeout \
+            = con.execute_command("AI.MODELGET", key_name, "META")
+        self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, minbatchtimeout, inputs, outputs],
+                             [b"TF", b"CPU", b"TF_GRAPH", 4, 2, 1000, [b"a", b"b"], [b"mul"]])
+        tf_model_run(self.env, key_name)
+
+        # Reinsert the model (without minbatchtimeout)
+        con.execute_command('AI.MODELSTORE', key_name, 'TF', 'CPU', 'TAG', 'TF_GRAPH1', 'batchsize', 4, 'minbatchsize', 2,
+                            'INPUTS', 2, 'a', 'b', 'OUTPUTS', 1, 'mul', 'BLOB', tf_model)
+        # Redis should save the stored model by calling the AOF rewrite callback and then reload from AOF.
+        self.env.restartAndReload()
+        _, backend, _, device, _, tag, _, batchsize, _, minbatchsize, _, inputs, _, outputs, _, minbatchtimeout \
+            = con.execute_command("AI.MODELGET", key_name, "META")
+        self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, minbatchtimeout, inputs, outputs],
+                             [b"TF", b"CPU", b"TF_GRAPH1", 4, 2, 0, [b"a", b"b"], [b"mul"]])
+
+        # Reinsert the model (without minbatch)
+        con.execute_command('AI.MODELSTORE', key_name, 'TF', 'CPU', 'TAG', 'TF_GRAPH2', 'batchsize', 4,
+                            'INPUTS', 2, 'a', 'b', 'OUTPUTS', 1, 'mul', 'BLOB', tf_model)
+        # Redis should save the stored model by calling the AOF rewrite callback and then reload from AOF.
+        self.env.restartAndReload()
+        _, backend, _, device, _, tag, _, batchsize, _, minbatchsize, _, inputs, _, outputs, _, minbatchtimeout \
+            = con.execute_command("AI.MODELGET", key_name, "META")
+        self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, minbatchtimeout, inputs, outputs],
+                             [b"TF", b"CPU", b"TF_GRAPH2", 4, 0, 0, [b"a", b"b"], [b"mul"]])
+
+        # Reinsert the model (without batching)
+        con.execute_command('AI.MODELSTORE', key_name, 'TF', 'CPU', 'TAG', 'TF_GRAPH3',
+                            'INPUTS', 2, 'a', 'b', 'OUTPUTS', 1, 'mul', 'BLOB', tf_model)
+        # Redis should save the stored model by calling the AOF rewrite callback and then reload from AOF.
+        self.env.restartAndReload()
+        _, backend, _, device, _, tag, _, batchsize, _, minbatchsize, _, inputs, _, outputs, _, minbatchtimeout \
+            = con.execute_command("AI.MODELGET", key_name, "META")
+        self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, minbatchtimeout, inputs, outputs],
+                             [b"TF", b"CPU", b"TF_GRAPH3", 0, 0, 0, [b"a", b"b"], [b"mul"]])
+
+    def test_aof_rewrite_torch_model(self):
+        key_name = "pt-minimal{1}"
+        con = self.env.getConnection()
+        torch_model = load_file_content("pt-minimal.pt")
+        con.execute_command('AI.MODELSTORE', key_name, 'TORCH', 'CPU', 'TAG', 'PT_MINIMAL', 'batchsize', 4, 'minbatchsize', 2,
+                            'minbatchtimeout', 1000, 'BLOB', torch_model)
+
+        # Redis should save the stored model by calling the AOF rewrite callback and then reload from AOF.
+        self.env.restartAndReload()
+        _, backend, _, device, _, tag, _, batchsize, _, minbatchsize, _, inputs, _, outputs, _, minbatchtimeout \
+            = con.execute_command("AI.MODELGET", key_name, "META")
+        self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, minbatchtimeout, inputs, outputs],
+                             [b"TORCH", b"CPU", b"PT_MINIMAL", 4, 2, 1000, [b"a", b"b"], [b'']])
+        torch_model_run(self.env, key_name)
+
+        # Reinsert the model (without minbatchtimeout)
+        con.execute_command('AI.MODELSTORE', key_name, 'TORCH', 'CPU', 'TAG', 'PT_MINIMAL1', 'batchsize', 4, 'minbatchsize', 2,
+                            'BLOB', torch_model)
+        self.env.restartAndReload()
+        _, backend, _, device, _, tag, _, batchsize, _, minbatchsize, _, inputs, _, outputs, _, minbatchtimeout \
+            = con.execute_command("AI.MODELGET", key_name, "META")
+        self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, minbatchtimeout, inputs, outputs],
+                             [b"TORCH", b"CPU", b"PT_MINIMAL1", 4, 2, 0, [b"a", b"b"], [b'']])
+
+        # Reinsert the model (without minbatch)
+        con.execute_command('AI.MODELSTORE', key_name, 'TORCH', 'CPU', 'TAG', 'PT_MINIMAL2', 'batchsize', 4,
+                            'BLOB', torch_model)
+        self.env.restartAndReload()
+        _, backend, _, device, _, tag, _, batchsize, _, minbatchsize, _, inputs, _, outputs, _, minbatchtimeout \
+            = con.execute_command("AI.MODELGET", key_name, "META")
+        self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, minbatchtimeout, inputs, outputs],
+                             [b"TORCH", b"CPU", b"PT_MINIMAL2", 4, 0, 0, [b"a", b"b"], [b'']])
+
+        # Reinsert the model (without batching)
+        con.execute_command('AI.MODELSTORE', key_name, 'TORCH', 'CPU', 'TAG', 'PT_MINIMAL3',
+                            'BLOB', torch_model)
+        self.env.restartAndReload()
+        _, backend, _, device, _, tag, _, batchsize, _, minbatchsize, _, inputs, _, outputs, _, minbatchtimeout \
+            = con.execute_command("AI.MODELGET", key_name, "META")
+        self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, minbatchtimeout, inputs, outputs],
+                             [b"TORCH", b"CPU", b"PT_MINIMAL3", 0, 0, 0, [b"a", b"b"], [b'']])
+
+    def test_aof_rewrite_troch_script(self):
+        key_name = "torch_script{1}"
+        con = self.env.getConnection()
+        torch_script = load_file_content("script.txt")
+        con.execute_command('AI.SCRIPTSET', key_name, 'CPU', 'TAG', 'TORCH_SCRIPT', 'SOURCE', torch_script)
+
+        # Redis should save the stored script by calling the AOF rewrite callback and then reload from AOF.
+        self.env.restartAndReload()
+        _, device, _, tag = con.execute_command("AI.SCRIPTGET", key_name, "META")
+        self.env.assertEqual([device, tag], [b"CPU", b"TORCH_SCRIPT"])
+        torch_script_run(self.env, key_name)
+
+    def test_aof_rewrite_onnx_model(self):
+        key_name = "linear_iris{1}"
+        con = self.env.getConnection()
+        onnx_model = load_file_content("linear_iris.onnx")
+        con.execute_command('AI.MODELSTORE', key_name, 'ONNX', 'CPU', 'TAG', 'ONNX_LINEAR_IRIS', 'batchsize', 4, 'minbatchsize', 2,
+                            'minbatchtimeout', 1000, 'BLOB', onnx_model)
+        # Redis should save the stored model by calling the AOF rewrite callback and then reload from AOF.
+        self.env.restartAndReload()
+        _, backend, _, device, _, tag, _, batchsize, _, minbatchsize, _, inputs, _, outputs, _, minbatchtimeout \
+            = con.execute_command("AI.MODELGET", key_name, "META")
+        self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, minbatchtimeout, inputs, outputs],
+                             [b"ONNX", b"CPU", b"ONNX_LINEAR_IRIS", 4, 2, 1000, [b'float_input'], [b'variable']])
+        onnx_model_run(self.env, key_name)
+
+        # Reinsert the model (without minbatchtimeout)
+        con.execute_command('AI.MODELSTORE', key_name, 'ONNX', 'CPU', 'TAG', 'ONNX_LINEAR_IRIS1', 'batchsize', 4,
+                            'minbatchsize', 2, 'BLOB', onnx_model)
+        self.env.restartAndReload()
+        _, backend, _, device, _, tag, _, batchsize, _, minbatchsize, _, inputs, _, outputs, _, minbatchtimeout \
+            = con.execute_command("AI.MODELGET", key_name, "META")
+        self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, minbatchtimeout, inputs, outputs],
+                             [b"ONNX", b"CPU", b"ONNX_LINEAR_IRIS1", 4, 2, 0, [b'float_input'], [b'variable']])
+
+        # Reinsert the model (without minbatch)
+        con.execute_command('AI.MODELSTORE', key_name, 'ONNX', 'CPU', 'TAG', 'ONNX_LINEAR_IRIS2', 'batchsize', 4,
+                            'BLOB', onnx_model)
+        self.env.restartAndReload()
+        _, backend, _, device, _, tag, _, batchsize, _, minbatchsize, _, inputs, _, outputs, _, minbatchtimeout \
+            = con.execute_command("AI.MODELGET", key_name, "META")
+        self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, minbatchtimeout, inputs, outputs],
+                             [b"ONNX", b"CPU", b"ONNX_LINEAR_IRIS2", 4, 0, 0, [b'float_input'], [b'variable']])
+
+        # Reinsert the model (without batching)
+        con.execute_command('AI.MODELSTORE', key_name, 'ONNX', 'CPU', 'TAG', 'ONNX_LINEAR_IRIS3',
+                            'BLOB', onnx_model)
+        self.env.restartAndReload()
+        _, backend, _, device, _, tag, _, batchsize, _, minbatchsize, _, inputs, _, outputs, _, minbatchtimeout \
+            = con.execute_command("AI.MODELGET", key_name, "META")
+        self.env.assertEqual([backend, device, tag, batchsize, minbatchsize, minbatchtimeout, inputs, outputs],
+                             [b"ONNX", b"CPU", b"ONNX_LINEAR_IRIS3", 0, 0, 0, [b'float_input'], [b'variable']])
+
+    def test_aof_rewrite_tensor(self):
+        key_name = "tensor{1}"
+        con = self.env.getConnection()
+        con.execute_command('AI.TENSORSET', key_name, 'INT32', 2, 1, 'VALUES', 1, 2)
+        # Redis should save the stored tensor by calling the AOF rewrite callback and then reload from AOF.
+        self.env.restartAndReload()
+        _, tensor_type, _, tensor_shape = con.execute_command('AI.TENSORGET', key_name, 'META')
+        self.env.assertEqual([tensor_type, tensor_shape], [b"INT32", [2, 1]])
+        values = con.execute_command('AI.TENSORGET', key_name, 'VALUES')
+        self.env.assertEqual(values, [1, 2])
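
The tests above lean on RLTest: Env(useAof=True) runs the server with the append-only file enabled, and restartAndReload() forces the rewrite-and-reload cycle. As a rough standalone check of the same path, something like the sketch below should work; it assumes a local Redis started with the RedisAI module and appendonly yes, the graph.pb fixture on disk, and it substitutes BGREWRITEAOF plus DEBUG LOADAOF for a full server restart.

# Hedged sketch, not part of the commit: manually exercising the AOF rewrite path, assuming
# a server started with `redis-server --loadmodule redisai.so --appendonly yes`.
import time
import redis

con = redis.Redis()
with open("graph.pb", "rb") as f:
    tf_model = f.read()

con.execute_command('AI.MODELSTORE', 'tf_graph', 'TF', 'CPU', 'TAG', 'TF_GRAPH',
                    'BATCHSIZE', 4, 'MINBATCHSIZE', 2, 'MINBATCHTIMEOUT', 1000,
                    'INPUTS', 2, 'a', 'b', 'OUTPUTS', 1, 'mul', 'BLOB', tf_model)

con.execute_command('BGREWRITEAOF')      # invokes RAI_AOFRewriteModel for the stored key
while con.info('persistence')['aof_rewrite_in_progress']:
    time.sleep(0.1)                      # wait for the background rewrite to finish

con.execute_command('DEBUG', 'LOADAOF')  # clear the dataset and reload it from the AOF

meta = con.execute_command('AI.MODELGET', 'tf_graph', 'META')
print(meta)  # INPUTS [b'a', b'b'] and OUTPUTS [b'mul'] should survive the reload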
