@@ -24,54 +24,57 @@ PytorchSerDes::PytorchSerDes() : BaseSerDes(BaseSerDes::Kind::Pytorch) {
2424
2525void PytorchSerDes::setFeature (const std::string &Name, const int Value) {
2626 auto tensor = torch::tensor ({Value}, torch::kInt32 );
27- static_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
27+ reinterpret_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
2828}
2929
3030void PytorchSerDes::setFeature (const std::string &Name, const long Value) {
3131 auto tensor = torch::tensor ({Value}, torch::kInt64 );
32- static_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
32+ reinterpret_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
3333}
3434
3535void PytorchSerDes::setFeature (const std::string &Name, const float Value) {
3636 auto tensor = torch::tensor ({Value}, torch::kFloat32 );
37- static_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
37+ reinterpret_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
3838}
3939
4040void PytorchSerDes::setFeature (const std::string &Name, const double Value) {
4141 auto tensor = torch::tensor ({Value}, torch::kFloat64 );
42- static_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
42+ reinterpret_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
4343}
4444
4545void PytorchSerDes::setFeature (const std::string &Name, const std::string Value) {
4646 std::vector<int8_t > encoded_str (Value.begin (), Value.end ());
4747 auto tensor = torch::tensor (encoded_str, torch::kInt8 );
48- static_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
48+ reinterpret_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
4949}
5050
5151void PytorchSerDes::setFeature (const std::string &Name, const bool Value) {
5252 auto tensor = torch::tensor ({Value}, torch::kBool );
53- static_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
53+ reinterpret_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
5454}
5555
5656void PytorchSerDes::setFeature (const std::string &Name, const std::vector<int > &Value) {
5757 auto tensor = torch::tensor (Value, torch::kInt32 );
58- static_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
58+ reinterpret_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
5959}
6060
6161void PytorchSerDes::setFeature (const std::string &Name, const std::vector<long > &Value) {
6262 auto tensor = torch::tensor (Value, torch::kInt64 );
63- static_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
63+ reinterpret_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
6464}
6565
6666void PytorchSerDes::setFeature (const std::string &Name, const std::vector<float > &Value) {
6767 auto tensor = torch::tensor (Value, torch::kFloat32 );
68+ llvm::errs () << Value.size () << " [Vec Size]\n " ;
6869 tensor = tensor.reshape ({1 , Value.size ()});
69- static_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
70+ llvm::errs () << tensor.sizes ()[1 ] << " [Tensor Size]\n " ;
71+ reinterpret_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
72+ llvm::errs () << reinterpret_cast <TensorVec*>(this ->RequestVoid )->size () << " [In serdes, len of req, (TensorVec)]\n " ;
7073}
7174
7275void PytorchSerDes::setFeature (const std::string &Name, const std::vector<double > &Value) {
7376 auto tensor = torch::tensor (Value, torch::kFloat64 );
74- static_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
77+ reinterpret_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
7578}
7679
7780void PytorchSerDes::setFeature (const std::string &Name, const std::vector<std::string> &Value) {
@@ -81,21 +84,21 @@ void PytorchSerDes::setFeature(const std::string &Name, const std::vector<std::s
8184 flat_vec.push_back (' \0 ' ); // Null-terminate each string
8285 }
8386 auto tensor = torch::tensor (flat_vec, torch::kInt8 );
84- static_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
87+ reinterpret_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
8588}
8689
8790void PytorchSerDes::setFeature (const std::string &Name, const std::vector<bool > &Value) {
8891 std::vector<uint8_t > bool_vec (Value.begin (), Value.end ());
8992 auto tensor = torch::tensor (bool_vec, torch::kUInt8 );
90- static_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
93+ reinterpret_cast <TensorVec*>(this ->RequestVoid )->push_back (tensor.clone ());
9194}
9295
9396// void PytorchSerDes::setRequest(void *Request) {
94- // CompiledModel = static_cast <torch::inductor::AOTIModelContainerRunnerCpu *>(Request);
97+ // CompiledModel = reinterpret_cast <torch::inductor::AOTIModelContainerRunnerCpu *>(Request);
9598// }
9699
97100void PytorchSerDes::cleanDataStructures () {
98- static_cast <TensorVec*>(this ->RequestVoid )->clear (); // Clear the input vector
101+ reinterpret_cast <TensorVec*>(this ->RequestVoid )->clear (); // Clear the input vector
99102}
100103
101104void *PytorchSerDes::deserializeUntyped (void *Data) {
@@ -104,7 +107,7 @@ void *PytorchSerDes::deserializeUntyped(void *Data) {
104107 }
105108
106109 // Assume Data is a pointer to a vector of tensors
107- std::vector<torch::Tensor> *serializedTensors = static_cast <TensorVec *>(Data);
110+ std::vector<torch::Tensor> *serializedTensors = reinterpret_cast <TensorVec *>(Data);
108111
109112 if (serializedTensors->empty ()) {
110113 return nullptr ;
@@ -119,9 +122,11 @@ void *PytorchSerDes::deserializeUntyped(void *Data) {
119122 return copyTensorToVect<int64_t >(Data);
120123 }
121124 else if (type_vect == torch::kFloat32 ) {
125+ llvm::errs () << " f32 here!\n " ;
122126 return copyTensorToVect<float >(Data);
123127 }
124128 else if (type_vect == torch::kFloat64 ) {
129+ llvm::errs () << " f64 here!\n " ;
125130 return copyTensorToVect<double >(Data);
126131 }
127132 else if (type_vect == torch::kBool ) {
@@ -138,23 +143,37 @@ void *PytorchSerDes::deserializeUntyped(void *Data) {
138143
139144void *PytorchSerDes::getSerializedData () {
140145 return this ->ResponseVoid ; // TODO - check
141- // TensorVec serializedData = *static_cast <TensorVec*>(this->ReponseVoid);
146+ // TensorVec serializedData = *reinterpret_cast <TensorVec*>(this->ReponseVoid);
142147
143148 // // Allocate memory for the output and copy the serialized data
144149 // auto *output = new TensorVec(serializedData);
145- // return static_cast <void *>(output);
150+ // return reinterpret_cast <void *>(output);
146151}
147152
148153 template <typename T>
149154 std::vector<T> *PytorchSerDes::copyTensorToVect (void *serializedTensors) {
150155 auto *ret = new std::vector<T>();
151- for (const auto &tensor : *static_cast <TensorVec*>(serializedTensors)) {
156+ for (const auto &tensor : *reinterpret_cast <TensorVec*>(serializedTensors)) {
152157 ret->insert (ret->end (), tensor.data_ptr <T>(), tensor.data_ptr <T>() + tensor.numel ());
153158 }
154159 return ret;
155160 }
156161
157- void *PytorchSerDes::getRequest () { return this ->RequestVoid ; }
162+ // void *PytorchSerDes::getRequest() { llvm::errs() << reinterpret_cast<TensorVec*>(this->RequestVoid)->size() << "[In getrequest, len of req]\n"; return this->RequestVoid; }
163+ void *PytorchSerDes::getRequest () {
164+ // return nullptr;
165+ auto *tensorVecPtr = reinterpret_cast <TensorVec*>(this ->RequestVoid );
166+ llvm::errs () << " Inside get request\n " ;
167+ if (!tensorVecPtr) {
168+ llvm::errs () << " Error: RequestVoid could not be cast to TensorVec*\n " ;
169+ return nullptr ;
170+ }
171+ else {
172+ llvm::errs () << reinterpret_cast <TensorVec*>(this ->RequestVoid )->size () << " [In getrequest, len of req]\n " ;
173+ }
174+
175+ return this ->RequestVoid ;
176+ }
// Accessor for the type-erased response buffer; ownership stays with this object.
void *PytorchSerDes::getResponse() { return this->ResponseVoid; }
159178
160179
0 commit comments