diff --git a/src/mednet/config/data/visceral/__init__.py b/src/mednet/libs/classification/config/data/visceral/__init__.py
similarity index 100%
rename from src/mednet/config/data/visceral/__init__.py
rename to src/mednet/libs/classification/config/data/visceral/__init__.py
diff --git a/src/mednet/config/data/visceral/datamodule.py b/src/mednet/libs/classification/config/data/visceral/datamodule.py
similarity index 73%
rename from src/mednet/config/data/visceral/datamodule.py
rename to src/mednet/libs/classification/config/data/visceral/datamodule.py
index 023962cf76d8b98b948739aaf5ceedf31ad9d881..1b5b96644814d24d11851d8d0e59c3ad25e762a3 100644
--- a/src/mednet/config/data/visceral/datamodule.py
+++ b/src/mednet/libs/classification/config/data/visceral/datamodule.py
@@ -9,21 +9,23 @@ Database reference:
 
 import os
 import pathlib
+import typing
 
 import torchio as tio
-
-from ....data.datamodule import CachingDataModule
-from ....data.split import make_split
-from ....data.typing import RawDataLoader as _BaseRawDataLoader
-from ....data.typing import Sample
-from ....utils.rc import load_rc
+from mednet.libs.classification.data.typing import (
+    ClassificationRawDataLoader as _ClassificationRawDataLoader,
+)
+from mednet.libs.classification.data.typing import Sample
+from mednet.libs.common.data.datamodule import CachingDataModule
+from mednet.libs.common.data.split import make_split
+from mednet.libs.common.utils.rc import load_rc
 
 CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
 """Key to search for in the configuration file for the root directory of this
 database."""
 
 
-class RawDataLoader(_BaseRawDataLoader):
+class ClassificationRawDataLoader(_ClassificationRawDataLoader):
     """A specialized raw-data-loader for the VISCERAL dataset."""
 
     datadir: pathlib.Path
@@ -38,7 +40,7 @@ class RawDataLoader(_BaseRawDataLoader):
             ),
         )
 
-    def sample(self, sample: tuple[str, int]) -> Sample:
+    def sample(self, sample: tuple[str, int, typing.Any | None]) -> Sample:
         """Load a single volume sample from the disk.
 
         Parameters
@@ -46,7 +48,7 @@ class RawDataLoader(_BaseRawDataLoader):
         sample
             A tuple containing the path suffix, within the dataset root folder,
             where to find the volume to be loaded and an integer, representing
-            the sample label.
+            the sample target.
 
         Returns
         -------
@@ -58,25 +60,25 @@ class RawDataLoader(_BaseRawDataLoader):
 
         image = tio.ScalarImage(self.datadir / sample[0])
         image = preprocess(image)
         tensor = image.data
-        return tensor, dict(label=sample[1], name=sample[0])
+        return tensor, dict(target=sample[1], name=sample[0])
 
-    def label(self, sample: tuple[str, int]) -> int:
-        """Load a single image sample label from the disk.
+    def target(self, k: typing.Any) -> int | list[int]:
+        """Load a single image sample target from the disk.
 
         Parameters
         ----------
-        sample
+        k
             A tuple containing the path suffix, within the dataset root folder,
             where to find the image to be loaded, and an integer, representing
-            the sample label.
+            the sample target.
 
         Returns
         -------
         int
-            The integer label associated with the sample.
+            The integer target associated with the sample.
         """
-        return sample[1]
+        return k[1]
 
 
 class DataModule(CachingDataModule):
@@ -100,7 +102,7 @@ class DataModule(CachingDataModule):
     * Final specifications
 
         * 32-bit floats, cubes 16x16x16 pixels
-        * Labels: 0 (bladder), 1 (lung)
+        * targets: 0 (bladder), 1 (lung)
 
     Parameters
     ----------
@@ -109,9 +111,10 @@ class DataModule(CachingDataModule):
     """
 
     def __init__(self, split_filename: str):
+        assert __package__ is not None
         super().__init__(
             make_split(__package__, split_filename),
-            raw_data_loader=RawDataLoader(),
+            raw_data_loader=ClassificationRawDataLoader(),
             database_name=__package__.split(".")[-1],
             split_name=pathlib.Path(split_filename).stem,
         )
diff --git a/src/mednet/config/data/visceral/default.json b/src/mednet/libs/classification/config/data/visceral/default.json
similarity index 100%
rename from src/mednet/config/data/visceral/default.json
rename to src/mednet/libs/classification/config/data/visceral/default.json
diff --git a/src/mednet/config/data/visceral/default.py b/src/mednet/libs/classification/config/data/visceral/default.py
similarity index 100%
rename from src/mednet/config/data/visceral/default.py
rename to src/mednet/libs/classification/config/data/visceral/default.py
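For reference, a minimal usage sketch of the relocated VISCERAL data module (not part of the diff): it assumes the package layout introduced by the renames above, the `default.json` split renamed alongside, and a configured `datadir.visceral` entry in the runtime configuration.

```python
# Hedged sketch, assuming the new module path from the renames above.
from mednet.libs.classification.config.data.visceral.datamodule import DataModule

datamodule = DataModule("default.json")  # split file shipped with the package
```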
""" - return sample[1] + return k[1] class DataModule(CachingDataModule): @@ -100,7 +102,7 @@ class DataModule(CachingDataModule): * Final specifications * 32-bit floats, cubes 16x16x16 pixels - * Labels: 0 (bladder), 1 (lung) + * targets: 0 (bladder), 1 (lung) Parameters ---------- @@ -109,9 +111,10 @@ class DataModule(CachingDataModule): """ def __init__(self, split_filename: str): + assert __package__ is not None super().__init__( make_split(__package__, split_filename), - raw_data_loader=RawDataLoader(), + raw_data_loader=ClassificationRawDataLoader(), database_name=__package__.split(".")[-1], split_name=pathlib.Path(split_filename).stem, ) diff --git a/src/mednet/config/data/visceral/default.json b/src/mednet/libs/classification/config/data/visceral/default.json similarity index 100% rename from src/mednet/config/data/visceral/default.json rename to src/mednet/libs/classification/config/data/visceral/default.json diff --git a/src/mednet/config/data/visceral/default.py b/src/mednet/libs/classification/config/data/visceral/default.py similarity index 100% rename from src/mednet/config/data/visceral/default.py rename to src/mednet/libs/classification/config/data/visceral/default.py diff --git a/src/mednet/libs/classification/data/typing.py b/src/mednet/libs/classification/data/typing.py index 10dc375be754fd0a7a0cb541d678ab3d341fd6de..3a29c2bf2bf48b822f3911e6573d65875b4d8863 100644 --- a/src/mednet/libs/classification/data/typing.py +++ b/src/mednet/libs/classification/data/typing.py @@ -23,8 +23,8 @@ class ClassificationRawDataLoader(RawDataLoader): raise NotImplementedError("You must implement the `sample()` method") - def label(self, k: typing.Any) -> int | list[int]: - """Load only sample label from media. + def target(self, k: typing.Any) -> int | list[int]: + """Load only sample target from media. 
diff --git a/src/mednet/libs/classification/models/classification_model.py b/src/mednet/libs/classification/models/classification_model.py
index b125dec9d66f9881b7d43e4d0d6036197efc950a..2a8c24ee3655dce1d6823700a2f0cca4c4bb8471 100644
--- a/src/mednet/libs/classification/models/classification_model.py
+++ b/src/mednet/libs/classification/models/classification_model.py
@@ -49,11 +49,11 @@ class ClassificationModel(Model):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
diff --git a/src/mednet/libs/classification/models/cnn3d.py b/src/mednet/libs/classification/models/cnn3d.py
index b9abf5e98c3c7fef25e255e5a8b656dee7b272c2..9a77092df9113f7083a7db31b191384a71546d19 100644
--- a/src/mednet/libs/classification/models/cnn3d.py
+++ b/src/mednet/libs/classification/models/cnn3d.py
@@ -12,13 +12,12 @@ import torch.optim.optimizer
 import torch.utils.data
 
 from ...common.data.typing import TransformSequence
-from ...common.models.model import Model
-from .separate import separate
+from ..models.classification_model import ClassificationModel
 
 logger = logging.getLogger(__name__)
 
 
-class Conv3DNet(Model):
+class Conv3DNet(ClassificationModel):
     """Implementation of 3D CNN.
 
     This network has a linear output. You should use losses with ``WithLogit``
@@ -39,6 +38,13 @@
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``params``.
+    model_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
     augmentation_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -48,27 +54,28 @@
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
+        model_transforms: TransformSequence = [],
        augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            augmentation_transforms,
-            num_classes,
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
-        self.name = "cnn3D"
-        self.num_classes = num_classes
-
-        self.model_transforms = []
-
         # First convolution block
         self.conv3d_1_1 = nn.Conv3d(
             in_channels=1, out_channels=4, kernel_size=3, stride=1, padding=1
         )
@@ -172,35 +179,3 @@
 
         return self.fc2(x)
         # x = F.log_softmax(x, dim=1)  # 0 is batch size
-
-    def training_step(self, batch, _):
-        images = batch[0]
-        labels = batch[1]["label"]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # Forward pass on the network
-        outputs = self(self.augmentation_transforms(images))
-
-        return self._train_loss(outputs, labels.float())
-
-    def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch[0]
-        labels = batch[1]["label"]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # data forwarding on the existing network
-        outputs = self(images)
-        return self._validation_loss(outputs, labels.float())
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        outputs = self(batch[0])
-        probabilities = torch.sigmoid(outputs)
-        return separate((probabilities, batch[1]))
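Since `Conv3DNet` now delegates its training, validation, and prediction steps to `ClassificationModel`, constructing it only requires the shared keyword arguments. A hedged construction sketch (the learning-rate value is illustrative, not taken from the repository):

```python
# Hedged sketch: building the refactored Conv3DNet; step hooks are assumed to
# be inherited from ClassificationModel after this change.
import torch
from mednet.libs.classification.models.cnn3d import Conv3DNet

model = Conv3DNet(
    loss_type=torch.nn.BCEWithLogitsLoss,  # a class, not an instance (see type fix)
    optimizer_type=torch.optim.Adam,
    optimizer_arguments=dict(lr=1e-4),     # illustrative value only
    scheduler_type=None,                   # optionally an LRScheduler subclass
    num_classes=1,
)
```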
diff --git a/src/mednet/libs/common/models/model.py b/src/mednet/libs/common/models/model.py
index b59301e2b061785ae1c3d728588a767db213f06a..669ba6c18b35faf30876780c14fa30f6368846ef 100644
--- a/src/mednet/libs/common/models/model.py
+++ b/src/mednet/libs/common/models/model.py
@@ -58,7 +58,7 @@ class Model(pl.LightningModule):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -73,10 +73,8 @@ class Model(pl.LightningModule):
 
         self._loss_type = loss_type
 
-        self._train_loss = None
         self._train_loss_arguments = loss_arguments
 
-        self._validation_loss = None
         self._validation_loss_arguments = loss_arguments
 
         self._optimizer_type = optimizer_type
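The corrected `scheduler_type` annotation means callers pass a scheduler *class* (or `None`), never an instance. A small sketch of a value satisfying the new annotation (the concrete scheduler and its arguments are illustrative):

```python
# Hedged sketch: a value matching type[torch.optim.lr_scheduler.LRScheduler] | None.
import torch

scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = (
    torch.optim.lr_scheduler.StepLR
)
scheduler_arguments = {"step_size": 10, "gamma": 0.1}  # illustrative values only

assert scheduler_type is None or issubclass(
    scheduler_type, torch.optim.lr_scheduler.LRScheduler
)
```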
diff --git a/src/mednet/libs/segmentation/models/segmentation_model.py b/src/mednet/libs/segmentation/models/segmentation_model.py
index 51599b3ef91ada9428cf9f796c5404dff2d58bda..f545e67f3cff41cff46890d13851bb9aa1885723 100644
--- a/src/mednet/libs/segmentation/models/segmentation_model.py
+++ b/src/mednet/libs/segmentation/models/segmentation_model.py
@@ -50,11 +50,11 @@ class SegmentationModel(Model):
 
     def __init__(
         self,
-        loss_type: torch.nn.Module = MultiWeightedBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = MultiWeightedBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -89,7 +89,7 @@
             )
             self.normalizer = make_z_normalizer(dataloader)
 
-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch, _):
         images = self.augmentation_transforms(batch[0]["image"])
         ground_truths = self.augmentation_transforms(batch[0]["target"])
         masks = self.augmentation_transforms(batch[0]["mask"])
@@ -97,7 +97,7 @@
         outputs = self(images)
         return self._train_loss(outputs, ground_truths, masks)
 
-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[0]["image"]
         ground_truths = batch[0]["target"]
         masks = batch[0]["mask"]
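The adjusted hook signatures (unused positional argument renamed to `_`, `dataloader_idx` accepted during validation) keep subclasses compatible with PyTorch Lightning's convention for multiple validation dataloaders. A hypothetical override could look like this (class name and logged metric key are illustrative only):

```python
# Hedged sketch: a hypothetical subclass forwarding the updated hook signature.
from mednet.libs.segmentation.models.segmentation_model import SegmentationModel


class LoggingSegmentationModel(SegmentationModel):
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        loss = super().validation_step(batch, batch_idx, dataloader_idx)
        self.log("validation/loss", loss)  # illustrative logging only
        return loss
```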