Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions experiment_params/train_config_default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ optimization:
epochs: 5
batch_size: 16
input_frames: 5 # Number of frames to feed to the encoder while training
autoencoding: False # Whether to predict the whole input sequence
# Learning rates
encoder_lr: 1.5e-4
transformer_lr: 1.5e-4
Expand All @@ -50,9 +51,9 @@ geco:
alpha: 0.99 # decay of the moving average
tol: 3.3e-2 # per pixel error tolerance. keep in mind this gets squared
initial_lagrange_multiplier: 1.0 # this is 1/beta
lagrange_multiplier_param: 0.1 # adjust update on langrange multiplier
lagrange_multiplier_param: 0.1 # adjust update on lagrange multiplier
# To train in a beta-vae fashion use the following parameters:
# alpha: 0.0
# tol: 0.0
# initial_lagrange_multiplier: 1 / beta
# lagrange_multiplier_param = 1.0
# initial_lagrange_multiplier: 100
# lagrange_multiplier_param: 0.0
19 changes: 13 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,19 @@ def training_step(self, rollouts):
self.optimizer.zero_grad()

rollout_len = rollouts.shape[1]
input_frames = self.params['optimization']['input_frames']
assert(input_frames <= rollout_len) # optimization.use_steps must be smaller than (or equal to) rollout.sequence_length
roll = rollouts[:, :input_frames]
autoencoding = self.params['optimization']['autoencoding']
hgn_output, target = None, None

if autoencoding: # We feed the whole sequence and try to fit it
hgn_output = self.hgn.forward(rollout_batch=rollouts, n_steps=rollout_len)
target = rollouts # Target is the full rollout
else:
input_frames = self.params['optimization']['input_frames']
assert(input_frames < rollout_len) # optimization.use_steps must be strictly smaller than rollout.sequence_length
roll = rollouts[:, :input_frames]
hgn_output = self.hgn.forward(rollout_batch=roll, n_steps=rollout_len - input_frames)
target = rollouts[:, input_frames-1:] # Fit first input_frames and try to predict the last + the next (rollout_len - input_frames)

hgn_output = self.hgn.forward(rollout_batch=roll, n_steps=rollout_len - input_frames)
target = rollouts[:, input_frames-1:] # Fit first input_frames and try to predict the last + the next (rollout_len - input_frames)
prediction = hgn_output.reconstructed_rollout

if self.params["networks"]["variational"]:
Expand Down Expand Up @@ -180,7 +187,7 @@ def training_step(self, rollouts):
# clamping the lagrange multiplier to avoid inf values
self.langrange_multiplier = self.langrange_multiplier * torch.exp(
lagrange_mult_param * C.detach())
self.langrange_multiplier = torch.clamp(self.langrange_multiplier, 1e-10, 1e10)
self.langrange_multiplier = torch.clamp(self.langrange_multiplier, 1e-10, 1e4)

losses = {
'loss/train': train_loss.item(),
Expand Down