@@ -46,14 +46,18 @@ static size_t Tensor_DataTypeSize(DLDataType dtype) {
4646 return dtype .bits / 8 ;
4747}
4848
49- static void Tensor_DataTypeStr (DLDataType dtype , char * * dtypestr ) {
49+ void Tensor_DataTypeStr (DLDataType dtype , char * * dtypestr ) {
5050 * dtypestr = RedisModule_Calloc (8 , sizeof (char ));
5151 if (dtype .code == kDLFloat ) {
5252 if (dtype .bits == 32 ) {
53- strcpy (* dtypestr , "FLOAT32 " );
53+ strcpy (* dtypestr , "FLOAT " );
5454 }
5555 else if (dtype .bits == 64 ) {
56- strcpy (* dtypestr , "FLOAT64" );
56+ strcpy (* dtypestr , "DOUBLE" );
57+ }
58+ else {
59+ RedisModule_Free (* dtypestr );
60+ * dtypestr = NULL ;
5761 }
5862 }
5963 else if (dtype .code == kDLInt ) {
@@ -69,6 +73,10 @@ static void Tensor_DataTypeStr(DLDataType dtype, char **dtypestr) {
6973 else if (dtype .bits == 64 ) {
7074 strcpy (* dtypestr , "INT64" );
7175 }
76+ else {
77+ RedisModule_Free (* dtypestr );
78+ * dtypestr = NULL ;
79+ }
7280 }
7381 else if (dtype .code == kDLUInt ) {
7482 if (dtype .bits == 8 ) {
@@ -77,6 +85,10 @@ static void Tensor_DataTypeStr(DLDataType dtype, char **dtypestr) {
7785 else if (dtype .bits == 16 ) {
7886 strcpy (* dtypestr , "UINT16" );
7987 }
88+ else {
89+ RedisModule_Free (* dtypestr );
90+ * dtypestr = NULL ;
91+ }
8092 }
8193}
8294
@@ -175,51 +187,21 @@ static void RAI_Tensor_AofRewrite(RedisModuleIO *aof, RedisModuleString *key, vo
175187 RAI_Tensor * tensor = (RAI_Tensor * )value ;
176188
177189 char * dtypestr = NULL ;
178-
179190 Tensor_DataTypeStr (RAI_TensorDataType (tensor ), & dtypestr );
180191
181- int64_t * shape = tensor -> tensor .dl_tensor .shape ;
182- char * data = RAI_TensorData (tensor );
183- size_t size = RAI_TensorByteSize (tensor );
192+ char * data = RAI_TensorData (tensor );
193+ long long size = RAI_TensorByteSize (tensor );
194+
195+ long long ndims = RAI_TensorNumDims (tensor );
196+
197+ RedisModuleString * dims [ndims ];
184198
185- // We switch over the dimensions of the tensor up to 7
186- // The reason is that we don't have a way to pass a vector of long long to RedisModule_EmitAOF,
187- // there's no format for it. Vector of strings is supported (format 'v').
188- // This might change in the future, but it needs to change in redis/src/module.c
189-
190- switch (RAI_TensorNumDims (tensor )) {
191- case 1 :
192- RedisModule_EmitAOF (aof , "AI.TENSORSET" , "sllcb" ,
193- key , dtypestr , RAI_SPLICE_SHAPE_1 (shape ), "BLOB" , data , size );
194- break ;
195- case 2 :
196- RedisModule_EmitAOF (aof , "AI.TENSORSET" , "slllcb" ,
197- key , dtypestr , RAI_SPLICE_SHAPE_2 (shape ), "BLOB" , data , size );
198- break ;
199- case 3 :
200- RedisModule_EmitAOF (aof , "AI.TENSORSET" , "sllllcb" ,
201- key , dtypestr , RAI_SPLICE_SHAPE_3 (shape ), "BLOB" , data , size );
202- break ;
203- case 4 :
204- RedisModule_EmitAOF (aof , "AI.TENSORSET" , "slllllcb" ,
205- key , dtypestr , RAI_SPLICE_SHAPE_4 (shape ), "BLOB" , data , size );
206- break ;
207- case 5 :
208- RedisModule_EmitAOF (aof , "AI.TENSORSET" , "sllllllcb" ,
209- key , dtypestr , RAI_SPLICE_SHAPE_5 (shape ), "BLOB" , data , size );
210- break ;
211- case 6 :
212- RedisModule_EmitAOF (aof , "AI.TENSORSET" , "slllllllcb" ,
213- key , dtypestr , RAI_SPLICE_SHAPE_6 (shape ), "BLOB" , data , size );
214- break ;
215- case 7 :
216- RedisModule_EmitAOF (aof , "AI.TENSORSET" , "sllllllllcb" ,
217- key , dtypestr , RAI_SPLICE_SHAPE_7 (shape ), "BLOB" , data , size );
218- break ;
219- default :
220- printf ("ERR: AOF serialization supports tensors of dimension up to 7\n" );
199+ for (long long i = 0 ; i < ndims ; i ++ ) {
200+ dims [i ] = RedisModule_CreateStringFromLongLong (RedisModule_GetContextFromIO (aof ), RAI_TensorDim (tensor , i ));
221201 }
222202
203+ RedisModule_EmitAOF (aof , "AI.TENSORSET" , "scvcb" , key , dtypestr , dims , ndims , "BLOB" , data , size );
204+
223205 RedisModule_Free (dtypestr );
224206}
225207
0 commit comments