Skip to content

Commit 054ddae

Browse files
committed
Add manual device placement for graphs
1 parent adc1987 commit 054ddae

File tree

11 files changed

+62
-17
lines changed

11 files changed

+62
-17
lines changed

examples/js/mobilenet.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ async function run(filenames) {
4949
const buffer = fs.readFileSync(graph_filename, {'flag': 'r'});
5050

5151
console.log("Setting graph");
52-
redis.call('AI.SET', 'GRAPH', 'mobilenet', 'TF', buffer);
52+
redis.call('AI.SET', 'GRAPH', 'mobilenet', 'TF', 'GPU', buffer);
5353

5454
const image_height = 224;
5555
const image_width = 224;

examples/models/load_model.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
REDIS_CLI=../../deps/redis/src/redis-cli
22

33
echo "SET GRAPH"
4-
$REDIS_CLI -x AI.SET GRAPH foo TF < graph.pb
4+
$REDIS_CLI -x AI.SET GRAPH foo TF GPU < graph.pb
55

66
echo "SET TENSORS"
77
$REDIS_CLI AI.SET TENSOR a FLOAT 1 2 VALUES 2 3

examples/models/load_yolo.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ IMAGE_WIDTH=224
1313
IMAGE_HEIGHT=224
1414

1515
echo "SET GRAPH"
16-
$REDIS_CLI -x AI.SET GRAPH $GRAPH_KEY TF < $GRAPH_FILE
16+
$REDIS_CLI -x AI.SET GRAPH $GRAPH_KEY TF GPU < $GRAPH_FILE
1717

1818
# TODO: cast tensor, change shape of tensor (NHWC, NCHW)
1919
# instead of casting, we could specify the type of data provided in the blob

src/backends/tensorflow.c

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ TF_Tensor* RAI_TFTensorFromTensor(RAI_Tensor* t){
136136
}
137137

138138

139-
RAI_Graph *RAI_GraphCreateTF(const char *prefix, RAI_Backend backend,
139+
RAI_Graph *RAI_GraphCreateTF(const char *prefix,
140+
RAI_Backend backend, RAI_Device device,
140141
const char *graphdef, size_t graphlen) {
141142
TF_Graph* graph = TF_NewGraph();
142143

@@ -160,9 +161,31 @@ RAI_Graph *RAI_GraphCreateTF(const char *prefix, RAI_Backend backend,
160161
TF_DeleteBuffer(buffer);
161162
TF_DeleteStatus(status);
162163

163-
TF_Status *sessionStatus = TF_NewStatus();
164+
TF_Status *optionsStatus = TF_NewStatus();
164165

165166
TF_SessionOptions *sessionOptions = TF_NewSessionOptions();
167+
168+
// For setting config options in session from the C API see:
169+
// https://github.com/tensorflow/tensorflow/issues/13853
170+
// import tensorflow as tf
171+
// config = tf.ConfigProto()
172+
// config.intra_op_parallelism_threads = 1
173+
// serialized = config.SerializeToString()
174+
// result = list(map(hex, serialized))
175+
176+
// TODO: complain if device is GPU and GPU not available?
177+
if (device == RAI_DEVICE_CPU) {
178+
uint8_t config[9] = {0x0a, 0x07, 0x0a, 0x03, 0x47, 0x50, 0x55, 0x10, 0x00};
179+
TF_SetConfig(sessionOptions, (void *)config, 9, status);
180+
}
181+
182+
if (TF_GetCode(optionsStatus) != TF_OK) {
183+
// TODO: free memory
184+
return NULL;
185+
}
186+
TF_DeleteStatus(optionsStatus);
187+
188+
TF_Status *sessionStatus = TF_NewStatus();
166189
TF_Session *session = TF_NewSession(graph, sessionOptions, sessionStatus);
167190

168191
if (TF_GetCode(sessionStatus) != TF_OK) {

src/backends/tensorflow.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ RAI_Tensor* RAI_TensorCreateFromTFTensor(TF_Tensor *tensor);
1111

1212
TF_Tensor* RAI_TFTensorFromTensor(RAI_Tensor* t);
1313

14-
RAI_Graph *RAI_GraphCreateTF(const char *prefix, RAI_Backend backend,
14+
RAI_Graph *RAI_GraphCreateTF(const char *prefix,
15+
RAI_Backend backend, RAI_Device device,
1516
const char *graphdef, size_t graphlen);
1617

1718
void RAI_GraphFreeTF(RAI_Graph* graph);

src/config.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ typedef enum {
1111
RAI_BACKEND_ONNXRUNTIME,
1212
} RAI_Backend;
1313

14+
typedef enum {
15+
RAI_DEVICE_CPU = 0,
16+
// TODO: multi GPU
17+
RAI_DEVICE_GPU,
18+
} RAI_Device;
19+
1420
//#define RAI_COPY_RUN_INPUT
1521
#define RAI_COPY_RUN_OUTPUT
1622

src/graph.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,11 @@ int RAI_GraphInit(RedisModuleCtx* ctx) {
3737
return RedisAI_GraphType != NULL;
3838
}
3939

40-
RAI_Graph *RAI_GraphCreate(const char *prefix, RAI_Backend backend,
40+
RAI_Graph *RAI_GraphCreate(const char *prefix,
41+
RAI_Backend backend, RAI_Device device,
4142
const char *graphdef, size_t graphlen) {
4243
if (backend == RAI_BACKEND_TENSORFLOW) {
43-
return RAI_GraphCreateTF(prefix, backend, graphdef, graphlen);
44+
return RAI_GraphCreateTF(prefix, backend, device, graphdef, graphlen);
4445
}
4546

4647
return NULL;

src/graph.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
extern RedisModuleType *RedisAI_GraphType;
1717

1818
int RAI_GraphInit(RedisModuleCtx* ctx);
19-
RAI_Graph* RAI_GraphCreate(const char* prefix, RAI_Backend backend, const char* graphdef, size_t graphlen);
19+
RAI_Graph* RAI_GraphCreate(const char* prefix, RAI_Backend backend, RAI_Device device, const char* graphdef, size_t graphlen);
2020
void RAI_GraphFree(RAI_Graph* graph);
2121
RAI_GraphRunCtx* RAI_RunCtxCreate(RAI_Graph* graph);
2222
int RAI_RunCtxAddInput(RAI_GraphRunCtx* gctx, const char* inputName, RAI_Tensor* inputTensor);

src/redisai.c

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ int RedisAI_Get_Tensor_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg
392392
int RedisAI_Set_Graph_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
393393
RedisModule_AutoMemory(ctx);
394394

395-
if ((argc != 4) && (argc != 5)) return RedisModule_WrongArity(ctx);
395+
if ((argc != 5) && (argc != 6)) return RedisModule_WrongArity(ctx);
396396

397397
const char* bckstr;
398398
int backend;
@@ -410,16 +410,29 @@ int RedisAI_Set_Graph_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
410410
return RedisModule_ReplyWithError(ctx, "ERR unsupported backend");
411411
}
412412

413+
const char* devicestr;
414+
int device;
415+
devicestr = RedisModule_StringPtrLen(argv[3], NULL);
416+
if (strcasecmp(devicestr, "CPU") == 0) {
417+
device = RAI_DEVICE_CPU;
418+
}
419+
else if (strcasecmp(devicestr, "GPU") == 0) {
420+
device = RAI_DEVICE_GPU;
421+
}
422+
else {
423+
return RedisModule_ReplyWithError(ctx, "ERR unsupported device");
424+
}
425+
413426
RAI_Graph *graph = NULL;
414427

415428
size_t graphlen;
416-
const char *graphdef = RedisModule_StringPtrLen(argv[3], &graphlen);
429+
const char *graphdef = RedisModule_StringPtrLen(argv[4], &graphlen);
417430
const char *prefix = "";
418-
if (argc == 5) {
419-
const char *prefix = RedisModule_StringPtrLen(argv[4], NULL);
431+
if (argc == 6) {
432+
const char *prefix = RedisModule_StringPtrLen(argv[5], NULL);
420433
}
421434

422-
graph = RAI_GraphCreate(prefix, backend, graphdef, graphlen);
435+
graph = RAI_GraphCreate(prefix, backend, device, graphdef, graphlen);
423436

424437
if(graph == NULL){
425438
return RedisModule_ReplyWithError(ctx, "ERR failed creating the graph");

src/redisai.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ typedef struct RAI_Graph RAI_Graph;
1313
typedef struct RAI_GraphRunCtx RAI_GraphRunCtx;
1414

1515
typedef enum RAI_Backend RAI_Backend;
16+
typedef enum RAI_Device RAI_Device;
1617

1718
RAI_Tensor* MODULE_API_FUNC(RedisAI_TensorCreate)(const char* dataTypeStr, long long* dims, int ndims);
1819
size_t MODULE_API_FUNC(RedisAI_TensorLength)(RAI_Tensor* t);
@@ -30,7 +31,7 @@ long long MODULE_API_FUNC(RedisAI_TensorDim)(RAI_Tensor* t, int dim);
3031
size_t MODULE_API_FUNC(RedisAI_TensorByteSize)(RAI_Tensor* t);
3132
char* MODULE_API_FUNC(RedisAI_TensorData)(RAI_Tensor* t);
3233

33-
RAI_Graph* MODULE_API_FUNC(RedisAI_GraphCreate)(const char* prefix, RAI_Backend backend, const char* graphdef, size_t graphlen);
34+
RAI_Graph* MODULE_API_FUNC(RedisAI_GraphCreate)(const char* prefix, RAI_Backend backend, RAI_Device device, const char* graphdef, size_t graphlen);
3435
void MODULE_API_FUNC(RedisAI_GraphFree)(RAI_Graph* graph);
3536
RAI_GraphRunCtx* MODULE_API_FUNC(RedisAI_RunCtxCreate)(RAI_Graph* graph);
3637
int MODULE_API_FUNC(RedisAI_RunCtxAddInput)(RAI_GraphRunCtx* gctx, const char* inputName, RAI_Tensor* inputTensor);

0 commit comments

Comments
 (0)