import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, padding=1):
        super().__init__()
        # First conv + batch norm; bias is omitted because BatchNorm adds its own shift.
        self.layer_one = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels)
        )
        self.relu = nn.ReLU()
        # Second conv operates on the first layer's output, so its input is out_channels.
        self.layer_two = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels)
        )
    def forward(self, x):
        # Identity skip connection: the block must preserve the input's shape,
        # i.e. in_channels == out_channels and stride == 1.
        skip_output = x
        x = self.layer_one(x)
        x = self.relu(x)
        x = self.layer_two(x)
        x += skip_output
        return self.relu(x)

# Example input; its shape follows from the printed output below, since the
# block keeps both the channel count and the spatial resolution unchanged.
in_feature_map = torch.randn(1, 64, 56, 56)

residual_block = ResidualBlock(64, 64, 3)
out = residual_block(in_feature_map)
print(out.shape)
# torch.Size([1, 64, 56, 56])
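
# The identity skip above only works when input and output shapes match.
# A minimal sketch of a common variant (not part of the original snippet;
# ProjectionResidualBlock is a hypothetical name): a 1x1 projection on the
# shortcut lets the block change channels and/or stride.
class ProjectionResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.layer_one = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False),
            nn.BatchNorm2d(out_channels)
        )
        self.relu = nn.ReLU()
        self.layer_two = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels)
        )
        # Project the skip path only when the main path changes the shape.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.layer_two(self.relu(self.layer_one(x)))
        return self.relu(out + self.shortcut(x))

# Example: halves the spatial resolution while doubling the channels.
down_block = ProjectionResidualBlock(64, 128, stride=2)
print(down_block(torch.randn(1, 64, 56, 56)).shape)
# torch.Size([1, 128, 28, 28])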