Skip to content

Commit a5b5b3e

Browse files
committed
Use bitmap for tensor format (blob/values/meta).
1 parent b762e61 commit a5b5b3e

File tree

5 files changed

+32
-52
lines changed

5 files changed

+32
-52
lines changed

src/DAG/dag_parser.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ int _ParseDAGOps(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
316316
RAI_HoldString(NULL, currentOp->argv[1]);
317317
currentOp->inkeys = array_append(currentOp->inkeys, currentOp->argv[1]);
318318
currentOp->fmt = ParseTensorGetArgs(ctx, currentOp->argv, currentOp->argc);
319-
if (currentOp->fmt == REDISAI_TENSOR_NONE)
319+
if (currentOp->fmt == TENSOR_NONE)
320320
return REDISMODULE_ERR;
321321
continue;
322322
}

src/redisai.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ int RedisAI_TensorGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
139139
return REDISMODULE_ERR;
140140
}
141141

142-
RedisAI_TensorFmt fmt = ParseTensorGetArgs(ctx, argv, argc);
143-
if (fmt == REDISAI_TENSOR_NONE) {
142+
uint fmt = ParseTensorGetArgs(ctx, argv, argc);
143+
if (fmt == TENSOR_NONE) {
144144
// This means that args are invalid.
145145
return REDISMODULE_ERR;
146146
}

src/run_info.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ typedef struct RAI_DagOp {
3636
RAI_Tensor *outTensor; // The tensor to upload in TENSORSET op.
3737
RAI_ModelRunCtx *mctx;
3838
RAI_ScriptRunCtx *sctx;
39-
RedisAI_TensorFmt fmt; // This is relevant for TENSORGET op.
39+
uint fmt; // This is relevant for TENSORGET op.
4040
char *devicestr;
4141
int result; // REDISMODULE_OK or REDISMODULE_ERR
4242
long long duration_us;

src/tensor.c

Lines changed: 21 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,7 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
657657
return -1;
658658
}
659659
const char *fmtstr;
660-
int datafmt = REDISAI_TENSOR_NONE;
660+
int datafmt = TENSOR_NONE;
661661
int tensorAllocMode = TENSORALLOC_CALLOC;
662662
size_t ndims = 0;
663663
long long len = 1;
@@ -671,7 +671,7 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
671671
const char *opt = RedisModule_StringPtrLen(argv[argpos], NULL);
672672
remaining_args = argc - 1 - argpos;
673673
if (!strcasecmp(opt, "BLOB")) {
674-
datafmt = REDISAI_TENSOR_BLOB;
674+
datafmt = TENSOR_BLOB;
675675
tensorAllocMode = TENSORALLOC_CALLOC;
676676
// if we've found the dataformat there are no more dimensions
677677
// check right away if the arity is correct
@@ -688,7 +688,7 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
688688
argpos++;
689689
break;
690690
} else if (!strcasecmp(opt, "VALUES")) {
691-
datafmt = REDISAI_TENSOR_VALUES;
691+
datafmt = TENSOR_VALUES;
692692
tensorAllocMode = TENSORALLOC_CALLOC;
693693
// if we've found the dataformat there are no more dimensions
694694
// check right away if the arity is correct
@@ -728,7 +728,7 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
728728
size_t datalen;
729729
const char *data;
730730
DLDataType datatype = RAI_TensorDataTypeFromString(typestr);
731-
if (datafmt == REDISAI_TENSOR_BLOB) {
731+
if (datafmt == TENSOR_BLOB) {
732732
RedisModuleString *rstr = argv[argpos];
733733
RedisModule_RetainString(NULL, rstr);
734734
*t = RAI_TensorCreateWithDLDataTypeAndRString(datatype, dims, ndims, rstr);
@@ -746,7 +746,7 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
746746
return -1;
747747
}
748748
long i = 0;
749-
if (datafmt == REDISAI_TENSOR_VALUES) {
749+
if (datafmt == TENSOR_VALUES) {
750750
for (; (argpos <= argc - 1) && (i < len); argpos++) {
751751
if (datatype.code == kDLFloat) {
752752
double val;
@@ -845,59 +845,42 @@ int RAI_TensorReplyWithValues(RedisModuleCtx *ctx, RAI_Tensor *t) {
845845
return 0;
846846
}
847847

848-
RedisAI_TensorFmt ParseTensorGetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
849-
RedisAI_TensorFmt fmt = REDISAI_TENSOR_NONE;
848+
uint ParseTensorGetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
849+
uint fmt = TENSOR_NONE;
850850
if (argc < 2 || argc > 4) {
851851
RedisModule_WrongArity(ctx);
852852
return fmt;
853853
}
854-
855-
bool meta_arg = false;
856-
bool blob_arg = false;
857-
bool values_arg = false;
858-
bool fmt_error = false;
859854
for (int i = 2; i < argc; i++) {
860855
const char *fmtstr = RedisModule_StringPtrLen(argv[i], NULL);
861856
if (!strcasecmp(fmtstr, "BLOB")) {
862-
blob_arg = true;
857+
fmt |= TENSOR_BLOB;
863858
} else if (!strcasecmp(fmtstr, "VALUES")) {
864-
values_arg = true;
859+
fmt |= TENSOR_VALUES;
865860
} else if (!strcasecmp(fmtstr, "META")) {
866-
meta_arg = true;
861+
fmt |= TENSOR_META;
867862
} else {
868-
fmt_error = true;
863+
RedisModule_ReplyWithError(ctx, "ERR unsupported data format");
864+
return TENSOR_NONE;
869865
}
870866
}
871-
if (fmt_error) {
872-
RedisModule_ReplyWithError(ctx, "ERR unsupported data format");
873-
return fmt;
874-
}
875-
if (blob_arg && values_arg) {
867+
868+
if (fmt == TENSOR_ILLEGAL_VALUES_BLOB) {
876869
RedisModule_ReplyWithError(ctx, "ERR both BLOB and VALUES specified");
877-
return fmt;
870+
return TENSOR_NONE;
878871
}
879-
if (blob_arg && !meta_arg)
880-
return REDISAI_TENSOR_BLOB;
881-
if (values_arg && !meta_arg)
882-
return REDISAI_TENSOR_VALUES;
883-
if (blob_arg && meta_arg)
884-
return REDISAI_TENSOR_BLOB_WITH_META;
885-
if (values_arg && meta_arg)
886-
return REDISAI_TENSOR_VALUES_WITH_META;
887-
if (!blob_arg && !values_arg && meta_arg)
888-
return REDISAI_TENSOR_META;
889872
return fmt;
890873
}
891874

892-
int ReplyWithTensor(RedisModuleCtx *ctx, RedisAI_TensorFmt fmt, RAI_Tensor *t) {
875+
int ReplyWithTensor(RedisModuleCtx *ctx, uint fmt, RAI_Tensor *t) {
893876

894-
if (fmt == REDISAI_TENSOR_BLOB) {
877+
if ((fmt & TENSOR_BLOB) && !(fmt & TENSOR_META)) {
895878
long long size = RAI_TensorByteSize(t);
896879
char *data = RAI_TensorData(t);
897880
RedisModule_ReplyWithStringBuffer(ctx, data, size);
898881
return REDISMODULE_OK;
899882
}
900-
if (fmt == REDISAI_TENSOR_VALUES) {
883+
if ((fmt & TENSOR_VALUES) && !(fmt & TENSOR_META)) {
901884
int ret = RAI_TensorReplyWithValues(ctx, t);
902885
if (ret == -1) {
903886
return REDISMODULE_ERR;
@@ -906,7 +889,7 @@ int ReplyWithTensor(RedisModuleCtx *ctx, RedisAI_TensorFmt fmt, RAI_Tensor *t) {
906889
}
907890

908891
long long resplen = 4;
909-
if (fmt == REDISAI_TENSOR_BLOB_WITH_META || fmt == REDISAI_TENSOR_VALUES_WITH_META)
892+
if ((fmt & TENSOR_BLOB) || (fmt & TENSOR_VALUES))
910893
resplen += 2;
911894

912895
const long long ndims = RAI_TensorNumDims(t);
@@ -929,13 +912,13 @@ int ReplyWithTensor(RedisModuleCtx *ctx, RedisAI_TensorFmt fmt, RAI_Tensor *t) {
929912
RedisModule_ReplyWithLongLong(ctx, dim);
930913
}
931914

932-
if (fmt == REDISAI_TENSOR_BLOB_WITH_META) {
915+
if (fmt & TENSOR_BLOB) {
933916
long long size = RAI_TensorByteSize(t);
934917
char *data = RAI_TensorData(t);
935918
RedisModule_ReplyWithCString(ctx, "blob");
936919
RedisModule_ReplyWithStringBuffer(ctx, data, size);
937920

938-
} else if (fmt == REDISAI_TENSOR_VALUES_WITH_META) {
921+
} else if (fmt & TENSOR_VALUES) {
939922
RedisModule_ReplyWithCString(ctx, "values");
940923
int ret = RAI_TensorReplyWithValues(ctx, t);
941924
if (ret != REDISMODULE_OK) {

src/tensor.h

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,11 @@ static const char *RAI_DATATYPE_STR_INT64 = "INT64";
3333
static const char *RAI_DATATYPE_STR_UINT8 = "UINT8";
3434
static const char *RAI_DATATYPE_STR_UINT16 = "UINT16";
3535

36-
typedef enum RedisAI_TensorFmt {
37-
REDISAI_TENSOR_NONE = 0,
38-
REDISAI_TENSOR_VALUES,
39-
REDISAI_TENSOR_META,
40-
REDISAI_TENSOR_BLOB_WITH_META,
41-
REDISAI_TENSOR_VALUES_WITH_META,
42-
REDISAI_TENSOR_BLOB
43-
} RedisAI_TensorFmt;
36+
#define TENSOR_NONE 0
37+
#define TENSOR_VALUES (1 << 0)
38+
#define TENSOR_META (1 << 1)
39+
#define TENSOR_BLOB (1 << 2)
40+
#define TENSOR_ILLEGAL_VALUES_BLOB (TENSOR_VALUES | TENSOR_BLOB)
4441

4542
extern RedisModuleType *RedisAI_TensorType;
4643

@@ -387,7 +384,7 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
387384
* @return The format in which tensor is returned.
388385
*/
389386

390-
RedisAI_TensorFmt ParseTensorGetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int argc);
387+
uint ParseTensorGetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int argc);
391388

392389
/**
393390
* Helper method to return a tensor to the client in a response to AI.TENSORGET
@@ -399,7 +396,7 @@ RedisAI_TensorFmt ParseTensorGetArgs(RedisModuleCtx *ctx, RedisModuleString **ar
399396
* @return REDISMODULE_OK in case of success, REDISMODULE_ERR otherwise.
400397
*/
401398

402-
int ReplyWithTensor(RedisModuleCtx *ctx, RedisAI_TensorFmt fmt, RAI_Tensor *t);
399+
int ReplyWithTensor(RedisModuleCtx *ctx, uint fmt, RAI_Tensor *t);
403400

404401
/**
405402
* @brief Returns the redis module type representing a tensor.

0 commit comments

Comments
 (0)