Skip to content

Commit 40cb927

Browse files
authored
New model encoding version (#749)
* Add new version of RDB decoding v2, currently breaking AI.MODELGET * Let AI.MODELGET return the minbatchtimeout meta data field at the end of the result array, change tests accordingly. * PR fixes + documentation of changes in AI.MODELGET
1 parent 25c3347 commit 40cb927

File tree

19 files changed

+397
-55
lines changed

19 files changed

+397
-55
lines changed

docs/commands.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ AI.MODELGET <key> [META] [BLOB]
222222
_Arguments
223223

224224
* **key**: the model's key name
225-
* **META**: will return the model's meta information on backend, device and tag
225+
* **META**: will return the model's meta information on backend, device, tag and batching parameters
226226
* **BLOB**: will return the model's blob containing the serialized model
227227

228228
_Return_
@@ -236,6 +236,7 @@ An array of alternating key-value pairs as follows:
236236
1. **MINBATCHSIZE**: The minimum size of any batch of incoming requests.
237237
1. **INPUTS**: array reply with one or more names of the model's input nodes (applicable only for TensorFlow models)
238238
1. **OUTPUTS**: array reply with one or more names of the model's output nodes (applicable only for TensorFlow models)
239+
1. **MINBATCHTIMEOUT**: The time in milliseconds for which the engine will wait before executing a request to run the model, when the number of incoming requests is lower than `MINBATCHSIZE`. When `MINBATCHTIMEOUT` is 0, the engine will not run the model before it receives at least `MINBATCHSIZE` requests.
239240
1. **BLOB**: a blob containing the serialized model (when called with the `BLOB` argument) as a String. If the size of the serialized model exceeds `MODEL_CHUNK_SIZE` (see `AI.CONFIG` command), then an array of chunks is returned. The full serialized model can be obtained by concatenating the chunks.
240241

241242
**Examples**
@@ -259,6 +260,8 @@ redis> AI.MODELGET mymodel META
259260
2) "b"
260261
13) "outputs"
261262
14) 1) "c"
263+
15) "minbatchtimeout"
264+
16) (integer) 0
262265
```
263266

264267
You can also save it to the local file 'model.ext' with [`redis-cli`](https://redis.io/topics/cli) like so:

src/redisai.c

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
463463
return REDISMODULE_OK;
464464
}
465465

466-
const int outentries = blob ? 16 : 14;
466+
const int outentries = blob ? 18 : 16;
467467
RedisModule_ReplyWithArray(ctx, outentries);
468468

469469
RedisModule_ReplyWithCString(ctx, "backend");
@@ -500,6 +500,9 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
500500
RedisModule_ReplyWithCString(ctx, mto->outputs[i]);
501501
}
502502

503+
RedisModule_ReplyWithCString(ctx, "minbatchtimeout");
504+
RedisModule_ReplyWithLongLong(ctx, (long)mto->opts.minbatchtimeout);
505+
503506
if (meta && blob) {
504507
RedisModule_ReplyWithCString(ctx, "blob");
505508
RAI_ReplyWithChunks(ctx, buffer, len);
Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
#include "decode_v2.h"
2+
#include "assert.h"
3+
4+
/**
5+
* In case of IO errors, the default return values are:
6+
* numbers - 0
7+
* strings - null
8+
* So only when it is necessary check for IO errors.
9+
*/
10+
11+
void *RAI_RDBLoadTensor_v2(RedisModuleIO *io) {
12+
int64_t *shape = NULL;
13+
int64_t *strides = NULL;
14+
15+
DLDevice device;
16+
device.device_type = RedisModule_LoadUnsigned(io);
17+
device.device_id = RedisModule_LoadUnsigned(io);
18+
if (RedisModule_IsIOError(io))
19+
goto cleanup;
20+
21+
// For now we only support CPU tensors (except during model and script run)
22+
assert(device.device_type == kDLCPU);
23+
assert(device.device_id == 0);
24+
25+
DLDataType dtype;
26+
dtype.bits = RedisModule_LoadUnsigned(io);
27+
dtype.code = RedisModule_LoadUnsigned(io);
28+
dtype.lanes = RedisModule_LoadUnsigned(io);
29+
30+
size_t ndims = RedisModule_LoadUnsigned(io);
31+
if (RedisModule_IsIOError(io))
32+
goto cleanup;
33+
34+
shape = RedisModule_Calloc(ndims, sizeof(*shape));
35+
for (size_t i = 0; i < ndims; ++i) {
36+
shape[i] = RedisModule_LoadUnsigned(io);
37+
}
38+
39+
strides = RedisModule_Calloc(ndims, sizeof(*strides));
40+
for (size_t i = 0; i < ndims; ++i) {
41+
strides[i] = RedisModule_LoadUnsigned(io);
42+
}
43+
44+
size_t byte_offset = RedisModule_LoadUnsigned(io);
45+
46+
size_t len;
47+
char *data = RedisModule_LoadStringBuffer(io, &len);
48+
if (RedisModule_IsIOError(io))
49+
goto cleanup;
50+
51+
RAI_Tensor *ret = RAI_TensorNew();
52+
ret->tensor = (DLManagedTensor){.dl_tensor = (DLTensor){.device = device,
53+
.data = data,
54+
.ndim = ndims,
55+
.dtype = dtype,
56+
.shape = shape,
57+
.strides = strides,
58+
.byte_offset = byte_offset},
59+
.manager_ctx = NULL,
60+
.deleter = NULL};
61+
return ret;
62+
63+
cleanup:
64+
if (shape)
65+
RedisModule_Free(shape);
66+
if (strides)
67+
RedisModule_Free(strides);
68+
RedisModule_LogIOError(io, "error", "Experienced a short read while reading a tensor from RDB");
69+
return NULL;
70+
}
71+
72+
void *RAI_RDBLoadModel_v2(RedisModuleIO *io) {
73+
74+
char *devicestr = NULL;
75+
RedisModuleString *tag = NULL;
76+
size_t ninputs = 0;
77+
const char **inputs = NULL;
78+
size_t noutputs = 0;
79+
const char **outputs = NULL;
80+
char *buffer = NULL;
81+
82+
RAI_Backend backend = RedisModule_LoadUnsigned(io);
83+
devicestr = RedisModule_LoadStringBuffer(io, NULL);
84+
tag = RedisModule_LoadString(io);
85+
86+
const size_t batchsize = RedisModule_LoadUnsigned(io);
87+
const size_t minbatchsize = RedisModule_LoadUnsigned(io);
88+
const size_t minbatchtimeout = RedisModule_LoadUnsigned(io);
89+
90+
ninputs = RedisModule_LoadUnsigned(io);
91+
if (RedisModule_IsIOError(io))
92+
goto cleanup;
93+
94+
inputs = RedisModule_Alloc(ninputs * sizeof(char *));
95+
96+
for (size_t i = 0; i < ninputs; i++) {
97+
inputs[i] = RedisModule_LoadStringBuffer(io, NULL);
98+
}
99+
100+
noutputs = RedisModule_LoadUnsigned(io);
101+
if (RedisModule_IsIOError(io))
102+
goto cleanup;
103+
104+
outputs = RedisModule_Alloc(noutputs * sizeof(char *));
105+
106+
for (size_t i = 0; i < noutputs; i++) {
107+
outputs[i] = RedisModule_LoadStringBuffer(io, NULL);
108+
}
109+
110+
RAI_ModelOpts opts = {
111+
.batchsize = batchsize,
112+
.minbatchsize = minbatchsize,
113+
.minbatchtimeout = minbatchtimeout,
114+
.backends_intra_op_parallelism = getBackendsIntraOpParallelism(),
115+
.backends_inter_op_parallelism = getBackendsInterOpParallelism(),
116+
};
117+
118+
size_t len = RedisModule_LoadUnsigned(io);
119+
if (RedisModule_IsIOError(io))
120+
goto cleanup;
121+
122+
buffer = RedisModule_Alloc(len);
123+
const size_t n_chunks = RedisModule_LoadUnsigned(io);
124+
long long chunk_offset = 0;
125+
for (size_t i = 0; i < n_chunks; i++) {
126+
size_t chunk_len;
127+
char *chunk_buffer = RedisModule_LoadStringBuffer(io, &chunk_len);
128+
if (RedisModule_IsIOError(io))
129+
goto cleanup;
130+
memcpy(buffer + chunk_offset, chunk_buffer, chunk_len);
131+
chunk_offset += chunk_len;
132+
RedisModule_Free(chunk_buffer);
133+
}
134+
135+
RAI_Error err = {0};
136+
RAI_Model *model = RAI_ModelCreate(backend, devicestr, tag, opts, ninputs, inputs, noutputs,
137+
outputs, buffer, len, &err);
138+
139+
if (err.code == RAI_EBACKENDNOTLOADED) {
140+
RedisModuleCtx *ctx = RedisModule_GetContextFromIO(io);
141+
int ret = RAI_LoadDefaultBackend(ctx, backend);
142+
if (ret == REDISMODULE_ERR) {
143+
RedisModule_Log(ctx, "warning", "Could not load default backend");
144+
RAI_ClearError(&err);
145+
goto cleanup;
146+
}
147+
RAI_ClearError(&err);
148+
model = RAI_ModelCreate(backend, devicestr, tag, opts, ninputs, inputs, noutputs, outputs,
149+
buffer, len, &err);
150+
}
151+
152+
if (err.code != RAI_OK) {
153+
RedisModuleCtx *ctx = RedisModule_GetContextFromIO(io);
154+
RedisModule_Log(ctx, "warning", "%s", err.detail);
155+
RAI_ClearError(&err);
156+
goto cleanup;
157+
}
158+
159+
RedisModuleCtx *stats_ctx = RedisModule_GetContextFromIO(io);
160+
RedisModuleString *stats_keystr =
161+
RedisModule_CreateStringFromString(stats_ctx, RedisModule_GetKeyNameFromIO(io));
162+
163+
model->infokey = RAI_AddStatsEntry(stats_ctx, stats_keystr, RAI_MODEL, backend, devicestr, tag);
164+
165+
for (size_t i = 0; i < ninputs; i++) {
166+
RedisModule_Free((void *)inputs[i]);
167+
}
168+
RedisModule_Free(inputs);
169+
for (size_t i = 0; i < noutputs; i++) {
170+
RedisModule_Free((void *)outputs[i]);
171+
}
172+
RedisModule_Free(outputs);
173+
RedisModule_Free(buffer);
174+
RedisModule_Free(devicestr);
175+
RedisModule_FreeString(NULL, stats_keystr);
176+
RedisModule_FreeString(NULL, tag);
177+
178+
return model;
179+
180+
cleanup:
181+
if (devicestr)
182+
RedisModule_Free(devicestr);
183+
if (tag)
184+
RedisModule_FreeString(NULL, tag);
185+
if (inputs) {
186+
for (size_t i = 0; i < ninputs; i++) {
187+
RedisModule_Free((void *)inputs[i]);
188+
}
189+
RedisModule_Free(inputs);
190+
}
191+
192+
if (outputs) {
193+
for (size_t i = 0; i < noutputs; i++) {
194+
RedisModule_Free((void *)outputs[i]);
195+
}
196+
RedisModule_Free(outputs);
197+
}
198+
199+
if (buffer)
200+
RedisModule_Free(buffer);
201+
202+
RedisModule_LogIOError(io, "error", "Experienced a short read while reading a model from RDB");
203+
return NULL;
204+
}
205+
206+
void *RAI_RDBLoadScript_v2(RedisModuleIO *io) {
207+
RedisModuleString *tag = NULL;
208+
char *devicestr = NULL;
209+
char *scriptdef = NULL;
210+
RAI_Error err = {0};
211+
212+
devicestr = RedisModule_LoadStringBuffer(io, NULL);
213+
tag = RedisModule_LoadString(io);
214+
215+
size_t len;
216+
scriptdef = RedisModule_LoadStringBuffer(io, &len);
217+
if (RedisModule_IsIOError(io))
218+
goto cleanup;
219+
220+
RAI_Script *script = RAI_ScriptCreate(devicestr, tag, scriptdef, &err);
221+
222+
if (err.code == RAI_EBACKENDNOTLOADED) {
223+
RedisModuleCtx *ctx = RedisModule_GetContextFromIO(io);
224+
int ret = RAI_LoadDefaultBackend(ctx, RAI_BACKEND_TORCH);
225+
if (ret == REDISMODULE_ERR) {
226+
RedisModule_Log(ctx, "warning", "Could not load default TORCH backend\n");
227+
RAI_ClearError(&err);
228+
goto cleanup;
229+
}
230+
RAI_ClearError(&err);
231+
script = RAI_ScriptCreate(devicestr, tag, scriptdef, &err);
232+
}
233+
234+
if (err.code != RAI_OK) {
235+
printf("ERR: %s\n", err.detail);
236+
RAI_ClearError(&err);
237+
goto cleanup;
238+
}
239+
240+
RedisModuleCtx *stats_ctx = RedisModule_GetContextFromIO(io);
241+
RedisModuleString *stats_keystr =
242+
RedisModule_CreateStringFromString(stats_ctx, RedisModule_GetKeyNameFromIO(io));
243+
244+
script->infokey =
245+
RAI_AddStatsEntry(stats_ctx, stats_keystr, RAI_SCRIPT, RAI_BACKEND_TORCH, devicestr, tag);
246+
247+
RedisModule_FreeString(NULL, stats_keystr);
248+
RedisModule_FreeString(NULL, tag);
249+
RedisModule_Free(devicestr);
250+
RedisModule_Free(scriptdef);
251+
return script;
252+
cleanup:
253+
if (devicestr)
254+
RedisModule_Free(devicestr);
255+
if (scriptdef)
256+
RedisModule_Free(scriptdef);
257+
if (tag)
258+
RedisModule_FreeString(NULL, tag);
259+
return NULL;
260+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#pragma once
2+
#include "serialization/serialization_include.h"
3+
4+
void *RAI_RDBLoadTensor_v2(RedisModuleIO *io);
5+
6+
void *RAI_RDBLoadModel_v2(RedisModuleIO *io);
7+
8+
void *RAI_RDBLoadScript_v2(RedisModuleIO *io);

src/serialization/RDB/decoder/decode_previous.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
#include "decode_previous.h"
22
#include "previous/v0/decode_v0.h"
3+
#include "previous/v1/decode_v1.h"
4+
35
void *Decode_PreviousTensor(RedisModuleIO *rdb, int encver) {
46
switch (encver) {
57
case 0:
68
return RAI_RDBLoadTensor_v0(rdb);
9+
case 1:
10+
return RAI_RDBLoadTensor_v1(rdb);
711
default:
812
assert(false && "Invalid encoding version");
913
}
@@ -14,6 +18,8 @@ void *Decode_PreviousModel(RedisModuleIO *rdb, int encver) {
1418
switch (encver) {
1519
case 0:
1620
return RAI_RDBLoadModel_v0(rdb);
21+
case 1:
22+
return RAI_RDBLoadModel_v1(rdb);
1723
default:
1824
assert(false && "Invalid encoding version");
1925
}
@@ -24,6 +30,8 @@ void *Decode_PreviousScript(RedisModuleIO *rdb, int encver) {
2430
switch (encver) {
2531
case 0:
2632
return RAI_RDBLoadScript_v0(rdb);
33+
case 1:
34+
return RAI_RDBLoadScript_v1(rdb);
2735
default:
2836
assert(false && "Invalid encoding version");
2937
}
File renamed without changes.

src/serialization/RDB/decoder/current/v1/decode_v1.h renamed to src/serialization/RDB/decoder/previous/v1/decode_v1.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#pragma once
2-
#include "../../../../serialization_include.h"
2+
#include "serialization/serialization_include.h"
33

44
void *RAI_RDBLoadTensor_v1(RedisModuleIO *io);
55

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#include "rai_rdb_decoder.h"
2-
#include "current/v1/decode_v1.h"
2+
#include "current/v2/decode_v2.h"
33

4-
void *RAI_RDBLoadTensor(RedisModuleIO *io) { return RAI_RDBLoadTensor_v1(io); }
4+
void *RAI_RDBLoadTensor(RedisModuleIO *io) { return RAI_RDBLoadTensor_v2(io); }
55

6-
void *RAI_RDBLoadModel(RedisModuleIO *io) { return RAI_RDBLoadModel_v1(io); }
6+
void *RAI_RDBLoadModel(RedisModuleIO *io) { return RAI_RDBLoadModel_v2(io); }
77

8-
void *RAI_RDBLoadScript(RedisModuleIO *io) { return RAI_RDBLoadScript_v1(io); }
8+
void *RAI_RDBLoadScript(RedisModuleIO *io) { return RAI_RDBLoadScript_v2(io); }
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#include "rai_rdb_encode.h"
2-
#include "v1/encode_v1.h"
2+
#include "v2/encode_v2.h"
33

4-
void RAI_RDBSaveTensor(RedisModuleIO *io, void *value) { RAI_RDBSaveTensor_v1(io, value); }
4+
void RAI_RDBSaveTensor(RedisModuleIO *io, void *value) { RAI_RDBSaveTensor_v2(io, value); }
55

6-
void RAI_RDBSaveModel(RedisModuleIO *io, void *value) { RAI_RDBSaveModel_v1(io, value); }
6+
void RAI_RDBSaveModel(RedisModuleIO *io, void *value) { RAI_RDBSaveModel_v2(io, value); }
77

8-
void RAI_RDBSaveScript(RedisModuleIO *io, void *value) { RAI_RDBSaveScript_v1(io, value); }
8+
void RAI_RDBSaveScript(RedisModuleIO *io, void *value) { RAI_RDBSaveScript_v2(io, value); }

src/serialization/RDB/encoder/v1/encode_v1.h

Lines changed: 0 additions & 8 deletions
This file was deleted.

0 commit comments

Comments
 (0)