forked from tobegit3hub/tensorflow_template_application
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsparse_predict_client.cc
More file actions
173 lines (142 loc) · 6.27 KB
/
sparse_predict_client.cc
File metadata and controls
173 lines (142 loc) · 6.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
/* A C++ version of the sparse_predict_client example.
* Build it the same way as inception_client.cc.
=======================================================*/
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include <grpc++/create_channel.h>
#include "tensorflow_serving/apis/prediction_service.grpc.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/util/command_line_flags.h"
using grpc::Channel;
using grpc::ClientContext;
using grpc::ClientReader;
using grpc::ClientReaderWriter;
using grpc::ClientWriter;
using grpc::Status;
using tensorflow::serving::PredictRequest;
using tensorflow::serving::PredictResponse;
using tensorflow::serving::PredictionService;
// Protobuf map from tensor name to TensorProto: the type behind
// PredictRequest::inputs() and PredictResponse::outputs().
using OutMap = google::protobuf::Map<std::string, tensorflow::TensorProto>;
class ServingClient {
public:
ServingClient(std::shared_ptr<Channel> channel)
: stub_(PredictionService::NewStub(channel)) {
}
std::string callPredict(std::string model_name) {
PredictRequest predictRequest;
PredictResponse response;
ClientContext context;
predictRequest.mutable_model_spec()->set_name(model_name);
google::protobuf::Map< std::string, tensorflow::TensorProto >& inputs =
*predictRequest.mutable_inputs();
// Example libSVM data:
// 0 5:1 6:1 17:1 21:1 35:1 40:1 53:1 63:1 71:1 73:1 74:1 76:1 80:1 83:1
// 1 5:1 7:1 17:1 22:1 36:1 40:1 51:1 63:1 67:1 73:1 74:1 76:1 81:1 83:1
// Generate keys proto
tensorflow::TensorProto keys_tensor_proto;
keys_tensor_proto.set_dtype(tensorflow::DataType::DT_INT32);
keys_tensor_proto.add_int_val(1);
keys_tensor_proto.add_int_val(2);
keys_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
inputs["keys"] = keys_tensor_proto;
// Generate indexs TensorProto
tensorflow::TensorProto indexs_tensor_proto;
indexs_tensor_proto.set_dtype(tensorflow::DataType::DT_INT64);
long indexs[28][2] = { {0, 0}, {0, 1}, {0, 2}, {0, 3}, {0, 4}, {0, 5},
{0, 6}, {0, 7}, {0, 8}, {0, 9}, {0, 10}, {0, 11},
{0, 12}, {0, 13}, {1, 0}, {1, 1}, {1, 2}, {1, 3},
{1, 4}, {1, 5}, {1, 6}, {1, 7}, {1, 8}, {1, 9},
{1, 10}, {1, 11}, {1, 12}, {1, 13} };
for (int i = 0; i < 28; i++) {
for (int j = 0; j < 2; j++) {
indexs_tensor_proto.add_int64_val(indexs[i][j]);
}
}
indexs_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(28);
indexs_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
inputs["indexs"] = indexs_tensor_proto;
std::cout << "Generate indexs tensorproto ok." << std::endl;
// Generate ids TensorProto
tensorflow::TensorProto ids_tensor_proto;
ids_tensor_proto.set_dtype(tensorflow::DataType::DT_INT64);
int ids[28] = {5, 6, 17, 21, 35, 40, 53, 63, 71, 73, 74, 76, 80, 83, 5,
7, 17, 22, 36, 40, 51, 63, 67, 73, 74, 76, 81, 83};
for (int i = 0; i < 28; i++) {
ids_tensor_proto.add_int64_val(ids[i]);
}
ids_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(28);
inputs["ids"] = ids_tensor_proto;
std::cout << "Generate ids tensorproto ok." << std::endl;
// Generate values TensorProto
tensorflow::TensorProto values_tensor_proto;
values_tensor_proto.set_dtype(tensorflow::DataType::DT_FLOAT);
float values[] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
for (int i = 0; i < 28; i++) {
values_tensor_proto.add_float_val(values[i]);
}
values_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(28);
inputs["values"] = values_tensor_proto;
std::cout << "Generate values tensorproto ok." << std::endl;
// Generate shape TensorProto
tensorflow::TensorProto shape_tensor_proto;
shape_tensor_proto.set_dtype(tensorflow::DataType::DT_INT64);
shape_tensor_proto.add_int64_val(2); // ins num
shape_tensor_proto.add_int64_val(124); // feature num
shape_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
inputs["shape"] = shape_tensor_proto;
std::cout << "Generate shape tensorproto ok." << std::endl;
Status status = stub_->Predict(&context, predictRequest, &response);
std::cout << "check status.." << std::endl;
if (status.ok()) {
std::cout << "call predict ok" << std::endl;
std::cout << "outputs size is "<< response.outputs_size() << std::endl;
OutMap& map_outputs = *response.mutable_outputs();
OutMap::iterator iter;
int output_index = 0;
for(iter = map_outputs.begin();iter != map_outputs.end(); ++iter){
tensorflow::TensorProto& result_tensor_proto= iter->second;
tensorflow::Tensor tensor;
bool converted = tensor.FromProto(result_tensor_proto);
if (converted) {
std::cout << "the " <<iter->first <<" result tensor[" << output_index << "] is:" <<
std::endl << tensor.SummarizeValue(13) << std::endl;
}else {
std::cout << "the " <<iter->first <<" result tensor[" << output_index <<
"] convert failed." << std::endl;
}
++output_index;
}
return "Done.";
} else {
std::cout << "gRPC call return code: "
<<status.error_code() << ": " << status.error_message()
<< std::endl;
return "gRPC failed.";
}
}
private:
std::unique_ptr<PredictionService::Stub> stub_;
};
// Entry point: parses --server_port and --model_name, sends one sparse
// Predict request over an insecure gRPC channel and prints the result.
// Exits with 0 when the RPC succeeds and a non-zero status otherwise.
int main(int argc, char** argv) {
  std::string server_port = "localhost:9000";
  std::string model_name = "sparse";
  std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag("server_port", &server_port,
                       "the IP and port of the server"),
      tensorflow::Flag("model_name", &model_name, "name of model")};
  const std::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    std::cerr << usage;
    return -1;
  }
  // Flags::Parse removes the flags it recognised and leaves the rest in argv;
  // reject leftovers so a mistyped flag is not silently replaced by its
  // default value.
  if (argc > 1) {
    std::cerr << "Unknown argument " << argv[1] << "\n" << usage;
    return -1;
  }

  ServingClient guide(
      grpc::CreateChannel(server_port, grpc::InsecureChannelCredentials()));
  std::cout << "Calling sparse predictor..." << std::endl;
  const std::string result = guide.callPredict(model_name);
  std::cout << result << std::endl;
  // callPredict reports success only through its return text; surface a
  // failed RPC in the process exit status so scripts can detect it.
  return result == "Done." ? 0 : 1;
}