Source code for pypots.nn.modules.inception.layers
""" """
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause
import torch
import torch.nn as nn
[docs]
class InceptionBlockV1(nn.Module):
def __init__(
self,
in_channels,
out_channels,
num_kernels=6,
stride=1,
init_weight=True,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.num_kernels = num_kernels
self.stride = stride
kernels = []
for i in range(self.num_kernels):
kernels.append(
nn.Conv2d(
in_channels,
out_channels,
kernel_size=2 * i + 1,
padding=i,
stride=stride,
)
)
self.kernels = nn.ModuleList(kernels)
if init_weight:
self._initialize_weights()
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
[docs]
def forward(self, x):
res_list = []
for i in range(self.num_kernels):
res_list.append(self.kernels[i](x))
res = torch.stack(res_list, dim=-1).mean(-1)
return res
[docs]
class InceptionTransBlockV1(nn.Module):
def __init__(
self,
in_channels,
out_channels,
stride=1,
num_kernels=6,
init_weight=True,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.num_kernels = num_kernels
self.stride = stride
kernels = []
for i in range(self.num_kernels):
kernels.append(
nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=2 * i + 1,
padding=i,
stride=stride,
)
)
self.kernels = nn.ModuleList(kernels)
if init_weight:
self._initialize_weights()
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.ConvTranspose2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
[docs]
def forward(self, x, output_size):
res_list = []
for i in range(self.num_kernels):
res_list.append(self.kernels[i](x, output_size=output_size))
res = torch.stack(res_list, dim=-1).mean(-1)
return res