Skip to content

Commit ab4383e

Browse files
committed
batching in serdes
1 parent 9e9269e commit ab4383e

4 files changed

Lines changed: 70 additions & 40 deletions

File tree

MLModelRunner/PTModelRunner/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
if(LLVM_MLBRIDGE)
2-
add_llvm_library(PTModelRunnerLib PTModelRunner.cpp)
2+
add_llvm_component_library(PTModelRunnerLib PTModelRunner.cpp)
33
else()
44
add_library(PTModelRunnerLib OBJECT PTModelRunner.cpp)
55
endif(LLVM_MLBRIDGE)

MLModelRunner/PTModelRunner/PTModelRunner.cpp

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,37 +16,45 @@
1616

1717
#include <memory>
1818
#include <vector>
19+
#include <string>
20+
1921

2022
using TensorVec = std::vector<torch::Tensor>;
2123

2224
namespace MLBridge
2325
{
2426

25-
PTModelRunner::PTModelRunner(const std::string &modelPath, llvm::LLVMContext &Ctx)
27+
PTModelRunner::PTModelRunner(const char* modelPath, llvm::LLVMContext &Ctx)
2628
: MLModelRunner(MLModelRunner::Kind::PTAOT, BaseSerDes::Kind::Pytorch, &Ctx)
2729
{
28-
this->SerDes = new PytorchSerDes();
29-
30+
// this->SerDes = new PytorchSerDes();
31+
llvm::errs() << "ModelPathName: " << std::string(modelPath) << "[END]\n";
3032
c10::InferenceMode mode;
31-
this->CompiledModel = new torch::inductor::AOTIModelContainerRunnerCpu(modelPath);
33+
this->CompiledModel = new torch::inductor::AOTIModelContainerRunnerCpu(std::string(modelPath));
3234
}
3335

3436

3537

3638
void *PTModelRunner::evaluateUntyped()
3739
{
40+
SerDes->getRequest();
3841

39-
if ((*static_cast<TensorVec*>(this->SerDes->getRequest())).empty())
42+
if (reinterpret_cast<TensorVec*>(this->SerDes->getRequest())->empty())
4043
{
4144
llvm::errs() << "Input vector is empty.\n";
4245
return nullptr;
4346
}
4447

4548
try
4649
{
47-
48-
std::vector<torch::Tensor> *outputTensors = static_cast<std::vector<torch::Tensor>*>(this->SerDes->getResponse());
50+
TensorVec* outputTensors = static_cast<TensorVec*>(this->SerDes->getResponse());
51+
// 2 torch::Tensor of size 1
52+
torch::Tensor state_ins = torch::ones(1);
53+
torch::Tensor seq_lens = torch::ones(1);
54+
static_cast<TensorVec*>(this->SerDes->getRequest())->push_back(state_ins);
55+
static_cast<TensorVec*>(this->SerDes->getRequest())->push_back(seq_lens);
4956
auto outputs = static_cast<torch::inductor::AOTIModelContainerRunnerCpu*>(this->CompiledModel)->run((*static_cast<TensorVec*>(this->SerDes->getRequest())));
57+
5058
for (auto i = outputs.begin(); i != outputs.end(); ++i)
5159
(*(outputTensors)).push_back(*i);
5260
void *rawData = this->SerDes->deserializeUntyped(outputTensors);
@@ -59,12 +67,15 @@ namespace MLBridge
5967
}
6068
}
6169

62-
template <typename U, typename T, typename... Types>
63-
void PTModelRunner::populateFeatures(const std::pair<U, T> &var1,
64-
const std::pair<U, Types> &...var2)
65-
{
66-
SerDes->setFeature(var1.first, var1.second);
67-
PTModelRunner::populateFeatures(var2...);
68-
}
70+
// template <typename U, typename T, typename... Types>
71+
// void PTModelRunner::populateFeatures(const std::pair<U, T> &var1,
72+
// const std::pair<U, Types> &...var2)
73+
// {
74+
// llvm::errs() << "Inside populate of ptmodelrunner\n";
75+
// SerDes->setFeature(var1.first, var1.second);
76+
// PTModelRunner::populateFeatures(var2...);
77+
// llvm::errs() << reinterpret_cast<TensorVec*>(this->SerDes->getRequest())->size() << "[In Runner after pop, len of req]\n";
78+
79+
// }
6980

7081
} // namespace MLBridge

SerDes/pytorchSerDes/pytorchSerDes.cpp

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,54 +24,57 @@ PytorchSerDes::PytorchSerDes() : BaseSerDes(BaseSerDes::Kind::Pytorch) {
2424

2525
void PytorchSerDes::setFeature(const std::string &Name, const int Value) {
2626
auto tensor = torch::tensor({Value}, torch::kInt32);
27-
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
27+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
2828
}
2929

3030
void PytorchSerDes::setFeature(const std::string &Name, const long Value) {
3131
auto tensor = torch::tensor({Value}, torch::kInt64);
32-
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
32+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
3333
}
3434

3535
void PytorchSerDes::setFeature(const std::string &Name, const float Value) {
3636
auto tensor = torch::tensor({Value}, torch::kFloat32);
37-
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
37+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
3838
}
3939

4040
void PytorchSerDes::setFeature(const std::string &Name, const double Value) {
4141
auto tensor = torch::tensor({Value}, torch::kFloat64);
42-
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
42+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
4343
}
4444

4545
void PytorchSerDes::setFeature(const std::string &Name, const std::string Value) {
4646
std::vector<int8_t> encoded_str(Value.begin(), Value.end());
4747
auto tensor = torch::tensor(encoded_str, torch::kInt8);
48-
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
48+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
4949
}
5050

5151
void PytorchSerDes::setFeature(const std::string &Name, const bool Value) {
5252
auto tensor = torch::tensor({Value}, torch::kBool);
53-
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
53+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
5454
}
5555

5656
void PytorchSerDes::setFeature(const std::string &Name, const std::vector<int> &Value) {
5757
auto tensor = torch::tensor(Value, torch::kInt32);
58-
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
58+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
5959
}
6060

6161
void PytorchSerDes::setFeature(const std::string &Name, const std::vector<long> &Value) {
6262
auto tensor = torch::tensor(Value, torch::kInt64);
63-
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
63+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
6464
}
6565

6666
void PytorchSerDes::setFeature(const std::string &Name, const std::vector<float> &Value) {
6767
auto tensor = torch::tensor(Value, torch::kFloat32);
68+
llvm::errs() << Value.size() << "[Vec Size]\n";
6869
tensor = tensor.reshape({1, Value.size()});
69-
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
70+
llvm::errs() << tensor.sizes()[1] << "[Tensor Size]\n";
71+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
72+
llvm::errs() << reinterpret_cast<TensorVec*>(this->RequestVoid)->size() << "[In serdes, len of req, (TensorVec)]\n";
7073
}
7174

7275
void PytorchSerDes::setFeature(const std::string &Name, const std::vector<double> &Value) {
7376
auto tensor = torch::tensor(Value, torch::kFloat64);
74-
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
77+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
7578
}
7679

7780
void PytorchSerDes::setFeature(const std::string &Name, const std::vector<std::string> &Value) {
@@ -81,21 +84,21 @@ void PytorchSerDes::setFeature(const std::string &Name, const std::vector<std::s
8184
flat_vec.push_back('\0'); // Null-terminate each string
8285
}
8386
auto tensor = torch::tensor(flat_vec, torch::kInt8);
84-
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
87+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
8588
}
8689

8790
void PytorchSerDes::setFeature(const std::string &Name, const std::vector<bool> &Value) {
8891
std::vector<uint8_t> bool_vec(Value.begin(), Value.end());
8992
auto tensor = torch::tensor(bool_vec, torch::kUInt8);
90-
static_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
93+
reinterpret_cast<TensorVec*>(this->RequestVoid)->push_back(tensor.clone());
9194
}
9295

9396
// void PytorchSerDes::setRequest(void *Request) {
94-
// CompiledModel = static_cast<torch::inductor::AOTIModelContainerRunnerCpu *>(Request);
97+
// CompiledModel = reinterpret_cast<torch::inductor::AOTIModelContainerRunnerCpu *>(Request);
9598
// }
9699

97100
void PytorchSerDes::cleanDataStructures() {
98-
static_cast<TensorVec*>(this->RequestVoid)->clear(); // Clear the input vector
101+
reinterpret_cast<TensorVec*>(this->RequestVoid)->clear(); // Clear the input vector
99102
}
100103

101104
void *PytorchSerDes::deserializeUntyped(void *Data) {
@@ -104,7 +107,7 @@ void *PytorchSerDes::deserializeUntyped(void *Data) {
104107
}
105108

106109
// Assume Data is a pointer to a vector of tensors
107-
std::vector<torch::Tensor> *serializedTensors = static_cast<TensorVec *>(Data);
110+
std::vector<torch::Tensor> *serializedTensors = reinterpret_cast<TensorVec *>(Data);
108111

109112
if (serializedTensors->empty()) {
110113
return nullptr;
@@ -119,9 +122,11 @@ void *PytorchSerDes::deserializeUntyped(void *Data) {
119122
return copyTensorToVect<int64_t>(Data);
120123
}
121124
else if (type_vect == torch::kFloat32) {
125+
llvm::errs() << "f32 here!\n";
122126
return copyTensorToVect<float>(Data);
123127
}
124128
else if (type_vect == torch::kFloat64) {
129+
llvm::errs() << "f64 here!\n";
125130
return copyTensorToVect<double>(Data);
126131
}
127132
else if (type_vect == torch::kBool) {
@@ -138,23 +143,37 @@ void *PytorchSerDes::deserializeUntyped(void *Data) {
138143

139144
void *PytorchSerDes::getSerializedData() {
140145
return this->ResponseVoid; // TODO - check
141-
// TensorVec serializedData = *static_cast<TensorVec*>(this->ReponseVoid);
146+
// TensorVec serializedData = *reinterpret_cast<TensorVec*>(this->ReponseVoid);
142147

143148
// // Allocate memory for the output and copy the serialized data
144149
// auto *output = new TensorVec(serializedData);
145-
// return static_cast<void *>(output);
150+
// return reinterpret_cast<void *>(output);
146151
}
147152

148153
template <typename T>
149154
std::vector<T> *PytorchSerDes::copyTensorToVect(void *serializedTensors) {
150155
auto *ret = new std::vector<T>();
151-
for (const auto &tensor : *static_cast<TensorVec*>(serializedTensors)) {
156+
for (const auto &tensor : *reinterpret_cast<TensorVec*>(serializedTensors)) {
152157
ret->insert(ret->end(), tensor.data_ptr<T>(), tensor.data_ptr<T>() + tensor.numel());
153158
}
154159
return ret;
155160
}
156161

157-
void *PytorchSerDes::getRequest() { return this->RequestVoid; }
162+
// void *PytorchSerDes::getRequest() { llvm::errs() << reinterpret_cast<TensorVec*>(this->RequestVoid)->size() << "[In getrequest, len of req]\n"; return this->RequestVoid; }
163+
void *PytorchSerDes::getRequest() {
164+
// return nullptr;
165+
auto *tensorVecPtr = reinterpret_cast<TensorVec*>(this->RequestVoid);
166+
llvm::errs() << "Inside get request\n";
167+
if (!tensorVecPtr) {
168+
llvm::errs() << "Error: RequestVoid could not be cast to TensorVec*\n";
169+
return nullptr;
170+
}
171+
else {
172+
llvm::errs() << reinterpret_cast<TensorVec*>(this->RequestVoid)->size() << "[In getrequest, len of req]\n";
173+
}
174+
175+
return this->RequestVoid;
176+
}
158177
void *PytorchSerDes::getResponse() { return this->ResponseVoid; }
159178

160179

include/MLModelRunner/PTModelRunner.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace MLBridge
2121
{
2222
public:
2323
// New constructor that takes the model path as an input
24-
PTModelRunner(const std::string &modelPath, llvm::LLVMContext &Ctx);
24+
PTModelRunner(const char* modelPath, llvm::LLVMContext &Ctx);
2525
// {
2626
// this->SerDes = new PytorchSerDes();
2727

@@ -41,14 +41,14 @@ namespace MLBridge
4141
return R->getKind() == MLModelRunner::Kind::PTAOT;
4242
}
4343

44-
template <typename U, typename T, typename... Types>
45-
void populateFeatures(const std::pair<U, T> &var1,
46-
const std::pair<U, Types> &...var2);
44+
// template <typename U, typename T, typename... Types>
45+
// void populateFeatures(const std::pair<U, T> &var1,
46+
// const std::pair<U, Types> &...var2);
4747

48-
void populateFeatures() {}
48+
// void populateFeatures() {}
4949

5050
void *evaluateUntyped() override;
51-
PytorchSerDes *SerDes;
51+
// PytorchSerDes *SerDes;
5252
// Compiled model container added to the PTModelRunner
5353
private:
5454
void *CompiledModel;

0 commit comments

Comments (0)