Commit 5249d724 authored by André Anjos

[*/models] Unify modelling by applying DRY; Fix a nasty loss-balancing bug; Make loss-configuration private to base model class; Automate loss-configuration; Fix typing across model submodules; Move `name` into parent type
parent f09d3fe8
1 merge request: !46 Create common library
Showing 240 additions and 343 deletions
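One of the changes below replaces plain assignment of `loss_arguments` in the `Model` base class with `copy.deepcopy`. A minimal sketch of the aliasing problem that the deep copy addresses, assuming the previous behaviour in which both attributes referenced the caller's dict (the variable names here are illustrative only, not taken from the diff):

import copy

# previous behaviour: both attributes alias the *same* dict object
loss_arguments: dict = {}
train_args = loss_arguments
validation_args = loss_arguments
train_args["pos_weight"] = 2.5                # balancing the training loss...
assert validation_args["pos_weight"] == 2.5   # ...silently leaks into the validation loss

# fixed behaviour (as in Model.__init__ below): independent copies
loss_arguments = {}
train_args = copy.deepcopy(loss_arguments)
validation_args = copy.deepcopy(loss_arguments)
train_args["pos_weight"] = 2.5
assert "pos_weight" not in validation_args    # validation arguments stay untouched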
@@ -56,11 +56,11 @@ class Alexnet(ClassificationModel):
     def __init__(
         self,
-        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -68,20 +68,18 @@ class Alexnet(ClassificationModel):
         num_classes: int = 1,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="alexnet",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
-        self.name = "alexnet"
-        self.num_classes = num_classes
         self.pretrained = pretrained

         # Load pretrained model
@@ -20,6 +20,8 @@ class ClassificationModel(Model):
     Parameters
     ----------
+    name
+        Common name to give to models of this type.
     loss_type
         The loss to be used for training and evaluation.
@@ -49,6 +51,7 @@ class ClassificationModel(Model):
     def __init__(
         self,
+        name: str,
         loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
@@ -60,6 +63,7 @@ class ClassificationModel(Model):
         num_classes: int = 1,
     ):
         super().__init__(
+            name,
             loss_type,
             loss_arguments,
             optimizer_type,
@@ -65,6 +65,7 @@ class Conv3DNet(ClassificationModel):
         num_classes: int = 1,
     ):
         super().__init__(
+            name="cnn3d",
             loss_type=loss_type,
             loss_arguments=loss_arguments,
             optimizer_type=optimizer_type,
@@ -56,11 +56,11 @@ class Densenet(ClassificationModel):
     def __init__(
         self,
-        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -69,20 +69,18 @@ class Densenet(ClassificationModel):
         num_classes: int = 1,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="densenet-121",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
-        self.name = "densenet-121"
-        self.num_classes = num_classes
         self.pretrained = pretrained

         # Load pretrained model
@@ -4,109 +4,59 @@
 import typing

-import lightning.pytorch as pl
 import torch
-import torch.nn as nn
+import torch.nn
+
+from .classification_model import ClassificationModel

-class LogisticRegression(pl.LightningModule):
+class LogisticRegression(ClassificationModel):
     """Logistic regression classifier with a single output.

     Parameters
     ----------
-    train_loss
-        The loss to be used during the training.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-
-    validation_loss
-        The loss to be used for validation (may be different from the training
-        loss). If extra-validation sets are provided, the same loss will be
-        used throughout.
+    loss_type
+        The loss to be used for training and evaluation.

         .. warning::

            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.

+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    num_classes
+        Number of outputs (classes) for this model.
     input_size
         The number of inputs this classifer shall process.
     """

     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {"lr": 1e-2},
+        num_classes: int = 1,
         input_size: int = 14,
     ):
-        super().__init__()
-
-        self._train_loss = train_loss.to(self.device)
-        self._validation_loss = (
-            validation_loss if validation_loss is not None else train_loss
-        ).to(self.device)
-        self._optimizer_type = optimizer_type
-        self._optimizer_arguments = optimizer_arguments
-
-        self.name = "logistic-regression"
-
-        self.linear = nn.Linear(input_size, 1)
+        super().__init__(
+            name="logistic-regression",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=None,
+            scheduler_arguments={},
+            model_transforms=[],
+            augmentation_transforms=[],
+            num_classes=num_classes,
+        )
+
+        self.linear = torch.nn.Linear(input_size, self.num_classes)

     def forward(self, x):
         return self.linear(self.normalizer(x))
-
-    def training_step(self, batch, batch_idx):
-        _input = batch[1]
-        labels = batch[2]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # Forward pass on the network
-        output = self(_input)
-
-        # Manually move criterion to selected device, since not part of the model.
-        self._train_loss = self._train_loss.to(self.device)
-        training_loss = self._train_loss(output, labels.float())
-
-        return {"loss": training_loss}
-
-    def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        _input = batch[1]
-        labels = batch[2]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # data forwarding on the existing network
-        output = self(_input)
-
-        # Manually move criterion to selected device, since not part of the model.
-        self._validation_loss = self._validation_loss.to(self.device)
-        validation_loss = self._validation_loss(output, labels.float())
-
-        if dataloader_idx == 0:
-            return {"validation_loss": validation_loss}
-
-        return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        outputs = self(batch[0])
-        return torch.sigmoid(outputs)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(
-            self.parameters(),
-            **self._optimizer_arguments,
-        )
@@ -4,35 +4,32 @@
 import typing

-import lightning.pytorch as pl
 import torch
+import torch.nn
+
+from .classification_model import ClassificationModel

-class MultiLayerPerceptron(pl.LightningModule):
+class MultiLayerPerceptron(ClassificationModel):
     """MLP with a variable number of inputs and hidden neurons (single layer).

     Parameters
     ----------
-    train_loss
-        The loss to be used during the training.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-
-    validation_loss
-        The loss to be used for validation (may be different from the training
-        loss). If extra-validation sets are provided, the same loss will be
-        used throughout.
+    loss_type
+        The loss to be used for training and evaluation.

         .. warning::

            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.

+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
         Arguments to the optimizer after ``params``.
+    num_classes
+        Number of outputs (classes) for this model.
     input_size
         The number of inputs this classifer shall process.
     hidden_size
@@ -41,76 +38,30 @@ class MultiLayerPerceptron(pl.LightningModule):
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {"lr": 1e-2},
+        num_classes: int = 1,
         input_size: int = 14,
         hidden_size: int = 10,
     ):
-        super().__init__()
-
-        self._train_loss = train_loss.to(self.device)
-        self._validation_loss = (
-            validation_loss if validation_loss is not None else train_loss
-        ).to(self.device)
-        self._optimizer_type = optimizer_type
-        self._optimizer_arguments = optimizer_arguments
-
-        self.name = "mlp"
+        super().__init__(
+            name="mlp",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=None,
+            scheduler_arguments={},
+            model_transforms=[],
+            augmentation_transforms=[],
+            num_classes=num_classes,
+        )

         self.fc1 = torch.nn.Linear(input_size, hidden_size)
         self.relu = torch.nn.ReLU()
-        self.fc2 = torch.nn.Linear(hidden_size, 1)
+        self.fc2 = torch.nn.Linear(hidden_size, self.num_classes)

     def forward(self, x):
         return self.fc2(self.relu(self.fc1(self.normalizer(x))))
-
-    def training_step(self, batch, batch_idx):
-        _input = batch[1]
-        labels = batch[2]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # Forward pass on the network
-        output = self(_input)
-
-        # Manually move criterion to selected device, since not part of the model.
-        self._train_loss = self._train_loss.to(self.device)
-        training_loss = self._train_loss(output, labels.float())
-
-        return {"loss": training_loss}
-
-    def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        _input = batch[1]
-        labels = batch[2]
-
-        # Increase label dimension if too low
-        # Allows single and multiclass usage
-        if labels.ndim == 1:
-            labels = torch.reshape(labels, (labels.shape[0], 1))
-
-        # data forwarding on the existing network
-        output = self(_input)
-
-        # Manually move criterion to selected device, since not part of the model.
-        self._validation_loss = self._validation_loss.to(self.device)
-        validation_loss = self._validation_loss(output, labels.float())
-
-        if dataloader_idx == 0:
-            return {"validation_loss": validation_loss}
-
-        return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        outputs = self(batch[0])
-        return torch.sigmoid(outputs)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(
-            self.parameters(),
-            **self._optimizer_arguments,
-        )
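The two rewrites above (LogisticRegression and MultiLayerPerceptron) illustrate the DRY unification: subclasses no longer carry their own training/validation/predict steps or optimizer setup; they only declare layers and a forward pass and delegate everything else to ClassificationModel. A minimal sketch of the resulting pattern follows; the class TinyClassifier and its layer sizes are hypothetical and not part of this commit:

import torch
import torch.nn

from .classification_model import ClassificationModel


class TinyClassifier(ClassificationModel):
    """Hypothetical single-layer classifier following the unified pattern."""

    def __init__(self, input_size: int = 14, num_classes: int = 1):
        super().__init__(
            name="tiny-classifier",
            loss_type=torch.nn.BCEWithLogitsLoss,
            loss_arguments={},
            optimizer_type=torch.optim.Adam,
            optimizer_arguments={"lr": 1e-2},
            scheduler_type=None,
            scheduler_arguments={},
            model_transforms=[],
            augmentation_transforms=[],
            num_classes=num_classes,
        )
        # only the architecture lives here; steps, losses and optimizers
        # are handled by the base classes
        self.linear = torch.nn.Linear(input_size, self.num_classes)

    def forward(self, x):
        return self.linear(self.normalizer(x))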
@@ -56,31 +56,29 @@ class Pasa(ClassificationModel):
     def __init__(
         self,
-        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
        loss_arguments: dict[str, typing.Any] = {},
        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="pasa",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
-        self.name = "pasa"
-        self.num_classes = num_classes

         # First convolution block
         self.fc1 = torch.nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
         self.fc2 = torch.nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
@@ -73,8 +73,6 @@ def run(
     output_folder.mkdir(parents=True, exist_ok=True)

-    model.configure_losses()
-
     from .loggers import CustomTensorboardLogger

     log_dir = "logs"
@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later

+import copy
 import logging
 import typing
@@ -25,6 +26,8 @@ class Model(pl.LightningModule):
     Parameters
     ----------
+    name
+        Common name to give to models of this type.
     loss_type
         The loss to be used for training and evaluation.
@@ -54,6 +57,7 @@ class Model(pl.LightningModule):
     def __init__(
         self,
+        name: str,
         loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
@@ -66,25 +70,21 @@ class Model(pl.LightningModule):
     ):
         super().__init__()

-        self.name = "model"
+        self.name = name
         self.num_classes = num_classes
         self.model_transforms = model_transforms

         self._loss_type = loss_type
-        self._train_loss_arguments = loss_arguments
-        self._validation_loss_arguments = loss_arguments
+        self._train_loss_arguments = copy.deepcopy(loss_arguments)
+        self._validation_loss_arguments = copy.deepcopy(loss_arguments)

         self._optimizer_type = optimizer_type
         self._optimizer_arguments = optimizer_arguments

         self._scheduler_type = scheduler_type
         self._scheduler_arguments = scheduler_arguments

         self.augmentation_transforms = augmentation_transforms

+        # initializes losses from input arguments
+        self._configure_losses()
+
     @property
     def augmentation_transforms(self) -> torchvision.transforms.Compose:
         return self._augmentation_transforms
@@ -149,8 +149,14 @@ class Model(pl.LightningModule):
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         raise NotImplementedError

-    def configure_losses(self):
+    def _configure_losses(self):
+        """Create loss objects for train and validation."""
+        logger.info(f"Configuring train loss ({self._train_loss_arguments})...")
         self._train_loss = self._loss_type(**self._train_loss_arguments)
+        logger.info(
+            f"Configuring validation loss ({self._validation_loss_arguments})..."
+        )
         self._validation_loss = self._loss_type(**self._validation_loss_arguments)

     def configure_optimizers(self):
@@ -195,10 +201,8 @@ class Model(pl.LightningModule):
             ]
         )

-        if self._train_loss is not None:
-            self._train_loss.to(*args, **kwargs)
-        if self._validation_loss is not None:
-            self._validation_loss.to(*args, **kwargs)
+        self._train_loss.to(*args, **kwargs)
+        self._validation_loss.to(*args, **kwargs)

         return self
@@ -213,23 +217,36 @@ class Model(pl.LightningModule):
         try:
             getattr(self._loss_type(), "pos_weight")
         except AttributeError:
             logger.warning(
                 f"Loss {self._loss_type} does not posess a 'pos_weight' attribute and will not be balanced."
             )
         else:
-            logger.info(f"Balancing training loss {self._loss_type}.")
             train_weights = get_positive_weights(datamodule.train_dataloader())
             self._train_loss_arguments["pos_weight"] = train_weights
+            logger.info(
+                f"Balanced training loss {self._loss_type}: "
+                f"`pos_weight={train_weights.item():.3f}`."
+            )

-            logger.info(f"Balancing validation loss {self._loss_type}.")
             if "validation" in datamodule.val_dataloader().keys():
                 validation_weights = get_positive_weights(
                     datamodule.val_dataloader()["validation"]
                 )
             else:
                 logger.warning(
-                    "Datamodule does not contain a validation dataloader. The training dataloader will be used instead."
+                    "Datamodule does not contain a validation dataloader. "
+                    "The training dataloader will be used instead."
                 )
                 validation_weights = get_positive_weights(datamodule.train_dataloader())

             self._validation_loss_arguments["pos_weight"] = validation_weights
+            logger.info(
+                f"Balanced validation loss {self._loss_type}: "
+                f"`pos_weight={validation_weights.item():.3f}`."
+            )
+
+            # re-instantiates losses for train and validation
+            self._configure_losses()
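With the changes above, loss construction is automated: `_configure_losses()` runs once at the end of `Model.__init__()`, and `balance_losses()` now writes `pos_weight` into the two independent argument dicts before re-instantiating both losses itself, which is why the explicit `model.configure_losses()` call was dropped from the trainer's `run()` function earlier in this diff. A rough sketch of the resulting call flow, assuming a datamodule exposing the `train_dataloader()`/`val_dataloader()` interface used above (object names are illustrative):

model = Pasa()                    # __init__ -> _configure_losses() builds train/validation losses
model.balance_losses(datamodule)  # writes pos_weight into each argument dict,
                                  # then calls _configure_losses() again internally
# the trainer then fits the model without any further loss-configuration step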
@@ -8,6 +8,7 @@ import typing
 import torch
 import torch.nn
+import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence

 from .backbones.vgg import vgg16_for_segmentation
@@ -43,7 +44,7 @@ class DRIUHead(torch.nn.Module):
         Number of channels for each feature map that is returned from backbone.
     """

-    def __init__(self, in_channels_list=None):
+    def __init__(self, in_channels_list):
         super().__init__()
         (
             in_conv_1_2_16,
@@ -106,11 +107,11 @@ class DRIU(SegmentationModel):
     def __init__(
         self,
-        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = SoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -118,17 +119,17 @@ class DRIU(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="driu",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
-        self.name = "driu"
         self.pretrained = pretrained
@@ -8,6 +8,7 @@ import typing
 import torch
 import torch.nn
+import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence

 from .backbones.vgg import vgg16_for_segmentation
@@ -46,7 +47,7 @@ class DRIUBNHead(torch.nn.Module):
         Number of channels for each feature map that is returned from backbone.
     """

-    def __init__(self, in_channels_list=None):
+    def __init__(self, in_channels_list):
         super().__init__()
         (
             in_conv_1_2_16,
@@ -109,11 +110,11 @@ class DRIUBN(SegmentationModel):
     def __init__(
         self,
-        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = SoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -121,17 +122,17 @@ class DRIUBN(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="driu-bn",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
-        self.name = "driu-bn"
         self.pretrained = pretrained
@@ -28,7 +28,7 @@ class DRIUODHead(torch.nn.Module):
         Number of channels for each feature map that is returned from backbone.
     """

-    def __init__(self, in_channels_list=None):
+    def __init__(self, in_channels_list):
         super().__init__()
         (
             in_upsample2,
@@ -91,11 +91,11 @@ class DRIUOD(SegmentationModel):
     def __init__(
         self,
-        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = SoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -103,17 +103,17 @@ class DRIUOD(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="driu-od",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
-        self.name = "driu-od"
         self.pretrained = pretrained
@@ -8,6 +8,7 @@ import typing
 import torch
 import torch.nn
+import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence

 from .backbones.vgg import vgg16_for_segmentation
@@ -28,7 +29,7 @@ class DRIUPIXHead(torch.nn.Module):
         Number of channels for each feature map that is returned from backbone.
     """

-    def __init__(self, in_channels_list=None):
+    def __init__(self, in_channels_list):
         super().__init__()
         (
             in_conv_1_2_16,
@@ -95,11 +96,11 @@ class DRIUPix(SegmentationModel):
     def __init__(
         self,
-        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = SoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -107,17 +108,17 @@ class DRIUPix(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="driu-pix",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
        )
-        self.name = "driu-pix"
        self.pretrained = pretrained
@@ -7,6 +7,7 @@ import typing
 import torch
 import torch.nn
+import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence

 from .backbones.vgg import vgg16_for_segmentation
@@ -40,7 +41,7 @@ class HEDHead(torch.nn.Module):
         Number of channels for each feature map that is returned from backbone.
     """

-    def __init__(self, in_channels_list=None):
+    def __init__(self, in_channels_list):
         super().__init__()
         (
             in_conv_1_2_16,
@@ -109,11 +110,11 @@ class HED(SegmentationModel):
     def __init__(
         self,
-        loss_type: torch.nn.Module = MultiSoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = MultiSoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -121,19 +122,18 @@ class HED(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="hed",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
        )
-        self.name = "hed"
        self.pretrained = pretrained

         self.backbone = vgg16_for_segmentation(
@@ -310,31 +310,29 @@ class LittleWNet(SegmentationModel):
     def __init__(
         self,
-        loss_type: torch.nn.Module = MultiWeightedBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = MultiWeightedBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="lwnet",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
         )
-        self.name = "lwnet"
-        self.num_classes = num_classes

         self.unet1 = LittleUNet(
             in_c=3,
             n_classes=self.num_classes,
@@ -7,6 +7,7 @@ import typing
 import torch
 import torch.nn
+import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence
 from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
 from torchvision.models.mobilenetv2 import InvertedResidual
@@ -157,11 +158,11 @@ class M2UNET(SegmentationModel):
     def __init__(
         self,
-        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = SoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -169,19 +170,18 @@ class M2UNET(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="m2unet",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
        )
-        self.name = "m2unet"
        self.pretrained = pretrained

         self.backbone = mobilenet_v2_for_segmentation(
@@ -21,6 +21,8 @@ class SegmentationModel(Model):
     Parameters
     ----------
+    name
+        Common name to give to models of this type.
     loss_type
         The loss to be used for training and evaluation.
@@ -50,6 +52,7 @@ class SegmentationModel(Model):
     def __init__(
         self,
+        name: str,
         loss_type: type[torch.nn.Module] = MultiWeightedBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
@@ -61,6 +64,7 @@ class SegmentationModel(Model):
         num_classes: int = 1,
     ):
         super().__init__(
+            name,
             loss_type,
             loss_arguments,
             optimizer_type,
@@ -7,6 +7,7 @@ import logging
 import typing

 import torch.nn
+import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence

 from .backbones.vgg import vgg16_for_segmentation
@@ -28,7 +29,7 @@ class UNetHead(torch.nn.Module):
         If True, upsample using PixelShuffleICNR.
     """

-    def __init__(self, in_channels_list: list[int] = None, pixel_shuffle=False):
+    def __init__(self, in_channels_list: list[int], pixel_shuffle=False):
         super().__init__()
         # number of channels
         c_decode1, c_decode2, c_decode3, c_decode4, c_decode5 = in_channels_list
@@ -98,11 +99,11 @@ class Unet(SegmentationModel):
     def __init__(
         self,
-        loss_type: torch.nn.Module = SoftJaccardBCELogitsLoss,
+        loss_type: type[torch.nn.Module] = SoftJaccardBCELogitsLoss,
         loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
-        scheduler_type: type[torch.optim.lr_scheduler] = None,
+        scheduler_type: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
         scheduler_arguments: dict[str, typing.Any] = {},
         model_transforms: TransformSequence = [],
         augmentation_transforms: TransformSequence = [],
@@ -110,19 +111,18 @@ class Unet(SegmentationModel):
         pretrained: bool = False,
     ):
         super().__init__(
-            loss_type,
-            loss_arguments,
-            optimizer_type,
-            optimizer_arguments,
-            scheduler_type,
-            scheduler_arguments,
-            model_transforms,
-            augmentation_transforms,
-            num_classes,
+            name="unet",
+            loss_type=loss_type,
+            loss_arguments=loss_arguments,
+            optimizer_type=optimizer_type,
+            optimizer_arguments=optimizer_arguments,
+            scheduler_type=scheduler_type,
+            scheduler_arguments=scheduler_arguments,
+            model_transforms=model_transforms,
+            augmentation_transforms=augmentation_transforms,
+            num_classes=num_classes,
        )
-        self.name = "unet"
        self.pretrained = pretrained

         self.backbone = vgg16_for_segmentation(
@@ -160,26 +160,3 @@ class Unet(SegmentationModel):
             self.normalizer = make_imagenet_normalizer()
         else:
             super().set_normalizer(dataloader)
-
-    def training_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(self._augmentation_transforms(images))
-        return self._train_loss(outputs, ground_truths, masks)
-
-    def validation_step(self, batch, batch_idx):
-        images = batch[0]["image"]
-        ground_truths = batch[0]["target"]
-        masks = batch[0]["mask"]
-
-        outputs = self(images)
-        return self._validation_loss(outputs, ground_truths, masks)
-
-    def predict_step(self, batch, batch_idx, dataloader_idx=0):
-        output = self(batch[0]["image"])[1]
-        return torch.sigmoid(output)
-
-    def configure_optimizers(self):
-        return self._optimizer_type(self.parameters(), **self._optimizer_arguments)