Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backends/gpu/metal/metal_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1467,6 +1467,10 @@ void MetalOps::commit() {

// Block until the most recently committed command buffer finishes on the GPU,
// report any execution error it recorded, then release the buffer.
void MetalOps::wait() {
    commandBuffer->waitUntilCompleted();
    // Surface GPU-side failures that would otherwise be silently dropped:
    // error() is only populated after the buffer has completed.
    auto error = commandBuffer->error();
    if (error) {
        std::cerr << "Error: " << error->localizedDescription()->utf8String() << std::endl;
    }
    // NOTE(review): the buffer is released here, so wait() must be called
    // exactly once per commit() — confirm callers uphold that pairing.
    commandBuffer->release();
}

Expand Down
84 changes: 84 additions & 0 deletions checkpoint.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "checkpoint.h"
#include "backends/backend_ops.h"

#include <sstream>
#include <fstream>
Expand Down Expand Up @@ -51,4 +52,87 @@ void loadfrom_checkpoint(const std::string& filename, std::vector<Parameter*>& p
parameters[i]->deserialize(buffer);
::free(buffer);
}
}

// Compare a tensor's on-device contents against a host-side buffer,
// element-wise as floats. Logs the first element whose absolute difference
// exceeds eps (plus the tensor's meta info) and stops at that element.
// `buffer` must point at the float payload for this tensor inside the
// deserialization buffer.
// NOTE(review): length floats are read from both buffers — assumes
// size == length * sizeof(float); confirm against Tensor's contract.
void diff_tensor_buffer(Tensor* tensor, char* buffer) {
    const int size = tensor->size();      // payload size in bytes
    const int length = tensor->length();  // number of float elements
    // RAII host copy instead of raw malloc/free.
    std::vector<char> tensor_buffer(size);
    g_backend_ops->cp_from_device(tensor_buffer.data(), tensor, size);
    const float* tensor_buffer_f = reinterpret_cast<const float*>(tensor_buffer.data());
    const float* buffer_f = reinterpret_cast<const float*>(buffer);
    const float eps = 1e-4f;

    for (int i = 0; i < length; ++i) {
        if (std::abs(tensor_buffer_f[i] - buffer_f[i]) > eps) {
            std::cerr << "diff tensor failed at index " << i
                      << ", expected: " << tensor_buffer_f[i]
                      << ", got: " << buffer_f[i] << std::endl;
            std::cerr << "tensor meta : " << tensor->get_meta_info() << std::endl;
            break;
            // abort();
        }
    }
}

// Compare one in-memory Parameter against its serialized counterpart `info`.
// Any structural mismatch (tensor byte sizes or the step counter t) logs and
// aborts; element-wise value differences are reported by diff_tensor_buffer.
void diff_para(Parameter* p, ParameterInfo* info) {
    Tensor* w = p->get_w();
    Tensor* m = p->get_m();
    Tensor* v = p->get_v();

    // One shared size check replaces three copy-pasted blocks; messages and
    // abort behavior are identical to the originals.
    auto check_size = [](const char* name, int expected, Tensor* t) {
        if (t->size() != expected) {
            std::cerr << name << " size mismatch: expected " << expected
                      << ", got " << t->size() << std::endl;
            std::cerr << "parameter meta : " << t->get_meta_info() << std::endl;
            abort();
        }
    };
    check_size("weight", info->weight_size, w);
    check_size("m", info->m_size, m);
    check_size("v", info->v_size, v);

    if (p->get_t() != info->t) {
        std::cerr << "t mismatch: expected " << info->t
                  << ", got " << p->get_t() << std::endl;
        std::cerr << "parameter meta : " << w->get_meta_info() << std::endl;
        abort();
    }

    diff_tensor_buffer(w, info->weight_start);
    diff_tensor_buffer(m, info->m_start);
    diff_tensor_buffer(v, info->v_start);
    std::cout << "diff parameter success : " << w->get_meta_info() << std::endl;
    std::cout << "diff parameter success : " << m->get_meta_info() << std::endl;
    std::cout << "diff parameter success : " << v->get_meta_info() << std::endl;
}

