Commit c80bc9f0 authored by Yvan Pannatier

[models] add skip connections to cnn3d

parent 4eb31eb5
Tags v2.0.2
1 merge request !51 "3d cnn visceral"
@@ -69,37 +69,93 @@ class Conv3DNet(Model):
         self.model_transforms = []
-        # First convolution block
-        self.conv3d_1 = nn.Conv3d(1, 32, kernel_size=3, padding=1)
-        self.batchnorm_1 = nn.BatchNorm3d(32)
+        # First convolution block
+        self.conv3d_1_1 = nn.Conv3d(in_channels=1, out_channels=4, kernel_size=3, stride=1, padding=1)
+        self.conv3d_1_2 = nn.Conv3d(in_channels=4, out_channels=16, kernel_size=3, stride=1, padding=1)
+        self.conv3d_1_3 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=1, stride=1)
+        self.batch_norm_1_1 = nn.BatchNorm3d(4)
+        self.batch_norm_1_2 = nn.BatchNorm3d(16)
+        self.batch_norm_1_3 = nn.BatchNorm3d(16)
         # Second convolution block
-        self.conv3d_2 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
-        self.batchnorm_2 = nn.BatchNorm3d(64)
+        self.conv3d_2_1 = nn.Conv3d(in_channels=16, out_channels=24, kernel_size=3, stride=1, padding=1)
+        self.conv3d_2_2 = nn.Conv3d(in_channels=24, out_channels=32, kernel_size=3, stride=1, padding=1)
+        self.conv3d_2_3 = nn.Conv3d(in_channels=16, out_channels=32, kernel_size=1, stride=1)
+        self.batch_norm_2_1 = nn.BatchNorm3d(24)
+        self.batch_norm_2_2 = nn.BatchNorm3d(32)
+        self.batch_norm_2_3 = nn.BatchNorm3d(32)
         # Third convolution block
-        self.conv3d_3 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
-        self.batchnorm_3 = nn.BatchNorm3d(128)
+        self.conv3d_3_1 = nn.Conv3d(
+            in_channels=32, out_channels=40, kernel_size=3, stride=1, padding=1
+        )
+        self.conv3d_3_2 = nn.Conv3d(
+            in_channels=40, out_channels=48, kernel_size=3, stride=1, padding=1
+        )
+        self.conv3d_3_3 = nn.Conv3d(
+            in_channels=32, out_channels=48, kernel_size=1, stride=1
+        )
+        self.batch_norm_3_1 = nn.BatchNorm3d(40)
+        self.batch_norm_3_2 = nn.BatchNorm3d(48)
+        self.batch_norm_3_3 = nn.BatchNorm3d(48)
         # Fourth convolution block
-        self.conv3d_4 = nn.Conv3d(128, 256, kernel_size=3, padding=1)
-        self.batchnorm_4 = nn.BatchNorm3d(256)
+        self.conv3d_4_1 = nn.Conv3d(
+            in_channels=48, out_channels=56, kernel_size=3, stride=1, padding=1
+        )
+        self.conv3d_4_2 = nn.Conv3d(
+            in_channels=56, out_channels=64, kernel_size=3, stride=1, padding=1
+        )
+        self.conv3d_4_3 = nn.Conv3d(
+            in_channels=48, out_channels=64, kernel_size=1, stride=1
+        )
+        self.batch_norm_4_1 = nn.BatchNorm3d(56)
+        self.batch_norm_4_2 = nn.BatchNorm3d(64)
+        self.batch_norm_4_3 = nn.BatchNorm3d(64)
         self.pool = nn.MaxPool3d(2)
         self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
         self.dropout = nn.Dropout(0.3)
-        self.fc1 = nn.Linear(256, 64)
-        self.fc2 = nn.Linear(64, num_classes)
+        self.fc1 = nn.Linear(64, 32)
+        self.fc2 = nn.Linear(32, num_classes)
     def forward(self, x):
-        # x = self.normalizer(x) # type: ignore
+        x = self.normalizer(x) # type: ignore
-        x = F.relu(self.batchnorm_1(self.conv3d_1(x)))
+        # First convolution block
+        _x = x
+        x = F.relu(self.batch_norm_1_1(self.conv3d_1_1(x)))
+        x = F.relu(self.batch_norm_1_2(self.conv3d_1_2(x)))
+        x = (x + F.relu(self.batch_norm_1_3(self.conv3d_1_3(_x)))) / 2
         x = self.pool(x)
-        x = F.relu(self.batchnorm_2(self.conv3d_2(x)))
+        # Second convolution block
+        _x = x
+        x = F.relu(self.batch_norm_2_1(self.conv3d_2_1(x)))
+        x = F.relu(self.batch_norm_2_2(self.conv3d_2_2(x)))
+        x = (x + F.relu(self.batch_norm_2_3(self.conv3d_2_3(_x)))) / 2
         x = self.pool(x)
-        x = F.relu(self.batchnorm_3(self.conv3d_3(x)))
+        # Third convolution block
+        _x = x
+        x = F.relu(self.batch_norm_3_1(self.conv3d_3_1(x)))
+        x = F.relu(self.batch_norm_3_2(self.conv3d_3_2(x)))
+        x = (x + F.relu(self.batch_norm_3_3(self.conv3d_3_3(_x)))) / 2
         x = self.pool(x)
-        x = F.relu(self.batchnorm_4(self.conv3d_4(x)))
+        # Fourth convolution block
+        _x = x
+        x = F.relu(self.batch_norm_4_1(self.conv3d_4_1(x)))
+        x = F.relu(self.batch_norm_4_2(self.conv3d_4_2(x)))
+        x = (x + F.relu(self.batch_norm_4_3(self.conv3d_4_3(_x)))) / 2
         x = self.global_pool(x)
-        x = x.view(x.size(0), -1)
+        x = x.view(x.size(0), x.size(1))
         x = F.relu(self.fc1(x))
         x = self.dropout(x)
         return self.fc2(x)
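
For reference, the repeated pattern this commit introduces can be read as a small residual-style block: the main path applies two 3x3x3 convolutions (each followed by batch norm and ReLU), a 1x1x1 convolution projects the block input to the same channel count, and the two paths are averaged. The sketch below is illustrative only; the module name SkipBlock3d, the channel sizes, and the dummy input shape are assumptions and are not part of the commit.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SkipBlock3d(nn.Module):
    """Residual-style block: two 3x3x3 convs on the main path, a 1x1x1
    projection on the shortcut, and the two paths averaged (illustrative
    sketch of the pattern used in Conv3DNet above, not the actual class)."""

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        self.conv_1 = nn.Conv3d(in_ch, mid_ch, kernel_size=3, stride=1, padding=1)
        self.conv_2 = nn.Conv3d(mid_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.conv_skip = nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=1)
        self.bn_1 = nn.BatchNorm3d(mid_ch)
        self.bn_2 = nn.BatchNorm3d(out_ch)
        self.bn_skip = nn.BatchNorm3d(out_ch)

    def forward(self, x):
        identity = x
        x = F.relu(self.bn_1(self.conv_1(x)))
        x = F.relu(self.bn_2(self.conv_2(x)))
        # The 1x1x1 projection gives the shortcut the same channel count as
        # the main path so the two tensors can be combined elementwise.
        skip = F.relu(self.bn_skip(self.conv_skip(identity)))
        return (x + skip) / 2

if __name__ == "__main__":
    # Quick shape sanity check with a dummy single-channel volume
    # (batch, channels, depth, height, width).
    block = SkipBlock3d(in_ch=1, mid_ch=4, out_ch=16)
    x = torch.randn(2, 1, 32, 32, 32)
    print(block(x).shape)  # expected: torch.Size([2, 16, 32, 32, 32])

Averaging the two paths rather than summing them keeps the activation scale comparable to either path alone; the commit itself does not state a motivation, so this is only one plausible reading.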