From ffad08889edf2c1fc25584103636b856f3bb85de Mon Sep 17 00:00:00 2001 From: Yvan Pannatier <yvan.pannatier@idiap.ch> Date: Wed, 19 Jun 2024 14:37:35 +0200 Subject: [PATCH] Fix qa. --- src/mednet/config/data/visceral/datamodule.py | 17 ++++--- src/mednet/config/data/visceral/default.json | 2 +- src/mednet/config/data/visceral/default.py | 4 +- src/mednet/config/models/cnn3d.py | 3 +- src/mednet/models/cnn3d.py | 9 ++-- src/mednet/models/pasa3d.py | 50 ++++++++++++------- tests/test_visceral.py | 4 +- 7 files changed, 50 insertions(+), 39 deletions(-) diff --git a/src/mednet/config/data/visceral/datamodule.py b/src/mednet/config/data/visceral/datamodule.py index c3142015..023962cf 100644 --- a/src/mednet/config/data/visceral/datamodule.py +++ b/src/mednet/config/data/visceral/datamodule.py @@ -1,18 +1,20 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -""" VISCERAL dataset for 3D organ classification (only lungs and bladders). +"""VISCERAL dataset for 3D organ classification (only lungs and bladders). Loaded samples are not full scans but 16x16x16 volumes of organs. -Database reference: +Database reference: """ + import os import pathlib + import torchio as tio from ....data.datamodule import CachingDataModule -from ....data.typing import RawDataLoader as _BaseRawDataLoader from ....data.split import make_split +from ....data.typing import RawDataLoader as _BaseRawDataLoader from ....data.typing import Sample from ....utils.rc import load_rc @@ -20,6 +22,7 @@ CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2]) """Key to search for in the configuration file for the root directory of this database.""" + class RawDataLoader(_BaseRawDataLoader): """A specialized raw-data-loader for the VISCERAL dataset.""" @@ -51,12 +54,11 @@ class RawDataLoader(_BaseRawDataLoader): """ clamp = tio.Clamp(out_min=-1000, out_max=2000) rescale = tio.RescaleIntensity(percentiles=(0.5, 99.5)) - preprocess = tio.Compose([clamp,rescale,]) + preprocess = tio.Compose([clamp, rescale]) image = tio.ScalarImage(self.datadir / sample[0]) image = preprocess(image) tensor = image.data - return tensor, dict(label=sample[1], name=sample[0]) - + return tensor, dict(label=sample[1], name=sample[0]) def label(self, sample: tuple[str, int]) -> int: """Load a single image sample label from the disk. @@ -77,8 +79,6 @@ class RawDataLoader(_BaseRawDataLoader): return sample[1] - - class DataModule(CachingDataModule): """VISCERAL DataModule for 3D organ binary classification. @@ -107,6 +107,7 @@ class DataModule(CachingDataModule): split_filename Name of the .json file containing the split to load. """ + def __init__(self, split_filename: str): super().__init__( make_split(__package__, split_filename), diff --git a/src/mednet/config/data/visceral/default.json b/src/mednet/config/data/visceral/default.json index f0d1c302..6bd8d5bd 100644 --- a/src/mednet/config/data/visceral/default.json +++ b/src/mednet/config/data/visceral/default.json @@ -350,4 +350,4 @@ ["16/10000194_1_237_117.nii.gz",0], ["16/10000059_1_1302_117.nii.gz",1] ] -} \ No newline at end of file +} diff --git a/src/mednet/config/data/visceral/default.py b/src/mednet/config/data/visceral/default.py index 385d00c0..46942bb7 100644 --- a/src/mednet/config/data/visceral/default.py +++ b/src/mednet/config/data/visceral/default.py @@ -1,9 +1,9 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -""" VISCERAL dataset for 3D organ classification. +"""VISCERAL dataset for 3D organ classification. -Database reference: +Database reference: See :py:class:`mednet.config.data.visceral.datamodule.DataModule` for technical details. diff --git a/src/mednet/config/models/cnn3d.py b/src/mednet/config/models/cnn3d.py index 790570ab..305b5452 100644 --- a/src/mednet/config/models/cnn3d.py +++ b/src/mednet/config/models/cnn3d.py @@ -1,8 +1,7 @@ # SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Simple CNN for 3D organ classification, to be trained from scratch. -""" +"""Simple CNN for 3D organ classification, to be trained from scratch.""" from torch.nn import BCEWithLogitsLoss from torch.optim import Adam diff --git a/src/mednet/models/cnn3d.py b/src/mednet/models/cnn3d.py index def0a916..b27c073c 100644 --- a/src/mednet/models/cnn3d.py +++ b/src/mednet/models/cnn3d.py @@ -88,10 +88,9 @@ class Conv3DNet(Model): self.fc1 = nn.Linear(256, 64) self.fc2 = nn.Linear(64, num_classes) - def forward(self, x): - #x = self.normalizer(x) # type: ignore - + # x = self.normalizer(x) # type: ignore + x = F.relu(self.batchnorm_1(self.conv3d_1(x))) x = self.pool(x) x = F.relu(self.batchnorm_2(self.conv3d_2(x))) @@ -103,9 +102,7 @@ class Conv3DNet(Model): x = x.view(x.size(0), -1) x = F.relu(self.fc1(x)) x = self.dropout(x) - x = self.fc2(x) - return x - + return self.fc2(x) # x = F.log_softmax(x, dim=1) # 0 is batch size diff --git a/src/mednet/models/pasa3d.py b/src/mednet/models/pasa3d.py index 7fcb0e08..6b73d297 100644 --- a/src/mednet/models/pasa3d.py +++ b/src/mednet/models/pasa3d.py @@ -69,8 +69,7 @@ class Pasa3D(Model): self.name = "pasa3D" self.num_classes = num_classes - self.model_transforms = [ - ] + self.model_transforms = [] # First convolution block self.conv3d_1_1 = nn.Conv3d(in_channels=1, out_channels=4, kernel_size=3) @@ -80,7 +79,6 @@ class Pasa3D(Model): self.batch_norm_1_2 = nn.BatchNorm3d(16) self.batch_norm_1_3 = nn.BatchNorm3d(16) - # Second convolution block self.conv3d_2_1 = nn.Conv3d(in_channels=16, out_channels=24, kernel_size=3) self.conv3d_2_2 = nn.Conv3d(in_channels=24, out_channels=32, kernel_size=3) @@ -89,37 +87,52 @@ class Pasa3D(Model): self.batch_norm_2_2 = nn.BatchNorm3d(32) self.batch_norm_2_3 = nn.BatchNorm3d(32) - # Third convolution block - self.conv3d_3_1 = nn.Conv3d(in_channels=32, out_channels=40, kernel_size=3, stride=1, padding=1) - self.conv3d_3_2 = nn.Conv3d(in_channels=40, out_channels=48, kernel_size=3, stride=1, padding=1) - self.conv3d_3_3 = nn.Conv3d(in_channels=32, out_channels=48, kernel_size=1, stride=1) + self.conv3d_3_1 = nn.Conv3d( + in_channels=32, out_channels=40, kernel_size=3, stride=1, padding=1 + ) + self.conv3d_3_2 = nn.Conv3d( + in_channels=40, out_channels=48, kernel_size=3, stride=1, padding=1 + ) + self.conv3d_3_3 = nn.Conv3d( + in_channels=32, out_channels=48, kernel_size=1, stride=1 + ) self.batch_norm_3_1 = nn.BatchNorm3d(40) self.batch_norm_3_2 = nn.BatchNorm3d(48) self.batch_norm_3_3 = nn.BatchNorm3d(48) # Fourth convolution block - self.conv3d_4_1 = nn.Conv3d(in_channels=48, out_channels=56, kernel_size=3, stride=1, padding=1) - self.conv3d_4_2 = nn.Conv3d(in_channels=56, out_channels=64, kernel_size=3, stride=1, padding=1) - self.conv3d_4_3 = nn.Conv3d(in_channels=48, out_channels=64, kernel_size=1, stride=1) + self.conv3d_4_1 = nn.Conv3d( + in_channels=48, out_channels=56, kernel_size=3, stride=1, padding=1 + ) + self.conv3d_4_2 = nn.Conv3d( + in_channels=56, out_channels=64, kernel_size=3, stride=1, padding=1 + ) + self.conv3d_4_3 = nn.Conv3d( + in_channels=48, out_channels=64, kernel_size=1, stride=1 + ) self.batch_norm_4_1 = nn.BatchNorm3d(56) self.batch_norm_4_2 = nn.BatchNorm3d(64) self.batch_norm_4_3 = nn.BatchNorm3d(64) # Fifth convolution block - self.conv3d_5_1 = nn.Conv3d(in_channels=64, out_channels=72, kernel_size=3, stride=1, padding=1) - self.conv3d_5_2 = nn.Conv3d(in_channels=72, out_channels=80, kernel_size=3, stride=1, padding=1) - self.conv3d_5_3 = nn.Conv3d(in_channels=64, out_channels=80, kernel_size=1, stride=1) + self.conv3d_5_1 = nn.Conv3d( + in_channels=64, out_channels=72, kernel_size=3, stride=1, padding=1 + ) + self.conv3d_5_2 = nn.Conv3d( + in_channels=72, out_channels=80, kernel_size=3, stride=1, padding=1 + ) + self.conv3d_5_3 = nn.Conv3d( + in_channels=64, out_channels=80, kernel_size=1, stride=1 + ) self.batch_norm_5_1 = nn.BatchNorm3d(72) self.batch_norm_5_2 = nn.BatchNorm3d(80) self.batch_norm_5_3 = nn.BatchNorm3d(80) - self.pool = nn.MaxPool3d(kernel_size=3, stride=2) self.global_avg_pool = nn.AdaptiveAvgPool3d(1) - self.fc1 = nn.Linear(80, self.num_classes,) - + self.fc1 = nn.Linear(80, self.num_classes) def forward(self, x): x = self.normalizer(x) # type: ignore @@ -129,7 +142,7 @@ class Pasa3D(Model): x = F.relu(self.batch_norm_1_1(self.conv3d_1_1(x))) x = F.relu(self.batch_norm_1_2(self.conv3d_1_2(x))) x = (x + F.relu(self.batch_norm_1_3(self.conv3d_1_3(_x)))) / 2 - + # Second convolution block _x = x x = F.relu(self.batch_norm_2_1(self.conv3d_2_1(x))) @@ -156,8 +169,7 @@ class Pasa3D(Model): x = torch.mean(x.view(x.size(0), x.size(1), -1), dim=2) - x = self.fc1(x) - return x + return self.fc1(x) # x = F.log_softmax(x, dim=1) # 0 is batch size diff --git a/tests/test_visceral.py b/tests/test_visceral.py index fdd1b8ff..9743577e 100644 --- a/tests/test_visceral.py +++ b/tests/test_visceral.py @@ -6,11 +6,13 @@ import pytest from click.testing import CliRunner + def id_function(val): if isinstance(val, dict): return str(val) return repr(val) + @pytest.mark.parametrize( "split,lenghts", [ @@ -32,13 +34,13 @@ def test_protocol_consistency( possible_labels=(0, 1), ) + @pytest.mark.skip_if_rc_var_not_set("datadir.visceral") def test_database_check(): from mednet.scripts.database import check runner = CliRunner() result = runner.invoke(check, ["visceral"]) - assert(1 == 0), f"test" assert ( result.exit_code == 0 ), f"Exit code {result.exit_code} != 0 -- Output:\n{result.output}" -- GitLab