-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrainer.py
More file actions
67 lines (48 loc) · 2.41 KB
/
trainer.py
File metadata and controls
67 lines (48 loc) · 2.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import os
import logging
def train(model, train_loader, valid_loader, criterion, optimizer, start_epoch, num_epochs, device, best_model_weights, best_val_loss):
    """Train *model* over epochs [start_epoch, num_epochs), tracking the best validation loss.

    Each epoch runs one backprop pass over ``train_loader`` and one no-grad
    pass over ``valid_loader``, logs per-sample average losses, writes a
    rolling ``checkpoint.pth`` (for resuming), and snapshots the weights with
    the lowest validation loss. After the loop, the best weights (if any) are
    restored into *model* and saved to ``best_model_<loss>.pth``, and the
    rolling checkpoint is deleted.

    Args:
        model: ``torch.nn.Module`` called as ``model(src, tgt)`` returning logits.
        train_loader / valid_loader: iterables of dict batches with keys
            "src", "tgt", "label"; must expose ``.dataset`` supporting ``len()``.
        criterion: loss taking ``(N, C)`` logits and ``(N,)`` flat targets.
        optimizer: optimizer over ``model.parameters()``.
        start_epoch: first epoch index (non-zero when resuming).
        num_epochs: exclusive upper bound on the epoch index.
        device: device to move the model and every batch onto.
        best_model_weights: best state_dict from a previous run, or None/{} fresh.
        best_val_loss: best validation loss so far (``float("inf")`` fresh).
    """
    model.to(device)
    logging.info("Starting training at epoch %d for %d epochs.", start_epoch, num_epochs)
    for epoch in range(start_epoch, num_epochs):
        model.train()
        running_loss = 0.0
        for data in train_loader:
            src_inputs, tgt_inputs = data["src"].to(device), data["tgt"].to(device)
            labels = data["label"].to(device)
            optimizer.zero_grad()
            outputs = model(src_inputs, tgt_inputs)
            # Flatten (batch, seq, vocab) logits / (batch, seq) labels for the loss.
            loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch average is per-sample.
            running_loss += loss.item() * src_inputs.size(0)
        avg_train_loss = running_loss / len(train_loader.dataset)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for data in valid_loader:
                src_inputs, tgt_inputs = data["src"].to(device), data["tgt"].to(device)
                labels = data["label"].to(device)
                outputs = model(src_inputs, tgt_inputs)
                loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
                val_loss += loss.item() * src_inputs.size(0)
        avg_val_loss = val_loss / len(valid_loader.dataset)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # BUG FIX: dict.copy() is shallow — the tensors stay aliased to the
            # live parameters and keep mutating as training continues. Clone
            # each tensor so the snapshot is frozen at this epoch.
            best_model_weights = {k: v.detach().clone() for k, v in model.state_dict().items()}

        logging.info("Epoch [%d/%d], Train Loss: %.4f, Val Loss: %.4f",
                     epoch + 1, num_epochs, avg_train_loss, avg_val_loss)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_train_loss,
            'best_model_weights': best_model_weights,
            'best_val_loss': best_val_loss
        }, "checkpoint.pth")
        logging.info("Saved checkpoint at %d", epoch + 1)

    if best_model_weights:
        model.load_state_dict(best_model_weights)
        torch.save(model.state_dict(), f'best_model_{best_val_loss:.2f}.pth')
        logging.info("Best model weights saved with Val Loss: %.2f", best_val_loss)
    # BUG FIX: the epoch loop may never run (start_epoch >= num_epochs), in
    # which case no checkpoint was written; removing unconditionally raised
    # FileNotFoundError. Delete only if present.
    if os.path.exists("checkpoint.pth"):
        os.remove("checkpoint.pth")
    logging.info("Training complete.")