Skip to content
Snippets Groups Projects
Commit cfff0324 authored by Yvan Pannatier's avatar Yvan Pannatier
Browse files

[models] remove pasa3d code

parent 44df1610
No related branches found
No related tags found
1 merge request!513d cnn visceral
Pipeline #89210 passed
......@@ -247,7 +247,6 @@ densenet = "mednet.config.models.densenet"
densenet-pretrained = "mednet.config.models.densenet_pretrained"
# 3D models
pasa3d = "mednet.config.models.pasa3d"
cnn3d = "mednet.config.models.cnn3d"
# lists of data augmentations
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Simple CNN for 3D organ classification, to be trained from scratch.
3D Adaptation of the model architecture proposed by F. Pasa in the article
"Efficient Deep Network Architectures for Fast Chest X-Ray Tuberculosis
Screening and Visualization".
Reference: [PASA-2019]_
"""
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from mednet.models.pasa3d import Pasa3D
model = Pasa3D(
loss_type=BCEWithLogitsLoss,
optimizer_type=Adam,
optimizer_arguments=dict(lr=8e-5),
)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import typing
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torch.optim.optimizer
import torch.utils.data
from ..data.typing import TransformSequence
from .model import Model
from .separate import separate
logger = logging.getLogger(__name__)
class Pasa3D(Model):
"""Implementation of a 3D version of CNN by Pasa and others.
Simple CNN for classification adapted from paper by [PASA-2019]_.
This network has a linear output. You should use losses with ``WithLogit``
instead of cross-entropy versions when training.
Parameters
----------
loss_type
The loss to be used for training and evaluation.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
loss_arguments
Arguments to the loss.
optimizer_type
The type of optimizer to use for training.
optimizer_arguments
Arguments to the optimizer after ``params``.
augmentation_transforms
An optional sequence of torch modules containing transforms to be
applied on the input **before** it is fed into the network.
num_classes
Number of outputs (classes) for this model.
"""
def __init__(
self,
loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
loss_arguments: dict[str, typing.Any] = {},
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
num_classes: int = 1,
):
super().__init__(
loss_type,
loss_arguments,
optimizer_type,
optimizer_arguments,
augmentation_transforms,
num_classes,
)
self.name = "pasa3D"
self.num_classes = num_classes
self.model_transforms = []
# First convolution block
self.conv3d_1_1 = nn.Conv3d(in_channels=1, out_channels=4, kernel_size=3)
self.conv3d_1_2 = nn.Conv3d(in_channels=4, out_channels=16, kernel_size=3)
self.conv3d_1_3 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=5)
self.batch_norm_1_1 = nn.BatchNorm3d(4)
self.batch_norm_1_2 = nn.BatchNorm3d(16)
self.batch_norm_1_3 = nn.BatchNorm3d(16)
# Second convolution block
self.conv3d_2_1 = nn.Conv3d(in_channels=16, out_channels=24, kernel_size=3)
self.conv3d_2_2 = nn.Conv3d(in_channels=24, out_channels=32, kernel_size=3)
self.conv3d_2_3 = nn.Conv3d(in_channels=16, out_channels=32, kernel_size=5)
self.batch_norm_2_1 = nn.BatchNorm3d(24)
self.batch_norm_2_2 = nn.BatchNorm3d(32)
self.batch_norm_2_3 = nn.BatchNorm3d(32)
# Third convolution block
self.conv3d_3_1 = nn.Conv3d(
in_channels=32, out_channels=40, kernel_size=3, stride=1, padding=1
)
self.conv3d_3_2 = nn.Conv3d(
in_channels=40, out_channels=48, kernel_size=3, stride=1, padding=1
)
self.conv3d_3_3 = nn.Conv3d(
in_channels=32, out_channels=48, kernel_size=1, stride=1
)
self.batch_norm_3_1 = nn.BatchNorm3d(40)
self.batch_norm_3_2 = nn.BatchNorm3d(48)
self.batch_norm_3_3 = nn.BatchNorm3d(48)
# Fourth convolution block
self.conv3d_4_1 = nn.Conv3d(
in_channels=48, out_channels=56, kernel_size=3, stride=1, padding=1
)
self.conv3d_4_2 = nn.Conv3d(
in_channels=56, out_channels=64, kernel_size=3, stride=1, padding=1
)
self.conv3d_4_3 = nn.Conv3d(
in_channels=48, out_channels=64, kernel_size=1, stride=1
)
self.batch_norm_4_1 = nn.BatchNorm3d(56)
self.batch_norm_4_2 = nn.BatchNorm3d(64)
self.batch_norm_4_3 = nn.BatchNorm3d(64)
# Fifth convolution block
self.conv3d_5_1 = nn.Conv3d(
in_channels=64, out_channels=72, kernel_size=3, stride=1, padding=1
)
self.conv3d_5_2 = nn.Conv3d(
in_channels=72, out_channels=80, kernel_size=3, stride=1, padding=1
)
self.conv3d_5_3 = nn.Conv3d(
in_channels=64, out_channels=80, kernel_size=1, stride=1
)
self.batch_norm_5_1 = nn.BatchNorm3d(72)
self.batch_norm_5_2 = nn.BatchNorm3d(80)
self.batch_norm_5_3 = nn.BatchNorm3d(80)
self.pool = nn.MaxPool3d(kernel_size=3, stride=2)
self.global_avg_pool = nn.AdaptiveAvgPool3d(1)
self.fc1 = nn.Linear(80, self.num_classes)
def forward(self, x):
x = self.normalizer(x) # type: ignore
# First convolution block
_x = x
x = F.relu(self.batch_norm_1_1(self.conv3d_1_1(x)))
x = F.relu(self.batch_norm_1_2(self.conv3d_1_2(x)))
x = (x + F.relu(self.batch_norm_1_3(self.conv3d_1_3(_x)))) / 2
# Second convolution block
_x = x
x = F.relu(self.batch_norm_2_1(self.conv3d_2_1(x)))
x = F.relu(self.batch_norm_2_2(self.conv3d_2_2(x)))
x = (x + F.relu(self.batch_norm_2_3(self.conv3d_2_3(_x)))) / 2
# Third convolution block
_x = x
x = F.relu(self.batch_norm_3_1(self.conv3d_3_1(x)))
x = F.relu(self.batch_norm_3_2(self.conv3d_3_2(x)))
x = (x + F.relu(self.batch_norm_3_3(self.conv3d_3_3(_x)))) / 2
# Fourth convolution block
_x = x
x = F.relu(self.batch_norm_4_1(self.conv3d_4_1(x)))
x = F.relu(self.batch_norm_4_2(self.conv3d_4_2(x)))
x = (x + F.relu(self.batch_norm_4_3(self.conv3d_4_3(_x)))) / 2
# Fifth convolution block
_x = x
x = F.relu(self.batch_norm_5_1(self.conv3d_5_1(x)))
x = F.relu(self.batch_norm_5_2(self.conv3d_5_2(x)))
x = (x + F.relu(self.batch_norm_5_3(self.conv3d_5_3(_x)))) / 2
x = torch.mean(x.view(x.size(0), x.size(1), -1), dim=2)
return self.fc1(x)
# x = F.log_softmax(x, dim=1) # 0 is batch size
def training_step(self, batch, _):
images = batch[0]
labels = batch[1]["label"]
# Increase label dimension if too low
# Allows single and multiclass usage
if labels.ndim == 1:
labels = torch.reshape(labels, (labels.shape[0], 1))
# Forward pass on the network
outputs = self(images)
return self._train_loss(outputs, labels.float())
def validation_step(self, batch, batch_idx, dataloader_idx=0):
images = batch[0]
labels = batch[1]["label"]
# Increase label dimension if too low
# Allows single and multiclass usage
if labels.ndim == 1:
labels = torch.reshape(labels, (labels.shape[0], 1))
# data forwarding on the existing network
outputs = self(images)
return self._validation_loss(outputs, labels.float())
def predict_step(self, batch, batch_idx, dataloader_idx=0):
outputs = self(batch[0])
probabilities = torch.sigmoid(outputs)
return separate((probabilities, batch[1]))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment