Skip to content

Commit 9498574

Browse files
authored
Merge branch 'master' into DAGRUN_command_refactor
2 parents afcf7ef + 6a389c0 commit 9498574

29 files changed

+1172
-497
lines changed

src/CMakeLists.txt

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,22 @@ if (CMAKE_BUILD_TYPE STREQUAL Debug)
22
SET(DEBUG_SRC "${CMAKE_CURRENT_SOURCE_DIR}/../opt/readies/cetara/diag/gdb.c")
33
endif()
44

5+
file (GLOB_RECURSE SERIALIZATION_SRC
6+
tensor.c
7+
model.c
8+
script.c
9+
backends.c
10+
stats.c
11+
config.c
12+
serialization/*.c)
13+
14+
file (GLOB BACKEND_COMMON_SRC
15+
backends/util.c
16+
err.c
17+
util/dict.c
18+
tensor.c
19+
serialization/ai_datatypes.c)
20+
521
ADD_LIBRARY(redisai_obj OBJECT
622
util/dict.c
723
util/queue.c
@@ -27,42 +43,38 @@ ADD_LIBRARY(redisai_obj OBJECT
2743
rmutil/heap.c
2844
rmutil/priority_queue.c
2945
rmutil/vector.c run_info.c
46+
redis_ai_types/model_type.c
47+
redis_ai_types/tensor_type.c
48+
redis_ai_types/script_type.c
49+
${SERIALIZATION_SRC}
3050
${DEBUG_SRC})
3151

3252
IF(BUILD_TF)
3353
ADD_LIBRARY(redisai_tensorflow_obj OBJECT
3454
backends/tensorflow.c
35-
backends/util.c
36-
err.c
37-
util/dict.c
38-
tensor.c)
55+
${BACKEND_COMMON_SRC}
56+
)
3957
ENDIF()
4058

4159
IF(BUILD_TFLITE)
4260
ADD_LIBRARY(redisai_tflite_obj OBJECT
4361
backends/tflite.c
44-
backends/util.c
45-
err.c
46-
util/dict.c
47-
tensor.c)
62+
${BACKEND_COMMON_SRC}
63+
)
4864
ENDIF()
4965

5066
IF(BUILD_TORCH)
5167
ADD_LIBRARY(redisai_torch_obj OBJECT
5268
backends/torch.c
53-
backends/util.c
54-
err.c
55-
util/dict.c
56-
tensor.c)
69+
${BACKEND_COMMON_SRC}
70+
)
5771
ENDIF()
5872

5973
IF(BUILD_ORT)
6074
ADD_LIBRARY(redisai_onnxruntime_obj OBJECT
6175
backends/onnxruntime.c
62-
backends/util.c
63-
err.c
64-
util/dict.c
65-
tensor.c)
76+
${BACKEND_COMMON_SRC}
77+
)
6678
ENDIF()
6779

6880
INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR})

src/model.c

Lines changed: 0 additions & 251 deletions
Original file line numberDiff line numberDiff line change
@@ -21,234 +21,6 @@
2121
#include <pthread.h>
2222
#include "DAG/dag.h"
2323

24-
RedisModuleType *RedisAI_ModelType = NULL;
25-
26-
static void *RAI_Model_RdbLoad(struct RedisModuleIO *io, int encver) {
27-
// if (encver != RAI_ENC_VER) {
28-
// /* We should actually log an error here, or try to implement
29-
// the ability to load older versions of our data structure. */
30-
// return NULL;
31-
// }
32-
33-
RAI_Backend backend = RedisModule_LoadUnsigned(io);
34-
const char *devicestr = RedisModule_LoadStringBuffer(io, NULL);
35-
36-
RedisModuleString *tag = RedisModule_LoadString(io);
37-
38-
const size_t batchsize = RedisModule_LoadUnsigned(io);
39-
const size_t minbatchsize = RedisModule_LoadUnsigned(io);
40-
41-
const size_t ninputs = RedisModule_LoadUnsigned(io);
42-
const char **inputs = RedisModule_Alloc(ninputs * sizeof(char *));
43-
44-
for (size_t i = 0; i < ninputs; i++) {
45-
inputs[i] = RedisModule_LoadStringBuffer(io, NULL);
46-
}
47-
48-
const size_t noutputs = RedisModule_LoadUnsigned(io);
49-
50-
const char **outputs = RedisModule_Alloc(ninputs * sizeof(char *));
51-
52-
for (size_t i = 0; i < noutputs; i++) {
53-
outputs[i] = RedisModule_LoadStringBuffer(io, NULL);
54-
}
55-
56-
RAI_ModelOpts opts = {
57-
.batchsize = batchsize,
58-
.minbatchsize = minbatchsize,
59-
.backends_intra_op_parallelism = getBackendsIntraOpParallelism(),
60-
.backends_inter_op_parallelism = getBackendsInterOpParallelism(),
61-
};
62-
63-
size_t len;
64-
char *buffer = NULL;
65-
66-
if (encver <= 100) {
67-
buffer = RedisModule_LoadStringBuffer(io, &len);
68-
} else {
69-
len = RedisModule_LoadUnsigned(io);
70-
buffer = RedisModule_Alloc(len);
71-
const size_t n_chunks = RedisModule_LoadUnsigned(io);
72-
long long chunk_offset = 0;
73-
for (size_t i = 0; i < n_chunks; i++) {
74-
size_t chunk_len;
75-
char *chunk_buffer = RedisModule_LoadStringBuffer(io, &chunk_len);
76-
memcpy(buffer + chunk_offset, chunk_buffer, chunk_len);
77-
chunk_offset += chunk_len;
78-
RedisModule_Free(chunk_buffer);
79-
}
80-
}
81-
82-
RAI_Error err = {0};
83-
84-
RAI_Model *model = RAI_ModelCreate(backend, devicestr, tag, opts, ninputs, inputs, noutputs,
85-
outputs, buffer, len, &err);
86-
87-
if (err.code == RAI_EBACKENDNOTLOADED) {
88-
RedisModuleCtx *ctx = RedisModule_GetContextFromIO(io);
89-
int ret = RAI_LoadDefaultBackend(ctx, backend);
90-
if (ret == REDISMODULE_ERR) {
91-
RedisModule_Log(ctx, "error", "Could not load default backend");
92-
RAI_ClearError(&err);
93-
return NULL;
94-
}
95-
RAI_ClearError(&err);
96-
model = RAI_ModelCreate(backend, devicestr, tag, opts, ninputs, inputs, noutputs, outputs,
97-
buffer, len, &err);
98-
}
99-
100-
if (err.code != RAI_OK) {
101-
RedisModuleCtx *ctx = RedisModule_GetContextFromIO(io);
102-
RedisModule_Log(ctx, "error", "%s", err.detail);
103-
RAI_ClearError(&err);
104-
if (buffer) {
105-
RedisModule_Free(buffer);
106-
}
107-
return NULL;
108-
}
109-
110-
for (size_t i = 0; i < ninputs; i++) {
111-
RedisModule_Free(inputs[i]);
112-
}
113-
for (size_t i = 0; i < noutputs; i++) {
114-
RedisModule_Free(outputs[i]);
115-
}
116-
RedisModule_Free(inputs);
117-
RedisModule_Free(outputs);
118-
RedisModule_Free(buffer);
119-
120-
RedisModuleCtx *stats_ctx = RedisModule_GetContextFromIO(io);
121-
RedisModuleString *stats_keystr =
122-
RedisModule_CreateStringFromString(stats_ctx, RedisModule_GetKeyNameFromIO(io));
123-
124-
model->infokey = RAI_AddStatsEntry(stats_ctx, stats_keystr, RAI_MODEL, backend, devicestr, tag);
125-
126-
RedisModule_FreeString(NULL, tag);
127-
RedisModule_Free(devicestr);
128-
RedisModule_FreeString(NULL, stats_keystr);
129-
130-
return model;
131-
}
132-
133-
static void RAI_Model_RdbSave(RedisModuleIO *io, void *value) {
134-
RAI_Model *model = (RAI_Model *)value;
135-
char *buffer = NULL;
136-
size_t len = 0;
137-
RAI_Error err = {0};
138-
139-
int ret = RAI_ModelSerialize(model, &buffer, &len, &err);
140-
141-
if (err.code != RAI_OK) {
142-
RedisModuleCtx *stats_ctx = RedisModule_GetContextFromIO(io);
143-
printf("ERR: %s\n", err.detail);
144-
RAI_ClearError(&err);
145-
if (buffer) {
146-
RedisModule_Free(buffer);
147-
}
148-
return;
149-
}
150-
151-
RedisModule_SaveUnsigned(io, model->backend);
152-
RedisModule_SaveStringBuffer(io, model->devicestr, strlen(model->devicestr) + 1);
153-
RedisModule_SaveString(io, model->tag);
154-
RedisModule_SaveUnsigned(io, model->opts.batchsize);
155-
RedisModule_SaveUnsigned(io, model->opts.minbatchsize);
156-
RedisModule_SaveUnsigned(io, model->ninputs);
157-
for (size_t i = 0; i < model->ninputs; i++) {
158-
RedisModule_SaveStringBuffer(io, model->inputs[i], strlen(model->inputs[i]) + 1);
159-
}
160-
RedisModule_SaveUnsigned(io, model->noutputs);
161-
for (size_t i = 0; i < model->noutputs; i++) {
162-
RedisModule_SaveStringBuffer(io, model->outputs[i], strlen(model->outputs[i]) + 1);
163-
}
164-
long long chunk_size = getModelChunkSize();
165-
const size_t n_chunks = len / chunk_size + 1;
166-
RedisModule_SaveUnsigned(io, len);
167-
RedisModule_SaveUnsigned(io, n_chunks);
168-
for (size_t i = 0; i < n_chunks; i++) {
169-
size_t chunk_len = i < n_chunks - 1 ? chunk_size : len % chunk_size;
170-
RedisModule_SaveStringBuffer(io, buffer + i * chunk_size, chunk_len);
171-
}
172-
173-
if (buffer) {
174-
RedisModule_Free(buffer);
175-
}
176-
}
177-
178-
static void RAI_Model_AofRewrite(RedisModuleIO *aof, RedisModuleString *key, void *value) {
179-
RAI_Model *model = (RAI_Model *)value;
180-
181-
char *buffer = NULL;
182-
size_t len = 0;
183-
RAI_Error err = {0};
184-
185-
int ret = RAI_ModelSerialize(model, &buffer, &len, &err);
186-
187-
if (err.code != RAI_OK) {
188-
189-
printf("ERR: %s\n", err.detail);
190-
RAI_ClearError(&err);
191-
if (buffer) {
192-
RedisModule_Free(buffer);
193-
}
194-
return;
195-
}
196-
197-
// AI.MODELSET model_key backend device [INPUTS name1 name2 ... OUTPUTS name1
198-
// name2 ...] model_blob
199-
200-
RedisModuleString **inputs_ = array_new(RedisModuleString *, model->ninputs);
201-
RedisModuleString **outputs_ = array_new(RedisModuleString *, model->noutputs);
202-
203-
RedisModuleCtx *ctx = RedisModule_GetContextFromIO(aof);
204-
205-
for (size_t i = 0; i < model->ninputs; i++) {
206-
inputs_ = array_append(
207-
inputs_, RedisModule_CreateString(ctx, model->inputs[i], strlen(model->inputs[i])));
208-
}
209-
210-
for (size_t i = 0; i < model->noutputs; i++) {
211-
outputs_ = array_append(
212-
outputs_, RedisModule_CreateString(ctx, model->outputs[i], strlen(model->outputs[i])));
213-
}
214-
215-
long long chunk_size = getModelChunkSize();
216-
const size_t n_chunks = len / chunk_size + 1;
217-
RedisModuleString **buffers_ = array_new(RedisModuleString *, n_chunks);
218-
219-
for (size_t i = 0; i < n_chunks; i++) {
220-
size_t chunk_len = i < n_chunks - 1 ? chunk_size : len % chunk_size;
221-
buffers_ = array_append(buffers_,
222-
RedisModule_CreateString(ctx, buffer + i * chunk_size, chunk_len));
223-
}
224-
225-
if (buffer) {
226-
RedisModule_Free(buffer);
227-
}
228-
229-
const char *backendstr = RAI_BackendName(model->backend);
230-
231-
RedisModule_EmitAOF(aof, "AI.MODELSET", "sccsclclcvcvcv", key, backendstr, model->devicestr,
232-
model->tag, "BATCHSIZE", model->opts.batchsize, "MINBATCHSIZE",
233-
model->opts.minbatchsize, "INPUTS", inputs_, model->ninputs, "OUTPUTS",
234-
outputs_, model->noutputs, "BLOB", buffers_, n_chunks);
235-
236-
for (size_t i = 0; i < model->ninputs; i++) {
237-
RedisModule_FreeString(ctx, inputs_[i]);
238-
}
239-
array_free(inputs_);
240-
241-
for (size_t i = 0; i < model->noutputs; i++) {
242-
RedisModule_FreeString(ctx, outputs_[i]);
243-
}
244-
array_free(outputs_);
245-
246-
for (size_t i = 0; i < n_chunks; i++) {
247-
RedisModule_FreeString(ctx, buffers_[i]);
248-
}
249-
array_free(buffers_);
250-
}
251-
25224
/* Return REDISMODULE_ERR if there was an error getting the Model.
25325
* Return REDISMODULE_OK if the model value stored at key was correctly
25426
* returned and available at *model variable. */
@@ -270,29 +42,6 @@ int RAI_GetModelFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, Re
27042
return REDISMODULE_OK;
27143
}
27244

