@@ -657,7 +657,7 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
657657 return -1 ;
658658 }
659659 const char * fmtstr ;
660- int datafmt = REDISAI_TENSOR_NONE ;
660+ int datafmt = TENSOR_NONE ;
661661 int tensorAllocMode = TENSORALLOC_CALLOC ;
662662 size_t ndims = 0 ;
663663 long long len = 1 ;
@@ -671,7 +671,7 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
671671 const char * opt = RedisModule_StringPtrLen (argv [argpos ], NULL );
672672 remaining_args = argc - 1 - argpos ;
673673 if (!strcasecmp (opt , "BLOB" )) {
674- datafmt = REDISAI_TENSOR_BLOB ;
674+ datafmt = TENSOR_BLOB ;
675675 tensorAllocMode = TENSORALLOC_CALLOC ;
676676 // if we've found the dataformat there are no more dimensions
677677 // check right away if the arity is correct
@@ -688,7 +688,7 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
688688 argpos ++ ;
689689 break ;
690690 } else if (!strcasecmp (opt , "VALUES" )) {
691- datafmt = REDISAI_TENSOR_VALUES ;
691+ datafmt = TENSOR_VALUES ;
692692 tensorAllocMode = TENSORALLOC_CALLOC ;
693693 // if we've found the dataformat there are no more dimensions
694694 // check right away if the arity is correct
@@ -728,7 +728,7 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
728728 size_t datalen ;
729729 const char * data ;
730730 DLDataType datatype = RAI_TensorDataTypeFromString (typestr );
731- if (datafmt == REDISAI_TENSOR_BLOB ) {
731+ if (datafmt == TENSOR_BLOB ) {
732732 RedisModuleString * rstr = argv [argpos ];
733733 RedisModule_RetainString (NULL , rstr );
734734 * t = RAI_TensorCreateWithDLDataTypeAndRString (datatype , dims , ndims , rstr );
@@ -746,7 +746,7 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
746746 return -1 ;
747747 }
748748 long i = 0 ;
749- if (datafmt == REDISAI_TENSOR_VALUES ) {
749+ if (datafmt == TENSOR_VALUES ) {
750750 for (; (argpos <= argc - 1 ) && (i < len ); argpos ++ ) {
751751 if (datatype .code == kDLFloat ) {
752752 double val ;
@@ -845,59 +845,42 @@ int RAI_TensorReplyWithValues(RedisModuleCtx *ctx, RAI_Tensor *t) {
845845 return 0 ;
846846}
847847
848- RedisAI_TensorFmt ParseTensorGetArgs (RedisModuleCtx * ctx , RedisModuleString * * argv , int argc ) {
849- RedisAI_TensorFmt fmt = REDISAI_TENSOR_NONE ;
848+ uint ParseTensorGetArgs (RedisModuleCtx * ctx , RedisModuleString * * argv , int argc ) {
849+ uint fmt = TENSOR_NONE ;
850850 if (argc < 2 || argc > 4 ) {
851851 RedisModule_WrongArity (ctx );
852852 return fmt ;
853853 }
854-
855- bool meta_arg = false;
856- bool blob_arg = false;
857- bool values_arg = false;
858- bool fmt_error = false;
859854 for (int i = 2 ; i < argc ; i ++ ) {
860855 const char * fmtstr = RedisModule_StringPtrLen (argv [i ], NULL );
861856 if (!strcasecmp (fmtstr , "BLOB" )) {
862- blob_arg = true ;
857+ fmt |= TENSOR_BLOB ;
863858 } else if (!strcasecmp (fmtstr , "VALUES" )) {
864- values_arg = true ;
859+ fmt |= TENSOR_VALUES ;
865860 } else if (!strcasecmp (fmtstr , "META" )) {
866- meta_arg = true ;
861+ fmt |= TENSOR_META ;
867862 } else {
868- fmt_error = true;
863+ RedisModule_ReplyWithError (ctx , "ERR unsupported data format" );
864+ return TENSOR_NONE ;
869865 }
870866 }
871- if (fmt_error ) {
872- RedisModule_ReplyWithError (ctx , "ERR unsupported data format" );
873- return fmt ;
874- }
875- if (blob_arg && values_arg ) {
867+
868+ if (fmt == TENSOR_ILLEGAL_VALUES_BLOB ) {
876869 RedisModule_ReplyWithError (ctx , "ERR both BLOB and VALUES specified" );
877- return fmt ;
870+ return TENSOR_NONE ;
878871 }
879- if (blob_arg && !meta_arg )
880- return REDISAI_TENSOR_BLOB ;
881- if (values_arg && !meta_arg )
882- return REDISAI_TENSOR_VALUES ;
883- if (blob_arg && meta_arg )
884- return REDISAI_TENSOR_BLOB_WITH_META ;
885- if (values_arg && meta_arg )
886- return REDISAI_TENSOR_VALUES_WITH_META ;
887- if (!blob_arg && !values_arg && meta_arg )
888- return REDISAI_TENSOR_META ;
889872 return fmt ;
890873}
891874
892- int ReplyWithTensor (RedisModuleCtx * ctx , RedisAI_TensorFmt fmt , RAI_Tensor * t ) {
875+ int ReplyWithTensor (RedisModuleCtx * ctx , uint fmt , RAI_Tensor * t ) {
893876
894- if (fmt == REDISAI_TENSOR_BLOB ) {
877+ if (( fmt & TENSOR_BLOB ) && !( fmt & TENSOR_META ) ) {
895878 long long size = RAI_TensorByteSize (t );
896879 char * data = RAI_TensorData (t );
897880 RedisModule_ReplyWithStringBuffer (ctx , data , size );
898881 return REDISMODULE_OK ;
899882 }
900- if (fmt == REDISAI_TENSOR_VALUES ) {
883+ if (( fmt & TENSOR_VALUES ) && !( fmt & TENSOR_META ) ) {
901884 int ret = RAI_TensorReplyWithValues (ctx , t );
902885 if (ret == -1 ) {
903886 return REDISMODULE_ERR ;
@@ -906,7 +889,7 @@ int ReplyWithTensor(RedisModuleCtx *ctx, RedisAI_TensorFmt fmt, RAI_Tensor *t) {
906889 }
907890
908891 long long resplen = 4 ;
909- if (fmt == REDISAI_TENSOR_BLOB_WITH_META || fmt == REDISAI_TENSOR_VALUES_WITH_META )
892+ if (( fmt & TENSOR_BLOB ) || ( fmt & TENSOR_VALUES ) )
910893 resplen += 2 ;
911894
912895 const long long ndims = RAI_TensorNumDims (t );
@@ -929,13 +912,13 @@ int ReplyWithTensor(RedisModuleCtx *ctx, RedisAI_TensorFmt fmt, RAI_Tensor *t) {
929912 RedisModule_ReplyWithLongLong (ctx , dim );
930913 }
931914
932- if (fmt == REDISAI_TENSOR_BLOB_WITH_META ) {
915+ if (fmt & TENSOR_BLOB ) {
933916 long long size = RAI_TensorByteSize (t );
934917 char * data = RAI_TensorData (t );
935918 RedisModule_ReplyWithCString (ctx , "blob" );
936919 RedisModule_ReplyWithStringBuffer (ctx , data , size );
937920
938- } else if (fmt == REDISAI_TENSOR_VALUES_WITH_META ) {
921+ } else if (fmt & TENSOR_VALUES ) {
939922 RedisModule_ReplyWithCString (ctx , "values" );
940923 int ret = RAI_TensorReplyWithValues (ctx , t );
941924 if (ret != REDISMODULE_OK ) {
0 commit comments