From ffad08889edf2c1fc25584103636b856f3bb85de Mon Sep 17 00:00:00 2001
From: Yvan Pannatier <yvan.pannatier@idiap.ch>
Date: Wed, 19 Jun 2024 14:37:35 +0200
Subject: [PATCH] Fix qa.

---
 src/mednet/config/data/visceral/datamodule.py | 17 ++++---
 src/mednet/config/data/visceral/default.json  |  2 +-
 src/mednet/config/data/visceral/default.py    |  4 +-
 src/mednet/config/models/cnn3d.py             |  3 +-
 src/mednet/models/cnn3d.py                    |  9 ++--
 src/mednet/models/pasa3d.py                   | 50 ++++++++++++-------
 tests/test_visceral.py                        |  4 +-
 7 files changed, 50 insertions(+), 39 deletions(-)

diff --git a/src/mednet/config/data/visceral/datamodule.py b/src/mednet/config/data/visceral/datamodule.py
index c3142015..023962cf 100644
--- a/src/mednet/config/data/visceral/datamodule.py
+++ b/src/mednet/config/data/visceral/datamodule.py
@@ -1,18 +1,20 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-""" VISCERAL dataset for 3D organ classification (only lungs and bladders).
+"""VISCERAL dataset for 3D organ classification (only lungs and bladders).
 
 Loaded samples are not full scans but 16x16x16 volumes of organs.
-Database reference: 
+Database reference:
 """
+
 import os
 import pathlib
+
 import torchio as tio
 
 from ....data.datamodule import CachingDataModule
-from ....data.typing import RawDataLoader as _BaseRawDataLoader
 from ....data.split import make_split
+from ....data.typing import RawDataLoader as _BaseRawDataLoader
 from ....data.typing import Sample
 from ....utils.rc import load_rc
 
@@ -20,6 +22,7 @@ CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
 """Key to search for in the configuration file for the root directory of this
 database."""
 
+
 class RawDataLoader(_BaseRawDataLoader):
     """A specialized raw-data-loader for the VISCERAL dataset."""
 
@@ -51,12 +54,11 @@ class RawDataLoader(_BaseRawDataLoader):
         """
         clamp = tio.Clamp(out_min=-1000, out_max=2000)
         rescale = tio.RescaleIntensity(percentiles=(0.5, 99.5))
-        preprocess = tio.Compose([clamp,rescale,])
+        preprocess = tio.Compose([clamp, rescale])
         image = tio.ScalarImage(self.datadir / sample[0])
         image = preprocess(image)
         tensor = image.data
-        return tensor, dict(label=sample[1], name=sample[0]) 
-
+        return tensor, dict(label=sample[1], name=sample[0])
 
     def label(self, sample: tuple[str, int]) -> int:
         """Load a single image sample label from the disk.
@@ -77,8 +79,6 @@ class RawDataLoader(_BaseRawDataLoader):
         return sample[1]
 
 
-
-
 class DataModule(CachingDataModule):
     """VISCERAL DataModule for 3D organ binary classification.
 
@@ -107,6 +107,7 @@ class DataModule(CachingDataModule):
     split_filename
         Name of the .json file containing the split to load.
     """
+
     def __init__(self, split_filename: str):
         super().__init__(
             make_split(__package__, split_filename),
diff --git a/src/mednet/config/data/visceral/default.json b/src/mednet/config/data/visceral/default.json
index f0d1c302..6bd8d5bd 100644
--- a/src/mednet/config/data/visceral/default.json
+++ b/src/mednet/config/data/visceral/default.json
@@ -350,4 +350,4 @@
         ["16/10000194_1_237_117.nii.gz",0],
         ["16/10000059_1_1302_117.nii.gz",1]
     ]
-}
\ No newline at end of file
+}
diff --git a/src/mednet/config/data/visceral/default.py b/src/mednet/config/data/visceral/default.py
index 385d00c0..46942bb7 100644
--- a/src/mednet/config/data/visceral/default.py
+++ b/src/mednet/config/data/visceral/default.py
@@ -1,9 +1,9 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-""" VISCERAL dataset for 3D organ classification.
+"""VISCERAL dataset for 3D organ classification.
 
-Database reference: 
+Database reference:
 
 See :py:class:`mednet.config.data.visceral.datamodule.DataModule` for
 technical details.
diff --git a/src/mednet/config/models/cnn3d.py b/src/mednet/config/models/cnn3d.py
index 790570ab..305b5452 100644
--- a/src/mednet/config/models/cnn3d.py
+++ b/src/mednet/config/models/cnn3d.py
@@ -1,8 +1,7 @@
 # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-"""Simple CNN for 3D organ classification, to be trained from scratch.
