Commit da13b5b0 authored by Daniel CARRON, committed by André Anjos

[normalizer] Move normalizer out of common package and inside libs

parent 2bc38491
1 merge request: !46 Create common library
Showing changed files with 123 additions and 25 deletions
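In practice the change boils down to where the normalizer helpers live: instead of reaching into the shared common package, each library now carries its own normalizer module and the models import it relatively. A one-line before/after sketch (both import forms are taken verbatim from the hunks below; nothing else is assumed):

# before: shared helper in the common package
from mednet.libs.common.models.normalizer import make_imagenet_normalizer

# after: lib-local copy, imported relatively from the model modules
from .normalizer import make_imagenet_normalizer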
@@ -192,6 +192,23 @@ class Pasa(Model):
        # x = F.log_softmax(x, dim=1) # 0 is batch size

+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initialize the input normalizer for the current model.
+
+        Parameters
+        ----------
+        dataloader
+            A torch Dataloader from which to compute the mean and std.
+        """
+
+        from .normalizer import make_z_normalizer
+
+        logger.info(
+            f"Uninitialised {self.name} model - "
+            f"computing z-norm factors from train dataloader.",
+        )
+        self.normalizer = make_z_normalizer(dataloader)
+
    def training_step(self, batch, _):
        images = batch[0]
        labels = batch[1]["target"]
......
@@ -130,21 +130,7 @@ class Model(pl.LightningModule):
        self.normalizer = checkpoint["normalizer"]

    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
-        """Initialize the input normalizer for the current model.
-
-        Parameters
-        ----------
-        dataloader
-            A torch Dataloader from which to compute the mean and std.
-        """
-
-        from .normalizer import make_z_normalizer
-
-        logger.info(
-            f"Uninitialised {self.name} model - "
-            f"computing z-norm factors from train dataloader.",
-        )
-        self.normalizer = make_z_normalizer(dataloader)
+        raise NotImplementedError

    def training_step(self, batch, _):
        raise NotImplementedError
......
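With the shared implementation removed, the base class now only declares the interface and every concrete model supplies its own normalizer. A minimal sketch of the override pattern, assuming a hypothetical ExampleModel (only Model, set_normalizer and the dataloader signature come from the diff; a plain torchvision Normalize stands in for the lib-local make_z_normalizer to keep the sketch self-contained):

import torch.utils.data
import torchvision.transforms

from mednet.libs.common.models.model import Model


class ExampleModel(Model):
    """Hypothetical concrete model; each lib now provides this override."""

    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
        # The real models call the lib-local helper instead, e.g.
        # ``from .normalizer import make_z_normalizer`` as in the hunks below.
        self.normalizer = torchvision.transforms.Normalize((0.5,), (0.5,))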
@@ -57,11 +57,11 @@ class SegmentationRawDataLoader(_SegmentationRawDataLoader):
            PIL.Image.open(self.datadir / sample[2]).convert(mode="1", dither=None)
        )

-        tensor = tv_tensors.Image(crop_image_to_mask(image, mask))
-        target = tv_tensors.Image(crop_image_to_mask(target, mask))
+        image = tv_tensors.Image(crop_image_to_mask(image, mask))
+        target = tv_tensors.Mask(crop_image_to_mask(target, mask))
        mask = tv_tensors.Mask(crop_image_to_mask(mask, mask))

-        return tensor, dict(target=target, mask=mask, name=sample[0]) # type: ignore[arg-type]
+        return dict(image=image, target=target, mask=mask), dict(name=sample[0]) # type: ignore[arg-type]
class DataModule(CachingDataModule):
......
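After this change a segmentation sample is a pair of dictionaries rather than a (tensor, dict) pair, so downstream code looks tensors up by key. A short illustration of unpacking one collated batch under the new layout (the function name is hypothetical; the keys image, target, mask and name come from the return statement above):

def describe_batch(batch) -> None:
    # batch[0] carries the tensors, batch[1] the metadata, per the new layout.
    images = batch[0]["image"]    # tv_tensors.Image, shape (N, C, H, W)
    targets = batch[0]["target"]  # tv_tensors.Mask
    masks = batch[0]["mask"]      # tv_tensors.Mask
    names = batch[1]["name"]      # list of sample names
    print(names, images.shape, targets.shape, masks.shape)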
@@ -149,7 +149,7 @@ class DRIU(Model):
            Will not be used if the model is pretrained.
        """
        if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer

            logger.warning(
                f"ImageNet pre-trained {self.name} model - NOT "
......
@@ -152,7 +152,7 @@ class DRIUBN(Model):
            Will not be used if the model is pretrained.
        """
        if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer

            logger.warning(
                f"ImageNet pre-trained {self.name} model - NOT "
......
@@ -134,7 +134,7 @@ class DRIUOD(Model):
            Will not be used if the model is pretrained.
        """
        if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer

            logger.warning(
                f"ImageNet pre-trained {self.name} model - NOT "
......
@@ -138,7 +138,7 @@ class DRIUPix(Model):
            Will not be used if the model is pretrained.
        """
        if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer

            logger.warning(
                f"ImageNet pre-trained {self.name} model - NOT "
......
@@ -153,7 +153,7 @@ class HED(Model):
            Will not be used if the model is pretrained.
        """
        if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer

            logger.warning(
                f"ImageNet pre-trained {self.name} model - NOT "
......
@@ -15,6 +15,7 @@ guide segmentation.

Reference: [GALDRAN-2020]_
"""

+import logging
import typing

import torch
@@ -23,6 +24,8 @@ from mednet.libs.common.data.typing import TransformSequence
from mednet.libs.common.models.model import Model
from mednet.libs.segmentation.models.losses import MultiWeightedBCELogitsLoss

+logger = logging.getLogger("mednet")
+

def _conv1x1(in_planes, out_planes, stride=1):
    return torch.nn.Conv2d(
@@ -338,6 +341,23 @@ class LittleWNet(Model):
            shortcut=True,
        )

+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initialize the input normalizer for the current model.
+
+        Parameters
+        ----------
+        dataloader
+            A torch Dataloader from which to compute the mean and std.
+        """
+
+        from .normalizer import make_z_normalizer
+
+        logger.info(
+            f"Uninitialised {self.name} model - "
+            f"computing z-norm factors from train dataloader.",
+        )
+        self.normalizer = make_z_normalizer(dataloader)
+
    def forward(self, x):
        xn = self.normalizer(x)
        x1 = self.unet1(xn)
......
@@ -201,7 +201,7 @@ class M2UNET(Model):
            Will not be used if the model is pretrained.
        """
        if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer

            logger.warning(
                f"ImageNet pre-trained {self.name} model - NOT "
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Functions to compute normalisation factors based on dataloaders."""

import logging

import torch
import torch.nn
import torch.utils.data
import torchvision.transforms
import tqdm

logger = logging.getLogger("mednet")


def make_z_normalizer(
    dataloader: torch.utils.data.DataLoader,
) -> torchvision.transforms.Normalize:
    """Compute mean and standard deviation from a dataloader.

    This function will input a dataloader, and compute the mean and standard
    deviation by image channel. It will work for both monochromatic and color
    inputs with 2, 3 or more color planes.

    Parameters
    ----------
    dataloader
        A torch Dataloader from which to compute the mean and std.

    Returns
    -------
        An initialized normalizer.
    """

    # Peek the number of channels of batches in the data loader
    batch = next(iter(dataloader))
    channels = batch[0]["image"].shape[1]

    # Initialises accumulators
    mean = torch.zeros(channels, dtype=batch[0]["image"].dtype)
    var = torch.zeros(channels, dtype=batch[0]["image"].dtype)
    num_images = 0

    # Evaluates mean and standard deviation
    for batch in tqdm.tqdm(dataloader, unit="batch"):
        # samples are dictionaries; the image tensor sits under the "image" key
        data = batch[0]["image"]
        data = data.view(data.size(0), data.size(1), -1)

        num_images += data.size(0)
        mean += data.mean(2).sum(0)
        var += data.var(2).sum(0)

    mean /= num_images
    var /= num_images
    std = torch.sqrt(var)

    return torchvision.transforms.Normalize(mean, std)


def make_imagenet_normalizer() -> torchvision.transforms.Normalize:
    """Return the stock ImageNet normalisation weights from torchvision.

    The weights are wrapped in a torch module. This normalizer only works for
    **RGB (color) images**.

    Returns
    -------
        An initialized normalizer.
    """

    return torchvision.transforms.Normalize(
        (0.485, 0.456, 0.406),
        (0.229, 0.224, 0.225),
    )
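For orientation, a small usage sketch (not part of the commit), assuming it runs in the same module as the functions above; the synthetic dataloader only mimics the (dict(image=...), dict(name=...)) sample layout introduced by this commit:

import torch
import torch.utils.data

# Hypothetical in-memory dataset following the new sample layout, fed through
# make_z_normalizer() exactly like the set_normalizer() overrides do.
samples = [
    (dict(image=torch.rand(3, 64, 64)), dict(name=f"sample-{i}")) for i in range(8)
]
loader = torch.utils.data.DataLoader(samples, batch_size=4)

normalizer = make_z_normalizer(loader)  # a torchvision.transforms.Normalize
print(normalizer.mean, normalizer.std)  # per-channel statistics of the fake images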
@@ -142,7 +142,7 @@ class Unet(Model):
            Will not be used if the model is pretrained.
        """
        if self.pretrained:
-            from mednet.libs.common.models.normalizer import make_imagenet_normalizer
+            from .normalizer import make_imagenet_normalizer

            logger.warning(
                f"ImageNet pre-trained {self.name} model - NOT "
......