From 2766338c13862f48e358fdc97e4dfce7f656d4fb Mon Sep 17 00:00:00 2001
From: Akhilesh Khambekar <77014592+akhileshkh@users.noreply.github.com>
Date: Sun, 7 Dec 2025 22:15:35 -0500
Subject: [PATCH 1/4] Create FCN1D.py

Adds a Fully Convolutional Network for semantic segmentation of 1D data
---
 pyhealth/models/FCN1D.py | 147 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 147 insertions(+)
 create mode 100644 pyhealth/models/FCN1D.py

diff --git a/pyhealth/models/FCN1D.py b/pyhealth/models/FCN1D.py
new file mode 100644
index 000000000..20abec327
--- /dev/null
+++ b/pyhealth/models/FCN1D.py
@@ -0,0 +1,147 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FCN1D(nn.Module):
+
+    def __init__(self, n_class=4, in_channels=1, kernel_size=3, inplanes=64):
+        """
+        A Fully Convolutional Network for semantic segmentation of 1D data,
+        following the paper https://arxiv.org/pdf/1411.4038
+
+        Args:
+            n_class: The number of possible classes for the output prediction
+
+            in_channels: The number of channels in the input. For instance,
+                         a 12-lead ECG would have 12
+
+            kernel_size: The kernel size used by the convolutions in the
+                         feature-extraction blocks; the classifier and
+                         scoring layers use fixed kernel sizes
+
+            inplanes: The number of output channels of the first
+                      convolutional block. Subsequent blocks use
+                      inplanes*2, inplanes*4, and inplanes*8, respectively
+        """
+        super(FCN1D, self).__init__()
+
+        conv_padding = kernel_size // 2
+        l2 = 2 * inplanes
+        l3 = 4 * inplanes
+        self.features_123 = nn.Sequential(
+            # conv1
+            # pad by 100 (as in the FCN reference implementation) so the
+            # crops in forward() always stay in range
+            nn.Conv1d(in_channels, inplanes, kernel_size, padding=100),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(inplanes, inplanes, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.MaxPool1d(2, stride=2, ceil_mode=True),  # 1/2
+
+            # conv2
+            nn.Conv1d(inplanes, l2, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(l2, l2, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.MaxPool1d(2, stride=2, ceil_mode=True),  # 1/4
+
+            # conv3
+            nn.Conv1d(l2, l3, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(l3, l3, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(l3, l3, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.MaxPool1d(2, stride=2, ceil_mode=True),  # 1/8
+        )
+        l4 = 8 * inplanes
+        self.features_4 = nn.Sequential(
+            # conv4
+            nn.Conv1d(l3, l4, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(l4, l4, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(l4, l4, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.MaxPool1d(2, stride=2, ceil_mode=True),  # 1/16
+        )
+        self.features_5 = nn.Sequential(
+            # conv5
+            nn.Conv1d(l4, l4, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(l4, l4, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(l4, l4, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.MaxPool1d(2, stride=2, ceil_mode=True),  # 1/32
+        )
+        self.classifier = nn.Sequential(
+            # fc6
+            nn.Conv1d(l4, 4096, 7),
+            nn.ReLU(inplace=True),
+            nn.Dropout1d(),
+
+            # fc7
+            nn.Conv1d(4096, 4096, 1),
+            nn.ReLU(inplace=True),
+            nn.Dropout1d(),
+
+            # score_fr
+            nn.Conv1d(4096, n_class, 1),
+        )
+        self.score_feat3 = nn.Conv1d(l3, n_class, 1)
+        self.score_feat4 = nn.Conv1d(l4, n_class, 1)
+        self.upscore = nn.ConvTranspose1d(n_class, n_class, 16, stride=8,
+                                          bias=False)
+        self.upscore_4 = nn.ConvTranspose1d(n_class, n_class, 4, stride=2,
+                                            bias=False)
+        self.upscore_5 = nn.ConvTranspose1d(n_class, n_class, 4, stride=2,
+                                            bias=False)
+
+    def forward(self, x):
+        """
+        Input shape (Batch, Channels, Sequence Length)
+        """
+        feat3 = self.features_123(x)    # 1/8
+        feat4 = self.features_4(feat3)  # 1/16
+        feat5 = self.features_5(feat4)  # 1/32
+
+        score5 = self.classifier(feat5)
+        upscore5 = self.upscore_5(score5)
+        score4 = self.score_feat4(feat4)
+        score4 = score4[:, :, 5:5 + upscore5.size()[2]].contiguous()  # crop to match upscore5
+        score4 += upscore5
+
+        score3 = self.score_feat3(feat3)
+        upscore4 = self.upscore_4(score4)
+        score3 = score3[:, :, 9:9 + upscore4.size()[2]].contiguous()  # crop to match upscore4
+        score3 += upscore4
+        h = self.upscore(score3)
+        h = h[:, :, 28:28 + x.size()[2]].contiguous()  # crop back to the input length
+
+        return h
+
+
+if __name__ == "__main__":
+
+    # simple test
+    n_class = 4
+    n_samples = 16
+    n_channels = 3
+    seq_len = 5000
+
+    model = FCN1D(n_class=n_class, in_channels=n_channels, kernel_size=3, inplanes=64)
+
+    test_x = torch.randn((n_samples, n_channels, seq_len))
+
+    test_y = model(test_x)
+
+    print(test_y.shape)
+
+    assert test_y.shape[0] == n_samples
+    assert test_y.shape[1] == n_class
+    assert test_y.shape[-1] == seq_len
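The crop offsets in forward() (5, 9, 28) and the padding=100 on the first convolution follow the FCN reference implementation, and for kernel_size=3 the arithmetic can be checked by hand. The standalone sanity sketch below (not part of the patch series) recomputes the intermediate sequence lengths from the layer definitions above and verifies that every crop stays in range:

    import math

    def fcn1d_lengths(L):
        # conv1 pads by 100 with kernel 3; the padding=1 convs preserve length
        l = L + 2 * 100 - 2
        feat3 = math.ceil(math.ceil(math.ceil(l / 2) / 2) / 2)  # 1/8 (ceil-mode pools)
        feat4 = math.ceil(feat3 / 2)                            # 1/16
        feat5 = math.ceil(feat4 / 2)                            # 1/32
        score5 = feat5 - 6            # fc6: kernel-7 conv, no padding
        up5 = (score5 - 1) * 2 + 4    # ConvTranspose1d(kernel=4, stride=2)
        up4 = (up5 - 1) * 2 + 4       # ConvTranspose1d(kernel=4, stride=2)
        up = (up4 - 1) * 8 + 16       # ConvTranspose1d(kernel=16, stride=8)
        return feat3, feat4, up5, up4, up

    feat3, feat4, up5, up4, up = fcn1d_lengths(5000)
    assert feat4 >= 5 + up5   # score4 crop [5:5+len(upscore5)] is in range
    assert feat3 >= 9 + up4   # score3 crop [9:9+len(upscore4)] is in range
    assert up >= 28 + 5000    # final crop [28:28+L] recovers the input length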
From 886be46425979f885454c9de1e09154169221952 Mon Sep 17 00:00:00 2001
From: Akhilesh Khambekar <77014592+akhileshkh@users.noreply.github.com>
Date: Sun, 7 Dec 2025 22:16:17 -0500
Subject: [PATCH 2/4] Delete pyhealth/models/FCN1D.py

---
 pyhealth/models/FCN1D.py | 147 ---------------------------------------
 1 file changed, 147 deletions(-)
 delete mode 100644 pyhealth/models/FCN1D.py

diff --git a/pyhealth/models/FCN1D.py b/pyhealth/models/FCN1D.py
deleted file mode 100644
index 20abec327..000000000
--- a/pyhealth/models/FCN1D.py
+++ /dev/null
@@ -1,147 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class FCN1D(nn.Module):
-
-    def __init__(self, n_class=4, in_channels=1, kernel_size=3, inplanes=64):
-        """
-        A Fully Convolutional Network for semantic segmentation of 1D data,
-        following the paper https://arxiv.org/pdf/1411.4038
-
-        Args:
-            n_class: The number of possible classes for the output prediction
-
-            in_channels: The number of channels in the input. For instance,
-                         a 12-lead ECG would have 12
-
-            kernel_size: The kernel size used by the convolutions in the
-                         feature-extraction blocks; the classifier and
-                         scoring layers use fixed kernel sizes
-
-            inplanes: The number of output channels of the first
-                      convolutional block. Subsequent blocks use
-                      inplanes*2, inplanes*4, and inplanes*8, respectively
-        """
-        super(FCN1D, self).__init__()
-
-        conv_padding = kernel_size // 2
-        l2 = 2 * inplanes
-        l3 = 4 * inplanes
-        self.features_123 = nn.Sequential(
-            # conv1
-            # pad by 100 (as in the FCN reference implementation) so the
-            # crops in forward() always stay in range
-            nn.Conv1d(in_channels, inplanes, kernel_size, padding=100),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(inplanes, inplanes, kernel_size, padding=conv_padding),
-            nn.ReLU(inplace=True),
-            nn.MaxPool1d(2, stride=2, ceil_mode=True),  # 1/2
-
-            # conv2
-            nn.Conv1d(inplanes, l2, kernel_size, padding=conv_padding),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(l2, l2, kernel_size, padding=conv_padding),
-            nn.ReLU(inplace=True),
-            nn.MaxPool1d(2, stride=2, ceil_mode=True),  # 1/4
-
-            # conv3
-            nn.Conv1d(l2, l3, kernel_size, padding=conv_padding),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(l3, l3, kernel_size, padding=conv_padding),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(l3, l3, kernel_size, padding=conv_padding),
-            nn.ReLU(inplace=True),
-            nn.MaxPool1d(2, stride=2, ceil_mode=True),  # 1/8
-        )
-        l4 = 8 * inplanes
-        self.features_4 = nn.Sequential(
-            # conv4
-            nn.Conv1d(l3, l4, kernel_size, padding=conv_padding),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(l4, l4, kernel_size, padding=conv_padding),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(l4, l4, kernel_size, padding=conv_padding),
-            nn.ReLU(inplace=True),
-            nn.MaxPool1d(2, stride=2, ceil_mode=True),  # 1/16
-        )
-        self.features_5 = nn.Sequential(
-            # conv5
-            nn.Conv1d(l4, l4, kernel_size, padding=conv_padding),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(l4, l4, kernel_size, padding=conv_padding),
-            nn.ReLU(inplace=True),
-            nn.Conv1d(l4, l4, kernel_size, padding=conv_padding),
-            nn.ReLU(inplace=True),
-            nn.MaxPool1d(2, stride=2, ceil_mode=True),  # 1/32
-        )
-        self.classifier = nn.Sequential(
-            # fc6
-            nn.Conv1d(l4, 4096, 7),
-            nn.ReLU(inplace=True),
-            nn.Dropout1d(),
-
-            # fc7
-            nn.Conv1d(4096, 4096, 1),
-            nn.ReLU(inplace=True),
-            nn.Dropout1d(),
-
-            # score_fr
-            nn.Conv1d(4096, n_class, 1),
-        )
-        self.score_feat3 = nn.Conv1d(l3, n_class, 1)
-        self.score_feat4 = nn.Conv1d(l4, n_class, 1)
-        self.upscore = nn.ConvTranspose1d(n_class, n_class, 16, stride=8,
-                                          bias=False)
-        self.upscore_4 = nn.ConvTranspose1d(n_class, n_class, 4, stride=2,
-                                            bias=False)
-        self.upscore_5 = nn.ConvTranspose1d(n_class, n_class, 4, stride=2,
-                                            bias=False)
-
-    def forward(self, x):
-        """
-        Input shape (Batch, Channels, Sequence Length)
-        """
-        feat3 = self.features_123(x)    # 1/8
-        feat4 = self.features_4(feat3)  # 1/16
-        feat5 = self.features_5(feat4)  # 1/32
-
-        score5 = self.classifier(feat5)
-        upscore5 = self.upscore_5(score5)
-        score4 = self.score_feat4(feat4)
-        score4 = score4[:, :, 5:5 + upscore5.size()[2]].contiguous()  # crop to match upscore5
-        score4 += upscore5
-
-        score3 = self.score_feat3(feat3)
-        upscore4 = self.upscore_4(score4)
-        score3 = score3[:, :, 9:9 + upscore4.size()[2]].contiguous()  # crop to match upscore4
-        score3 += upscore4
-        h = self.upscore(score3)
-        h = h[:, :, 28:28 + x.size()[2]].contiguous()  # crop back to the input length
-
-        return h
-
-
-if __name__ == "__main__":
-
-    # simple test
-    n_class = 4
-    n_samples = 16
-    n_channels = 3
-    seq_len = 5000
-
-    model = FCN1D(n_class=n_class, in_channels=n_channels, kernel_size=3, inplanes=64)
-
-    test_x = torch.randn((n_samples, n_channels, seq_len))
-
-    test_y = model(test_x)
-
-    print(test_y.shape)
-
-    assert test_y.shape[0] == n_samples
-    assert test_y.shape[1] == n_class
-    assert test_y.shape[-1] == seq_len
From ed92cb66101043c81706beb1c107a5ebdbc4f60f Mon Sep 17 00:00:00 2001
From: Akhilesh Khambekar <77014592+akhileshkh@users.noreply.github.com>
Date: Sun, 7 Dec 2025 22:17:11 -0500
Subject: [PATCH 3/4] Create fcn1d.py

Add the Fully Convolutional Network (FCN) for semantic segmentation of 1D
sequential data.
---
 pyhealth/models/fcn1d.py | 147 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 147 insertions(+)
 create mode 100644 pyhealth/models/fcn1d.py

diff --git a/pyhealth/models/fcn1d.py b/pyhealth/models/fcn1d.py
new file mode 100644
index 000000000..20abec327
--- /dev/null
+++ b/pyhealth/models/fcn1d.py
@@ -0,0 +1,147 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FCN1D(nn.Module):
+
+    def __init__(self, n_class=4, in_channels=1, kernel_size=3, inplanes=64):
+        """
+        A Fully Convolutional Network for semantic segmentation of 1D data,
+        following the paper https://arxiv.org/pdf/1411.4038
+
+        Args:
+            n_class: The number of possible classes for the output prediction
+
+            in_channels: The number of channels in the input. For instance,
+                         a 12-lead ECG would have 12
+
+            kernel_size: The kernel size used by the convolutions in the
+                         feature-extraction blocks; the classifier and
+                         scoring layers use fixed kernel sizes
+
+            inplanes: The number of output channels of the first
+                      convolutional block. Subsequent blocks use
+                      inplanes*2, inplanes*4, and inplanes*8, respectively
+        """
+        super(FCN1D, self).__init__()
+
+        conv_padding = kernel_size // 2
+        l2 = 2 * inplanes
+        l3 = 4 * inplanes
+        self.features_123 = nn.Sequential(
+            # conv1
+            # pad by 100 (as in the FCN reference implementation) so the
+            # crops in forward() always stay in range
+            nn.Conv1d(in_channels, inplanes, kernel_size, padding=100),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(inplanes, inplanes, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.MaxPool1d(2, stride=2, ceil_mode=True),  # 1/2
+
+            # conv2
+            nn.Conv1d(inplanes, l2, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(l2, l2, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.MaxPool1d(2, stride=2, ceil_mode=True),  # 1/4
+
+            # conv3
+            nn.Conv1d(l2, l3, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(l3, l3, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(l3, l3, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.MaxPool1d(2, stride=2, ceil_mode=True),  # 1/8
+        )
+        l4 = 8 * inplanes
+        self.features_4 = nn.Sequential(
+            # conv4
+            nn.Conv1d(l3, l4, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(l4, l4, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(l4, l4, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.MaxPool1d(2, stride=2, ceil_mode=True),  # 1/16
+        )
+        self.features_5 = nn.Sequential(
+            # conv5
+            nn.Conv1d(l4, l4, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(l4, l4, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.Conv1d(l4, l4, kernel_size, padding=conv_padding),
+            nn.ReLU(inplace=True),
+            nn.MaxPool1d(2, stride=2, ceil_mode=True),  # 1/32
+        )
+        self.classifier = nn.Sequential(
+            # fc6
+            nn.Conv1d(l4, 4096, 7),
+            nn.ReLU(inplace=True),
+            nn.Dropout1d(),
+
+            # fc7
+            nn.Conv1d(4096, 4096, 1),
+            nn.ReLU(inplace=True),
+            nn.Dropout1d(),
+
+            # score_fr
+            nn.Conv1d(4096, n_class, 1),
+        )
+        self.score_feat3 = nn.Conv1d(l3, n_class, 1)
+        self.score_feat4 = nn.Conv1d(l4, n_class, 1)
+        self.upscore = nn.ConvTranspose1d(n_class, n_class, 16, stride=8,
+                                          bias=False)
+        self.upscore_4 = nn.ConvTranspose1d(n_class, n_class, 4, stride=2,
+                                            bias=False)
+        self.upscore_5 = nn.ConvTranspose1d(n_class, n_class, 4, stride=2,
+                                            bias=False)
+
+    def forward(self, x):
+        """
+        Input shape (Batch, Channels, Sequence Length)
+        """
+        feat3 = self.features_123(x)    # 1/8
+        feat4 = self.features_4(feat3)  # 1/16
+        feat5 = self.features_5(feat4)  # 1/32
+
+        score5 = self.classifier(feat5)
+        upscore5 = self.upscore_5(score5)
+        score4 = self.score_feat4(feat4)
+        score4 = score4[:, :, 5:5 + upscore5.size()[2]].contiguous()  # crop to match upscore5
+        score4 += upscore5
+
+        score3 = self.score_feat3(feat3)
+        upscore4 = self.upscore_4(score4)
+        score3 = score3[:, :, 9:9 + upscore4.size()[2]].contiguous()  # crop to match upscore4
+        score3 += upscore4
+        h = self.upscore(score3)
+        h = h[:, :, 28:28 + x.size()[2]].contiguous()  # crop back to the input length
+
+        return h
+
+
+if __name__ == "__main__":
+
+    # simple test
+    n_class = 4
+    n_samples = 16
+    n_channels = 3
+    seq_len = 5000
+
+    model = FCN1D(n_class=n_class, in_channels=n_channels, kernel_size=3, inplanes=64)
+
+    test_x = torch.randn((n_samples, n_channels, seq_len))
+
+    test_y = model(test_x)
+
+    print(test_y.shape)
+
+    assert test_y.shape[0] == n_samples
+    assert test_y.shape[1] == n_class
+    assert test_y.shape[-1] == seq_len
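Since fcn1d.py returns per-timestep logits of shape (Batch, n_class, Sequence Length), a typical training step pairs it with nn.CrossEntropyLoss, which accepts (N, C, L) logits against (N, L) integer targets. A minimal sketch (not part of the patch series; the import path assumes this file lands at pyhealth/models/fcn1d.py, and the optimizer and learning rate are illustrative only):

    import torch
    import torch.nn as nn
    from pyhealth.models.fcn1d import FCN1D  # path assumed from this patch

    model = FCN1D(n_class=4, in_channels=12, kernel_size=3, inplanes=64)
    criterion = nn.CrossEntropyLoss()  # (N, C, L) logits vs. (N, L) targets
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # illustrative

    x = torch.randn(8, 12, 5000)        # e.g. a batch of 12-lead ECG windows
    y = torch.randint(0, 4, (8, 5000))  # per-timestep class labels

    optimizer.zero_grad()
    logits = model(x)                   # (8, 4, 5000)
    loss = criterion(logits, y)
    loss.backward()
    optimizer.step()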
From 350144c634bd6ec480f6c43bf545957b60ccdc80 Mon Sep 17 00:00:00 2001
From: Akhilesh Khambekar <77014592+akhileshkh@users.noreply.github.com>
Date: Sun, 7 Dec 2025 22:18:51 -0500
Subject: [PATCH 4/4] Create unet1d.py

Add a U-Net (UNet 3+) model for semantic segmentation of 1D signal data.
---
 pyhealth/models/unet1d.py | 417 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 417 insertions(+)
 create mode 100644 pyhealth/models/unet1d.py

diff --git a/pyhealth/models/unet1d.py b/pyhealth/models/unet1d.py
new file mode 100644
index 000000000..d1aa56e8d
--- /dev/null
+++ b/pyhealth/models/unet1d.py
@@ -0,0 +1,417 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+
+
+def weights_init_normal(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('Linear') != -1:
+        init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm') != -1:
+        init.normal_(m.weight.data, 1.0, 0.02)
+        init.constant_(m.bias.data, 0.0)
+
+
+def weights_init_xavier(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        init.xavier_normal_(m.weight.data, gain=1)
+    elif classname.find('Linear') != -1:
+        init.xavier_normal_(m.weight.data, gain=1)
+    elif classname.find('BatchNorm') != -1:
+        init.normal_(m.weight.data, 1.0, 0.02)
+        init.constant_(m.bias.data, 0.0)
+
+
+def weights_init_kaiming(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+    elif classname.find('Linear') != -1:
+        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+    elif classname.find('BatchNorm') != -1:
+        init.normal_(m.weight.data, 1.0, 0.02)
+        init.constant_(m.bias.data, 0.0)
+
+
+def weights_init_orthogonal(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        init.orthogonal_(m.weight.data, gain=1)
+    elif classname.find('Linear') != -1:
+        init.orthogonal_(m.weight.data, gain=1)
+    elif classname.find('BatchNorm') != -1:
+        init.normal_(m.weight.data, 1.0, 0.02)
+        init.constant_(m.bias.data, 0.0)
+
+
+def init_weights(net, init_type='normal'):
+    if init_type == 'normal':
+        net.apply(weights_init_normal)
+    elif init_type == 'xavier':
+        net.apply(weights_init_xavier)
+    elif init_type == 'kaiming':
+        net.apply(weights_init_kaiming)
+    elif init_type == 'orthogonal':
+        net.apply(weights_init_orthogonal)
+    else:
+        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+
+
+class unetConv2(nn.Module):
+    """
+    Encoder convolution block with configurable parameters.
+    (The unetConv2 name is kept from the 2D reference implementation.)
+
+    Args:
+        in_size: Input channels
+        out_size: Output channels
+        is_batchnorm: Whether to use BatchNorm
+        n: Number of Conv-BN-ReLU layers
+        ks: Kernel size
+        stride: Convolution stride
+        padding: Convolution padding (auto-calculated if None)
+    """
+    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=None):
+        super(unetConv2, self).__init__()
+        self.n = n
+        self.ks = ks
+        self.stride = stride
+        # Auto-calculate padding to maintain spatial dimensions: padding = (ks - 1) // 2
+        self.padding = padding if padding is not None else (ks - 1) // 2
+        s = stride
+        p = self.padding
+
+        if is_batchnorm:
+            for i in range(1, n + 1):
+                conv = nn.Sequential(
+                    nn.Conv1d(in_size, out_size, ks, s, p),
+                    nn.BatchNorm1d(out_size),
+                    nn.ReLU(inplace=True),
+                )
+                setattr(self, 'conv%d' % i, conv)
+                in_size = out_size
+        else:
+            for i in range(1, n + 1):
+                conv = nn.Sequential(
+                    nn.Conv1d(in_size, out_size, ks, s, p),
+                    nn.ReLU(inplace=True),
+                )
+                setattr(self, 'conv%d' % i, conv)
+                in_size = out_size
+
+        # Initialize weights
+        for m in self.children():
+            init_weights(m, init_type='kaiming')
+
+    def forward(self, inputs):
+        x = inputs
+        for i in range(1, self.n + 1):
+            conv = getattr(self, 'conv%d' % i)
+            x = conv(x)
+        return x
+
+
+class UNet1D(nn.Module):
+    """
+    UNet 3+ with configurable parameters for 1D signal data.
+
+    Args:
+        seq_len: Expected input sequence length; must be a multiple of 32
+        in_channels: Number of input channels
+        n_classes: Number of output classes
+        inplanes: Base channel count. Channels at depth i = inplanes * 2^i
+        kernel_size: Convolution kernel size
+        num_encoder_layers: Number of Conv-BN-ReLU blocks per encoder stage
+        encoder_batchnorm: Whether to use batch normalization in the encoder
+        interpolate_mode: Interpolation mode for upsampling ('linear' or 'nearest')
+    """
+
+    def __init__(self, seq_len, in_channels=1, n_classes=4, inplanes=64, kernel_size=3,
+                 num_encoder_layers=2, encoder_batchnorm=True, interpolate_mode='linear'):
+        super(UNet1D, self).__init__()
+
+        assert seq_len % 32 == 0, "seq_len must be a multiple of 32; pad if necessary"
+
+        self.in_channels = in_channels
+        self.inplanes = inplanes
+        self.kernel_size = kernel_size
+        self.num_encoder_layers = num_encoder_layers
+        self.encoder_batchnorm = encoder_batchnorm
+        self.interpolate_mode = interpolate_mode
+
+        # Calculate filters from inplanes: filters[i] = inplanes * 2^i
+        # For inplanes=64:  [64, 128, 256, 512, 1024]
+        # For inplanes=32:  [32, 64, 128, 256, 512]
+        # For inplanes=128: [128, 256, 512, 1024, 2048]
+        filters = [inplanes * (2 ** i) for i in range(5)]
+        self.filters = filters
+
+        ## -------------Encoder--------------
+        self.conv1 = unetConv2(self.in_channels, filters[0], encoder_batchnorm,
+                               n=num_encoder_layers, ks=kernel_size)
+        self.maxpool1 = nn.MaxPool1d(kernel_size=2)
+
+        self.conv2 = unetConv2(filters[0], filters[1], encoder_batchnorm,
+                               n=num_encoder_layers, ks=kernel_size)
+        self.maxpool2 = nn.MaxPool1d(kernel_size=2)
+
+        self.conv3 = unetConv2(filters[1], filters[2], encoder_batchnorm,
+                               n=num_encoder_layers, ks=kernel_size)
+        self.maxpool3 = nn.MaxPool1d(kernel_size=2)
+
+        self.conv4 = unetConv2(filters[2], filters[3], encoder_batchnorm,
+                               n=num_encoder_layers, ks=kernel_size)
+        self.maxpool4 = nn.MaxPool1d(kernel_size=2)
+
+        self.conv5 = unetConv2(filters[3], filters[4], encoder_batchnorm,
+                               n=num_encoder_layers, ks=kernel_size)
+
+        ## -------------Decoder--------------
+        # The size comments below (e.g. h1->320*320) are kept from the 2D
+        # UNet 3+ reference code and indicate relative scales only.
+        self.CatChannels = filters[0]
+        self.CatBlocks = 5
+        self.UpChannels = self.CatChannels * self.CatBlocks
+
+        '''stage 4d'''
+        # h1->320*320, hd4->40*40, Pooling 8 times
+        self.h1_PT_hd4 = nn.MaxPool1d(8, 8, ceil_mode=True)
+        self.h1_PT_hd4_conv = nn.Conv1d(filters[0], self.CatChannels, 3, padding=1)
+        self.h1_PT_hd4_bn = nn.BatchNorm1d(self.CatChannels)
+        self.h1_PT_hd4_relu = nn.ReLU(inplace=True)
+
+        # h2->160*160, hd4->40*40, Pooling 4 times
+        self.h2_PT_hd4 = nn.MaxPool1d(4, 4, ceil_mode=True)
+        self.h2_PT_hd4_conv = nn.Conv1d(filters[1], self.CatChannels, 3, padding=1)
+        self.h2_PT_hd4_bn = nn.BatchNorm1d(self.CatChannels)
+        self.h2_PT_hd4_relu = nn.ReLU(inplace=True)
+
+        # h3->80*80, hd4->40*40, Pooling 2 times
+        self.h3_PT_hd4 = nn.MaxPool1d(2, 2, ceil_mode=True)
+        self.h3_PT_hd4_conv = nn.Conv1d(filters[2], self.CatChannels, 3, padding=1)
+        self.h3_PT_hd4_bn = nn.BatchNorm1d(self.CatChannels)
+        self.h3_PT_hd4_relu = nn.ReLU(inplace=True)
+
+        # h4->40*40, hd4->40*40, Concatenation
+        self.h4_Cat_hd4_conv = nn.Conv1d(filters[3], self.CatChannels, 3, padding=1)
+        self.h4_Cat_hd4_bn = nn.BatchNorm1d(self.CatChannels)
+        self.h4_Cat_hd4_relu = nn.ReLU(inplace=True)
+
+        # hd5->20*20, hd4->40*40, Upsample 2 times
+        self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode=self.interpolate_mode)
+        self.hd5_UT_hd4_conv = nn.Conv1d(filters[4], self.CatChannels, 3, padding=1)
+        self.hd5_UT_hd4_bn = nn.BatchNorm1d(self.CatChannels)
+        self.hd5_UT_hd4_relu = nn.ReLU(inplace=True)
+
+        # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
+        self.conv4d_1 = nn.Conv1d(self.UpChannels, self.UpChannels, 3, padding=1)
+        self.bn4d_1 = nn.BatchNorm1d(self.UpChannels)
+        self.relu4d_1 = nn.ReLU(inplace=True)
+
+        '''stage 3d'''
+        # h1->320*320, hd3->80*80, Pooling 4 times
+        self.h1_PT_hd3 = nn.MaxPool1d(4, 4, ceil_mode=True)
+        self.h1_PT_hd3_conv = nn.Conv1d(filters[0], self.CatChannels, 3, padding=1)
+        self.h1_PT_hd3_bn = nn.BatchNorm1d(self.CatChannels)
+        self.h1_PT_hd3_relu = nn.ReLU(inplace=True)
+
+        # h2->160*160, hd3->80*80, Pooling 2 times
+        self.h2_PT_hd3 = nn.MaxPool1d(2, 2, ceil_mode=True)
+        self.h2_PT_hd3_conv = nn.Conv1d(filters[1], self.CatChannels, 3, padding=1)
+        self.h2_PT_hd3_bn = nn.BatchNorm1d(self.CatChannels)
+        self.h2_PT_hd3_relu = nn.ReLU(inplace=True)
+
+        # h3->80*80, hd3->80*80, Concatenation
+        self.h3_Cat_hd3_conv = nn.Conv1d(filters[2], self.CatChannels, 3, padding=1)
+        self.h3_Cat_hd3_bn = nn.BatchNorm1d(self.CatChannels)
+        self.h3_Cat_hd3_relu = nn.ReLU(inplace=True)
+
+        # hd4->40*40, hd3->80*80, Upsample 2 times
+        self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode=self.interpolate_mode)
+        self.hd4_UT_hd3_conv = nn.Conv1d(self.UpChannels, self.CatChannels, 3, padding=1)
+        self.hd4_UT_hd3_bn = nn.BatchNorm1d(self.CatChannels)
+        self.hd4_UT_hd3_relu = nn.ReLU(inplace=True)
+
+        # hd5->20*20, hd3->80*80, Upsample 4 times
+        self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode=self.interpolate_mode)
+        self.hd5_UT_hd3_conv = nn.Conv1d(filters[4], self.CatChannels, 3, padding=1)
+        self.hd5_UT_hd3_bn = nn.BatchNorm1d(self.CatChannels)
+        self.hd5_UT_hd3_relu = nn.ReLU(inplace=True)
+
+        # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
+        self.conv3d_1 = nn.Conv1d(self.UpChannels, self.UpChannels, 3, padding=1)
+        self.bn3d_1 = nn.BatchNorm1d(self.UpChannels)
+        self.relu3d_1 = nn.ReLU(inplace=True)
+
+        '''stage 2d'''
+        # h1->320*320, hd2->160*160, Pooling 2 times
+        self.h1_PT_hd2 = nn.MaxPool1d(2, 2, ceil_mode=True)
+        self.h1_PT_hd2_conv = nn.Conv1d(filters[0], self.CatChannels, 3, padding=1)
+        self.h1_PT_hd2_bn = nn.BatchNorm1d(self.CatChannels)
+        self.h1_PT_hd2_relu = nn.ReLU(inplace=True)
+
+        # h2->160*160, hd2->160*160, Concatenation
+        self.h2_Cat_hd2_conv = nn.Conv1d(filters[1], self.CatChannels, 3, padding=1)
+        self.h2_Cat_hd2_bn = nn.BatchNorm1d(self.CatChannels)
+        self.h2_Cat_hd2_relu = nn.ReLU(inplace=True)
+
+        # hd3->80*80, hd2->160*160, Upsample 2 times
+        self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode=self.interpolate_mode)
+        self.hd3_UT_hd2_conv = nn.Conv1d(self.UpChannels, self.CatChannels, 3, padding=1)
+        self.hd3_UT_hd2_bn = nn.BatchNorm1d(self.CatChannels)
+        self.hd3_UT_hd2_relu = nn.ReLU(inplace=True)
+
+        # hd4->40*40, hd2->160*160, Upsample 4 times
+        self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode=self.interpolate_mode)
+        self.hd4_UT_hd2_conv = nn.Conv1d(self.UpChannels, self.CatChannels, 3, padding=1)
+        self.hd4_UT_hd2_bn = nn.BatchNorm1d(self.CatChannels)
+        self.hd4_UT_hd2_relu = nn.ReLU(inplace=True)
+
+        # hd5->20*20, hd2->160*160, Upsample 8 times
+        self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode=self.interpolate_mode)
+        self.hd5_UT_hd2_conv = nn.Conv1d(filters[4], self.CatChannels, 3, padding=1)
+        self.hd5_UT_hd2_bn = nn.BatchNorm1d(self.CatChannels)
+        self.hd5_UT_hd2_relu = nn.ReLU(inplace=True)
+
+        # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
+        self.conv2d_1 = nn.Conv1d(self.UpChannels, self.UpChannels, 3, padding=1)
+        self.bn2d_1 = nn.BatchNorm1d(self.UpChannels)
+        self.relu2d_1 = nn.ReLU(inplace=True)
+
+        '''stage 1d'''
+        # h1->320*320, hd1->320*320, Concatenation
+        self.h1_Cat_hd1_conv = nn.Conv1d(filters[0], self.CatChannels, 3, padding=1)
+        self.h1_Cat_hd1_bn = nn.BatchNorm1d(self.CatChannels)
+        self.h1_Cat_hd1_relu = nn.ReLU(inplace=True)
+
+        # hd2->160*160, hd1->320*320, Upsample 2 times
+        self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode=self.interpolate_mode)
+        self.hd2_UT_hd1_conv = nn.Conv1d(self.UpChannels, self.CatChannels, 3, padding=1)
+        self.hd2_UT_hd1_bn = nn.BatchNorm1d(self.CatChannels)
+        self.hd2_UT_hd1_relu = nn.ReLU(inplace=True)
+
+        # hd3->80*80, hd1->320*320, Upsample 4 times
+        self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode=self.interpolate_mode)
+        self.hd3_UT_hd1_conv = nn.Conv1d(self.UpChannels, self.CatChannels, 3, padding=1)
+        self.hd3_UT_hd1_bn = nn.BatchNorm1d(self.CatChannels)
+        self.hd3_UT_hd1_relu = nn.ReLU(inplace=True)
+
+        # hd4->40*40, hd1->320*320, Upsample 8 times
+        self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode=self.interpolate_mode)
+        self.hd4_UT_hd1_conv = nn.Conv1d(self.UpChannels, self.CatChannels, 3, padding=1)
+        self.hd4_UT_hd1_bn = nn.BatchNorm1d(self.CatChannels)
+        self.hd4_UT_hd1_relu = nn.ReLU(inplace=True)
+
+        # hd5->20*20, hd1->320*320, Upsample 16 times
+        self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode=self.interpolate_mode)
+        self.hd5_UT_hd1_conv = nn.Conv1d(filters[4], self.CatChannels, 3, padding=1)
+        self.hd5_UT_hd1_bn = nn.BatchNorm1d(self.CatChannels)
+        self.hd5_UT_hd1_relu = nn.ReLU(inplace=True)
+
+        # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
+        self.conv1d_1 = nn.Conv1d(self.UpChannels, self.UpChannels, 3, padding=1)
+        self.bn1d_1 = nn.BatchNorm1d(self.UpChannels)
+        self.relu1d_1 = nn.ReLU(inplace=True)
+
+        # output
+        self.outconv1 = nn.Conv1d(self.UpChannels, n_classes, 3, padding=1)
+
+        # initialise weights
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d):
+                init_weights(m, init_type='kaiming')
+            elif isinstance(m, nn.BatchNorm1d):
+                init_weights(m, init_type='kaiming')
+
+    def forward(self, inputs):
+        """
+        Inputs should be of shape (Batch, Channels, Sequence)
+        """
+        assert inputs.shape[-1] % 32 == 0, "Sequence dimension must be a multiple of 32; pad if necessary"
+
+        ## -------------Encoder-------------
+        h1 = self.conv1(inputs)   # h1->320*320*64
+
+        h2 = self.maxpool1(h1)
+        h2 = self.conv2(h2)       # h2->160*160*128
+
+        h3 = self.maxpool2(h2)
+        h3 = self.conv3(h3)       # h3->80*80*256
+
+        h4 = self.maxpool3(h3)
+        h4 = self.conv4(h4)       # h4->40*40*512
+
+        h5 = self.maxpool4(h4)
+        hd5 = self.conv5(h5)      # h5->20*20*1024
+
+        ## -------------Decoder-------------
+        h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1))))
+        h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))
+        h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))
+        h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))
+        hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))
+        hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(
+            torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1))))  # hd4->40*40*UpChannels
+
+        h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1))))
+        h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))
+        h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3)))
+        hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4))))
+        hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5))))
+        hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(
+            torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1))))  # hd3->80*80*UpChannels
+
+        h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1))))
+        h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2)))
+        hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3))))
+        hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4))))
+        hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5))))
+        hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(
+            torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1))))  # hd2->160*160*UpChannels
+
+        h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1)))
+        hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2))))
+        hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3))))
+        hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4))))
+        hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5))))
+        hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(
+            torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1))))  # hd1->320*320*UpChannels
+
+        d1 = self.outconv1(hd1)
+        return d1
+
+
+if __name__ == "__main__":
+
+    # simple test
+    n_class = 4
+    n_samples = 16
+    n_channels = 3
+    seq_len = 5024
+
+    model = UNet1D(seq_len=seq_len, in_channels=n_channels, n_classes=n_class, inplanes=64, kernel_size=3,
+                   num_encoder_layers=2, encoder_batchnorm=True, interpolate_mode='linear')
+
+    test_x = torch.randn((n_samples, n_channels, seq_len))
+
+    test_y = model(test_x)
+
+    print(test_y.shape)
+
+    assert test_y.shape[0] == n_samples
+    assert test_y.shape[1] == n_class
+    assert test_y.shape[-1] == seq_len
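UNet1D asserts that the input length is a multiple of 32, so arbitrary-length signals need padding before the forward pass and cropping afterwards. A minimal wrapper sketch (not part of the patch series; the import path is assumed from this patch, and zero right-padding is one reasonable choice):

    import torch
    import torch.nn.functional as F
    from pyhealth.models.unet1d import UNet1D  # path assumed from this patch

    def segment(model, x):
        # Right-pad a (Batch, Channels, Length) batch to a multiple of 32,
        # run UNet1D, then crop the logits back to the original length.
        L = x.shape[-1]
        pad = (-L) % 32
        logits = model(F.pad(x, (0, pad)))  # zero-pad the sequence dimension
        return logits[..., :L]

    model = UNet1D(seq_len=5024, in_channels=3, n_classes=4)
    model.eval()
    with torch.no_grad():
        y = segment(model, torch.randn(2, 3, 5000))  # 5000 is not a multiple of 32
    assert y.shape == (2, 4, 5000)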