@@ -289,6 +289,60 @@ RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, in
289289 return ret ;
290290}
291291
292+ void RAI_RStringDataTensorDeleter (DLManagedTensor * arg ) {
293+ if (arg -> dl_tensor .shape ) {
294+ RedisModule_Free (arg -> dl_tensor .shape );
295+ }
296+ if (arg -> dl_tensor .strides ) {
297+ RedisModule_Free (arg -> dl_tensor .strides );
298+ }
299+ if (arg -> manager_ctx ) {
300+ RedisModuleString * rstr = (RedisModuleString * )arg -> manager_ctx ;
301+ RedisModule_FreeString (NULL , rstr );
302+ }
303+
304+ RedisModule_Free (arg );
305+ }
306+
307+ RAI_Tensor * RAI_TensorCreateWithDLDataTypeAndRString (DLDataType dtype , long long * dims , int ndims ,
308+ RedisModuleString * rstr ) {
309+ const size_t dtypeSize = Tensor_DataTypeSize (dtype );
310+ if (dtypeSize == 0 ) {
311+ return NULL ;
312+ }
313+
314+ RAI_Tensor * ret = RedisModule_Alloc (sizeof (* ret ));
315+ int64_t * shape = RedisModule_Alloc (ndims * sizeof (* shape ));
316+ int64_t * strides = RedisModule_Alloc (ndims * sizeof (* strides ));
317+
318+ size_t len = 1 ;
319+ for (int64_t i = 0 ; i < ndims ; ++ i ) {
320+ shape [i ] = dims [i ];
321+ strides [i ] = 1 ;
322+ len *= dims [i ];
323+ }
324+ for (int64_t i = ndims - 2 ; i >= 0 ; -- i ) {
325+ strides [i ] *= strides [i + 1 ] * shape [i + 1 ];
326+ }
327+
328+ DLContext ctx = (DLContext ){.device_type = kDLCPU , .device_id = 0 };
329+
330+ char * data = (char * )RedisModule_StringPtrLen (rstr , NULL );
331+
332+ ret -> tensor = (DLManagedTensor ){.dl_tensor = (DLTensor ){.ctx = ctx ,
333+ .data = data ,
334+ .ndim = ndims ,
335+ .dtype = dtype ,
336+ .shape = shape ,
337+ .strides = strides ,
338+ .byte_offset = 0 },
339+ .manager_ctx = rstr ,
340+ .deleter = RAI_RStringDataTensorDeleter };
341+
342+ ret -> refCount = 1 ;
343+ return ret ;
344+ }
345+
292346RAI_Tensor * RAI_TensorCreate (const char * dataType , long long * dims , int ndims , int hasdata ) {
293347 DLDataType dtype = RAI_TensorDataTypeFromString (dataType );
294348 return RAI_TensorCreateWithDLDataType (dtype , dims , ndims , TENSORALLOC_ALLOC );
@@ -815,7 +869,14 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
815869 size_t datalen ;
816870 const char * data ;
817871 DLDataType datatype = RAI_TensorDataTypeFromString (typestr );
818- * t = RAI_TensorCreateWithDLDataType (datatype , dims , ndims , tensorAllocMode );
872+ if (datafmt == REDISAI_DATA_BLOB ) {
873+ RedisModuleString * rstr = argv [argpos ];
874+ RedisModule_RetainString (NULL , rstr );
875+ * t = RAI_TensorCreateWithDLDataTypeAndRString (datatype , dims , ndims , rstr );
876+ } else {
877+ * t = RAI_TensorCreateWithDLDataType (datatype , dims , ndims , tensorAllocMode );
878+ }
879+
819880 if (!t ) {
820881 array_free (dims );
821882 if (ctx == NULL ) {
@@ -826,24 +887,7 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
826887 return -1 ;
827888 }
828889 long i = 0 ;
829- switch (datafmt ) {
830- case REDISAI_DATA_BLOB : {
831- const char * blob = RedisModule_StringPtrLen (argv [argpos ], & datalen );
832- if (datalen != nbytes ) {
833- RAI_TensorFree (* t );
834- array_free (dims );
835- if (ctx == NULL ) {
836- RAI_SetError (error , RAI_ETENSORSET ,
837- "ERR data length does not match tensor shape and type" );
838- } else {
839- RedisModule_ReplyWithError (ctx ,
840- "ERR data length does not match tensor shape and type" );
841- }
842- return -1 ;
843- }
844- RAI_TensorSetData (* t , blob , datalen );
845- } break ;
846- case REDISAI_DATA_VALUES :
890+ if (datafmt == REDISAI_DATA_VALUES ) {
847891 for (; (argpos <= argc - 1 ) && (i < len ); argpos ++ ) {
848892 if (datatype .code == kDLFloat ) {
849893 double val ;
@@ -900,10 +944,6 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
900944 }
901945 i ++ ;
902946 }
903- break ;
904- default :
905- // default does not require tensor data setting since calloc setted it to 0
906- break ;
907947 }
908948 array_free (dims );
909949 return argpos ;
0 commit comments