273-
// TODO: pass err in?
274-
static void RAI_Model_DTFree(void *value) {
275-
RAI_Error err = {0};
276-
RAI_ModelFree(value, &err);
277-
if (err.code != RAI_OK) {
278-
printf("ERR: %s\n", err.detail);
279-
RAI_ClearError(&err);
280-
}
281-
}
282-
283-
int RAI_ModelInit(RedisModuleCtx *ctx) {
284-
RedisModuleTypeMethods tmModel = {.version = REDISMODULE_TYPE_METHOD_VERSION,
285-
.rdb_load = RAI_Model_RdbLoad,
286-
.rdb_save = RAI_Model_RdbSave,
287-
.aof_rewrite = RAI_Model_AofRewrite,
288-
.mem_usage = NULL,
289-
.free = RAI_Model_DTFree,
290-
.digest = NULL};
291-
292-
RedisAI_ModelType = RedisModule_CreateDataType(ctx, "AI__MODEL", RAI_ENC_VER_MM, &tmModel);
293-
return RedisAI_ModelType != NULL;
294-
}
295-
29645
RAI_Model *RAI_ModelCreate(RAI_Backend backend, const char *devicestr, RedisModuleString *tag,
29746
RAI_ModelOpts opts, size_t ninputs, const char **inputs, size_t noutputs,
29847
const char **outputs, const char *modeldef, size_t modellen,

src/redis_ai_types/model_type.c

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#include "model_type.h"
2+
#include "../model.h"
3+
#include "../serialization/AOF/rai_aof_rewrite.h"
4+
#include "../serialization/RDB/encoder/rai_rdb_encode.h"
5+
#include "../serialization/RDB/decoder/rai_rdb_decoder.h"
6+
#include "../serialization/RDB/decoder/decode_previous.h"
7+
8+
extern RedisModuleType *RedisAI_ModelType;
9+
10+
static void *RAI_Model_RdbLoad(struct RedisModuleIO *io, int encver) {
11+
if (encver > REDISAI_ENC_VER) {
12+
RedisModule_LogIOError(
13+
io, "error", "Failed loading model, RedisAI version (%d) is not forward compatible.\n",
14+
REDISAI_MODULE_VERSION);
15+
return NULL;
16+
} else if (encver < REDISAI_ENC_VER) {
17+
return Decode_PreviousModel(io, encver);
18+
} else {
19+
return RAI_RDBLoadModel(io);
20+
}
21+
}
22+
23+
static void RAI_Model_RdbSave(RedisModuleIO *io, void *value) { RAI_RDBSaveModel(io, value); }
24+
25+
static void RAI_Model_AofRewrite(RedisModuleIO *aof, RedisModuleString *key, void *value) {
26+
RAI_AOFRewriteModel(aof, key, value);
27+
}
28+
29+
static void RAI_Model_DTFree(void *value) {
30+
RAI_Error err = {0};
31+
RAI_ModelFree(value, &err);
32+
if (err.code != RAI_OK) {
33+
printf("ERR: %s\n", err.detail);
34+
RAI_ClearError(&err);
35+
}
36+
}
37+
38+
int ModelType_Register(RedisModuleCtx *ctx) {
39+
RedisModuleTypeMethods tmModel = {.version = REDISMODULE_TYPE_METHOD_VERSION,
40+
.rdb_load = RAI_Model_RdbLoad,
41+
.rdb_save = RAI_Model_RdbSave,
42+
.aof_rewrite = RAI_Model_AofRewrite,
43+
.mem_usage = NULL,
44+
.free = RAI_Model_DTFree,
45+
.digest = NULL};
46+
47+
RedisAI_ModelType = RedisModule_CreateDataType(ctx, "AI__MODEL", REDISAI_ENC_VER, &tmModel);
48+
return RedisAI_ModelType != NULL;
49+
}

src/redis_ai_types/model_type.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#pragma once
2+
3+
#include "redismodule.h"
4+
5+
int ModelType_Register(RedisModuleCtx *ctx);

0 commit comments

Comments
 (0)