// Read a checkpoint file and diff every stored parameter record against the
// corresponding in-memory parameter. Exits the process if the file is
// missing; per-parameter mismatches are handled by diff_para (abort) and
// diff_tensor_buffer (log first differing element).
void difffrom_checkpoint(const std::string& filename, std::vector<Parameter*>& parameters) {
    std::ifstream in(filename, std::ios::in | std::ios::binary);
    // check file exists
    if (!in) {
        std::cerr << "file not found : " << filename << std::endl;
        exit(1);
    }
    int num_params = 0;
    in.read((char*)&num_params, sizeof(num_params));
    // Cast avoids a signed/unsigned comparison.
    assert(static_cast<size_t>(num_params) == parameters.size());
    for (int i = 0; i < num_params; i++) {
        int size = 0;
        in.read((char*)&size, sizeof(size));
        assert(size == parameters[i]->get_serialized_size());
        // RAII buffer replaces raw malloc/free; also dropped the unused
        // `bool succ` flag the original never read.
        std::vector<char> buffer(size);
        in.read(buffer.data(), size);
        ParameterInfo info;
        parameters[i]->deserialize_info(buffer.data(), &info);
        diff_para(parameters[i], &info);
    }
}
1 change: 1 addition & 0 deletions checkpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
std::string generateDateTimeSuffix();
void save_checkpoint(const std::string& prefix, int epoch, const std::vector<Parameter*>& parameters);
void loadfrom_checkpoint(const std::string& filename, std::vector<Parameter*>& parameters);
void difffrom_checkpoint(const std::string& filename, std::vector<Parameter*>& parameters);

#endif
36 changes: 25 additions & 11 deletions lm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ std::vector<uint> trim_or_padding(const std::vector<uint>& src, uint max_len, ui
std::vector<uint> res = src;
if (src.size() > max_len) {
res.resize(max_len);
}
else {
} else {
res.resize(max_len, pad_id);
}
return res;
Expand Down Expand Up @@ -104,10 +103,12 @@ int main(int argc, char* argv[]) {
int max_words_cnt = 256;
float lr = 0.001f;
int lm_predict_cnt = LM_PREDICT_CNT;
float dropout = 0.2f;
std::string checkpoint;
std::string checkpoint_diff_tgt;
std::string corpus = TIMEMACHINE_RESOURCE_NAME;

while ((opt = getopt(argc, argv, "f:c:e:l:b:g:m:p:")) != -1) {
while ((opt = getopt(argc, argv, "f:c:e:d:l:b:g:m:p:k:h")) != -1) {
switch (opt) {
case 'f':
corpus = optarg;
Expand All @@ -118,6 +119,9 @@ int main(int argc, char* argv[]) {
case 'e':
epochs = atoi(optarg);
break;
case 'd':
dropout = atof(optarg);
break;
case 'l':
lr = atof(optarg);
break;
Expand All @@ -133,24 +137,29 @@ int main(int argc, char* argv[]) {
case 'p':
lm_predict_cnt = atoi(optarg);
break;
case 'k':
checkpoint_diff_tgt = optarg;
break;
case 'h':
default:
std::cerr << "Usage: " << argv[0]
<< " -f <corpus> -c <checpoint> -e <epochs>" << std::endl;
<< " -f <corpus> -c <checpoint> -e <epochs> -d <dropout> -l <lr> -b <batch_size> -g <gpu> -m <max_words_cnt> -p <lm_predict_cnt>" << std::endl;
return 1;
}
}

std::cout << "corpus : " << corpus << std::endl;
std::cout << "epochs : " << epochs << std::endl;
std::cout << "batch_size : " << batch_size << std::endl;
std::cout << "dropout : " << dropout << std::endl;
std::cout << "gpu : " << gpu << std::endl;
std::cout << "learning rate : " << lr << std::endl;
std::cout << "checkpoint : " << checkpoint << std::endl;
std::cout << "max_words_cnt : " << max_words_cnt << std::endl;

int num_hiddens = 256;
int num_blks = 2;
float dropout = 0.2f;

int ffn_num_hiddens = 64;
int num_heads = 4;
int num_steps = LM_NUM_STEPS;
Expand Down Expand Up @@ -231,6 +240,14 @@ int main(int argc, char* argv[]) {
std::cout << "loaded from checkpoint" << std::endl;
}

if (!checkpoint_diff_tgt.empty()) {
std::cout << "diff mode start." << std::endl;
std::cout << "checkpoint 0 is : " << checkpoint << std::endl;
std::cout << "checkpoint 1 is : " << checkpoint_diff_tgt << std::endl;
difffrom_checkpoint(checkpoint_diff_tgt, parameters);
return 0;
}

if (predicting) {
assert(batch_size == 1);
std::cout << "serving mode" << std::endl;
Expand All @@ -247,8 +264,7 @@ int main(int argc, char* argv[]) {
auto origin_size = src_token_ids.size();
if (src_token_ids.size() < num_steps) {
src_token_ids.resize(num_steps, loader.get_pad_id());
}
else if (src_token_ids.size() > num_steps) {
} else if (src_token_ids.size() > num_steps) {
src_token_ids.erase(src_token_ids.begin(), src_token_ids.end() - num_steps);
}
auto cur_step = origin_size - 1;
Expand Down Expand Up @@ -286,17 +302,15 @@ int main(int argc, char* argv[]) {
if (cur_step >= num_steps - 1) {
src_token_ids.push_back(max_index);
src_token_ids.erase(src_token_ids.begin(), src_token_ids.end() - num_steps);
}
else {
} else {
src_token_ids[++cur_step] = max_index;
}
}
std::cout << std::endl;
std::cout << "-----------------" << std::endl;
::free(res_buffer);
}
}
else {
} else {
init_dec_valid_lens_for_training(dec_valid_lens);
signal(SIGINT, signal_callback_handler);
int epoch = 0;
Expand Down
73 changes: 52 additions & 21 deletions optimizers/parameter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,18 @@ bool Parameter::is_require_grad() {

std::string Parameter::serialize() {
int weight_size = get_w()->size();
int grad_size = get_grad()->size();
// int grad_size = get_grad()->size();
int m_size = m->size();
int v_size = v->size();

int tot_size = 0;
tot_size += sizeof(weight_size);
tot_size += sizeof(grad_size);
// tot_size += sizeof(grad_size);
tot_size += sizeof(m_size);
tot_size += sizeof(v_size);
tot_size += sizeof(t);
tot_size += weight_size;
tot_size += grad_size;
// tot_size += grad_size;
tot_size += m_size;
tot_size += v_size;

Expand All @@ -54,8 +54,8 @@ std::string Parameter::serialize() {

::memcpy(buffer + offset, &weight_size, sizeof(weight_size));
offset += sizeof(weight_size);
::memcpy(buffer + offset, &grad_size, sizeof(grad_size));
offset += sizeof(grad_size);
// ::memcpy(buffer + offset, &grad_size, sizeof(grad_size));
// offset += sizeof(grad_size);
::memcpy(buffer + offset, &m_size, sizeof(m_size));
offset += sizeof(m_size);
::memcpy(buffer + offset, &v_size, sizeof(v_size));
Expand All @@ -69,12 +69,12 @@ std::string Parameter::serialize() {
weight_size
);
offset += weight_size;
g_backend_ops->cp_from_device(
buffer + offset,
get_grad(),
grad_size
);
offset += grad_size;
// g_backend_ops->cp_from_device(
// buffer + offset,
// get_grad(),
// grad_size
// );
// offset += grad_size;
g_backend_ops->cp_from_device(
buffer + offset,
m,
Expand All @@ -94,13 +94,13 @@ std::string Parameter::serialize() {
}

void Parameter::deserialize(char* buffer) {
int weight_size, grad_size, m_size, v_size;
int weight_size, m_size, v_size;
int offset = 0;

::memcpy(&weight_size, buffer + offset, sizeof(weight_size));
offset += sizeof(weight_size);
::memcpy(&grad_size, buffer + offset, sizeof(grad_size));
offset += sizeof(grad_size);
// ::memcpy(&grad_size, buffer + offset, sizeof(grad_size));
// offset += sizeof(grad_size);
::memcpy(&m_size, buffer + offset, sizeof(m_size));
offset += sizeof(m_size);
::memcpy(&v_size, buffer + offset, sizeof(v_size));
Expand All @@ -109,7 +109,7 @@ void Parameter::deserialize(char* buffer) {
offset += sizeof(t);

assert(weight_size == get_w()->size());
assert(grad_size == get_grad()->size());
// assert(grad_size == get_grad()->size());
assert(m_size == m->size());
assert(v_size == v->size());

Expand All @@ -119,12 +119,12 @@ void Parameter::deserialize(char* buffer) {
weight_size
);
offset += weight_size;
g_backend_ops->cp_to_device(
get_grad(),
buffer + offset,
grad_size
);
offset += grad_size;
// g_backend_ops->cp_to_device(
// get_grad(),
// buffer + offset,
// grad_size
// );
// offset += grad_size;
g_backend_ops->cp_to_device(
m,
buffer + offset,
Expand All @@ -140,6 +140,37 @@ void Parameter::deserialize(char* buffer) {
assert(offset == get_serialized_size());
}

// Parse the header of a serialized parameter record and fill `info` with the
// section sizes, the step counter, and pointers into `buffer` where each
// tensor payload begins. Does not copy any tensor data.
void Parameter::deserialize_info(char* buffer, ParameterInfo* info) {
    int offset = 0;
    // Pull one int-sized field out of the buffer and advance the cursor.
    auto read_int = [&](int& dst) {
        ::memcpy(&dst, buffer + offset, sizeof(dst));
        offset += sizeof(dst);
    };

    int weight_size, m_size, v_size, _t;
    read_int(weight_size);
    read_int(m_size);
    read_int(v_size);
    read_int(_t);

    assert(weight_size == get_w()->size());
    assert(m_size == m->size());
    assert(v_size == v->size());

    info->weight_size = weight_size;
    info->m_size = m_size;
    info->v_size = v_size;
    info->t = _t;

    info->weight_start = buffer + offset;
    offset += weight_size;
    info->m_start = buffer + offset;
    offset += m_size;
    info->v_start = buffer + offset;
    offset += v_size;
    // NOTE(review): get_serialized_size() still appears to count four ints
    // plus get_grad()->size(), while this path consumes three ints + t +
    // weight + m + v — confirm this assert can actually hold.
    assert(offset == get_serialized_size());
}

int Parameter::get_serialized_size() {
return sizeof(int) * 4 + sizeof(t) +
get_w()->size() + get_grad()->size() + m->size() + v->size();
Expand Down
11 changes: 11 additions & 0 deletions optimizers/parameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@

#include "graph/node.h"

// Parsed view of one serialized Parameter record. The *_start pointers point
// INTO the caller's deserialization buffer and are only valid while that
// buffer stays alive — this struct owns nothing.
// Members are default-initialized so a freshly declared ParameterInfo never
// carries indeterminate values.
struct ParameterInfo {
    int weight_size = 0;           // byte size of the weight payload
    int m_size = 0;                // byte size of the m payload (presumably optimizer moment — confirm)
    int v_size = 0;                // byte size of the v payload (presumably optimizer moment — confirm)
    int t = 0;                     // step counter stored with the parameter
    char* weight_start = nullptr;  // start of weight bytes in the buffer
    char* m_start = nullptr;       // start of m bytes in the buffer
    char* v_start = nullptr;       // start of v bytes in the buffer
};

class Parameter {
public:
Parameter(graph::Node* _node);
Expand All @@ -19,6 +29,7 @@ class Parameter {
}
std::string serialize();
void deserialize(char* buffer);
void deserialize_info(char* buffer, ParameterInfo* info);
int get_serialized_size();

private:
Expand Down