-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathpy_train.h
More file actions
99 lines (84 loc) · 4.11 KB
/
py_train.h
File metadata and controls
99 lines (84 loc) · 4.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
// Copyright (c) 2025, IST Austria, developed by Erik Schultheis
// SPDX-License-Identifier: Apache-2.0
//
#ifndef LLMQ_SRC_BINDING_PY_TRAIN_H
#define LLMQ_SRC_BINDING_PY_TRAIN_H
#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>

#include "../training/transformer_config.h"
#include "models/llama_model.h"
// Forward declarations. In this header these types are only named, never used by value in
// inline code: IGPUUtilTracker and CommunicatorThreadsPack are held through std::unique_ptr,
// GPUUtilInfo and sSegmentMemory appear only in return types of declared (not defined) functions.
// Their full definitions are expected to be included by the implementation file.
// NOTE(review): DataLoader is not referenced by any declaration visible in this header
// (only mentioned in the class documentation) -- confirm whether it is still needed.
class DataLoader;
class IGPUUtilTracker;
struct GPUUtilInfo;
struct sSegmentMemory;
class CommunicatorThreadsPack;
//! \brief A multi-GPU trainer wrapper to be used for python bindings
//! \details When wrapping llm.q for python, the main source of difficulty is handling
//! multi-GPU support. The cpp version supports both multi-process and multi-thread, with
//! multi-thread being the more interesting (due to cudaMemcpy) option.
//! However, mapping multi-threading to python is problematic due to the GIL (maybe that will be better once
//! free-threaded python is widely used); hence, this wrapper is used to hide all worker threads
//! from the python interface.
//!
//! Internally, we start up one thread per GPU, and keep track of its training state (`sThreadContext`).
//! Each interface function wraps the desired model call into a std::function that gets sent to the thread
//! context. Each thread runs an infinite loop, and picks up the work it has been sent. Interface functions
//! only return once the work is done. If the work function does not synchronize with the GPU, "done" in this
//! case means that the CPU execution has finished, but the GPU might still be busy. This allows overlap of
//! python execution with GPU execution.
//!
//! As a consequence of this implementation strategy, data loading in python will be slightly different from the
//! cpp implementation. For cpp, each thread has its own DataLoader, providing `B*T` tokens each step. For python,
//! we have only one interface-visible thread, which gets `nGPU*B*T` tokens per step, and splits them into `B*T`-sized
//! chunks for each GPU.
//!
//! The class is neither copyable nor movable (see the deleted special members below).
class MultiGPUPyTrainer
{
public:
    //! \brief Creates a trainer for `ngpus` GPUs.
    //! \param ngpus Number of GPUs; per the class description, one worker thread is started per GPU.
    //! \param config Transformer configuration, exposed again through `config()`.
    //! \param options Model options, exposed again through `options()`.
    //! \param batch_size Batch size; presumably the value later returned by `batch_size()` -- confirm in the implementation.
    //! \param seq_len Sequence length; presumably the value later returned by `seq_length()` -- confirm in the implementation.
    //! \param grad_accum Gradient accumulation setting (exact semantics defined in the implementation).
    //! \param memcpy_all_gather, memcpy_send_recv Communication flags (semantics defined in the implementation).
    MultiGPUPyTrainer(int ngpus, TransformerConfig config, LLamaOptions options, int batch_size, int seq_len, int grad_accum, bool memcpy_all_gather, bool memcpy_send_recv);
    ~MultiGPUPyTrainer();

    // Rule of Five: a destructor is declared, so the remaining special members are spelled out.
    // The std::mutex / std::atomic / std::unique_ptr members already make the implicit copy operations
    // deleted, and the user-declared destructor suppresses the implicit move operations; deleting them
    // explicitly documents the intent and yields clearer diagnostics.
    MultiGPUPyTrainer(const MultiGPUPyTrainer&) = delete;
    MultiGPUPyTrainer& operator=(const MultiGPUPyTrainer&) = delete;
    MultiGPUPyTrainer(MultiGPUPyTrainer&&) = delete;
    MultiGPUPyTrainer& operator=(MultiGPUPyTrainer&&) = delete;

    //! \name Weights and checkpoints
    //! Paths/directories are taken by value; their expected layout is defined by the implementation.
    //! @{
    void import_weights(std::string path);
    void export_model(std::string path);
    void init_weights();
    void load_checkpoint(std::string directory, int step);
    void save_checkpoint(std::string directory, int step);
    //! @}

    //! \name Training interface
    //! \details `inputs` / `targets` point to token ids. Following the class description, a call is expected
    //! to receive `nGPU*B*T` tokens that are split into `B*T`-sized chunks per GPU -- TODO confirm the exact
    //! buffer size requirements against the implementation.
    //! @{
    void step(const std::int32_t* inputs, const std::int32_t* targets, float z_loss);
    //! \return A pair of floats; their meaning is defined by the implementation (not visible in this header).
    std::pair<float, float> validate(const std::int32_t* inputs, const std::int32_t* targets);
    //! \return A 5-tuple of floats; their meaning is defined by the implementation (not visible in this header).
    std::tuple<float, float, float, float, float> update(float lr, float beta1, float beta2, int step, float weight_decay, float grad_clip);
    void stop();
    //! @}

    //! \name Introspection
    //! @{
    std::vector<GPUUtilInfo> get_gpu_info();
    int world_size() const;
    //! \brief Returns the stored `B` member.
    int batch_size() const { return B; }
    //! \brief Returns the stored `T` member.
    int seq_length() const { return T; }
    const TransformerConfig& config() const { return mConfig; }
    const LLamaOptions& options() const { return mOptions; }

    //! Per-GPU diagnostics as (name, value) pairs; `gpu_id` presumably indexes the worker contexts -- confirm.
    std::vector<std::pair<std::string, sSegmentMemory>> get_allocations(int gpu_id);
    std::vector<std::pair<std::string, long>> get_stack_info(int gpu_id);
    std::vector<std::pair<std::string, Tensor>> get_gradients(int gpu_id);
    //! @}

private:
    // NOTE: the declaration order of the data members is kept exactly as before, since it determines
    // their construction and destruction order.
    TransformerConfig mConfig;
    LLamaOptions mOptions;
    // Defaulted to 0 for consistency with the other in-class initialized members below.
    int B = 0;
    int T = 0;

    int mTrainMicroStep = 0;
    int mEvalStep = 0;
    int mGradAccumulation = 1;

    std::unique_ptr<CommunicatorThreadsPack> mThreads;

    //! Per-thread state (one entry per GPU, see class description).
    struct sThreadContext {
        NCCLCommunicator* Communicator = nullptr;               //!< Non-owning raw pointer; null until assigned.
        std::unique_ptr<LLamaModel> Model;
        std::unique_ptr<IGPUUtilTracker> GPUUtil;
        std::function<void(sThreadContext& ctx)> Work;          //!< Work item handed to this thread.
    };
    std::vector<sThreadContext> mContexts;

    std::mutex mGlobalMutex;
    std::atomic<bool> mIsRunning = false;
    std::atomic<bool> mHasCrashed = false;
    std::atomic<int> mIsReady = 0;
    std::atomic<int> mWorkDone = 0;

    std::function<void(sThreadContext& ctx)> fetch_work(sThreadContext& ctx);
    //! \param idx Defaults to -1; presumably "-1" addresses all worker threads and a non-negative value a
    //! single one -- confirm in the implementation.
    void run_work(std::function<void(sThreadContext& ctx)> work, int idx=-1);
    void main_loop(NCCLCommunicator& comm);
};
#endif //LLMQ_SRC_BINDING_PY_TRAIN_H