-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathcommon.py
More file actions
84 lines (65 loc) · 2.86 KB
/
common.py
File metadata and controls
84 lines (65 loc) · 2.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
from torch import nn
from torch.nn import *
from collections import OrderedDict
from typing import Any, Iterable, Iterator, Mapping, Optional, TYPE_CHECKING, overload, Tuple, TypeVar, Union
# Generic type variable restricted to subclasses of `Module`
# (`Module` is brought into scope by the `from torch.nn import *` above).
T = TypeVar('T', bound=Module)
class Conv2dWrapper(nn.Conv2d):
    """
    Drop-in replacement for ``nn.Conv2d`` whose ``forward`` tolerates extra arguments.

    Any positional or keyword arguments passed after the input tensor (for
    example a temperature) are accepted and silently discarded, so this layer
    can be called with the same signature as temperature-aware layers.
    """

    def __init__(self, *args, **kwargs):
        # Construction is delegated unchanged to nn.Conv2d.
        super(Conv2dWrapper, self).__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Apply the plain 2D convolution to ``x``; extra arguments are ignored."""
        convolved = super(Conv2dWrapper, self).forward(x)
        return convolved
class TempModule(nn.Module):
    """
    Base class for modules whose ``forward`` takes a temperature next to the input.

    The default implementation is the identity: it returns ``x`` untouched and
    does not use ``temperature``.
    """

    def __init__(self):
        super(TempModule, self).__init__()

    def forward(self, x, temperature) -> torch.Tensor:
        """Return ``x`` unchanged; ``temperature`` is accepted but unused."""
        return x
class BaseModel(TempModule):
    """
    Temperature-aware base model that remembers which convolution layer to use.

    Args:
        ConvLayer: stored as-is on ``self.ConvLayer``. Presumably a convolution
            layer class/factory used by subclasses to build their layers
            (not visible in this file section -- confirm against subclasses).
    """

    def __init__(self, ConvLayer):
        super(BaseModel, self).__init__()
        # Keep a handle on the convolution layer for later use.
        self.ConvLayer = ConvLayer
class TemperatureScheduler:
    """
    Linear schedule from ``initial_value`` (epoch 1) to ``final_value`` (``final_epoch``).

    Epochs are 1-based: ``get(1)`` returns ``initial_value`` and every epoch at
    or beyond ``final_epoch`` returns ``final_value``.

    Args:
        initial_value: value returned at epoch 1.
        final_value: value reached at ``final_epoch``. ``None`` (the default)
            means "same as ``initial_value``", i.e. a constant schedule.
            An explicit ``0`` is honoured as a real target value.
        final_epoch: epoch at which ``final_value`` is reached. ``None`` or
            ``0`` is treated as ``1`` (constant schedule).

    Attributes:
        step: per-epoch increment (``0`` for a constant schedule).
    """

    def __init__(self, initial_value, final_value=None, final_epoch=None):
        self.initial_value = initial_value
        # Only a missing final value falls back to the initial one; 0 is valid.
        self.final_value = initial_value if final_value is None else final_value
        self.final_epoch = final_epoch if final_epoch else 1
        if self.final_epoch == 1:
            self.step = 0
        else:
            # Use the resolved attributes, not the raw arguments, so that a
            # missing final_value does not end up as `None - initial_value`.
            self.step = (self.final_value - initial_value) / (self.final_epoch - 1)

    def get(self, crt_epoch=None):
        """
        Return the scheduled value for ``crt_epoch``.

        Args:
            crt_epoch: 1-based epoch number. ``None`` or ``0`` selects the
                final epoch. Values above ``final_epoch`` are clamped to it.

        Returns:
            ``initial_value + (epoch - 1) * step`` for the clamped epoch.
        """
        epoch = crt_epoch if crt_epoch else self.final_epoch
        epoch = min(epoch, self.final_epoch)
        return self.initial_value + (epoch - 1) * self.step
class CustomSequential(TempModule):
    """
    Sequential container that supports passing temperature to TempModule.

    Layers are applied in order. A layer that is a ``TempModule`` is called as
    ``layer(x, temperature)``; any other layer is called as ``layer(x)``.

    Args:
        *args: the layers, in application order.
    """

    def __init__(self, *args):
        super().__init__()
        self.layers = nn.ModuleList(args)

    def forward(self, x, temperature):
        """Run ``x`` through every layer, forwarding ``temperature`` where supported."""
        for layer in self.layers:
            if isinstance(layer, TempModule):
                x = layer(x, temperature)
            else:
                x = layer(x)
        return x

    def __getitem__(self, idx):
        """
        Index or slice the container.

        Args:
            idx: an ``int`` or a ``slice``.

        Returns:
            For a slice, a new ``CustomSequential`` holding the selected layers
            (the layer objects themselves are shared, not copied). For an
            integer, the layer at that position.
        """
        selected = self.layers[idx]
        if isinstance(idx, slice):
            return CustomSequential(*selected)
        # An integer index yields a single module; wrapping it in list() as the
        # slice path does would fail for ordinary (non-iterable) layers.
        return selected
# Implementation inspired by
# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/train.py#L38 and
# https://github.com/pytorch/pytorch/issues/7455
class SmoothNLLLoss(nn.Module):
    """
    Negative log-likelihood loss with label smoothing.

    The target distribution puts ``1 - smoothing`` on the true class and
    spreads ``smoothing`` uniformly over the remaining classes.

    Args:
        smoothing: probability mass moved away from the true class
            (``0.0`` gives a one-hot target).
        dim: axis of ``prediction`` that indexes the classes.
    """

    def __init__(self, smoothing=0.0, dim=-1):
        super().__init__()
        self.smoothing = smoothing
        self.dim = dim

    def forward(self, prediction, target):
        """
        Compute the smoothed loss, averaged over all non-class positions.

        Args:
            prediction: scores with the classes along ``self.dim``.
                NOTE(review): presumably log-probabilities (e.g. the output of
                ``log_softmax``), as the class name suggests -- confirm with callers.
            target: integer class indices, shaped like ``prediction`` with the
                class axis removed.

        Returns:
            Scalar tensor: mean over positions of ``sum(-smooth_target * prediction)``
            taken along the class axis.
        """
        # The smoothed target is a constant; no gradient should flow through it.
        with torch.no_grad():
            smooth_target = torch.zeros_like(prediction)
            n_class = prediction.size(self.dim)
            smooth_target.fill_(self.smoothing / (n_class - 1))
            # Scatter along the configured class axis (not a hard-coded axis 1)
            # so inputs such as (batch, seq, classes) with dim=-1 are handled.
            smooth_target.scatter_(
                self.dim, target.unsqueeze(self.dim), 1 - self.smoothing
            )
        return torch.mean(torch.sum(-smooth_target * prediction, dim=self.dim))