Commit 11b65a40, authored by Daniel CARRON, committed by André Anjos

Functional model training step for pasa

- Implemented BCEWithLogitsLoss reweighting function
- Removed save_hyperparameters from model
- Apply augmentation transforms to individual images
- Fixed model summary
parent 9eb1e023
Part of merge request !6: Making use of LightningDataModule and simplification of data loading
@@ -13,22 +13,25 @@ Reference: [PASA-2019]_
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
+from torch.optim import Adam
 
+from ...data.transforms import ElasticDeformation
 from ...models.pasa import PASA
 
-# config
-optimizer_configs = {"lr": 8e-5}
-
 # optimizer
-optimizer = "Adam"
+optimizer = Adam
+optimizer_configs = {"lr": 8e-5}
 
 # criterion
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
-from ...data.transforms import ElasticDeformation
-
 augmentation_transforms = [ElasticDeformation(p=0.8)]
+# from torchvision.transforms.v2 import ElasticTransform, InterpolationMode
+# augmentation_transforms = [ElasticTransform(alpha=1000.0, sigma=30.0, interpolation=InterpolationMode.NEAREST)]
 
 # model
 model = PASA(
     criterion,
...
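Note that `optimizer` is now the `torch.optim.Adam` class itself rather than the string "Adam": with `save_hyperparameters()` removed from the model (see below), the config can hand over the class directly and instantiation is deferred until the model's parameters exist. A minimal sketch of that deferred-instantiation pattern, with an illustrative stand-in for the model's parameters:

import torch
from torch.optim import Adam

# Store the optimizer class and its keyword arguments separately;
# neither requires a live model at config time.
optimizer_cls = Adam
optimizer_configs = {"lr": 8e-5}

params = [torch.nn.Parameter(torch.zeros(3))]  # stand-in for model.parameters()
optimizer = optimizer_cls(params, **optimizer_configs)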
@@ -13,6 +13,8 @@ import torch
 import torch.utils.data
 import torchvision.transforms
+from tqdm import tqdm
 
 logger = logging.getLogger(__name__)
@@ -150,8 +152,13 @@ class _CachedDataset(torch.utils.data.Dataset):
             typing.Callable[[torch.Tensor], torch.Tensor]
         ] = [],
     ):
-        self.transform = torchvision.transforms.Compose(*transforms)
-        self.data = [raw_data_loader(k) for k in split]
+        # Cannot unpack empty list
+        if len(transforms) > 0:
+            self.transform = torchvision.transforms.Compose([*transforms])
+        else:
+            self.transform = torchvision.transforms.Compose([])
+        self.data = [raw_data_loader(k) for k in tqdm(split)]
 
     def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
         tensor, metadata = self.data[key]
@@ -446,6 +453,7 @@ class CachingDataModule(lightning.LightningDataModule):
             logger.info(f"Dataset {name} is already setup. Not reloading it.")
             return
         if self.cache_samples:
+            logger.info(f"Caching {name} dataset")
             self._datasets[name] = _CachedDataset(
                 self.database_split[name],
                 self.raw_data_loader,
...
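The `Compose(*transforms)` call was the bug here: unpacking an empty list invokes `Compose()` with no arguments, and `Compose` requires a list of callables, so construction fails with a `TypeError`. A minimal sketch of the failure and the fix; note that an empty `Compose` simply passes its input through, so the `if`/`else` above could also be collapsed to `Compose(transforms)`:

import torch
import torchvision.transforms

transforms = []  # no augmentations configured

# Buggy: Compose(*[]) == Compose(), missing the required argument.
# torchvision.transforms.Compose(*transforms)  # raises TypeError

# Fixed: pass the (possibly empty) list itself; an empty Compose
# acts as the identity transform.
identity = torchvision.transforms.Compose(transforms)
x = torch.zeros(1, 8, 8)
assert identity(x) is x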
@@ -10,11 +10,11 @@ import torch.utils.data
 logger = logging.getLogger(__name__)
 
-def _get_positive_weights(dataset):
+def _get_positive_weights(dataloader):
     """Compute the positive weights of each class of the dataset to balance the
     BCEWithLogitsLoss criterion.
 
-    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
+    This function takes as input a :py:class:`torch.utils.data.DataLoader`
     and computes the positive weights of each class to use them to have
     a balanced loss.
 
@@ -22,9 +22,8 @@ def _get_positive_weights(dataloader):
     Parameters
     ----------
 
-    dataset : torch.utils.data.dataset.Dataset
-        An instance of torch.utils.data.dataset.Dataset
-        ConcatDataset are supported
+    dataloader : :py:class:`torch.utils.data.DataLoader`
+        A DataLoader from which to compute the positive weights. Must contain
+        a 'label' key in the metadata returned by __getitem__().
 
     Returns
@@ -35,14 +34,8 @@ def _get_positive_weights(dataloader):
     """
     targets = []
 
-    if isinstance(dataset, torch.utils.data.ConcatDataset):
-        for ds in dataset.datasets:
-            for s in ds._samples:
-                targets.append(s["label"])
-    else:
-        for s in dataset._samples:
-            targets.append(s["label"])
+    for batch in dataloader:
+        targets.extend(batch[1]["label"])
 
     targets = torch.tensor(targets)
@@ -71,33 +64,3 @@ def _get_positive_weights(dataloader):
     )
 
     return positive_weights
-
-
-def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid):
-    from torch.nn import BCEWithLogitsLoss
-
-    datamodule.prepare_data()
-    datamodule.setup(stage="fit")
-
-    train_dataset = datamodule.train_dataset
-    validation_dataset = datamodule.validation_dataset
-
-    # Redefine a weighted criterion if possible
-    if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
-        positive_weights = _get_positive_weights(train_dataset)
-        model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
-    else:
-        logger.warning("Weighted criterion not supported")
-
-    if validation_dataset is not None:
-        # Redefine a weighted valid criterion if possible
-        if (
-            isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss)
-            or criterion_valid is None
-        ):
-            positive_weights = _get_positive_weights(validation_dataset)
-            model.hparams.criterion_valid = BCEWithLogitsLoss(
-                pos_weight=positive_weights
-            )
-    else:
-        logger.warning("Weighted valid criterion not supported")
@@ -94,9 +94,7 @@ class LoggingCallback(Callback):
         self.log("total_time", current_time)
         self.log("eta", eta_seconds)
         self.log("loss", numpy.average(self.training_loss))
-        self.log(
-            "learning_rate", pl_module.hparams["optimizer_configs"]["lr"]
-        )
+        self.log("learning_rate", pl_module.optimizer_configs["lr"])
         self.log("validation_loss", numpy.sum(self.validation_loss))
         if len(self.extra_validation_loss) > 0:
...
@@ -51,20 +51,24 @@ def save_model_summary(
     Returns
     -------
 
-    r
+    summary:
         The model summary in a text format.
-    n
+    total_parameters:
        The number of parameters of the model.
     """
     summary_path = os.path.join(output_folder, "model_summary.txt")
     logger.info(f"Saving model summary at {summary_path}...")
     with open(summary_path, "w") as f:
-        summary = lightning.pytorch.callbacks.ModelSummary(model, max_depth=-1)
+        summary = lightning.pytorch.utilities.model_summary.ModelSummary(
+            model, max_depth=-1
+        )
         f.write(str(summary))
     return (
         summary,
-        lightning.pytorch.callbacks.ModelSummary(model).total_parameters,
+        lightning.pytorch.utilities.model_summary.ModelSummary(
+            model
+        ).total_parameters,
     )
...
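This is the "Fixed model summary" item from the commit message: `lightning.pytorch.callbacks.ModelSummary` is a Trainer callback and does not take a model argument, whereas the utility class does. A short usage sketch, assuming `model` is any LightningModule; note the corrected code still builds the summary twice, once per call site, where a single instance would do:

from lightning.pytorch.utilities.model_summary import ModelSummary

# Build the summary once; max_depth=-1 expands all nested submodules.
summary = ModelSummary(model, max_depth=-1)
print(str(summary))              # human-readable table
print(summary.total_parameters)  # total parameter count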
@@ -9,6 +9,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
+import torchvision.transforms
 
 logger = logging.getLogger(__name__)
@@ -31,7 +32,15 @@ class PASA(pl.LightningModule):
         self.name = "pasa"
 
-        self.augmentation_transforms = augmentation_transforms
+        self.augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms
+        )
+
+        self.criterion = criterion
+        self.criterion_valid = criterion_valid
+
+        self.optimizer = optimizer
+        self.optimizer_configs = optimizer_configs
 
         self.normalizer = None
@@ -137,7 +146,7 @@ class PASA(pl.LightningModule):
         Parameters
         ----------
 
-        dataloader:
+        dataloader: :py:class:`torch.utils.data.DataLoader`
            A torch Dataloader from which to compute the mean and std
         """
         from .normalizer import make_z_normalizer
@@ -148,6 +157,35 @@ class PASA(pl.LightningModule):
         )
         self.normalizer = make_z_normalizer(dataloader)
 
+    def set_bce_loss_weights(self, datamodule):
+        """Reweights loss weights if BCEWithLogitsLoss is used.
+
+        Parameters
+        ----------
+
+        datamodule:
+            A datamodule implementing train_dataloader() and val_dataloader()
+        """
+        from ..data.dataset import _get_positive_weights
+
+        if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss training criterion.")
+            train_positive_weights = _get_positive_weights(
+                datamodule.train_dataloader()
+            )
+            self.criterion = torch.nn.BCEWithLogitsLoss(
+                pos_weight=train_positive_weights
+            )
+
+        if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
+            validation_positive_weights = _get_positive_weights(
+                datamodule.val_dataloader()["validation"]
+            )
+            self.criterion_valid = torch.nn.BCEWithLogitsLoss(
+                pos_weight=validation_positive_weights
+            )
+
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
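A hedged usage sketch of the new method; the names below are illustrative, and the lookup `val_dataloader()["validation"]` above implies the validation loaders come back as a dict keyed by split name:

# Illustrative wiring; `datamodule`, `model` and `trainer` are hypothetical
# instances, not repository code.
datamodule.prepare_data()
datamodule.setup(stage="fit")

model.set_bce_loss_weights(datamodule)  # replaces both criteria in place
trainer.fit(model, datamodule)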
@@ -158,12 +196,11 @@ class PASA(pl.LightningModule):
         labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        augmented_images = self.augmentation_transforms(images)
+        augmented_images = [self.augmentation_transforms(img) for img in images]
+        augmented_images = torch.unsqueeze(torch.cat(augmented_images, 0), 1)
         outputs = self(augmented_images)
 
-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion = self.hparams.criterion.to(self.device)
-        training_loss = self.hparams.criterion(outputs, labels.double())
+        training_loss = self.criterion(outputs, labels.double())
 
         return {"loss": training_loss}
@@ -179,11 +216,7 @@ class PASA(pl.LightningModule):
         # data forwarding on the existing network
         outputs = self(images)
 
-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
-            self.device
-        )
-        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
+        validation_loss = self.criterion_valid(outputs, labels.double())
 
         if dataloader_idx == 0:
             return {"validation_loss": validation_loss}
@@ -233,9 +266,5 @@ class PASA(pl.LightningModule):
         # raise NotImplementedError
 
     def configure_optimizers(self):
-        # Dynamically instantiates the optimizer given the configs
-        optimizer = getattr(torch.optim, self.hparams.optimizer)(
-            self.parameters(), **self.hparams.optimizer_configs
-        )
+        optimizer = self.optimizer(self.parameters(), **self.optimizer_configs)
         return optimizer
...
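This final hunk closes the loop on the config change: since the config now stores the optimizer class, `configure_optimizers` calls it directly instead of resolving a string through `getattr(torch.optim, ...)`. The refactor is behavior-preserving, as this one-line check illustrates:

import torch

# The string-based lookup and the stored class resolve to the same object.
assert getattr(torch.optim, "Adam") is torch.optim.Adam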