@@ -59,47 +59,38 @@ DLDataType RAI_TensorDataTypeFromString(const char *typestr) {
5959
6060static size_t Tensor_DataTypeSize (DLDataType dtype ) { return dtype .bits / 8 ; }
6161
62- int Tensor_DataTypeStr (DLDataType dtype , char * * dtypestr ) {
62+ int Tensor_DataTypeStr (DLDataType dtype , char * dtypestr ) {
6363 int result = REDISMODULE_ERR ;
64- * dtypestr = RedisModule_Calloc ( 8 , sizeof ( char ));
64+
6565 if (dtype .code == kDLFloat ) {
6666 if (dtype .bits == 32 ) {
67- strcpy (* dtypestr , RAI_DATATYPE_STR_FLOAT );
67+ strcpy (dtypestr , RAI_DATATYPE_STR_FLOAT );
6868 result = REDISMODULE_OK ;
6969 } else if (dtype .bits == 64 ) {
70- strcpy (* dtypestr , RAI_DATATYPE_STR_DOUBLE );
70+ strcpy (dtypestr , RAI_DATATYPE_STR_DOUBLE );
7171 result = REDISMODULE_OK ;
72- } else {
73- RedisModule_Free (* dtypestr );
74- * dtypestr = NULL ;
7572 }
7673 } else if (dtype .code == kDLInt ) {
7774 if (dtype .bits == 8 ) {
78- strcpy (* dtypestr , RAI_DATATYPE_STR_INT8 );
75+ strcpy (dtypestr , RAI_DATATYPE_STR_INT8 );
7976 result = REDISMODULE_OK ;
8077 } else if (dtype .bits == 16 ) {
81- strcpy (* dtypestr , RAI_DATATYPE_STR_INT16 );
78+ strcpy (dtypestr , RAI_DATATYPE_STR_INT16 );
8279 result = REDISMODULE_OK ;
8380 } else if (dtype .bits == 32 ) {
84- strcpy (* dtypestr , RAI_DATATYPE_STR_INT32 );
81+ strcpy (dtypestr , RAI_DATATYPE_STR_INT32 );
8582 result = REDISMODULE_OK ;
8683 } else if (dtype .bits == 64 ) {
87- strcpy (* dtypestr , RAI_DATATYPE_STR_INT64 );
84+ strcpy (dtypestr , RAI_DATATYPE_STR_INT64 );
8885 result = REDISMODULE_OK ;
89- } else {
90- RedisModule_Free (* dtypestr );
91- * dtypestr = NULL ;
9286 }
9387 } else if (dtype .code == kDLUInt ) {
9488 if (dtype .bits == 8 ) {
95- strcpy (* dtypestr , RAI_DATATYPE_STR_UINT8 );
89+ strcpy (dtypestr , RAI_DATATYPE_STR_UINT8 );
9690 result = REDISMODULE_OK ;
9791 } else if (dtype .bits == 16 ) {
98- strcpy (* dtypestr , RAI_DATATYPE_STR_UINT16 );
92+ strcpy (dtypestr , RAI_DATATYPE_STR_UINT16 );
9993 result = REDISMODULE_OK ;
100- } else {
101- RedisModule_Free (* dtypestr );
102- * dtypestr = NULL ;
10394 }
10495 }
10596 return result ;
@@ -195,8 +186,9 @@ static void RAI_Tensor_RdbSave(RedisModuleIO *io, void *value) {
195186static void RAI_Tensor_AofRewrite (RedisModuleIO * aof , RedisModuleString * key , void * value ) {
196187 RAI_Tensor * tensor = (RAI_Tensor * )value ;
197188
198- char * dtypestr = NULL ;
199- Tensor_DataTypeStr (RAI_TensorDataType (tensor ), & dtypestr );
189+ char dtypestr [8 ];
190+ const int status = Tensor_DataTypeStr (RAI_TensorDataType (tensor ), dtypestr );
191+ RedisModule_Assert (status == REDISMODULE_OK );
200192
201193 char * data = RAI_TensorData (tensor );
202194 long long size = RAI_TensorByteSize (tensor );
@@ -212,8 +204,6 @@ static void RAI_Tensor_AofRewrite(RedisModuleIO *aof, RedisModuleString *key, vo
212204
213205 RedisModule_EmitAOF (aof , "AI.TENSORSET" , "scvcb" , key , dtypestr , dims , ndims , "BLOB" , data ,
214206 size );
215-
216- RedisModule_Free (dtypestr );
217207}
218208
219209static void RAI_Tensor_DTFree (void * value ) { RAI_TensorFree (value ); }
@@ -756,10 +746,9 @@ int RAI_getTensorFromLocalContext(RedisModuleCtx *ctx, AI_dict *localContextDict
756746void RedisAI_ReplicateTensorSet (RedisModuleCtx * ctx , RedisModuleString * key , RAI_Tensor * t ) {
757747 long long ndims = RAI_TensorNumDims (t );
758748
759- char * dtypestr = NULL ;
760- Tensor_DataTypeStr (RAI_TensorDataType (t ), & dtypestr );
761-
762- assert (dtypestr );
749+ char dtypestr [8 ];
750+ const int status = Tensor_DataTypeStr (RAI_TensorDataType (t ), dtypestr );
751+ RedisModule_Assert (status == REDISMODULE_OK );
763752
764753 char * data = RAI_TensorData (t );
765754 long long size = RAI_TensorByteSize (t );
@@ -776,8 +765,6 @@ void RedisAI_ReplicateTensorSet(RedisModuleCtx *ctx, RedisModuleString *key, RAI
776765 for (long long i = 0 ; i < ndims ; i ++ ) {
777766 RedisModule_FreeString (ctx , dims [i ]);
778767 }
779-
780- RedisModule_Free (dtypestr );
781768}
782769
783770int RAI_parseTensorSetArgs (RedisModuleCtx * ctx , RedisModuleString * * argv , int argc , RAI_Tensor * * t ,
@@ -1054,19 +1041,18 @@ int RAI_parseTensorGetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
10541041
10551042 const long long ndims = RAI_TensorNumDims (t );
10561043
1057- char * dtypestr = NULL ;
1058- const int dtypestr_result = Tensor_DataTypeStr (RAI_TensorDataType (t ), & dtypestr );
1044+ char dtypestr [ 8 ] ;
1045+ const int dtypestr_result = Tensor_DataTypeStr (RAI_TensorDataType (t ), dtypestr );
10591046 if (dtypestr_result == REDISMODULE_ERR ) {
10601047 RedisModule_ReplyWithError (ctx , "ERR unsupported dtype" );
10611048 return -1 ;
10621049 }
10631050
10641051 RedisModule_ReplyWithArray (ctx , resplen );
1065-
10661052 RedisModule_ReplyWithCString (ctx , "dtype" );
10671053 RedisModule_ReplyWithCString (ctx , dtypestr );
1068-
10691054 RedisModule_ReplyWithCString (ctx , "shape" );
1055+
10701056 RedisModule_ReplyWithArray (ctx , ndims );
10711057 for (long long i = 0 ; i < ndims ; i ++ ) {
10721058 const long long dim = RAI_TensorDim (t , i );
0 commit comments