Commit 5ccbe6d6 authored by Daniel CARRON, committed by André Anjos

[models] Create specialized model per library

parent 145e8f74
Merge request: !46 Create common library
Showing 13 changed files with 198 additions and 119 deletions
alexnet.py
@@ -11,12 +11,13 @@ import torch.optim.optimizer
 import torch.utils.data
 import torchvision.models as models
 from mednet.libs.common.data.typing import TransformSequence
-from mednet.libs.common.models.model import Model
+from .classification_model import ClassificationModel

 logger = logging.getLogger(__name__)


-class Alexnet(Model):
+class Alexnet(ClassificationModel):
     """Alexnet module.

     Note: only usable with a normalized dataset
@@ -115,13 +116,7 @@ class Alexnet(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            from .normalizer import make_z_normalizer
-
-            logger.info(
-                f"Uninitialised {self.name} model - "
-                f"computing z-norm factors from train dataloader.",
-            )
-            self.normalizer = make_z_normalizer(dataloader)
+            super().set_normalizer(dataloader)

     def training_step(self, batch, _):
         images = batch[0]
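This else-branch rewrite repeats in every model below: the inline z-normalization block collapses into a call to the shared base class. For orientation, here is a reconstructed sketch of the full Alexnet.set_normalizer() after this hunk; the ImageNet branch is diff context, while the self.pretrained attribute name is an assumption, not shown in the excerpt:

    # Reconstructed shape of Alexnet.set_normalizer() after this commit.
    # "self.pretrained" is an assumed attribute name (hypothetical here).
    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
        if self.pretrained:
            from .normalizer import make_imagenet_normalizer

            self.normalizer = make_imagenet_normalizer()
        else:
            # Delegates to ClassificationModel.set_normalizer(), which
            # computes z-norm factors from the training dataloader.
            super().set_normalizer(dataloader)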
classification_model.py (new file)
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import logging
import typing

import torch
import torch.nn
import torch.optim.optimizer
import torch.utils.data
from mednet.libs.common.data.typing import TransformSequence
from mednet.libs.common.models.model import Model

logger = logging.getLogger("mednet")


class ClassificationModel(Model):
    """Base model for classification task.

    Parameters
    ----------
    loss_type
        The loss to be used for training and evaluation.

        .. warning::

           The loss should be set to always return batch averages (as opposed
           to the batch sum), as our logging system expects it so.
    loss_arguments
        Arguments to the loss.
    optimizer_type
        The type of optimizer to use for training.
    optimizer_arguments
        Arguments to the optimizer after ``params``.
    model_transforms
        An optional sequence of torch modules containing transforms to be
        applied on the input **before** it is fed into the network.
    augmentation_transforms
        An optional sequence of torch modules containing transforms to be
        applied on the input **before** it is fed into the network.
    num_classes
        Number of outputs (classes) for this model.
    """

    def __init__(
        self,
        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
        loss_arguments: dict[str, typing.Any] = {},
        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_arguments: dict[str, typing.Any] = {},
        model_transforms: TransformSequence = [],
        augmentation_transforms: TransformSequence = [],
        num_classes: int = 1,
    ):
        super().__init__(
            loss_type,
            loss_arguments,
            optimizer_type,
            optimizer_arguments,
            model_transforms,
            augmentation_transforms,
            num_classes,
        )

    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
        """Initialize the input normalizer for the current model.

        Parameters
        ----------
        dataloader
            A torch Dataloader from which to compute the mean and std.
        """
        from .normalizer import make_z_normalizer

        logger.info(
            f"Uninitialised {self.name} model - "
            f"computing z-norm factors from train dataloader.",
        )
        self.normalizer = make_z_normalizer(dataloader)
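The make_z_normalizer() factory used above comes from the pre-existing .normalizer module and is not part of this diff. A minimal sketch of what such a helper conventionally does, assuming batches arrive as (N, C, H, W) tensors with images as the first batch element:

    # Hypothetical sketch only; the real implementation lives in .normalizer.
    import torch
    import torch.utils.data
    import torchvision.transforms

    def make_z_normalizer(
        dataloader: torch.utils.data.DataLoader,
    ) -> torchvision.transforms.Normalize:
        """Estimate per-channel mean/std over one pass of the dataloader."""
        total = None
        total_sq = None
        count = 0
        for batch in dataloader:
            images = batch[0]  # assumption: images are the first batch element
            if total is None:
                total = torch.zeros(images.shape[1])
                total_sq = torch.zeros(images.shape[1])
            # Accumulate sums over batch and spatial dimensions, per channel.
            total += images.sum(dim=(0, 2, 3))
            total_sq += images.pow(2).sum(dim=(0, 2, 3))
            count += images.shape[0] * images.shape[2] * images.shape[3]
        mean = total / count
        std = (total_sq / count - mean.pow(2)).sqrt()
        return torchvision.transforms.Normalize(mean.tolist(), std.tolist())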
densenet.py
@@ -11,12 +11,13 @@ import torch.optim.optimizer
 import torch.utils.data
 import torchvision.models as models
 from mednet.libs.common.data.typing import TransformSequence
-from mednet.libs.common.models.model import Model
+from .classification_model import ClassificationModel

 logger = logging.getLogger(__name__)


-class Densenet(Model):
+class Densenet(ClassificationModel):
     """Densenet-121 module.

     Parameters
@@ -118,13 +119,7 @@ class Densenet(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            from .normalizer import make_z_normalizer
-
-            logger.info(
-                f"Uninitialised {self.name} model - "
-                f"computing z-norm factors from train dataloader.",
-            )
-            self.normalizer = make_z_normalizer(dataloader)
+            super().set_normalizer(dataloader)

     def training_step(self, batch, _):
         images = batch[0]
pasa.py
@@ -11,12 +11,13 @@ import torch.nn.functional as F  # noqa: N812
 import torch.optim.optimizer
 import torch.utils.data
 from mednet.libs.common.data.typing import TransformSequence
-from mednet.libs.common.models.model import Model
+from .classification_model import ClassificationModel

 logger = logging.getLogger(__name__)


-class Pasa(Model):
+class Pasa(ClassificationModel):
     """Implementation of CNN by Pasa and others.

     Simple CNN for classification based on paper by [PASA-2019]_.
@@ -192,23 +193,6 @@ class Pasa(Model):
         # x = F.log_softmax(x, dim=1)  # 0 is batch size

-    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
-        """Initialize the input normalizer for the current model.
-
-        Parameters
-        ----------
-        dataloader
-            A torch Dataloader from which to compute the mean and std.
-        """
-        from .normalizer import make_z_normalizer
-
-        logger.info(
-            f"Uninitialised {self.name} model - "
-            f"computing z-norm factors from train dataloader.",
-        )
-        self.normalizer = make_z_normalizer(dataloader)
-
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["target"]
driu.py
@@ -9,11 +9,11 @@ import typing
 import torch
 import torch.nn
 from mednet.libs.common.data.typing import TransformSequence
-from mednet.libs.common.models.model import Model

 from .backbones.vgg import vgg16_for_segmentation
 from .losses import SoftJaccardBCELogitsLoss
 from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
+from .segmentation_model import SegmentationModel

 logger = logging.getLogger(__name__)
@@ -70,7 +70,7 @@ class DRIUHead(torch.nn.Module):
         return self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8)


-class DRIU(Model):
+class DRIU(SegmentationModel):
     """Implementation of the DRIU model.

     Parameters
@@ -158,13 +158,7 @@ class DRIU(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            from .normalizer import make_z_normalizer
-
-            logger.info(
-                f"Uninitialised {self.name} model - "
-                f"computing z-norm factors from train dataloader.",
-            )
-            self.normalizer = make_z_normalizer(dataloader)
+            super().set_normalizer(dataloader)

     def training_step(self, batch, batch_idx):
         images = batch[0]["image"]
driu_bn.py
@@ -9,11 +9,11 @@ import typing
 import torch
 import torch.nn
 from mednet.libs.common.data.typing import TransformSequence
-from mednet.libs.common.models.model import Model

 from .backbones.vgg import vgg16_for_segmentation
 from .losses import SoftJaccardBCELogitsLoss
 from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
+from .segmentation_model import SegmentationModel

 logger = logging.getLogger(__name__)
@@ -73,7 +73,7 @@ class DRIUBNHead(torch.nn.Module):
         return self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8)


-class DRIUBN(Model):
+class DRIUBN(SegmentationModel):
     """Implementation of the DRIU-BN model.

     Parameters
@@ -161,13 +161,7 @@ class DRIUBN(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            from .normalizer import make_z_normalizer
-
-            logger.info(
-                f"Uninitialised {self.name} model - "
-                f"computing z-norm factors from train dataloader.",
-            )
-            self.normalizer = make_z_normalizer(dataloader)
+            super().set_normalizer(dataloader)

     def training_step(self, batch, batch_idx):
         images = batch[0]["image"]
driu_od.py
@@ -9,12 +9,12 @@ import typing
 import torch
 import torch.nn
 from mednet.libs.common.data.typing import TransformSequence
-from mednet.libs.common.models.model import Model

 from .backbones.vgg import vgg16_for_segmentation
 from .driu import ConcatFuseBlock
 from .losses import SoftJaccardBCELogitsLoss
 from .make_layers import UpsampleCropBlock
+from .segmentation_model import SegmentationModel

 logger = logging.getLogger(__name__)
@@ -55,7 +55,7 @@ class DRIUODHead(torch.nn.Module):
         return self.concatfuse(upsample2, upsample4, upsample8, upsample16)


-class DRIUOD(Model):
+class DRIUOD(SegmentationModel):
     """Implementation of the DRIU-OD model.

     Parameters
@@ -143,13 +143,7 @@ class DRIUOD(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            from .normalizer import make_z_normalizer
-
-            logger.info(
-                f"Uninitialised {self.name} model - "
-                f"computing z-norm factors from train dataloader.",
-            )
-            self.normalizer = make_z_normalizer(dataloader)
+            super().set_normalizer(dataloader)

     def training_step(self, batch, batch_idx):
         images = batch[0]["image"]
driu_pix.py
@@ -9,12 +9,12 @@ import typing
 import torch
 import torch.nn
 from mednet.libs.common.data.typing import TransformSequence
-from mednet.libs.common.models.model import Model

 from .backbones.vgg import vgg16_for_segmentation
 from .driu import ConcatFuseBlock
 from .losses import SoftJaccardBCELogitsLoss
 from .make_layers import UpsampleCropBlock
+from .segmentation_model import SegmentationModel

 logger = logging.getLogger(__name__)
@@ -59,7 +59,7 @@ class DRIUPIXHead(torch.nn.Module):
         return self.concatfuse(conv1_2_16, upsample2, upsample4, upsample8)


-class DRIUPix(Model):
+class DRIUPix(SegmentationModel):
     """Implementation of the DRIU-BN model.

     Parameters
@@ -147,13 +147,7 @@ class DRIUPix(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            from .normalizer import make_z_normalizer
-
-            logger.info(
-                f"Uninitialised {self.name} model - "
-                f"computing z-norm factors from train dataloader.",
-            )
-            self.normalizer = make_z_normalizer(dataloader)
+            super().set_normalizer(dataloader)

     def training_step(self, batch, batch_idx):
         images = batch[0]["image"]
hed.py
@@ -8,11 +8,11 @@ import typing
 import torch
 import torch.nn
 from mednet.libs.common.data.typing import TransformSequence
-from mednet.libs.common.models.model import Model

 from .backbones.vgg import vgg16_for_segmentation
 from .losses import MultiSoftJaccardBCELogitsLoss
 from .make_layers import UpsampleCropBlock, conv_with_kaiming_uniform
+from .segmentation_model import SegmentationModel

 logger = logging.getLogger(__name__)
@@ -73,7 +73,7 @@ class HEDHead(torch.nn.Module):
         return (upsample2, upsample4, upsample8, upsample16, concatfuse)


-class HED(Model):
+class HED(SegmentationModel):
     """Implementation of the HED model.

     Parameters
@@ -162,13 +162,7 @@ class HED(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            from .normalizer import make_z_normalizer
-
-            logger.info(
-                f"Uninitialised {self.name} model - "
-                f"computing z-norm factors from train dataloader.",
-            )
-            self.normalizer = make_z_normalizer(dataloader)
+            super().set_normalizer(dataloader)

     def training_step(self, batch, batch_idx):
         images = batch[0]["image"]
lwnet.py
@@ -21,9 +21,10 @@ import typing
 import torch
 import torch.nn
 from mednet.libs.common.data.typing import TransformSequence
-from mednet.libs.common.models.model import Model
 from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss

+from .segmentation_model import SegmentationModel
+
 logger = logging.getLogger("mednet")
@@ -275,7 +276,7 @@ class LittleUNet(torch.nn.Module):
         return self.final(x)


-class LittleWNet(Model):
+class LittleWNet(SegmentationModel):
     """Little W-Net model, concatenating two Little U-Net models.

     Parameters
@@ -341,23 +342,6 @@ class LittleWNet(Model):
             shortcut=True,
         )

-    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
-        """Initialize the input normalizer for the current model.
-
-        Parameters
-        ----------
-        dataloader
-            A torch Dataloader from which to compute the mean and std.
-        """
-        from .normalizer import make_z_normalizer
-
-        logger.info(
-            f"Uninitialised {self.name} model - "
-            f"computing z-norm factors from train dataloader.",
-        )
-        self.normalizer = make_z_normalizer(dataloader)
-
     def forward(self, x):
         xn = self.normalizer(x)
         x1 = self.unet1(xn)
m2unet.py
@@ -8,11 +8,11 @@ import typing
 import torch
 import torch.nn
 from mednet.libs.common.data.typing import TransformSequence
-from mednet.libs.common.models.model import Model
 from mednet.libs.segmentation.models.losses import SoftJaccardBCELogitsLoss
 from torchvision.models.mobilenetv2 import InvertedResidual

 from .backbones.mobilenetv2 import mobilenet_v2_for_segmentation
+from .segmentation_model import SegmentationModel

 logger = logging.getLogger(__name__)
@@ -121,7 +121,7 @@ class M2UNetHead(torch.nn.Module):
         return self.decode1(decode2, x[1])  # 30, 3


-class M2UNET(Model):
+class M2UNET(SegmentationModel):
     """Implementation of the M2UNET model.

     Parameters
@@ -210,13 +210,7 @@ class M2UNET(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            from .normalizer import make_z_normalizer
-
-            logger.info(
-                f"Uninitialised {self.name} model - "
-                f"computing z-norm factors from train dataloader.",
-            )
-            self.normalizer = make_z_normalizer(dataloader)
+            super().set_normalizer(dataloader)

     def training_step(self, batch, batch_idx):
         images = batch[0]["image"]
segmentation_model.py (new file)
# SPDX-FileCopyrightText: Copyright © 2024 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import logging
import typing

import torch
import torch.nn
import torch.optim.optimizer
import torch.utils.data
from mednet.libs.common.data.typing import TransformSequence
from mednet.libs.common.models.model import Model
from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss

logger = logging.getLogger("mednet")


class SegmentationModel(Model):
    """Base model for segmentation task.

    Parameters
    ----------
    loss_type
        The loss to be used for training and evaluation.

        .. warning::

           The loss should be set to always return batch averages (as opposed
           to the batch sum), as our logging system expects it so.
    loss_arguments
        Arguments to the loss.
    optimizer_type
        The type of optimizer to use for training.
    optimizer_arguments
        Arguments to the optimizer after ``params``.
    model_transforms
        An optional sequence of torch modules containing transforms to be
        applied on the input **before** it is fed into the network.
    augmentation_transforms
        An optional sequence of torch modules containing transforms to be
        applied on the input **before** it is fed into the network.
    num_classes
        Number of outputs (classes) for this model.
    """

    def __init__(
        self,
        loss_type: torch.nn.Module = MultiWeightedBCELogitsLoss,
        loss_arguments: dict[str, typing.Any] = {},
        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_arguments: dict[str, typing.Any] = {},
        model_transforms: TransformSequence = [],
        augmentation_transforms: TransformSequence = [],
        num_classes: int = 1,
    ):
        super().__init__(
            loss_type,
            loss_arguments,
            optimizer_type,
            optimizer_arguments,
            model_transforms,
            augmentation_transforms,
            num_classes,
        )

    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
        """Initialize the input normalizer for the current model.

        Parameters
        ----------
        dataloader
            A torch Dataloader from which to compute the mean and std.
        """
        from .normalizer import make_z_normalizer

        logger.info(
            f"Uninitialised {self.name} model - "
            f"computing z-norm factors from train dataloader.",
        )
        self.normalizer = make_z_normalizer(dataloader)
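Both new base classes repeat the docstring warning that losses must return batch averages. Stock torch losses satisfy this out of the box because their default reduction is "mean"; overriding it to "sum" through loss_arguments is exactly what the warning cautions against. A quick, runnable illustration:

    import torch

    logits = torch.zeros(4, 1)
    targets = torch.ones(4, 1)

    mean_loss = torch.nn.BCEWithLogitsLoss()                # default: reduction="mean"
    sum_loss = torch.nn.BCEWithLogitsLoss(reduction="sum")  # violates the warning

    print(mean_loss(logits, targets))  # batch average, independent of batch size
    print(sum_loss(logits, targets))   # 4x larger here; scales with batch size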
unet.py
@@ -8,11 +8,11 @@ import typing
 import torch.nn
 from mednet.libs.common.data.typing import TransformSequence
-from mednet.libs.common.models.model import Model

 from .backbones.vgg import vgg16_for_segmentation
 from .losses import SoftJaccardBCELogitsLoss
 from .make_layers import UnetBlock, conv_with_kaiming_uniform
+from .segmentation_model import SegmentationModel

 logger = logging.getLogger(__name__)
@@ -62,7 +62,7 @@ class UNetHead(torch.nn.Module):
         return self.final(decode1)


-class Unet(Model):
+class Unet(SegmentationModel):
     """Implementation of the Unet model.

     Parameters
@@ -151,13 +151,7 @@ class Unet(Model):
             )
             self.normalizer = make_imagenet_normalizer()
         else:
-            from .normalizer import make_z_normalizer
-
-            logger.info(
-                f"Uninitialised {self.name} model - "
-                f"computing z-norm factors from train dataloader.",
-            )
-            self.normalizer = make_z_normalizer(dataloader)
+            super().set_normalizer(dataloader)

     def training_step(self, batch, batch_idx):
         images = batch[0]["image"]