-"""
+"""Simple CNN for 3D organ classification, to be trained from scratch."""
 
 from torch.nn import BCEWithLogitsLoss
 from torch.optim import Adam
diff --git a/src/mednet/models/cnn3d.py b/src/mednet/models/cnn3d.py
index def0a916..b27c073c 100644
--- a/src/mednet/models/cnn3d.py
+++ b/src/mednet/models/cnn3d.py
@@ -88,10 +88,9 @@ class Conv3DNet(Model):
         self.fc1 = nn.Linear(256, 64)
         self.fc2 = nn.Linear(64, num_classes)
 
-
     def forward(self, x):
-        #x = self.normalizer(x)  # type: ignore
-    
+        # x = self.normalizer(x)  # type: ignore
+
         x = F.relu(self.batchnorm_1(self.conv3d_1(x)))
         x = self.pool(x)
         x = F.relu(self.batchnorm_2(self.conv3d_2(x)))
@@ -103,9 +102,7 @@ class Conv3DNet(Model):
         x = x.view(x.size(0), -1)
         x = F.relu(self.fc1(x))
         x = self.dropout(x)
-        x = self.fc2(x)
-        return x
-
+        return self.fc2(x)
 
         # x = F.log_softmax(x, dim=1) # 0 is batch size
 
diff --git a/src/mednet/models/pasa3d.py b/src/mednet/models/pasa3d.py
index 7fcb0e08..6b73d297 100644
--- a/src/mednet/models/pasa3d.py
+++ b/src/mednet/models/pasa3d.py
@@ -69,8 +69,7 @@ class Pasa3D(Model):
         self.name = "pasa3D"
         self.num_classes = num_classes
 
-        self.model_transforms = [
-        ]
+        self.model_transforms = []
 
         # First convolution block
         self.conv3d_1_1 = nn.Conv3d(in_channels=1, out_channels=4, kernel_size=3)
@@ -80,7 +79,6 @@ class Pasa3D(Model):
         self.batch_norm_1_2 = nn.BatchNorm3d(16)
         self.batch_norm_1_3 = nn.BatchNorm3d(16)
 
-
         # Second convolution block
         self.conv3d_2_1 = nn.Conv3d(in_channels=16, out_channels=24, kernel_size=3)
         self.conv3d_2_2 = nn.Conv3d(in_channels=24, out_channels=32, kernel_size=3)
@@ -89,37 +87,52 @@ class Pasa3D(Model):
         self.batch_norm_2_2 = nn.BatchNorm3d(32)
         self.batch_norm_2_3 = nn.BatchNorm3d(32)
 
-
         # Third convolution block
-        self.conv3d_3_1 = nn.Conv3d(in_channels=32, out_channels=40, kernel_size=3, stride=1, padding=1)
-        self.conv3d_3_2 = nn.Conv3d(in_channels=40, out_channels=48, kernel_size=3, stride=1, padding=1)
-        self.conv3d_3_3 = nn.Conv3d(in_channels=32, out_channels=48, kernel_size=1, stride=1)
+        self.conv3d_3_1 = nn.Conv3d(
+            in_channels=32, out_channels=40, kernel_size=3, stride=1, padding=1
+        )
+        self.conv3d_3_2 = nn.Conv3d(
+            in_channels=40, out_channels=48, kernel_size=3, stride=1, padding=1
+        )
+        self.conv3d_3_3 = nn.Conv3d(
+            in_channels=32, out_channels=48, kernel_size=1, stride=1
+        )
         self.batch_norm_3_1 = nn.BatchNorm3d(40)
         self.batch_norm_3_2 = nn.BatchNorm3d(48)
         self.batch_norm_3_3 = nn.BatchNorm3d(48)
 
         # Fourth convolution block
-        self.conv3d_4_1 = nn.Conv3d(in_channels=48, out_channels=56, kernel_size=3, stride=1, padding=1)
-        self.conv3d_4_2 = nn.Conv3d(in_channels=56, out_channels=64, kernel_size=3, stride=1, padding=1)
-        self.conv3d_4_3 = nn.Conv3d(in_channels=48, out_channels=64, kernel_size=1, stride=1)
+        self.conv3d_4_1 = nn.Conv3d(
+            in_channels=48, out_channels=56, kernel_size=3, stride=1, padding=1
+        )
+        self.conv3d_4_2 = nn.Conv3d(
+            in_channels=56, out_channels=64, kernel_size=3, stride=1, padding=1
+        )
+        self.conv3d_4_3 = nn.Conv3d(
+            in_channels=48, out_channels=64, kernel_size=1, stride=1
+        )
         self.batch_norm_4_1 = nn.BatchNorm3d(56)
         self.batch_norm_4_2 = nn.BatchNorm3d(64)
         self.batch_norm_4_3 = nn.BatchNorm3d(64)
 
         # Fifth convolution block
-        self.conv3d_5_1 = nn.Conv3d(in_channels=64, out_channels=72, kernel_size=3, stride=1, padding=1)
-        self.conv3d_5_2 = nn.Conv3d(in_channels=72, out_channels=80, kernel_size=3, stride=1, padding=1)
-        self.conv3d_5_3 = nn.Conv3d(in_channels=64, out_channels=80, kernel_size=1, stride=1)
+        self.conv3d_5_1 = nn.Conv3d(
+            in_channels=64, out_channels=72, kernel_size=3, stride=1, padding=1
+        )
+        self.conv3d_5_2 = nn.Conv3d(
+            in_channels=72, out_channels=80, kernel_size=3, stride=1, padding=1
+        )
+        self.conv3d_5_3 = nn.Conv3d(
+            in_channels=64, out_channels=80, kernel_size=1, stride=1
+        )
         self.batch_norm_5_1 = nn.BatchNorm3d(72)
         self.batch_norm_5_2 = nn.BatchNorm3d(80)
         self.batch_norm_5_3 = nn.BatchNorm3d(80)
 
-
         self.pool = nn.MaxPool3d(kernel_size=3, stride=2)
         self.global_avg_pool = nn.AdaptiveAvgPool3d(1)
 
-        self.fc1 = nn.Linear(80, self.num_classes,)
-
+        self.fc1 = nn.Linear(80, self.num_classes)
 
     def forward(self, x):
         x = self.normalizer(x)  # type: ignore
@@ -129,7 +142,7 @@ class Pasa3D(Model):
         x = F.relu(self.batch_norm_1_1(self.conv3d_1_1(x)))
         x = F.relu(self.batch_norm_1_2(self.conv3d_1_2(x)))
         x = (x + F.relu(self.batch_norm_1_3(self.conv3d_1_3(_x)))) / 2
-        
+
         # Second convolution block
         _x = x
         x = F.relu(self.batch_norm_2_1(self.conv3d_2_1(x)))
@@ -156,8 +169,7 @@ class Pasa3D(Model):
 
         x = torch.mean(x.view(x.size(0), x.size(1), -1), dim=2)
 
-        x = self.fc1(x)
-        return x
+        return self.fc1(x)
 
         # x = F.log_softmax(x, dim=1) # 0 is batch size
 
diff --git a/tests/test_visceral.py b/tests/test_visceral.py
index fdd1b8ff..9743577e 100644
--- a/tests/test_visceral.py
+++ b/tests/test_visceral.py
@@ -6,11 +6,13 @@
 import pytest
 from click.testing import CliRunner
 
+
 def id_function(val):
     if isinstance(val, dict):
         return str(val)
     return repr(val)
 
+
 @pytest.mark.parametrize(
     "split,lenghts",
     [
@@ -32,13 +34,13 @@ def test_protocol_consistency(
         possible_labels=(0, 1),
     )
 
+
 @pytest.mark.skip_if_rc_var_not_set("datadir.visceral")
 def test_database_check():
     from mednet.scripts.database import check
 
     runner = CliRunner()
     result = runner.invoke(check, ["visceral"])
-    assert(1 == 0), f"test"
     assert (
         result.exit_code == 0
     ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}"
-- 
GitLab