-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathcheckpoint.cpp
More file actions
137 lines (128 loc) · 5.08 KB
/
checkpoint.cpp
File metadata and controls
137 lines (128 loc) · 5.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#include "checkpoint.h"
#include "backends/backend_ops.h"
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
/// Builds a timestamp suffix from the current local time.
///
/// @return A string of the form "_YYYYMMDD_HHMMSS", or an empty string if the
///         current time cannot be converted to local calendar time.
std::string generateDateTimeSuffix() {
    const auto now = std::chrono::system_clock::now();
    const std::time_t currentTime = std::chrono::system_clock::to_time_t(now);

    // std::localtime hands back a pointer to storage shared with other calls
    // (or nullptr on failure), so check it and copy the value out right away.
    const std::tm* shared = std::localtime(&currentTime);
    if (shared == nullptr) {
        return std::string();
    }
    const std::tm localTime = *shared;

    std::ostringstream oss;
    oss << std::put_time(&localTime, "_%Y%m%d_%H%M%S");
    return oss.str();
}
void save_checkpoint(const std::string& prefix, int epoch, const std::vector<Parameter*>& parameters) {
std::ostringstream oss;
oss << prefix << "_" << epoch << ".bin";
std::string checkpoint_name = oss.str();
std::string path = "./checkpoints/" + checkpoint_name;
std::ofstream out(path, std::ios::out | std::ios::binary);
int num_params = parameters.size();
out.write((char*)&num_params, sizeof(num_params));
for (auto p : parameters) {
std::string serialized = p->serialize();
int size = serialized.size();
out.write((char*)&size, sizeof(size));
out.write(serialized.c_str(), serialized.size());
}
out.close();
std::cout << "checkpoint saved : " << path << std::endl;
}
/// Restores every parameter from a checkpoint file written by save_checkpoint.
///
/// The file must contain exactly `parameters.size()` records, and each record's
/// byte length must equal the corresponding parameter's get_serialized_size().
///
/// Failure handling (the process is terminated in every case):
///  - file cannot be opened, or is truncated: message on std::cerr, exit(1);
///  - parameter count or record size mismatch: message on std::cerr, abort().
/// These checks stay active in builds compiled with NDEBUG.
///
/// @param filename    Path of the checkpoint file to read.
/// @param parameters  Parameters to fill; each receives its record through
///                    deserialize().
void loadfrom_checkpoint(const std::string& filename, std::vector<Parameter*>& parameters) {
    std::ifstream in(filename, std::ios::in | std::ios::binary);
    // check file exists
    if (!in) {
        std::cerr << "file not found : " << filename << std::endl;
        exit(1);
    }

    int num_params = 0;
    in.read(reinterpret_cast<char*>(&num_params), sizeof(num_params));
    if (!in) {
        std::cerr << "failed to read parameter count : " << filename << std::endl;
        exit(1);
    }
    if (num_params < 0 || static_cast<std::size_t>(num_params) != parameters.size()) {
        std::cerr << "parameter count mismatch: expected " << parameters.size()
                  << ", got " << num_params << std::endl;
        abort();
    }

    // One buffer reused for every record; it is released automatically on return.
    std::vector<char> buffer;
    for (int i = 0; i < num_params; ++i) {
        int size = 0;
        in.read(reinterpret_cast<char*>(&size), sizeof(size));
        if (!in) {
            std::cerr << "failed to read size of parameter " << i << " : " << filename << std::endl;
            exit(1);
        }

        // Validate before allocating so a corrupt length never drives the allocation.
        const long long expected = static_cast<long long>(parameters[i]->get_serialized_size());
        if (size < 0 || static_cast<long long>(size) != expected) {
            std::cerr << "serialized size mismatch for parameter " << i
                      << ": expected " << expected << ", got " << size << std::endl;
            abort();
        }

        buffer.resize(static_cast<std::size_t>(size));
        in.read(buffer.data(), size);
        if (!in) {
            std::cerr << "failed to read data of parameter " << i << " : " << filename << std::endl;
            exit(1);
        }
        parameters[i]->deserialize(buffer.data());
    }
}
/// Compares a tensor's contents with a host buffer of floats and reports the
/// first element whose absolute difference exceeds 1e-4.
///
/// The tensor is copied to host memory through g_backend_ops->cp_from_device
/// using tensor->size() as the byte count; tensor->length() elements are then
/// compared. On the first mismatch the index, both values and the tensor's meta
/// info are printed to std::cerr and the comparison stops. Nothing is returned,
/// so callers cannot tell whether a mismatch was found.
///
/// NOTE(review): assumes both sides hold 32-bit floats and that size() covers
/// at least length() * sizeof(float) bytes — confirm against Tensor.
///
/// @param tensor  Tensor whose contents are checked.
/// @param buffer  Host memory holding at least tensor->length() floats; it does
///                not need to be float-aligned.
void diff_tensor_buffer(Tensor* tensor, char* buffer) {
    const int size = tensor->size();
    const int length = tensor->length();

    // Owning host copy: freed on every exit path, including exceptions.
    std::vector<char> host(static_cast<std::size_t>(size));
    g_backend_ops->cp_from_device(host.data(), tensor, size);

    const float eps = 1e-4f;
    for (int i = 0; i < length; ++i) {
        const std::size_t offset = static_cast<std::size_t>(i) * sizeof(float);

        // memcpy instead of a float* cast: `buffer` may point at an arbitrary
        // byte offset, where a direct float load would be a misaligned access.
        float tensor_value = 0.0f;
        float buffer_value = 0.0f;
        std::memcpy(&tensor_value, host.data() + offset, sizeof(float));
        std::memcpy(&buffer_value, buffer + offset, sizeof(float));

        if (std::abs(tensor_value - buffer_value) > eps) {
            std::cerr << "diff tensor failed at index " << i
                      << ", expected: " << tensor_value
                      << ", got: " << buffer_value << std::endl;
            std::cerr << "tensor meta : " << tensor->get_meta_info() << std::endl;
            break;
        }
    }
}
/// Checks one parameter against the state described by `info`.
///
/// First the byte sizes of the w, m and v tensors and the step counter `t` are
/// compared with the values in `info`; any mismatch prints a report to
/// std::cerr and calls abort(). Then the contents of each tensor are compared
/// with the matching region of `info` via diff_tensor_buffer.
///
/// NOTE(review): diff_tensor_buffer returns void, so the "diff parameter
/// success" lines below are printed even when it reported a value mismatch.
///
/// @param p     Parameter to check.
/// @param info  Expected sizes, step counter and data start pointers.
void diff_para(Parameter* p, ParameterInfo* info) {
    Tensor* const weight = p->get_w();
    Tensor* const m_tensor = p->get_m();
    Tensor* const v_tensor = p->get_v();

    // Common tail of every mismatch report: print the tensor's meta info, then abort.
    const auto dump_meta_and_abort = [](Tensor* t) {
        std::cerr << "parameter meta : " << t->get_meta_info() << std::endl;
        abort();
    };

    if (weight->size() != info->weight_size) {
        std::cerr << "weight size mismatch: expected " << info->weight_size
                  << ", got " << weight->size() << std::endl;
        dump_meta_and_abort(p->get_w());
    }
    if (m_tensor->size() != info->m_size) {
        std::cerr << "m size mismatch: expected " << info->m_size
                  << ", got " << m_tensor->size() << std::endl;
        dump_meta_and_abort(p->get_m());
    }
    if (v_tensor->size() != info->v_size) {
        std::cerr << "v size mismatch: expected " << info->v_size
                  << ", got " << v_tensor->size() << std::endl;
        dump_meta_and_abort(p->get_v());
    }
    if (p->get_t() != info->t) {
        std::cerr << "t mismatch: expected " << info->t
                  << ", got " << p->get_t() << std::endl;
        dump_meta_and_abort(p->get_w());
    }

    // Run all three content comparisons before printing any summary line.
    diff_tensor_buffer(weight, info->weight_start);
    diff_tensor_buffer(m_tensor, info->m_start);
    diff_tensor_buffer(v_tensor, info->v_start);

    for (Tensor* t : {weight, m_tensor, v_tensor}) {
        std::cout << "diff parameter success : " << t->get_meta_info() << std::endl;
    }
}
/// Compares every parameter with the corresponding record of a checkpoint file
/// without modifying the parameters.
///
/// Each record is read into memory, described by a ParameterInfo filled in by
/// deserialize_info(), and handed to diff_para() for comparison.
///
/// Failure handling (the process is terminated in every case):
///  - file cannot be opened, or is truncated: message on std::cerr, exit(1);
///  - parameter count or record size mismatch: message on std::cerr, abort().
/// These checks stay active in builds compiled with NDEBUG.
///
/// @param filename    Path of the checkpoint file to compare against.
/// @param parameters  Parameters to check, in the order they were saved.
void difffrom_checkpoint(const std::string& filename, std::vector<Parameter*>& parameters) {
    std::ifstream in(filename, std::ios::in | std::ios::binary);
    // check file exists
    if (!in) {
        std::cerr << "file not found : " << filename << std::endl;
        exit(1);
    }

    int num_params = 0;
    in.read(reinterpret_cast<char*>(&num_params), sizeof(num_params));
    if (!in) {
        std::cerr << "failed to read parameter count : " << filename << std::endl;
        exit(1);
    }
    if (num_params < 0 || static_cast<std::size_t>(num_params) != parameters.size()) {
        std::cerr << "parameter count mismatch: expected " << parameters.size()
                  << ", got " << num_params << std::endl;
        abort();
    }

    // One buffer reused for every record; it is released automatically on return.
    std::vector<char> buffer;
    for (int i = 0; i < num_params; ++i) {
        int size = 0;
        in.read(reinterpret_cast<char*>(&size), sizeof(size));
        if (!in) {
            std::cerr << "failed to read size of parameter " << i << " : " << filename << std::endl;
            exit(1);
        }

        // Validate before allocating so a corrupt length never drives the allocation.
        const long long expected = static_cast<long long>(parameters[i]->get_serialized_size());
        if (size < 0 || static_cast<long long>(size) != expected) {
            std::cerr << "serialized size mismatch for parameter " << i
                      << ": expected " << expected << ", got " << size << std::endl;
            abort();
        }

        buffer.resize(static_cast<std::size_t>(size));
        in.read(buffer.data(), size);
        if (!in) {
            std::cerr << "failed to read data of parameter " << i << " : " << filename << std::endl;
            exit(1);
        }

        // `info` is passed straight to diff_para, so the buffer it was built
        // from must stay unchanged until that call returns.
        ParameterInfo info;
        parameters[i]->deserialize_info(buffer.data(), &info);
        diff_para(parameters[i], &info);
    }
}