Commit 9726f356 authored by André Anjos

[3d-cnn/visceral] Adapt database and model to new-style datamodule/model

parent 740c3278
1 merge request: !46 Create common library
Pipeline #89236 failed
@@ -9,21 +9,23 @@ Database reference:
 import os
 import pathlib
+import typing
 
 import torchio as tio
 
-from ....data.datamodule import CachingDataModule
-from ....data.split import make_split
-from ....data.typing import RawDataLoader as _BaseRawDataLoader
-from ....data.typing import Sample
-from ....utils.rc import load_rc
+from mednet.libs.classification.data.typing import (
+    ClassificationRawDataLoader as _ClassificationRawDataLoader,
+)
+from mednet.libs.classification.data.typing import Sample
+from mednet.libs.common.data.datamodule import CachingDataModule
+from mednet.libs.common.data.split import make_split
+from mednet.libs.common.utils.rc import load_rc
 
 CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
 """Key to search for in the configuration file for the root directory of this
 database."""
 
 
-class RawDataLoader(_BaseRawDataLoader):
+class ClassificationRawDataLoader(_ClassificationRawDataLoader):
     """A specialized raw-data-loader for the VISCERAL dataset."""
 
     datadir: pathlib.Path
@@ -38,7 +40,7 @@ class RawDataLoader(_BaseRawDataLoader):
             ),
         )
 
-    def sample(self, sample: tuple[str, int]) -> Sample:
+    def sample(self, sample: tuple[str, int, typing.Any | None]) -> Sample:
        """Load a single volume sample from the disk.
 
         Parameters
@@ -46,7 +48,7 @@ class RawDataLoader(_BaseRawDataLoader):
         sample
             A tuple containing the path suffix, within the dataset root folder,
             where to find the volume to be loaded and an integer, representing
-            the sample label.
+            the sample target.
 
         Returns
         -------
@@ -58,25 +60,25 @@ class RawDataLoader(_BaseRawDataLoader):
         image = tio.ScalarImage(self.datadir / sample[0])
         image = preprocess(image)
         tensor = image.data
-        return tensor, dict(label=sample[1], name=sample[0])
+        return tensor, dict(target=sample[1], name=sample[0])
 
-    def label(self, sample: tuple[str, int]) -> int:
-        """Load a single image sample label from the disk.
+    def target(self, k: typing.Any) -> int | list[int]:
+        """Load a single image sample target from the disk.
 
         Parameters
         ----------
-        sample
+        k
             A tuple containing the path suffix, within the dataset root folder,
             where to find the image to be loaded, and an integer, representing
-            the sample label.
+            the sample target.
 
         Returns
         -------
         int
-            The integer label associated with the sample.
+            The integer target associated with the sample.
         """
 
-        return sample[1]
+        return k[1]
 
 
 class DataModule(CachingDataModule):
@@ -100,7 +102,7 @@ class DataModule(CachingDataModule):
     * Final specifications
 
      * 32-bit floats, cubes 16x16x16 pixels
-     * Labels: 0 (bladder), 1 (lung)
+     * Targets: 0 (bladder), 1 (lung)
 
     Parameters
    ----------
@@ -109,9 +111,10 @@ class DataModule(CachingDataModule):
     """
 
     def __init__(self, split_filename: str):
+        assert __package__ is not None
         super().__init__(
             make_split(__package__, split_filename),
-            raw_data_loader=RawDataLoader(),
+            raw_data_loader=ClassificationRawDataLoader(),
             database_name=__package__.split(".")[-1],
             split_name=pathlib.Path(split_filename).stem,
         )
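Taken together, these hunks rename the VISCERAL loader to ClassificationRawDataLoader, replace "label" by "target" throughout, and have the datamodule assert the package name before building the split. A minimal usage sketch of the resulting API, assuming a configured data directory; the split file name and sample path below are placeholders, not values from this repository:

# Hypothetical usage of the new-style loader/datamodule shown above.
loader = ClassificationRawDataLoader()

# target() replaces label(): it only reads the target out of the split entry
# (a (path-suffix, target, extra) tuple) and never touches the disk.
assert loader.target(("volumes/example.nii.gz", 1, None)) == 1

# sample() would additionally load and preprocess the volume found under
# datadir / "volumes/example.nii.gz", returning a (tensor, metadata) pair
# whose metadata now uses the key "target" instead of "label".

# The datamodule wires the loader in by itself; only a split file shipped
# with the package is needed.
datamodule = DataModule("example-split.json")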
@@ -23,8 +23,8 @@ class ClassificationRawDataLoader(RawDataLoader):
         raise NotImplementedError("You must implement the `sample()` method")
 
-    def label(self, k: typing.Any) -> int | list[int]:
-        """Load only sample label from media.
+    def target(self, k: typing.Any) -> int | list[int]:
+        """Load only sample target from media.
 
         If you do not override this implementation, then, by default,
         this method will call :py:meth:`sample` to load the whole sample
...
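The base-class hunk above only renames label() to target(); per its docstring, the default target() still falls back to calling sample() when not overridden. A hedged sketch of a concrete loader relying on that default (the class name and tensor shape are illustrative, not part of this merge request; ClassificationRawDataLoader and Sample come from mednet.libs.classification.data.typing):

import typing

import torch

class ExampleRawDataLoader(ClassificationRawDataLoader):
    # Illustrative subclass only.
    def sample(self, k: typing.Any) -> Sample:
        # A real loader would read the volume referenced by k[0] from disk;
        # a zero tensor stands in for that here.
        tensor = torch.zeros(1, 16, 16, 16)
        return tensor, dict(target=k[1], name=k[0])

    # No target() override: as documented above, the default implementation
    # loads the whole sample via sample() and extracts its target, which is
    # correct but slower than reading the target straight from the split.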
@@ -49,11 +49,11 @@ class ClassificationModel(Model):
     def __init__(
         self,
-        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
...
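Both changes in this hunk fix annotations rather than behaviour: loss_type was always a class (the default is torch.nn.BCEWithLogitsLoss itself, not an instance), and scheduler_type is now an optional LR-scheduler class instead of the module torch.optim.lr_scheduler. An illustrative argument set that satisfies the corrected signature; the concrete values are not taken from this repository:

import torch

kwargs = dict(
    loss_type=torch.nn.BCEWithLogitsLoss,               # a class, not an instance
    loss_arguments={"pos_weight": torch.tensor([2.0])},
    optimizer_type=torch.optim.Adam,
    optimizer_arguments={"lr": 1e-3},
    scheduler_type=torch.optim.lr_scheduler.StepLR,     # optional scheduler class
    scheduler_arguments={"step_size": 10},
)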
@@ -12,13 +12,12 @@ import torch.optim.optimizer
 import torch.utils.data
 
 from ...common.data.typing import TransformSequence
-from ...common.models.model import Model
-from .separate import separate
+from ..models.classification_model import ClassificationModel
 
 logger = logging.getLogger(__name__)
 
 
-class Conv3DNet(Model):
+class Conv3DNet(ClassificationModel):
     """Implementation of 3D CNN.
 
     This network has a linear output. You should use losses with ``WithLogit``
@@ -39,6 +38,13 @@ class Conv3DNet(Model):
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    scheduler_type
+        The type of scheduler to use for training.
+    scheduler_arguments
+        Arguments to the scheduler after ``params``.
+    model_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
     augmentation_transforms
         An optional sequence of torch modules containing transforms to be
         applied on the input **before** it is fed into the network.
@@ -48,27 +54,28 @@ class Conv3DNet(Model):
     def __init__(
         self,
-        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_arguments: dict[str, typing.Any] = {},
+        model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            augmentation_transforms,
-            num_classes,
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
 
-        self.name = "cnn3D"
-        self.num_classes = num_classes
-        self.model_transforms = []
-
         # First convolution block
         self.conv3d_1_1 = nn.Conv3d(
             in_channels=1, out_channels=4, kernel_size=3, stride=1, padding=1
@@ -172,35 +179,3 @@ class Conv3DNet(Model):
         return self.fc2(x)
         # x = F.log_softmax(x, dim=1) # 0 is batch size
-
-    def training_step(self, batch, _):
-        images = batch[0]
-        labels = batch[1]["label"]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # Forward pass on the network
-        outputs = self(self.augmentation_transforms(images))
-
-        return self._train_loss(outputs, labels.float())
-
-    def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch[0]
-        labels = batch[1]["label"]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # data forwarding on the existing network
-        outputs = self(images)
-
-        return self._validation_loss(outputs, labels.float())
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        outputs = self(batch[0])
-        probabilities = torch.sigmoid(outputs)
-        return separate((probabilities, batch[1]))
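After these hunks, Conv3DNet forwards all constructor arguments to its new ClassificationModel base by keyword and drops its own training_step/validation_step/predict_step, which are presumably provided by the shared classification model introduced in this merge request. A hedged construction example; only the keyword names come from the signature above, the hyper-parameter values are placeholders:

import torch

model = Conv3DNet(
    loss_type=torch.nn.BCEWithLogitsLoss,
    optimizer_type=torch.optim.Adam,
    optimizer_arguments={"lr": 1e-4},
    scheduler_type=None,   # schedulers remain optional
    num_classes=1,         # single logit: 0 (bladder) vs. 1 (lung)
)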
@@ -58,7 +58,7 @@ class Model(pl.LightningModule):
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -73,10 +73,8 @@ class Model(pl.LightningModule):
         self._loss_type = loss_type
 
-        self._train_loss = None
         self._train_loss_arguments = loss_arguments
 
-        self._validation_loss = None
         self._validation_loss_arguments = loss_arguments
 
         self._optimizer_type = optimizer_type
...
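The base Model keeps a loss/optimizer/scheduler type plus an argument dictionary rather than ready-made instances; dropping the "_train_loss = None" and "_validation_loss = None" placeholders suggests the losses are now built on demand from those pairs. A sketch of that type-plus-arguments pattern only; the actual construction site is not shown in this diff:

import torch

loss_type = torch.nn.BCEWithLogitsLoss
loss_arguments: dict = {}
optimizer_type = torch.optim.Adam
optimizer_arguments = {"lr": 1e-3}

net = torch.nn.Linear(4, 1)              # stand-in for the real network
criterion = loss_type(**loss_arguments)  # instantiated when needed
optimizer = optimizer_type(net.parameters(), **optimizer_arguments)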
@@ -50,11 +50,11 @@ class SegmentationModel(Model):
     def __init__(
         self,
-        loss_type: torch.nn.Module = MultiWeightedBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = MultiWeightedBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -89,7 +89,7 @@ class SegmentationModel(Model):
         )
         self.normalizer = make_z_normalizer(dataloader)
 
-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch, _):
         images = self.augmentation_transforms(batch[0]["image"])
         ground_truths = self.augmentation_transforms(batch[0]["target"])
         masks = self.augmentation_transforms(batch[0]["mask"])
@@ -97,7 +97,7 @@ class SegmentationModel(Model):
         outputs = self(images)
         return self._train_loss(outputs, ground_truths, masks)
 
-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[0]["image"]
         ground_truths = batch[0]["target"]
         masks = batch[0]["mask"]
...
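The segmentation hunks align the Lightning hook signatures with the conventions used elsewhere in this merge request: the unused batch_idx in training_step becomes "_", and validation_step gains dataloader_idx so validation can run over more than one dataloader. A minimal illustration of the same signatures outside mednet (the import path and bodies are placeholders):

import lightning.pytorch as pl  # assumed Lightning import; not shown in this diff
import torch

class ExampleModule(pl.LightningModule):
    def training_step(self, batch, _):
        # The unused batch index is conventionally named "_".
        return torch.tensor(0.0, requires_grad=True)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx is supplied by Lightning when several validation
        # dataloaders are configured.
        return torch.tensor(0.0)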