From 082ea689ecb5e6e8c7815b9210b8526205ccd789 Mon Sep 17 00:00:00 2001 From: freelw <“freelw81@qq.com“> Date: Mon, 16 Jun 2025 20:23:51 +0800 Subject: [PATCH 1/3] opt makefile --- makefile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/makefile b/makefile index 3dc68b0..aa15357 100644 --- a/makefile +++ b/makefile @@ -72,6 +72,8 @@ endif ifeq ($(RELEASE),1) NVCC_CFLAGS += -DNDEBUG NVCC_CFLAGS := $(filter-out -G,$(NVCC_CFLAGS)) +else + NVCC_CFLAGS := $(filter-out -O3,$(NVCC_CFLAGS)) endif ifeq ($(MACOS), 1) From 7039c947e7872fcdd3fb91f002c4d983d25488cd Mon Sep 17 00:00:00 2001 From: freelw <“freelw81@qq.com“> Date: Mon, 16 Jun 2025 20:26:12 +0800 Subject: [PATCH 2/3] update --- lm.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/lm.cpp b/lm.cpp index e057163..2ed7985 100644 --- a/lm.cpp +++ b/lm.cpp @@ -41,7 +41,8 @@ std::vector trim_or_padding(const std::vector& src, uint max_len, ui std::vector res = src; if (src.size() > max_len) { res.resize(max_len); - } else { + } + else { res.resize(max_len, pad_id); } return res; @@ -98,7 +99,7 @@ int main(int argc, char* argv[]) { int opt; int epochs = 10; - int batch_size = 16; + int batch_size = 4; int gpu = 1; int max_words_cnt = 256; float lr = 0.001f; @@ -246,7 +247,8 @@ int main(int argc, char* argv[]) { auto origin_size = src_token_ids.size(); if (src_token_ids.size() < num_steps) { src_token_ids.resize(num_steps, loader.get_pad_id()); - } else if (src_token_ids.size() > num_steps) { + } + else if (src_token_ids.size() > num_steps) { src_token_ids.erase(src_token_ids.begin(), src_token_ids.end() - num_steps); } auto cur_step = origin_size - 1; @@ -284,7 +286,8 @@ int main(int argc, char* argv[]) { if (cur_step >= num_steps - 1) { src_token_ids.push_back(max_index); src_token_ids.erase(src_token_ids.begin(), src_token_ids.end() - num_steps); - } else { + } + else { src_token_ids[++cur_step] = max_index; } } @@ -292,7 +295,8 @@ int main(int argc, char* argv[]) { std::cout << "-----------------" << std::endl; ::free(res_buffer); } - } else { + } + else { init_dec_valid_lens_for_training(dec_valid_lens); signal(SIGINT, signal_callback_handler); int epoch = 0; From 3ac86a057cbff2f1c63f67d852d8fd359646d60c Mon Sep 17 00:00:00 2001 From: freelw <“freelw81@qq.com“> Date: Mon, 16 Jun 2025 20:39:48 +0800 Subject: [PATCH 3/3] fix buffer overflow --- dataloaders/language_model/lm_dataloader.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataloaders/language_model/lm_dataloader.cpp b/dataloaders/language_model/lm_dataloader.cpp index 44cd6ef..5b2ab49 100644 --- a/dataloaders/language_model/lm_dataloader.cpp +++ b/dataloaders/language_model/lm_dataloader.cpp @@ -34,7 +34,7 @@ void LMDataLoader::get_token_ids( int token_ids_size = std::min((int)token_ids.size(), max_token_ids_size); std::cout << "token_ids_size : " << token_ids_size << std::endl; - for (size_t i = 0; i < token_ids_size; ++i) { + for (size_t i = 0; i < token_ids_size && i + num_steps < token_ids_size; ++i) { std::vector src_step_tokens; std::vector tgt_step_tokens; for (size_t j = 0; j < num_steps && (i + j) < token_ids_size - 1; ++j) {