diff --git a/pufferlib/config/ocean/drive.ini b/pufferlib/config/ocean/drive.ini index c5a68e0f93..c3c0149f0d 100644 --- a/pufferlib/config/ocean/drive.ini +++ b/pufferlib/config/ocean/drive.ini @@ -1,7 +1,7 @@ [base] package = ocean env_name = puffer_drive -policy_name = Drive +policy_name = MultiDiscreteDriveMLP rnn_name = Recurrent [vec] @@ -103,8 +103,8 @@ reward_bound_offroad_max = -0.1 reward_bound_comfort_min = -0.1 reward_bound_comfort_max = 0.0 -reward_bound_lane_align_min = 0.0020 -reward_bound_lane_align_max = 0.0025 +reward_bound_lane_align_min = 0.00020 +reward_bound_lane_align_max = 0.00025 reward_bound_lane_center_min = -0.00075 reward_bound_lane_center_max = -0.00065 diff --git a/pufferlib/models.py b/pufferlib/models.py deleted file mode 100644 index 0893a9db47..0000000000 --- a/pufferlib/models.py +++ /dev/null @@ -1,339 +0,0 @@ -import numpy as np - -import torch -import torch.nn as nn - -import pufferlib.emulation -import pufferlib.pytorch -import pufferlib.spaces - - -class Default(nn.Module): - """Default PyTorch policy. Flattens obs and applies a linear layer. - - PufferLib is not a framework. It does not enforce a base class. - You can use any PyTorch policy that returns actions and values. - We structure our forward methods as encode_observations and decode_actions - to make it easier to wrap policies with LSTMs. You can do that and use - our LSTM wrapper or implement your own. To port an existing policy - for use with our LSTM wrapper, simply put everything from forward() before - the recurrent cell into encode_observations and put everything after - into decode_actions. 
- """ - - def __init__(self, env, hidden_size=128): - super().__init__() - self.hidden_size = hidden_size - self.is_multidiscrete = isinstance(env.single_action_space, pufferlib.spaces.MultiDiscrete) - self.is_continuous = isinstance(env.single_action_space, pufferlib.spaces.Box) - try: - self.is_dict_obs = isinstance(env.env.observation_space, pufferlib.spaces.Dict) - except: - self.is_dict_obs = isinstance(env.observation_space, pufferlib.spaces.Dict) - - if self.is_dict_obs: - self.dtype = pufferlib.pytorch.nativize_dtype(env.emulated) - input_size = int(sum(np.prod(v.shape) for v in env.env.observation_space.values())) - self.encoder = nn.Linear(input_size, self.hidden_size) - else: - num_obs = np.prod(env.single_observation_space.shape) - self.encoder = torch.nn.Sequential( - pufferlib.pytorch.layer_init(nn.Linear(num_obs, hidden_size)), - nn.GELU(), - ) - - if self.is_multidiscrete: - self.action_nvec = tuple(env.single_action_space.nvec) - num_atns = sum(self.action_nvec) - self.decoder = pufferlib.pytorch.layer_init(nn.Linear(hidden_size, num_atns), std=0.01) - elif not self.is_continuous: - num_atns = env.single_action_space.n - self.decoder = pufferlib.pytorch.layer_init(nn.Linear(hidden_size, num_atns), std=0.01) - else: - self.decoder_mean = pufferlib.pytorch.layer_init( - nn.Linear(hidden_size, env.single_action_space.shape[0]), std=0.01 - ) - self.decoder_logstd = nn.Parameter(torch.zeros(1, env.single_action_space.shape[0])) - - self.value = pufferlib.pytorch.layer_init(nn.Linear(hidden_size, 1), std=1) - - def forward_eval(self, observations, state=None): - hidden = self.encode_observations(observations, state=state) - logits, values = self.decode_actions(hidden) - return logits, values - - def forward(self, observations, state=None): - return self.forward_eval(observations, state) - - def encode_observations(self, observations, state=None): - """Encodes a batch of observations into hidden states. 
Assumes - no time dimension (handled by LSTM wrappers).""" - batch_size = observations.shape[0] - if self.is_dict_obs: - observations = pufferlib.pytorch.nativize_tensor(observations, self.dtype) - observations = torch.cat([v.view(batch_size, -1) for v in observations.values()], dim=1) - else: - observations = observations.view(batch_size, -1) - return self.encoder(observations.float()) - - def decode_actions(self, hidden): - """Decodes a batch of hidden states into (multi)discrete actions. - Assumes no time dimension (handled by LSTM wrappers).""" - if self.is_multidiscrete: - logits = self.decoder(hidden).split(self.action_nvec, dim=1) - elif self.is_continuous: - mean = self.decoder_mean(hidden) - logstd = self.decoder_logstd.expand_as(mean) - std = torch.exp(logstd) - logits = torch.distributions.Normal(mean, std) - else: - logits = self.decoder(hidden) - - values = self.value(hidden) - return logits, values - - -class LSTMWrapper(nn.Module): - def __init__(self, env, policy, input_size=128, hidden_size=128): - """Wraps your policy with an LSTM without letting you shoot yourself in the - foot with bad transpose and shape operations. This saves much pain. - Requires that your policy define encode_observations and decode_actions. 
- See the Default policy for an example.""" - super().__init__() - self.obs_shape = env.single_observation_space.shape - - self.policy = policy - self.input_size = input_size - self.hidden_size = hidden_size - self.is_continuous = self.policy.is_continuous - - for name, param in self.named_parameters(): - if "layer_norm" in name: - continue - if "bias" in name: - nn.init.constant_(param, 0) - elif "weight" in name and param.ndim >= 2: - nn.init.orthogonal_(param, 1.0) - - self.lstm = nn.LSTM(input_size, hidden_size) - - self.cell = torch.nn.LSTMCell(input_size, hidden_size) - self.cell.weight_ih = self.lstm.weight_ih_l0 - self.cell.weight_hh = self.lstm.weight_hh_l0 - self.cell.bias_ih = self.lstm.bias_ih_l0 - self.cell.bias_hh = self.lstm.bias_hh_l0 - - # self.pre_layernorm = nn.LayerNorm(hidden_size) - # self.post_layernorm = nn.LayerNorm(hidden_size) - - def forward_eval(self, observations, state): - """Forward function for inference. 3x faster than using LSTM directly""" - hidden = self.policy.encode_observations(observations, state=state) - h = state["lstm_h"] - c = state["lstm_c"] - - # TODO: Don't break compile - if h is not None: - assert h.shape[0] == c.shape[0] == observations.shape[0], "LSTM state must be (h, c)" - lstm_state = (h, c) - else: - lstm_state = None - - # hidden = self.pre_layernorm(hidden) - hidden, c = self.cell(hidden, lstm_state) - # hidden = self.post_layernorm(hidden) - state["hidden"] = hidden - state["lstm_h"] = hidden - state["lstm_c"] = c - logits, values = self.policy.decode_actions(hidden) - return logits, values - - def forward(self, observations, state): - """Forward function for training. 
Uses LSTM for fast time-batching""" - x = observations - lstm_h = state["lstm_h"] - lstm_c = state["lstm_c"] - - x_shape, space_shape = x.shape, self.obs_shape - x_n, space_n = len(x_shape), len(space_shape) - if x_shape[-space_n:] != space_shape: - raise ValueError("Invalid input tensor shape", x.shape) - - if x_n == space_n + 1: - B, TT = x_shape[0], 1 - elif x_n == space_n + 2: - B, TT = x_shape[:2] - else: - raise ValueError("Invalid input tensor shape", x.shape) - - if lstm_h is not None: - assert lstm_h.shape[1] == lstm_c.shape[1] == B, "LSTM state must be (h, c)" - lstm_state = (lstm_h, lstm_c) - else: - lstm_state = None - - x = x.reshape(B * TT, *space_shape) - hidden = self.policy.encode_observations(x, state) - assert hidden.shape == (B * TT, self.input_size) - - hidden = hidden.reshape(B, TT, self.input_size) - - hidden = hidden.transpose(0, 1) - # hidden = self.pre_layernorm(hidden) - hidden, (lstm_h, lstm_c) = self.lstm.forward(hidden, lstm_state) - hidden = hidden.float() - - # hidden = self.post_layernorm(hidden) - hidden = hidden.transpose(0, 1) - - flat_hidden = hidden.reshape(B * TT, self.hidden_size) - logits, values = self.policy.decode_actions(flat_hidden) - values = values.reshape(B, TT) - # state.batch_logits = logits.reshape(B, TT, -1) - state["hidden"] = hidden - state["lstm_h"] = lstm_h.detach() - state["lstm_c"] = lstm_c.detach() - return logits, values - - -class Convolutional(nn.Module): - def __init__( - self, - env, - *args, - framestack, - flat_size, - input_size=512, - hidden_size=512, - output_size=512, - channels_last=False, - downsample=1, - **kwargs, - ): - """The CleanRL default NatureCNN policy used for Atari. - It's just a stack of three convolutions followed by a linear layer - - Takes framestack as a mandatory keyword argument. 
Suggested default is 1 frame - with LSTM or 4 frames without.""" - super().__init__() - self.channels_last = channels_last - self.downsample = downsample - - # TODO: Remove these from required params - self.hidden_size = hidden_size - self.is_continuous = False - - self.network = nn.Sequential( - pufferlib.pytorch.layer_init(nn.Conv2d(framestack, 32, 8, stride=4)), - nn.ReLU(), - pufferlib.pytorch.layer_init(nn.Conv2d(32, 64, 4, stride=2)), - nn.ReLU(), - pufferlib.pytorch.layer_init(nn.Conv2d(64, 64, 3, stride=1)), - nn.ReLU(), - nn.Flatten(), - pufferlib.pytorch.layer_init(nn.Linear(flat_size, hidden_size)), - nn.ReLU(), - ) - self.actor = pufferlib.pytorch.layer_init(nn.Linear(hidden_size, env.single_action_space.n), std=0.01) - self.value_fn = pufferlib.pytorch.layer_init(nn.Linear(output_size, 1), std=1) - - def forward(self, observations, state=None): - hidden = self.encode_observations(observations) - actions, value = self.decode_actions(hidden) - return actions, value - - def forward_train(self, observations, state=None): - return self.forward(observations, state) - - def encode_observations(self, observations, state=None): - if self.channels_last: - observations = observations.permute(0, 3, 1, 2) - if self.downsample > 1: - observations = observations[:, :, :: self.downsample, :: self.downsample] - return self.network(observations.float() / 255.0) - - def decode_actions(self, flat_hidden): - action = self.actor(flat_hidden) - value = self.value_fn(flat_hidden) - return action, value - - -class ProcgenResnet(nn.Module): - """Procgen baseline from the AICrowd NeurIPS 2020 competition - Based on the ResNet architecture that was used in the Impala paper.""" - - def __init__(self, env, cnn_width=16, mlp_width=256): - super().__init__() - h, w, c = env.single_observation_space.shape - shape = (c, h, w) - conv_seqs = [] - for out_channels in [cnn_width, 2 * cnn_width, 2 * cnn_width]: - conv_seq = ConvSequence(shape, out_channels) - shape = 
conv_seq.get_output_shape() - conv_seqs.append(conv_seq) - conv_seqs += [ - nn.Flatten(), - nn.ReLU(), - nn.Linear(in_features=shape[0] * shape[1] * shape[2], out_features=mlp_width), - nn.ReLU(), - ] - self.network = nn.Sequential(*conv_seqs) - self.actor = pufferlib.pytorch.layer_init(nn.Linear(mlp_width, env.single_action_space.n), std=0.01) - self.value = pufferlib.pytorch.layer_init(nn.Linear(mlp_width, 1), std=1) - - def forward(self, observations, state=None): - hidden = self.encode_observations(observations) - actions, value = self.decode_actions(hidden) - return actions, value - - def forward_train(self, observations, state=None): - return self.forward(observations, state) - - def encode_observations(self, x): - hidden = self.network(x.permute((0, 3, 1, 2)) / 255.0) - return hidden - - def decode_actions(self, hidden): - """linear decoder function""" - action = self.actor(hidden) - value = self.value(hidden) - return action, value - - -class ResidualBlock(nn.Module): - def __init__(self, channels): - super().__init__() - self.conv0 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1) - self.conv1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1) - - def forward(self, x): - inputs = x - x = nn.functional.relu(x) - x = self.conv0(x) - x = nn.functional.relu(x) - x = self.conv1(x) - return x + inputs - - -class ConvSequence(nn.Module): - def __init__(self, input_shape, out_channels): - super().__init__() - self._input_shape = input_shape - self._out_channels = out_channels - self.conv = nn.Conv2d( - in_channels=self._input_shape[0], out_channels=self._out_channels, kernel_size=3, padding=1 - ) - self.res_block0 = ResidualBlock(self._out_channels) - self.res_block1 = ResidualBlock(self._out_channels) - - def forward(self, x): - x = self.conv(x) - x = nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=1) - x = self.res_block0(x) - x = self.res_block1(x) - assert x.shape[1:] == 
self.get_output_shape() - return x - - def get_output_shape(self): - _c, h, w = self._input_shape - return (self._out_channels, (h + 1) // 2, (w + 1) // 2) diff --git a/pufferlib/ocean/__init__.py b/pufferlib/ocean/__init__.py index 55822f8a57..ce7a197275 100644 --- a/pufferlib/ocean/__init__.py +++ b/pufferlib/ocean/__init__.py @@ -1,13 +1 @@ from .environment import * - -try: - import torch -except ImportError: - pass -else: - from .torch import Policy - - try: - from .torch import Recurrent - except: - Recurrent = None diff --git a/pufferlib/ocean/drive/drivenet.h b/pufferlib/ocean/drive/drivenet.h index 162c7d2306..d89e2e364e 100644 --- a/pufferlib/ocean/drive/drivenet.h +++ b/pufferlib/ocean/drive/drivenet.h @@ -40,7 +40,6 @@ struct DriveNet { GELU *gelu; Linear *shared_embedding; ReLU *relu; - LSTM *lstm; Linear *actor; Linear *value_fn; Multidiscrete *multidiscrete; @@ -104,9 +103,6 @@ DriveNet *init_drivenet(Weights *weights, int num_agents, int dynamics_model, in net->relu = make_relu(num_agents, hidden_size); net->actor = make_linear(weights, num_agents, hidden_size, action_size); net->value_fn = make_linear(weights, num_agents, hidden_size, 1); - net->lstm = make_lstm(weights, num_agents, hidden_size, NN_HIDDEN_SIZE); - memset(net->lstm->state_h, 0, num_agents * NN_HIDDEN_SIZE * sizeof(float)); - memset(net->lstm->state_c, 0, num_agents * NN_HIDDEN_SIZE * sizeof(float)); net->multidiscrete = make_multidiscrete(num_agents, logit_sizes, action_dim); return net; } @@ -140,7 +136,6 @@ void free_drivenet(DriveNet *net) { free(net->multidiscrete); free(net->actor); free(net->value_fn); - free(net->lstm); free(net); } @@ -265,9 +260,8 @@ void forward(DriveNet *net, float *observations, int *actions) { gelu(net->gelu, net->cat2->output); linear(net->shared_embedding, net->gelu->output); relu(net->relu, net->shared_embedding->output); - lstm(net->lstm, net->relu->output); - linear(net->actor, net->lstm->state_h); - linear(net->value_fn, net->lstm->state_h); + 
linear(net->actor, net->relu->output); + linear(net->value_fn, net->relu->output); // Get action by taking argmax of actor output softmax_multidiscrete(net->multidiscrete, net->actor->output, actions); diff --git a/pufferlib/ocean/policies.py b/pufferlib/ocean/policies.py new file mode 100644 index 0000000000..8151fb0997 --- /dev/null +++ b/pufferlib/ocean/policies.py @@ -0,0 +1,262 @@ +from torch import nn +import torch +import torch.nn.functional as F + +import pufferlib + +from pufferlib.policy import Policy +from pufferlib.samplers import DiscreteSampler, MultiDiscreteSampler, Sampler + + +class MultiDiscreteDriveMLP(Policy): + def __init__(self, env, input_size=128, hidden_size=128, **kwargs): + super().__init__(sampler=MultiDiscreteSampler()) + self.hidden_size = hidden_size + self.observation_size = env.single_observation_space.shape[0] + self.max_partner_objects = env.max_partner_objects + self.partner_features = env.partner_features + self.max_road_objects = env.max_road_objects + self.road_features = env.road_features + self.road_features_after_onehot = env.road_features + 6 # 6 is the number of one-hot encoded categories + # Determine ego dimension from environment's feature layout + self.ego_dim = env.ego_features + + self.ego_encoder = nn.Sequential( + pufferlib.pytorch.layer_init(nn.Linear(self.ego_dim, input_size)), + nn.LayerNorm(input_size), + # nn.ReLU(), + pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), + ) + + self.road_encoder = nn.Sequential( + pufferlib.pytorch.layer_init(nn.Linear(self.road_features_after_onehot, input_size)), + nn.LayerNorm(input_size), + # nn.ReLU(), + pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), + ) + + self.partner_encoder = nn.Sequential( + pufferlib.pytorch.layer_init(nn.Linear(self.partner_features, input_size)), + nn.LayerNorm(input_size), + # nn.ReLU(), + pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), + ) + + self.shared_embedding = nn.Sequential( + nn.GELU(), + 
pufferlib.pytorch.layer_init(nn.Linear(3 * input_size, hidden_size)), + ) + + self.atn_dim = env.single_action_space.nvec.tolist() + + self.actor = pufferlib.pytorch.layer_init(nn.Linear(hidden_size, sum(self.atn_dim)), std=0.01) + self.value_fn = pufferlib.pytorch.layer_init(nn.Linear(hidden_size, 1), std=1) + + def forward_eval(self, observations, state=None): + assert observations.dim() == 2, "Expected input shape [batch_size, obs_dim]" + hidden = self.encode_observations(observations) + logits, value = self.decode_actions(hidden) + actions, logprobs = self.sampler.sample_actions(logits) + return actions, logprobs, value + + def forward_train(self, observations, actions, mask=None, state=None): + assert observations.dim() == 3, "Expected input shape [batch_size, bptt, obs_dim]" + flat_obs = observations.view(-1, observations.size(-1)) + flat_actions = actions.view(-1, actions.size(-1)) + if mask is not None: + assert mask.dim() == 2, "Expected mask shape [batch_size, bptt]" + flat_mask = mask.view(-1) + flat_obs = flat_obs[flat_mask] + flat_actions = flat_actions[flat_mask] + logits, newvalue = self.decode_actions(self.encode_observations(flat_obs)) + newlogprob, entropy = self.sampler.compute_logprobs(logits, flat_actions) + return newvalue, newlogprob, entropy + + def encode_observations(self, observations, state=None): + ego_dim = self.ego_dim + partner_dim = self.max_partner_objects * self.partner_features + road_dim = self.max_road_objects * self.road_features + ego_obs = observations[:, :ego_dim] + partner_obs = observations[:, ego_dim : ego_dim + partner_dim] + road_obs = observations[:, ego_dim + partner_dim : ego_dim + partner_dim + road_dim] + + partner_objects = partner_obs.view(-1, self.max_partner_objects, self.partner_features) + + road_objects = road_obs.view(-1, self.max_road_objects, self.road_features) + road_continuous = road_objects[:, :, : self.road_features - 1] + road_categorical = road_objects[:, :, self.road_features - 1] + road_onehot = 
F.one_hot(road_categorical.long(), num_classes=7) # Shape: [batch, ROAD_MAX_OBJECTS, 7] + road_objects = torch.cat([road_continuous, road_onehot], dim=2) + ego_features = self.ego_encoder(ego_obs) + partner_features, _ = self.partner_encoder(partner_objects).max(dim=1) + road_features, _ = self.road_encoder(road_objects).max(dim=1) + + concat_features = torch.cat([ego_features, road_features, partner_features], dim=1) + + # Pass through shared embedding + embedding = F.relu(self.shared_embedding(concat_features)) + return embedding + + def decode_actions(self, flat_hidden): + logits = self.actor(flat_hidden) + logits = torch.split(logits, self.atn_dim, dim=1) + value = self.value_fn(flat_hidden) + + return logits, value + + +class MultiDiscreteDriveLSTM(Policy): + def __init__(self, env, input_size=128, hidden_size=128, **kwargs): + super().__init__(sampler=MultiDiscreteSampler()) + self.hidden_size = hidden_size + self.observation_size = env.single_observation_space.shape[0] + self.max_partner_objects = env.max_partner_objects + self.partner_features = env.partner_features + self.max_road_objects = env.max_road_objects + self.road_features = env.road_features + self.road_features_after_onehot = env.road_features + 6 # 6 is the number of one-hot encoded categories + # Determine ego dimension from environment's feature layout + self.ego_dim = env.ego_features + + self.ego_encoder = nn.Sequential( + pufferlib.pytorch.layer_init(nn.Linear(self.ego_dim, input_size)), + nn.LayerNorm(input_size), + # nn.ReLU(), + pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), + ) + + self.road_encoder = nn.Sequential( + pufferlib.pytorch.layer_init(nn.Linear(self.road_features_after_onehot, input_size)), + nn.LayerNorm(input_size), + # nn.ReLU(), + pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), + ) + + self.partner_encoder = nn.Sequential( + pufferlib.pytorch.layer_init(nn.Linear(self.partner_features, input_size)), + nn.LayerNorm(input_size), + # 
nn.ReLU(), + pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), + ) + + self.shared_embedding = nn.Sequential( + nn.GELU(), + pufferlib.pytorch.layer_init(nn.Linear(3 * input_size, hidden_size)), + ) + + self.atn_dim = env.single_action_space.nvec.tolist() + + self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=False) + self.actor = pufferlib.pytorch.layer_init(nn.Linear(hidden_size, sum(self.atn_dim)), std=0.01) + self.value_fn = pufferlib.pytorch.layer_init(nn.Linear(hidden_size, 1), std=1) + + # Per-agent hidden state for rollout + # Not registered as parameters so they are excluded from the weight binary. + self.buffer_h = None # [1, num_agents, hidden_size] + self.buffer_c = None # [1, num_agents, hidden_size] + + # Not registered as parameters so they are excluded from the weight binary. + self._train_buffer_h = None # [1, num_agents, hidden_size] + self._train_buffer_c = None # [1, num_agents, hidden_size] + + def _init_buffers(self, batch_size: int, device: torch.device): + if self.buffer_h is None: + self.buffer_h = torch.zeros(1, batch_size, self.hidden_size, device=device) + self.buffer_c = torch.zeros(1, batch_size, self.hidden_size, device=device) + self._train_buffer_h = torch.zeros(1, batch_size, self.hidden_size, device=device) + self._train_buffer_c = torch.zeros(1, batch_size, self.hidden_size, device=device) + + def forward_eval(self, observations, state=None, truncations=None): + assert observations.dim() == 2, "Expected input shape [batch_size, obs_dim]" + assert "env_id" in state, "Expected state to contain 'env_id' for indexing recurrent buffer" + batch_size, device = observations.shape[0], observations.device + self._init_buffers(batch_size, device) + + embedding = self.encode_observations(observations) # [batch_size, hidden_size] + lstm_out, (h_new, c_new) = self.lstm( + embedding.unsqueeze(0), # [1, batch_size, hidden_size] + (self.buffer_h.detach(), self.buffer_c.detach()), + ) + 
self.buffer_h, self.buffer_c = h_new.detach(), c_new.detach() + + logits, value = self.decode_actions(lstm_out.squeeze(0)) + actions, logprobs = self.sampler.sample_actions(logits) + return actions, logprobs, value + + def forward_train(self, observations, actions, mask=None, truncations=None): + """ + inputs: + observations - (batch_size, bptt, obs_dim) + actions - (batch_size, bptt, 1) + mask - (batch_size, bptt) boolean tensor indicating valid samples (optional) + truncations - (batch_size, bptt) boolean tensor indicating truncation points (optional) + + when mask = false it indicates the sample is invalid, we drop the sample and interrupt the LSTM recurring states + when truncation = true, we interrupt the LSTM recurring states but keep the sample + + """ + assert observations.dim() == 3, "Expected input shape [batch_size, bptt, obs_dim]" + batch_size, bptt, _ = observations.shape + device = observations.device + + h = torch.zeros(1, batch_size, self.hidden_size, device=device) + c = torch.zeros(1, batch_size, self.hidden_size, device=device) + + hiddens = [] + for t in range(bptt): + if t > 0: + reset = torch.zeros(batch_size, dtype=torch.bool, device=device) + if mask is not None: + reset = reset | ~mask[:, t - 1].bool() + if truncations is not None: + reset = reset | truncations[:, t - 1].bool() + h[:, reset, :] = 0.0 + c[:, reset, :] = 0.0 + + embedding = self.encode_observations(observations[:, t, :]) + out, (h, c) = self.lstm(embedding.unsqueeze(0), (h, c)) + hiddens.append(out.squeeze(0)) + + hidden = torch.stack(hiddens, dim=1).reshape(batch_size * bptt, self.hidden_size) + flat_actions = actions.view(-1, actions.size(-1)) + + if mask is not None: + flat_mask = mask.reshape(-1) + hidden = hidden[flat_mask] + flat_actions = flat_actions[flat_mask] + + logits, newvalue = self.decode_actions(hidden) + newlogprob, entropy = self.sampler.compute_logprobs(logits, flat_actions) + return newvalue, newlogprob, entropy + + def encode_observations(self, 
observations, state=None): + ego_dim = self.ego_dim + partner_dim = self.max_partner_objects * self.partner_features + road_dim = self.max_road_objects * self.road_features + ego_obs = observations[:, :ego_dim] + partner_obs = observations[:, ego_dim : ego_dim + partner_dim] + road_obs = observations[:, ego_dim + partner_dim : ego_dim + partner_dim + road_dim] + + partner_objects = partner_obs.view(-1, self.max_partner_objects, self.partner_features) + + road_objects = road_obs.view(-1, self.max_road_objects, self.road_features) + road_continuous = road_objects[:, :, : self.road_features - 1] + road_categorical = road_objects[:, :, self.road_features - 1] + road_onehot = F.one_hot(road_categorical.long(), num_classes=7) # Shape: [batch, ROAD_MAX_OBJECTS, 7] + road_objects = torch.cat([road_continuous, road_onehot], dim=2) + ego_features = self.ego_encoder(ego_obs) + partner_features, _ = self.partner_encoder(partner_objects).max(dim=1) + road_features, _ = self.road_encoder(road_objects).max(dim=1) + + concat_features = torch.cat([ego_features, road_features, partner_features], dim=1) + + # Pass through shared embedding + embedding = F.relu(self.shared_embedding(concat_features)) + return embedding + + def decode_actions(self, flat_hidden): + logits = self.actor(flat_hidden) + logits = torch.split(logits, self.atn_dim, dim=1) + value = self.value_fn(flat_hidden) + + return logits, value diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index 925da37984..b7d0f3e982 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -3,13 +3,6 @@ import torch.nn.functional as F import pufferlib -import pufferlib.models - -from pufferlib.models import Default as Policy # noqa: F401 -from pufferlib.models import Convolutional as Conv # noqa: F401 - - -Recurrent = pufferlib.models.LSTMWrapper class Drive(nn.Module): diff --git a/pufferlib/policy.py b/pufferlib/policy.py new file mode 100644 index 0000000000..44d9b65ed2 --- /dev/null +++ 
b/pufferlib/policy.py @@ -0,0 +1,49 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from abc import ABC, abstractmethod +from typing import Tuple +from pufferlib.samplers import Sampler +from os import path + + +class Policy(nn.Module, ABC): + def __init__(self, sampler: Sampler): + super().__init__() + self.sampler = sampler + + @abstractmethod + def forward_eval(self, obs: torch.Tensor, state: dict) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + given the observation and state, return the sampled action, its log probability, and value estimate. + input: obs - (batch_size, obs_dim) + state: dict of metadata for the policy (e.g. env id) + + returns action - (batch_size, 1) sampled action + logprob - (batch_size, 1) log probability of the sampled action + value - (batch_size, 1) value estimate + """ + actions = None + logprobs = None + values = None + return actions, logprobs, values + + @abstractmethod + def forward_train( + self, obs: torch.Tensor, actions: torch.Tensor, mask: torch.Tensor = None, truncations: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + given stored observations and actions, recompute values, log probabilities, and entropy. 
+ input: obs - (batch_size, bptt, obs_dim) + actions - (batch_size, bptt, 1) stored actions from the rollout buffer + mask - (batch_size, bptt) boolean tensor indicating valid samples (optional) + state - dict of metadata (optional) + + returns newvalue - (num_valid_samples, 1) + newlogprob - (num_valid_samples, 1) + entropy - (num_valid_samples, 1) + """ + newvalue = None + newlogprob = None + entropy = None + return newvalue, newlogprob, entropy diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py index 424baf4f65..ccaca6d230 100644 --- a/pufferlib/pufferl.py +++ b/pufferlib/pufferl.py @@ -136,13 +136,6 @@ def __init__(self, config, vecenv, policy, logger=None): self.render_queue = multiprocessing.Queue() self.render_processes = [] - # LSTM - if config["use_rnn"]: - n = vecenv.agents_per_batch - h = policy.hidden_size - self.lstm_h = {i * n: torch.zeros(n, h, device=device) for i in range(total_agents // n)} - self.lstm_c = {i * n: torch.zeros(n, h, device=device) for i in range(total_agents // n)} - # Minibatching & gradient accumulation minibatch_size = config["minibatch_size"] max_minibatch_size = config["max_minibatch_size"] @@ -258,11 +251,6 @@ def evaluate(self): config = self.config device = config["device"] - if config["use_rnn"]: - for k in self.lstm_h: - self.lstm_h[k] = torch.zeros(self.lstm_h[k].shape, device=device) - self.lstm_c[k] = torch.zeros(self.lstm_c[k].shape, device=device) - self.full_rows = 0 while self.full_rows < self.segments: profile("env", epoch) @@ -290,21 +278,13 @@ def evaluate(self): mask=mask, ) - if config["use_rnn"]: - state["lstm_h"] = self.lstm_h[env_id.start] - state["lstm_c"] = self.lstm_c[env_id.start] + action, logprob, value = self.policy.forward_eval(o_device, state) - logits, value = self.policy.forward_eval(o_device, state) - action, logprob, _ = pufferlib.pytorch.sample_logits(logits) if config.get("clamp_reward", True): r = torch.clamp(r, -1, 1) profile("eval_copy", epoch) with torch.no_grad(): - if 
config["use_rnn"]: - self.lstm_h[env_id.start] = state["lstm_h"] - self.lstm_c[env_id.start] = state["lstm_c"] - # Fast path for fully vectorized envs l = self.ep_lengths[env_id.start].item() batch_rows = slice(self.ep_indices[env_id.start].item(), 1 + self.ep_indices[env_id.stop - 1].item()) @@ -338,8 +318,6 @@ def evaluate(self): self.full_rows += num_full action = action.cpu().numpy() - if isinstance(logits, torch.distributions.Normal): - action = np.clip(action, self.vecenv.action_space.low, self.vecenv.action_space.high) profile("eval_misc", epoch) for i in info: @@ -413,17 +391,8 @@ def train(self): mb_advantages = advantages[idx] profile("train_forward", epoch) - if not config["use_rnn"]: - mb_obs = mb_obs.reshape(-1, *self.vecenv.single_observation_space.shape) - - state = dict( - action=mb_actions, - lstm_h=None, - lstm_c=None, - ) - logits, newvalue = self.policy(mb_obs, state) - actions, newlogprob, entropy = pufferlib.pytorch.sample_logits(logits, action=mb_actions) + newvalue, newlogprob, entropy = self.policy.forward_train(mb_obs, mb_actions) profile("train_misc", epoch) newlogprob = newlogprob.reshape(mb_logprobs.shape) @@ -1248,11 +1217,6 @@ def train(env_name, args=None, vecenv=None, policy=None, logger=None): torch.distributed.init_process_group(backend="nccl", world_size=world_size) policy = policy.to(local_rank) model = torch.nn.parallel.DistributedDataParallel(policy, device_ids=[local_rank], output_device=local_rank) - if hasattr(policy, "lstm"): - # model.lstm = policy.lstm - model.hidden_size = policy.hidden_size - - model.forward_eval = policy.forward_eval policy = model.to(local_rank) # Only rank 0 should create the logger to avoid duplicate runs @@ -1412,11 +1376,6 @@ def eval(env_name, args=None, vecenv=None, policy=None): device = args["train"]["device"] state = {} - if args["train"]["use_rnn"]: - state = dict( - lstm_h=torch.zeros(num_agents, policy.hidden_size, device=device), - lstm_c=torch.zeros(num_agents, policy.hidden_size, 
device=device), - ) frames = [] while True: @@ -1681,18 +1640,16 @@ def load_env(env_name, args): def load_policy(args, vecenv, env_name=""): package = args["package"] - module_name = "pufferlib.ocean" if package == "ocean" else f"pufferlib.environments.{package}" - env_module = importlib.import_module(module_name) + if package == "ocean": + policies_module = importlib.import_module("pufferlib.ocean.policies") + else: + env_module = importlib.import_module(f"pufferlib.environments.{package}") + policies_module = env_module.torch device = args["train"]["device"] - policy_cls = getattr(env_module.torch, args["policy_name"]) + policy_cls = getattr(policies_module, args["policy_name"]) policy = policy_cls(vecenv.driver_env, **args["policy"]) - rnn_name = args["rnn_name"] - if rnn_name is not None: - rnn_cls = getattr(env_module.torch, args["rnn_name"]) - policy = rnn_cls(vecenv.driver_env, policy, **args["rnn"]) - policy = policy.to(device) load_id = args["load_id"] @@ -1801,7 +1758,6 @@ def puffer_type(value): prev[subkey] = value - args["train"]["use_rnn"] = args["rnn_name"] is not None return args diff --git a/pufferlib/samplers.py b/pufferlib/samplers.py new file mode 100644 index 0000000000..bf46f6e3f3 --- /dev/null +++ b/pufferlib/samplers.py @@ -0,0 +1,78 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from abc import ABC, abstractmethod +from typing import Dict, Tuple, List + + +class Sampler(nn.Module, ABC): + def __init__(self): + super().__init__() + + @abstractmethod + def sample_actions(self, logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + given the logits it samples actions and returns the sampled actions and their log probabilities + input: logits - (batch_size, action_dim) + + returns action - (batch_size, action_dim) sampled action + logprobs - (batch_size,) log probability per sample + """ + actions = None + logprobs = None + return actions, logprobs + + @abstractmethod + def compute_logprobs(self, logits: 
torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
+        """
+        given the logits and actions, compute the log probabilities of the actions under the given logits
+        input: logits - (batch_size, action_dim)
+               actions - (batch_size, action_dim)
+
+        returns logprobs - (batch_size,) log probability per sample
+                entropy - (batch_size,) entropy per sample
+        """
+        logprobs = None
+        return logprobs
+
+
+class DiscreteSampler(Sampler):
+    def __init__(self):
+        super().__init__()
+
+    def sample_actions(self, logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        action_dist = torch.distributions.Categorical(logits=logits)
+        actions = action_dist.sample().unsqueeze(-1)  # [batch_size, 1]
+        logprobs = action_dist.log_prob(actions.squeeze(-1))  # [batch_size]
+        return actions, logprobs
+
+    def compute_logprobs(self, logits: torch.Tensor, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        action_dist = torch.distributions.Categorical(logits=logits)
+        logprobs = action_dist.log_prob(actions.squeeze(-1))  # [batch_size]
+        entropy = action_dist.entropy()  # [batch_size]
+        return logprobs, entropy
+
+
+class MultiDiscreteSampler(Sampler):
+    def __init__(self):
+        super().__init__()
+
+    def _pad_logits(self, logits: List[torch.Tensor]) -> torch.Tensor:
+        # logits: list of per-head [batch, action_dim_i] → padded [num_heads, batch, max_action_dim]
+        return torch.nn.utils.rnn.pad_sequence(
+            [l.transpose(0, 1) for l in logits], batch_first=False, padding_value=-torch.inf
+        ).permute(1, 2, 0)  # [num_heads, batch, max_action_dim]
+
+    def sample_actions(self, logits: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+        padded = self._pad_logits(logits)  # [num_heads, batch, max_action_dim]
+        action_dist = torch.distributions.Categorical(logits=padded)
+        actions = action_dist.sample()  # [num_heads, batch]
+        logprobs = action_dist.log_prob(actions).sum(0)  # [batch]
+        return actions.T, logprobs  # [batch, num_heads], [batch]
+
+    def compute_logprobs(self, logits: List[torch.Tensor], 
actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + padded = self._pad_logits(logits) # [num_heads, batch, max_action_dim] + action_dist = torch.distributions.Categorical(logits=padded) + logprobs = action_dist.log_prob(actions.T).sum(0) # [batch] + entropy = action_dist.entropy().sum(0) # [batch] + return logprobs, entropy diff --git a/scripts/export_model_bin.py b/scripts/export_model_bin.py index 5abfbffeea..0bd7b08efe 100644 --- a/scripts/export_model_bin.py +++ b/scripts/export_model_bin.py @@ -6,9 +6,6 @@ import pufferlib.utils import pufferlib.vector -import pufferlib.models - -from pufferlib.ocean.torch import Drive def load_config(env_name, config_dir=None): @@ -86,12 +83,10 @@ def export_weights(): vecenv = pufferlib.vector.make(make_env, env_kwargs=env_kwargs, backend=pufferlib.vector.Serial, num_envs=1) # Initialize Policy - print("Initializing Policy...") - policy = Drive(vecenv.driver_env, **config["policy"]) - - if config["base"]["rnn_name"]: - print("Wrapping with LSTM...") - policy = pufferlib.models.LSTMWrapper(vecenv.driver_env, policy, **config["rnn"]) + policy_name = config["base"]["policy_name"] + policy_cls = getattr(__import__("pufferlib.ocean.policies", fromlist=[policy_name]), policy_name) + print(f"Initializing {policy_name}...") + policy = policy_cls(vecenv.driver_env, **config["policy"]) # Load Checkpoint print(f"Loading checkpoint from {args.checkpoint}...") diff --git a/tests/test_puffernet.py b/tests/test_puffernet.py index d91b4354a2..217c2f30d3 100644 --- a/tests/test_puffernet.py +++ b/tests/test_puffernet.py @@ -272,31 +272,8 @@ def test_puffernet_argmax_multidiscrete(batch_size=16, logit_sizes=[5, 7, 2]): assert_near(output_puffer, output_torch.numpy()) -def test_drive(batch_size=1, input_size=512, hidden_size=512): - from pufferlib.ocean.torch import Drive, Recurrent - from pufferlib.ocean import env_creator - - env = env_creator("puffer_drive")(num_maps=1) - input_torch = torch.arange(4 * env.num_obs) % 7 - 
input_torch = input_torch.view(4, -1).float() - - model = Drive(env, hidden_size=hidden_size) - model = Recurrent(env, policy=model, input_size=input_size, hidden_size=hidden_size) - - # state_dict = torch.load("nmmo3_642b.pt") - # state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} - # model.load_state_dict(state_dict) - - state = { - "lstm_h": torch.zeros(batch_size, hidden_size), - "lstm_c": torch.zeros(batch_size, hidden_size), - } - output = model.forward_eval(input_torch, state) - pass - - if __name__ == "__main__": - test_drive() + pass # exit() # test_puffernet_relu() # test_puffernet_sigmoid()