Skip to content
Snippets Groups Projects
Commit 44df1610 authored by Yvan Pannatier's avatar Yvan Pannatier
Browse files

[models.cnn3d] fix qa.

parent c80bc9f0
No related branches found
No related tags found
1 merge request!513d cnn visceral
Pipeline #89125 passed
......@@ -69,18 +69,30 @@ class Conv3DNet(Model):
self.model_transforms = []
# First convolution block
self.conv3d_1_1 = nn.Conv3d(in_channels=1, out_channels=4, kernel_size=3, stride=1, padding=1)
self.conv3d_1_2 = nn.Conv3d(in_channels=4, out_channels=16, kernel_size=3, stride=1, padding=1)
self.conv3d_1_3 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=1,stride=1)
# First convolution block
self.conv3d_1_1 = nn.Conv3d(
in_channels=1, out_channels=4, kernel_size=3, stride=1, padding=1
)
self.conv3d_1_2 = nn.Conv3d(
in_channels=4, out_channels=16, kernel_size=3, stride=1, padding=1
)
self.conv3d_1_3 = nn.Conv3d(
in_channels=1, out_channels=16, kernel_size=1, stride=1
)
self.batch_norm_1_1 = nn.BatchNorm3d(4)
self.batch_norm_1_2 = nn.BatchNorm3d(16)
self.batch_norm_1_3 = nn.BatchNorm3d(16)
# Second convolution block
self.conv3d_2_1 = nn.Conv3d(in_channels=16, out_channels=24, kernel_size=3, stride=1, padding=1)
self.conv3d_2_2 = nn.Conv3d(in_channels=24, out_channels=32, kernel_size=3, stride=1, padding=1)
self.conv3d_2_3 = nn.Conv3d(in_channels=16, out_channels=32, kernel_size=1, stride=1)
self.conv3d_2_1 = nn.Conv3d(
in_channels=16, out_channels=24, kernel_size=3, stride=1, padding=1
)
self.conv3d_2_2 = nn.Conv3d(
in_channels=24, out_channels=32, kernel_size=3, stride=1, padding=1
)
self.conv3d_2_3 = nn.Conv3d(
in_channels=16, out_channels=32, kernel_size=1, stride=1
)
self.batch_norm_2_1 = nn.BatchNorm3d(24)
self.batch_norm_2_2 = nn.BatchNorm3d(32)
self.batch_norm_2_3 = nn.BatchNorm3d(32)
......@@ -116,7 +128,7 @@ class Conv3DNet(Model):
self.pool = nn.MaxPool3d(2)
self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
self.dropout = nn.Dropout(0.3)
self.fc1 = nn.Linear(64,32)
self.fc1 = nn.Linear(64, 32)
self.fc2 = nn.Linear(32, num_classes)
def forward(self, x):
......@@ -137,7 +149,7 @@ class Conv3DNet(Model):
x = (x + F.relu(self.batch_norm_2_3(self.conv3d_2_3(_x)))) / 2
x = self.pool(x)
# Third convolution block
# Third convolution block
_x = x
x = F.relu(self.batch_norm_3_1(self.conv3d_3_1(x)))
......@@ -152,7 +164,6 @@ class Conv3DNet(Model):
x = F.relu(self.batch_norm_4_2(self.conv3d_4_2(x)))
x = (x + F.relu(self.batch_norm_4_3(self.conv3d_4_3(_x)))) / 2
x = self.global_pool(x)
x = x.view(x.size(0), x.size(1))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment