From c80bc9f0dcf21fd934c2b4403f4281e2019920a7 Mon Sep 17 00:00:00 2001
From: Yvan Pannatier <ypannatier@idiap.ch>
Date: Fri, 28 Jun 2024 14:56:57 +0200
Subject: [PATCH] [models] add skip connections to cnn3d

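Each convolution block now consists of two 3x3x3 convolutions, each followed by
batch normalization and ReLU, plus a parallel 1x1x1 convolution that projects
the block input to the block's output width; the two paths are averaged to form
the block output. Channel widths change from 32/64/128/256 to 16/32/48/64, the
classifier head is resized to 64 -> 32 -> num_classes, and the input normalizer
call in forward() is re-enabled.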
---
 src/mednet/models/cnn3d.py | 90 +++++++++++++++++++++++++++++++-------
 1 file changed, 73 insertions(+), 17 deletions(-)

diff --git a/src/mednet/models/cnn3d.py b/src/mednet/models/cnn3d.py
index b27c073c..8318714f 100644
--- a/src/mednet/models/cnn3d.py
+++ b/src/mednet/models/cnn3d.py
@@ -69,37 +69,93 @@ class Conv3DNet(Model):
 
         self.model_transforms = []
 
-        # First convolution block
-        self.conv3d_1 = nn.Conv3d(1, 32, kernel_size=3, padding=1)
-        self.batchnorm_1 = nn.BatchNorm3d(32)
+        # First convolution block
+        self.conv3d_1_1 = nn.Conv3d(in_channels=1, out_channels=4, kernel_size=3, stride=1, padding=1)
+        self.conv3d_1_2 = nn.Conv3d(in_channels=4, out_channels=16, kernel_size=3, stride=1, padding=1)
+        self.conv3d_1_3 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=1, stride=1)
+        self.batch_norm_1_1 = nn.BatchNorm3d(4)
+        self.batch_norm_1_2 = nn.BatchNorm3d(16)
+        self.batch_norm_1_3 = nn.BatchNorm3d(16)
+
         # Second convolution block
-        self.conv3d_2 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
-        self.batchnorm_2 = nn.BatchNorm3d(64)
+        self.conv3d_2_1 = nn.Conv3d(in_channels=16, out_channels=24, kernel_size=3, stride=1, padding=1)
+        self.conv3d_2_2 = nn.Conv3d(in_channels=24, out_channels=32, kernel_size=3, stride=1, padding=1)
+        self.conv3d_2_3 = nn.Conv3d(in_channels=16, out_channels=32, kernel_size=1, stride=1)
+        self.batch_norm_2_1 = nn.BatchNorm3d(24)
+        self.batch_norm_2_2 = nn.BatchNorm3d(32)
+        self.batch_norm_2_3 = nn.BatchNorm3d(32)
+
         # Third convolution block
-        self.conv3d_3 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
-        self.batchnorm_3 = nn.BatchNorm3d(128)
+        self.conv3d_3_1 = nn.Conv3d(
+            in_channels=32, out_channels=40, kernel_size=3, stride=1, padding=1
+        )
+        self.conv3d_3_2 = nn.Conv3d(
+            in_channels=40, out_channels=48, kernel_size=3, stride=1, padding=1
+        )
+        self.conv3d_3_3 = nn.Conv3d(
+            in_channels=32, out_channels=48, kernel_size=1, stride=1
+        )
+        self.batch_norm_3_1 = nn.BatchNorm3d(40)
+        self.batch_norm_3_2 = nn.BatchNorm3d(48)
+        self.batch_norm_3_3 = nn.BatchNorm3d(48)
+
         # Fourth convolution block
-        self.conv3d_4 = nn.Conv3d(128, 256, kernel_size=3, padding=1)
-        self.batchnorm_4 = nn.BatchNorm3d(256)
+        self.conv3d_4_1 = nn.Conv3d(
+            in_channels=48, out_channels=56, kernel_size=3, stride=1, padding=1
+        )
+        self.conv3d_4_2 = nn.Conv3d(
+            in_channels=56, out_channels=64, kernel_size=3, stride=1, padding=1
+        )
+        self.conv3d_4_3 = nn.Conv3d(
+            in_channels=48, out_channels=64, kernel_size=1, stride=1
+        )
+        self.batch_norm_4_1 = nn.BatchNorm3d(56)
+        self.batch_norm_4_2 = nn.BatchNorm3d(64)
+        self.batch_norm_4_3 = nn.BatchNorm3d(64)
 
         self.pool = nn.MaxPool3d(2)
         self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
         self.dropout = nn.Dropout(0.3)
-        self.fc1 = nn.Linear(256, 64)
-        self.fc2 = nn.Linear(64, num_classes)
+        self.fc1 = nn.Linear(64, 32)
+        self.fc2 = nn.Linear(32, num_classes)
 
     def forward(self, x):
-        # x = self.normalizer(x)  # type: ignore
+        x = self.normalizer(x)  # type: ignore
 
-        x = F.relu(self.batchnorm_1(self.conv3d_1(x)))
+        # First convolution block
+        _x = x
+        x = F.relu(self.batch_norm_1_1(self.conv3d_1_1(x)))
+        x = F.relu(self.batch_norm_1_2(self.conv3d_1_2(x)))
+        x = (x + F.relu(self.batch_norm_1_3(self.conv3d_1_3(_x)))) / 2
         x = self.pool(x)
-        x = F.relu(self.batchnorm_2(self.conv3d_2(x)))
+
+        # Second convolution block
+
+        _x = x
+        x = F.relu(self.batch_norm_2_1(self.conv3d_2_1(x)))
+        x = F.relu(self.batch_norm_2_2(self.conv3d_2_2(x)))
+        x = (x + F.relu(self.batch_norm_2_3(self.conv3d_2_3(_x)))) / 2
         x = self.pool(x)
-        x = F.relu(self.batchnorm_3(self.conv3d_3(x)))
+
+        # Third convolution block
+
+        _x = x
+        x = F.relu(self.batch_norm_3_1(self.conv3d_3_1(x)))
+        x = F.relu(self.batch_norm_3_2(self.conv3d_3_2(x)))
+        x = (x + F.relu(self.batch_norm_3_3(self.conv3d_3_3(_x)))) / 2
         x = self.pool(x)
-        x = F.relu(self.batchnorm_4(self.conv3d_4(x)))
+
+        # Fourth convolution block
+
+        _x = x
+        x = F.relu(self.batch_norm_4_1(self.conv3d_4_1(x)))
+        x = F.relu(self.batch_norm_4_2(self.conv3d_4_2(x)))
+        x = (x + F.relu(self.batch_norm_4_3(self.conv3d_4_3(_x)))) / 2
+
+
         x = self.global_pool(x)
-        x = x.view(x.size(0), -1)
+        x = x.view(x.size(0), x.size(1))
+
         x = F.relu(self.fc1(x))
         x = self.dropout(x)
         return self.fc2(x)
-- 
GitLab
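
For reference, a minimal sketch (not part of the patch) of the skip-connection
pattern that each of the four blocks above repeats, factored into a standalone
module. The name SkipBlock3d, the channel values in the shape check, and the
module boundaries are illustrative assumptions; only PyTorch is assumed.

import torch
import torch.nn as nn
import torch.nn.functional as F


class SkipBlock3d(nn.Module):
    """Two 3x3x3 conv+BN+ReLU stages averaged with a 1x1x1 projection of the input."""

    def __init__(self, in_channels: int, mid_channels: int, out_channels: int):
        super().__init__()
        self.conv_1 = nn.Conv3d(
            in_channels, mid_channels, kernel_size=3, stride=1, padding=1
        )
        self.conv_2 = nn.Conv3d(
            mid_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.proj = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1)
        self.bn_1 = nn.BatchNorm3d(mid_channels)
        self.bn_2 = nn.BatchNorm3d(out_channels)
        self.bn_proj = nn.BatchNorm3d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip = x
        x = F.relu(self.bn_1(self.conv_1(x)))
        x = F.relu(self.bn_2(self.conv_2(x)))
        # Average the main path with the projected input, as in the patch.
        return (x + F.relu(self.bn_proj(self.proj(skip)))) / 2


if __name__ == "__main__":
    # Quick shape check on a dummy single-channel 64^3 volume (hypothetical sizes).
    block = SkipBlock3d(in_channels=1, mid_channels=4, out_channels=16)
    out = block(torch.randn(2, 1, 64, 64, 64))
    print(out.shape)  # torch.Size([2, 16, 64, 64, 64])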