diff --git a/dataloaders/language_model/lm_dataloader.cpp b/dataloaders/language_model/lm_dataloader.cpp index 44cd6ef..5b2ab49 100644 --- a/dataloaders/language_model/lm_dataloader.cpp +++ b/dataloaders/language_model/lm_dataloader.cpp @@ -34,7 +34,7 @@ void LMDataLoader::get_token_ids( int token_ids_size = std::min((int)token_ids.size(), max_token_ids_size); std::cout << "token_ids_size : " << token_ids_size << std::endl; - for (size_t i = 0; i < token_ids_size; ++i) { + for (size_t i = 0; i < token_ids_size && i + num_steps < token_ids_size; ++i) { std::vector src_step_tokens; std::vector tgt_step_tokens; for (size_t j = 0; j < num_steps && (i + j) < token_ids_size - 1; ++j) { diff --git a/lm.cpp b/lm.cpp index e057163..2ed7985 100644 --- a/lm.cpp +++ b/lm.cpp @@ -41,7 +41,8 @@ std::vector trim_or_padding(const std::vector& src, uint max_len, ui std::vector res = src; if (src.size() > max_len) { res.resize(max_len); - } else { + } + else { res.resize(max_len, pad_id); } return res; @@ -98,7 +99,7 @@ int main(int argc, char* argv[]) { int opt; int epochs = 10; - int batch_size = 16; + int batch_size = 4; int gpu = 1; int max_words_cnt = 256; float lr = 0.001f; @@ -246,7 +247,8 @@ int main(int argc, char* argv[]) { auto origin_size = src_token_ids.size(); if (src_token_ids.size() < num_steps) { src_token_ids.resize(num_steps, loader.get_pad_id()); - } else if (src_token_ids.size() > num_steps) { + } + else if (src_token_ids.size() > num_steps) { src_token_ids.erase(src_token_ids.begin(), src_token_ids.end() - num_steps); } auto cur_step = origin_size - 1; @@ -284,7 +286,8 @@ int main(int argc, char* argv[]) { if (cur_step >= num_steps - 1) { src_token_ids.push_back(max_index); src_token_ids.erase(src_token_ids.begin(), src_token_ids.end() - num_steps); - } else { + } + else { src_token_ids[++cur_step] = max_index; } } @@ -292,7 +295,8 @@ int main(int argc, char* argv[]) { std::cout << "-----------------" << std::endl; ::free(res_buffer); } - } else { + } + else { init_dec_valid_lens_for_training(dec_valid_lens); signal(SIGINT, signal_callback_handler); int epoch = 0; diff --git a/makefile b/makefile index 3dc68b0..aa15357 100644 --- a/makefile +++ b/makefile @@ -72,6 +72,8 @@ endif ifeq ($(RELEASE),1) NVCC_CFLAGS += -DNDEBUG NVCC_CFLAGS := $(filter-out -G,$(NVCC_CFLAGS)) +else + NVCC_CFLAGS := $(filter-out -O3,$(NVCC_CFLAGS)) endif ifeq ($(MACOS), 1)