Commit 43775598 authored by Daniel CARRON

Improved normalizer, updated pasa model

parent 8d11b566
1 merge request: !7 "Reviewed DataModule design+docs+types"
Pipeline #75389 failed
@@ -95,7 +95,11 @@ class _DelayedLoadingDataset(torch.utils.data.Dataset):
     ):
         self.split = split
         self.raw_data_loader = raw_data_loader
-        self.transform = torchvision.transforms.Compose(*transforms)
+        # Cannot unpack empty list
+        if len(transforms) > 0:
+            self.transform = torchvision.transforms.Compose([*transforms])
+        else:
+            self.transform = torchvision.transforms.Compose([])
 
     def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
         tensor, metadata = self.raw_data_loader(self.split[key])
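For context on the Compose change: torchvision.transforms.Compose takes a single sequence argument, so unpacking with * only works by accident for a one-element list, and an empty list leaves no argument at all. A minimal standalone check (the Compose behavior is standard torchvision; the rest is illustrative):

    import torchvision.transforms

    transforms = []  # no augmentations configured

    # Old form: Compose(*transforms) unpacks the list into positional
    # arguments. With an empty list, Compose() receives no argument and
    # raises a TypeError; with two or more transforms it gets too many.
    try:
        torchvision.transforms.Compose(*transforms)
    except TypeError as exc:
        print(exc)

    # New form: always pass a (possibly empty) list. Compose([]) is a
    # valid no-op pipeline that returns its input unchanged.
    pipeline = torchvision.transforms.Compose(list(transforms))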
@@ -57,7 +57,10 @@ class Densenet(pl.LightningModule):
         if self.pretrained:
             from .normalizer import TorchVisionNormalizer
 
-            self.normalizer = TorchVisionNormalizer(..., ...)
+            self.normalizer = TorchVisionNormalizer(
+                torch.Tensor([0.485, 0.456, 0.406]),
+                torch.Tensor([0.229, 0.224, 0.225]),
+            )
         else:
             from .normalizer import get_znorm_normalizer
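The previously elided placeholder arguments are filled in with the standard ImageNet per-channel mean and std that torchvision's pretrained weights expect. The construction should be equivalent to the stock Normalize transform; a sketch with a dummy batch:

    import torch
    import torchvision.transforms

    # Standard ImageNet normalization, as documented for torchvision's
    # pretrained models; expects float images scaled to [0, 1].
    normalize = torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    x = torch.rand(2, 3, 224, 224)  # dummy NCHW batch in [0, 1]
    print(normalize(x).shape)       # torch.Size([2, 3, 224, 224])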
@@ -23,12 +23,17 @@ class TorchVisionNormalizer(torch.nn.Module):
     def __init__(self, subtract: torch.Tensor, divide: torch.Tensor):
         super().__init__()
-        assert len(subtract) == len(divide), "TODO"
-        assert len(subtract) in (1, 3), "TODO"
-        self.subtract = subtract
-        self.divided = divide
-        subtract = torch.zeros(len(subtract.shape))[None, :, None, None]
-        divide = torch.ones(len(divide.shape))[None, :, None, None]
+        if len(subtract) != len(divide):
+            raise ValueError(
+                "Lengths of 'subtract' and 'divide' tensors should be the same."
+            )
+        if len(subtract) not in (1, 3):
+            raise ValueError(
+                "Length of 'subtract' tensor should be either 1 or 3, depending on the number of color channels."
+            )
+        subtract = torch.as_tensor(subtract)[None, :, None, None]
+        divide = torch.as_tensor(divide)[None, :, None, None]
         self.register_buffer("subtract", subtract)
         self.register_buffer("divide", divide)
         self.name = "torchvision-normalizer"
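The removed lines were doubly broken: torch.zeros(...)/torch.ones(...) discarded the supplied statistics and replaced them with constants, and self.divided was a typo that never registered anything usable. The new indexing keeps the values and only reshapes them: [None, :, None, None] turns a length-C vector into shape (1, C, 1, 1) so it broadcasts across an NCHW batch, and register_buffer makes the statistics follow .to(device) and appear in the state_dict without becoming trainable parameters. A quick shape check:

    import torch

    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])
    print(mean[None, :, None, None].shape)  # torch.Size([1, 3, 1, 1])

    # (1, C, 1, 1) broadcasts against (N, C, H, W) per channel:
    x = torch.rand(8, 3, 32, 32)
    y = (x - mean[None, :, None, None]) / std[None, :, None, None]
    print(y.shape)  # torch.Size([8, 3, 32, 32])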
@@ -41,13 +46,35 @@ class TorchVisionNormalizer(torch.nn.Module):
 def get_znorm_normalizer(
     dataloader: torch.utils.data.DataLoader,
 ) -> TorchVisionNormalizer:
-    # TODO: Fix this function to use unaugmented training set
-    # TODO: This function is only applicable IFF we are not fine-tuning (ie.
-    # model does not re-use weights from imagenet training!)
-    # TODO: Add type hints
-    # TODO: Add documentation
-    # 1 extract mean/std from dataloader
-    # 2 return TorchVisionNormalizer(mean, std)
-    pass
+    """Returns a normalizer with the mean and std computed from a dataloader's
+    unaugmented training set.
+
+    Parameters
+    ----------
+
+    dataloader:
+        A torch DataLoader from which to compute the mean and std
+
+    Returns
+    -------
+
+    An initialized TorchVisionNormalizer
+    """
+    mean = 0.0
+    var = 0.0
+    num_images = 0
+
+    for batch in dataloader:
+        data = batch[0]
+        data = data.view(data.size(0), data.size(1), -1)
+
+        num_images += data.size(0)
+        mean += data.mean(2).sum(0)
+        var += data.var(2).sum(0)
+
+    mean /= num_images
+    var /= num_images
+    std = torch.sqrt(var)
+
+    normalizer = TorchVisionNormalizer(mean, std)
+    return normalizer
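One caveat worth recording: the loop averages per-image variances, which by the law of total variance drops the between-image spread of the per-image means, so the resulting std slightly underestimates the true dataset std when images differ in overall brightness. A self-contained check of the decomposition (population variance, unbiased=False; equal-sized images assumed, as here):

    import torch

    data = torch.rand(16, 3, 8, 8).view(16, 3, -1)   # N x C x pixels
    within = data.var(2, unbiased=False).mean(0)     # E[Var(pixels | image)]
    between = data.mean(2).var(0, unbiased=False)    # Var(E[pixels | image])
    exact = data.transpose(0, 1).reshape(3, -1).var(1, unbiased=False)
    print(torch.allclose(within + between, exact, atol=1e-5))  # True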
@@ -8,8 +8,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
 
-from .normalizer import TorchVisionNormalizer
-
 
 class PASA(pl.LightningModule):
     """PASA module.
@@ -30,7 +28,7 @@ class PASA(pl.LightningModule):
         self.name = "pasa"
 
-        self.normalizer = TorchVisionNormalizer(nb_channels=1)
+        self.normalizer = None
 
         # First convolution block
         self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
@@ -81,6 +79,11 @@ class PASA(pl.LightningModule):
         self.dense = nn.Linear(80, 1)  # Fully connected layer
 
     def forward(self, x):
+        if self.normalizer is None:
+            raise TypeError(
+                "The normalizer has not been initialized. Make sure to call set_normalizer() after creation of the model."
+            )
         x = self.normalizer(x)
 
         # First convolution block
@@ -129,14 +132,21 @@ class PASA(pl.LightningModule):
         return x
 
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
-        """TODO: Write this function documentation"""
+        """Initializes the normalizer for the current model.
+
+        Parameters
+        ----------
+
+        dataloader:
+            A torch DataLoader from which to compute the mean and std
+        """
         from .normalizer import get_znorm_normalizer
 
         self.normalizer = get_znorm_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
-        images = batch["data"]
-        labels = batch["label"]
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -153,8 +163,8 @@ class PASA(pl.LightningModule):
         return {"loss": training_loss}
 
     def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch["data"]
-        labels = batch["label"]
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -176,9 +186,9 @@ class PASA(pl.LightningModule):
         return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch["name"]
-        images = batch["data"]
-        labels = batch["label"]
+        images = batch[0]
+        labels = batch[1]["label"]
+        names = batch[1]["names"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
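The three *_step changes all encode the same contract: a batch is no longer a flat dict but the default collation of (tensor, metadata) samples, so batch[0] is the stacked image tensor and batch[1] a dict of collated metadata. A toy reproduction of that layout (ToySet and its field values are illustrative):

    import torch
    from torch.utils.data import DataLoader, Dataset

    class ToySet(Dataset):
        # Each sample mirrors _DelayedLoadingDataset: (tensor, metadata dict)
        def __len__(self):
            return 4

        def __getitem__(self, key):
            return torch.rand(1, 8, 8), {"label": key % 2, "names": f"s{key}"}

    batch = next(iter(DataLoader(ToySet(), batch_size=2)))
    images = batch[0]           # stacked tensor, shape (2, 1, 8, 8)
    labels = batch[1]["label"]  # tensor([0, 1])
    names = batch[1]["names"]   # ['s0', 's1'] (strings collate to a list)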
@@ -263,26 +263,21 @@ def train(
     import torch.cuda
     import torch.nn
 
-    from ..data.dataset import normalize_data, reweight_BCEWithLogitsLoss
+    from ..data.dataset import reweight_BCEWithLogitsLoss
     from ..engine.trainer import run
 
     seed_everything(seed)
 
     checkpoint_file = get_checkpoint(output_folder, resume_from)
 
-    datamodule.update_module_properties(
-        batch_size=batch_size,
-        batch_chunk_count=batch_chunk_count,
-        drop_incomplete_batch=drop_incomplete_batch,
-        cache_samples=cache_samples,
-        parallel=parallel,
-    )
+    datamodule.set_chunk_size(batch_size, batch_chunk_count)
+    datamodule.parallel = parallel
+
     datamodule.prepare_data()
     datamodule.setup(stage="fit")
 
     reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid)
-    normalize_data(normalization, model, datamodule)
+    model.set_normalizer(datamodule.unaugmented_train_dataloader())
 
     arguments = {}
     arguments["max_epoch"] = epochs
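The swap from normalize_data(...) to model.set_normalizer(datamodule.unaugmented_train_dataloader()) also pins down which data the statistics come from: an unaugmented pass over the training set. A toy illustration of why fitting on augmented data would bias the statistics (hypothetical brightness jitter):

    import torch

    clean = torch.rand(100, 1, 8, 8)
    jitter = 0.5 + torch.rand(100, 1, 1, 1)  # random brightness factor
    augmented = clean * jitter

    # The normalizer should capture the statistics of the data it will
    # standardize; augmented statistics drift away from them:
    print(clean.mean().item(), clean.std().item())
    print(augmented.mean().item(), augmented.std().item())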