# conv3_x stage of ResNet-50: one "dashed" (projection) shortcut block followed
# by three plain-shortcut blocks. ResidualBlock's arguments appear to be
# (in_channels, per-conv out_channels, per-conv kernel sizes) — TODO confirm
# against ResidualBlock's definition.
conv3_x_resnet_50 = nn.Sequential(
    # Entry block: input has 256 channels, block output has 512, so this one
    # uses the dashed shortcut.
    ResidualBlock(256, [128, 128, 512], [1, 3, 1], dashed_shortcut=True),
    # Three more blocks that stay at 512 channels with the plain shortcut.
    # Built in order, each with its own fresh argument lists.
    *(
        ResidualBlock(512, [128, 128, 512], [1, 3, 1], dashed_shortcut=False)
        for _ in range(3)
    ),
)

# Smoke check: push one random 256x56x56 feature map (batch size 1) through
# the stage and show the resulting shape.
in_feature_map = torch.randn(1, 256, 56, 56)
out = conv3_x_resnet_50(in_feature_map)
print(out.shape)
# Expected: torch.Size([1, 512, 28, 28]). The spatial size is halved,
# presumably by a stride inside the dashed-shortcut block (verify in
# ResidualBlock).