diff --git a/src/mednet/models/cnn3d.py b/src/mednet/models/cnn3d.py
index b27c073c03ffad2b6e2570685ff6e1c8e9dfb634..8318714f44f2f3533ab841b3f68589df8b7ccc5e 100644
--- a/src/mednet/models/cnn3d.py
+++ b/src/mednet/models/cnn3d.py
@@ -69,37 +69,93 @@ class Conv3DNet(Model):
         self.model_transforms = []
 
-        # First convolution block
-        self.conv3d_1 = nn.Conv3d(1, 32, kernel_size=3, padding=1)
-        self.batchnorm_1 = nn.BatchNorm3d(32)
+        # First convolution block
+        self.conv3d_1_1 = nn.Conv3d(in_channels=1, out_channels=4, kernel_size=3, stride=1, padding=1)
+        self.conv3d_1_2 = nn.Conv3d(in_channels=4, out_channels=16, kernel_size=3, stride=1, padding=1)
+        self.conv3d_1_3 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=1, stride=1)
+        self.batch_norm_1_1 = nn.BatchNorm3d(4)
+        self.batch_norm_1_2 = nn.BatchNorm3d(16)
+        self.batch_norm_1_3 = nn.BatchNorm3d(16)
+
         # Second convolution block
-        self.conv3d_2 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
-        self.batchnorm_2 = nn.BatchNorm3d(64)
+        self.conv3d_2_1 = nn.Conv3d(in_channels=16, out_channels=24, kernel_size=3, stride=1, padding=1)
+        self.conv3d_2_2 = nn.Conv3d(in_channels=24, out_channels=32, kernel_size=3, stride=1, padding=1)
+        self.conv3d_2_3 = nn.Conv3d(in_channels=16, out_channels=32, kernel_size=1, stride=1)
+        self.batch_norm_2_1 = nn.BatchNorm3d(24)
+        self.batch_norm_2_2 = nn.BatchNorm3d(32)
+        self.batch_norm_2_3 = nn.BatchNorm3d(32)
+
         # Third convolution block
-        self.conv3d_3 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
-        self.batchnorm_3 = nn.BatchNorm3d(128)
+        self.conv3d_3_1 = nn.Conv3d(
+            in_channels=32, out_channels=40, kernel_size=3, stride=1, padding=1
+        )
+        self.conv3d_3_2 = nn.Conv3d(
+            in_channels=40, out_channels=48, kernel_size=3, stride=1, padding=1
+        )
+        self.conv3d_3_3 = nn.Conv3d(
+            in_channels=32, out_channels=48, kernel_size=1, stride=1
+        )
+        self.batch_norm_3_1 = nn.BatchNorm3d(40)
+        self.batch_norm_3_2 = nn.BatchNorm3d(48)
+        self.batch_norm_3_3 = nn.BatchNorm3d(48)
+
         # Fourth convolution block
-        self.conv3d_4 = nn.Conv3d(128, 256, kernel_size=3, padding=1)
-        self.batchnorm_4 = nn.BatchNorm3d(256)
+        self.conv3d_4_1 = nn.Conv3d(
+            in_channels=48, out_channels=56, kernel_size=3, stride=1, padding=1
+        )
+        self.conv3d_4_2 = nn.Conv3d(
+            in_channels=56, out_channels=64, kernel_size=3, stride=1, padding=1
+        )
+        self.conv3d_4_3 = nn.Conv3d(
+            in_channels=48, out_channels=64, kernel_size=1, stride=1
+        )
+        self.batch_norm_4_1 = nn.BatchNorm3d(56)
+        self.batch_norm_4_2 = nn.BatchNorm3d(64)
+        self.batch_norm_4_3 = nn.BatchNorm3d(64)
 
         self.pool = nn.MaxPool3d(2)
         self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
         self.dropout = nn.Dropout(0.3)
 
-        self.fc1 = nn.Linear(256, 64)
-        self.fc2 = nn.Linear(64, num_classes)
+        self.fc1 = nn.Linear(64, 32)
+        self.fc2 = nn.Linear(32, num_classes)
 
     def forward(self, x):
-        # x = self.normalizer(x)  # type: ignore
+        x = self.normalizer(x)  # type: ignore
 
-        x = F.relu(self.batchnorm_1(self.conv3d_1(x)))
+        # First convolution block
+        _x = x
+        x = F.relu(self.batch_norm_1_1(self.conv3d_1_1(x)))
+        x = F.relu(self.batch_norm_1_2(self.conv3d_1_2(x)))
+        x = (x + F.relu(self.batch_norm_1_3(self.conv3d_1_3(_x)))) / 2
         x = self.pool(x)
-        x = F.relu(self.batchnorm_2(self.conv3d_2(x)))
+
+        # Second convolution block
+
+        _x = x
+        x = F.relu(self.batch_norm_2_1(self.conv3d_2_1(x)))
+        x = F.relu(self.batch_norm_2_2(self.conv3d_2_2(x)))
+        x = (x + F.relu(self.batch_norm_2_3(self.conv3d_2_3(_x)))) / 2
         x = self.pool(x)
-        x = F.relu(self.batchnorm_3(self.conv3d_3(x)))
+
+        # Third convolution block
+
+        _x = x
+        x = F.relu(self.batch_norm_3_1(self.conv3d_3_1(x)))
+        x = F.relu(self.batch_norm_3_2(self.conv3d_3_2(x)))
+        x = (x + F.relu(self.batch_norm_3_3(self.conv3d_3_3(_x)))) / 2
         x = self.pool(x)
-        x = F.relu(self.batchnorm_4(self.conv3d_4(x)))
+
+        # Fourth convolution block
+
+        _x = x
+        x = F.relu(self.batch_norm_4_1(self.conv3d_4_1(x)))
+        x = F.relu(self.batch_norm_4_2(self.conv3d_4_2(x)))
+        x = (x + F.relu(self.batch_norm_4_3(self.conv3d_4_3(_x)))) / 2
+
+
         x = self.global_pool(x)
-        x = x.view(x.size(0), -1)
+        x = x.view(x.size(0), x.size(1))
+
         x = F.relu(self.fc1(x))
         x = self.dropout(x)
         return self.fc2(x)
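
Reviewer note: below is a minimal, self-contained sketch of the block pattern this diff introduces in each stage of Conv3DNet: two 3x3x3 Conv3d + BatchNorm3d + ReLU steps on the main path, a 1x1x1 Conv3d + BatchNorm3d + ReLU projection of the block input, and an element-wise average of the two paths. The AveragedResidualBlock3d helper and its channel sizes (mirroring the first block) are illustrative assumptions for this note only; the change itself keeps the layers inlined in Conv3DNet as shown above.

import torch
import torch.nn as nn
import torch.nn.functional as F


class AveragedResidualBlock3d(nn.Module):
    """Hypothetical stand-alone version of the averaged-shortcut block used above."""

    def __init__(self, in_ch: int, mid_ch: int, out_ch: int) -> None:
        super().__init__()
        # Main path: two 3x3x3 convolutions that preserve the spatial size.
        self.conv1 = nn.Conv3d(in_ch, mid_ch, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv3d(mid_ch, out_ch, kernel_size=3, stride=1, padding=1)
        # Shortcut: a 1x1x1 convolution projecting the block input to out_ch channels.
        self.shortcut = nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm3d(mid_ch)
        self.bn2 = nn.BatchNorm3d(out_ch)
        self.bn_shortcut = nn.BatchNorm3d(out_ch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        # Average the main path with the projected input, as in the diff.
        return (x + F.relu(self.bn_shortcut(self.shortcut(identity)))) / 2


# Quick shape check with the first block's channel sizes (1 -> 4 -> 16).
block = AveragedResidualBlock3d(in_ch=1, mid_ch=4, out_ch=16)
out = block(torch.randn(2, 1, 32, 32, 32))
print(out.shape)  # torch.Size([2, 16, 32, 32, 32]); spatial dimensions are preserved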