From 8648d6390a80f1ff0bb44e641fc62d7e953dc653 Mon Sep 17 00:00:00 2001 From: Ugenteraan Date: Mon, 28 Dec 2020 16:14:12 +0800 Subject: [PATCH] detached dynamic routing and model mode change --- deepcaps.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/deepcaps.py b/deepcaps.py index ee9db0e..65e00f7 100644 --- a/deepcaps.py +++ b/deepcaps.py @@ -713,17 +713,22 @@ def forward(self, x): x = x.unsqueeze(2).unsqueeze(dim=4) u_hat = torch.matmul(self.W, x).squeeze() # u_hat -> [batch_size, 32, 10, 32] + u_hat_detached = u_hat.detach() # b_ij = torch.zeros((batch_size, self.num_routes, self.num_capsules, 1)) b_ij = x.new(x.shape[0], self.num_routes, self.num_capsules, 1).zero_() for itr in range(self.routing_iters): c_ij = func.softmax(b_ij, dim=2) - s_j = (c_ij * u_hat).sum(dim=1, keepdim=True) + self.bias - v_j = squash(s_j, dim=-1) + + if itr == self.routing_iters -1 : + s_j = (c_ij * u_hat).sum(dim=1, keepdim=True) + self.bias + v_j = squash(s_j, dim=-1) - if itr < self.routing_iters-1: - a_ij = (u_hat * v_j).sum(dim=-1, keepdim=True) + else: + s_j = (c_ij * u_hat_detached).sum(dim=1, keepdim=True) + v_j = squash(s_j, dim=-1) + a_ij = (u_hat_detached * v_j).sum(dim=-1, keepdim=True) b_ij = b_ij + a_ij v_j = v_j.squeeze() #.unsqueeze(-1) @@ -965,6 +970,7 @@ def accuracy(indices, labels): print("def test") def test(model, test_loader, loss, batch_size, lamda=0.5, m_plus=0.9, m_minus=0.1): + model.eval() test_loss = 0.0 correct = 0.0 for batch_idx, (data, label) in enumerate(test_loader): @@ -989,11 +995,13 @@ def test(model, test_loader, loss, batch_size, lamda=0.5, m_plus=0.9, m_minus=0. def train(train_loader, model, num_epochs, lr=0.001, batch_size=64, lamda=0.5, m_plus=0.9, m_minus=0.1): + optimizer = torch.optim.Adam(model.parameters(), lr) lambda1 = lambda epoch: 0.5**(epoch // 10) lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) #lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.96) for epoch in range(num_epochs): + model.train() for batch_idx, (data, label_) in enumerate(train_loader): data, label = data.cuda(), label_.cuda() labels = one_hot(label)