diff --git a/src/mednet/config/models/cnn3d.py b/src/mednet/config/models/cnn3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..790570ab93b5e4c67c1d219f861788759ca4e3c6
--- /dev/null
+++ b/src/mednet/config/models/cnn3d.py
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""Simple CNN for 3D organ classification, to be trained from scratch."""
+
+from torch.nn import BCEWithLogitsLoss
+from torch.optim import Adam
+
+from mednet.models.cnn3d import Conv3DNet
+
+model = Conv3DNet(
+    loss_type=BCEWithLogitsLoss,
+    optimizer_type=Adam,
+    optimizer_arguments=dict(lr=8e-5),
+)
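For context on the pairing above: the config hands the model a `BCEWithLogitsLoss` precisely because the network (below) ends in a linear layer. The loss fuses the sigmoid with binary cross-entropy internally, which is numerically more stable than applying `sigmoid` followed by `BCELoss`. The snippet below is purely illustrative and not part of the patch; it uses only standard `torch` API.

```python
# Illustration: BCEWithLogitsLoss on raw logits matches sigmoid + BCELoss,
# but avoids the numerically unstable explicit sigmoid/log round-trip.
import torch
from torch.nn import BCELoss, BCEWithLogitsLoss

logits = torch.randn(4, 1)                       # raw (linear) network outputs
targets = torch.randint(0, 2, (4, 1)).float()    # binary labels

fused = BCEWithLogitsLoss()(logits, targets)
manual = BCELoss()(torch.sigmoid(logits), targets)
assert torch.allclose(fused, manual, atol=1e-6)
```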
+ """ + + def __init__( + self, + loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss, + loss_arguments: dict[str, typing.Any] = {}, + optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, + optimizer_arguments: dict[str, typing.Any] = {}, + augmentation_transforms: TransformSequence = [], + num_classes: int = 1, + ): + super().__init__( + loss_type, + loss_arguments, + optimizer_type, + optimizer_arguments, + augmentation_transforms, + num_classes, + ) + + self.name = "cnn3D" + self.num_classes = num_classes + + self.model_transforms = [] + + # First convolution block + self.conv3d_1 = nn.Conv3d(1, 32, kernel_size=3, padding=1) + self.batchnorm_1 = nn.BatchNorm3d(32) + # Second convolution block + self.conv3d_2 = nn.Conv3d(32, 64, kernel_size=3, padding=1) + self.batchnorm_2 = nn.BatchNorm3d(64) + # Third convolution block + self.conv3d_3 = nn.Conv3d(64, 128, kernel_size=3, padding=1) + self.batchnorm_3 = nn.BatchNorm3d(128) + # Fourth convolution block + self.conv3d_4 = nn.Conv3d(128, 256, kernel_size=3, padding=1) + self.batchnorm_4 = nn.BatchNorm3d(256) + + self.pool = nn.MaxPool3d(2) + self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) + self.dropout = nn.Dropout(0.3) + self.fc1 = nn.Linear(256, 64) + self.fc2 = nn.Linear(64, num_classes) + + + def forward(self, x): + #x = self.normalizer(x) # type: ignore + + x = F.relu(self.batchnorm_1(self.conv3d_1(x))) + x = self.pool(x) + x = F.relu(self.batchnorm_2(self.conv3d_2(x))) + x = self.pool(x) + x = F.relu(self.batchnorm_3(self.conv3d_3(x))) + x = self.pool(x) + x = F.relu(self.batchnorm_4(self.conv3d_4(x))) + x = self.global_pool(x) + x = x.view(x.size(0), -1) + x = F.relu(self.fc1(x)) + x = self.dropout(x) + x = self.fc2(x) + return x + + + # x = F.log_softmax(x, dim=1) # 0 is batch size + + def training_step(self, batch, _): + images = batch[0] + labels = batch[1]["label"] + + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + + # Forward pass on the network + outputs = self(self.augmentation_transforms(images)) + + return self._train_loss(outputs, labels.float()) + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + images = batch[0] + labels = batch[1]["label"] + + # Increase label dimension if too low + # Allows single and multiclass usage + if labels.ndim == 1: + labels = torch.reshape(labels, (labels.shape[0], 1)) + + # data forwarding on the existing network + outputs = self(images) + return self._validation_loss(outputs, labels.float()) + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + outputs = self(batch[0]) + probabilities = torch.sigmoid(outputs) + return separate((probabilities, batch[1]))