Commit bf987617 authored by Daniel CARRON

Functional model training step for pasa

Implemented BCEWithLogitsLoss reweighting function
Removed save_hyperparameters from model
Apply augmentation transforms to individual images
Fixed model summary
parent bd46b6bd
Merge request !7: Reviewed DataModule design+docs+types
Pipeline #75543 failed
@@ -13,22 +13,25 @@ Reference: [PASA-2019]_
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
+from torch.optim import Adam
+
+from ...data.transforms import ElasticDeformation
+
 from ...models.pasa import PASA

-# config
-optimizer_configs = {"lr": 8e-5}
 # optimizer
-optimizer = "Adam"
+optimizer = Adam
+optimizer_configs = {"lr": 8e-5}

 # criterion
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))

-from ...data.transforms import ElasticDeformation
 augmentation_transforms = [ElasticDeformation(p=0.8)]

 # from torchvision.transforms.v2 import ElasticTransform, InterpolationMode
 # augmentation_transforms = [ElasticTransform(alpha=1000.0, sigma=30.0, interpolation=InterpolationMode.NEAREST)]

 # model
 model = PASA(
     criterion,
......
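A quick sketch of what switching from the string `"Adam"` to the `Adam` class buys: `configure_optimizers()` can call the class directly instead of looking it up with `getattr(torch.optim, ...)`. This is a hedged, minimal illustration; the `params` stand-in is an assumption, not code from this commit:

```python
import torch
from torch.optim import Adam

optimizer = Adam                  # the class itself, previously the string "Adam"
optimizer_configs = {"lr": 8e-5}

# Stand-in for model.parameters(); this is what configure_optimizers() now reduces to:
params = [torch.nn.Parameter(torch.zeros(3))]
opt = optimizer(params, **optimizer_configs)
print(opt)  # Adam with lr=8e-05
```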
@@ -13,6 +13,8 @@ import torch
 import torch.utils.data
 import torchvision.transforms
+
+from tqdm import tqdm

 logger = logging.getLogger(__name__)
@@ -150,8 +152,13 @@ class _CachedDataset(torch.utils.data.Dataset):
             typing.Callable[[torch.Tensor], torch.Tensor]
         ] = [],
     ):
-        self.transform = torchvision.transforms.Compose(*transforms)
-        self.data = [raw_data_loader(k) for k in split]
+        # Cannot unpack empty list
+        if len(transforms) > 0:
+            self.transform = torchvision.transforms.Compose([*transforms])
+        else:
+            self.transform = torchvision.transforms.Compose([])
+
+        self.data = [raw_data_loader(k) for k in tqdm(split)]

     def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
         tensor, metadata = self.data[key]
@@ -446,6 +453,7 @@ class CachingDataModule(lightning.LightningDataModule):
             logger.info(f"Dataset {name} is already setup. Not reloading it.")
             return
         if self.cache_samples:
+            logger.info(f"Caching {name} dataset")
             self._datasets[name] = _CachedDataset(
                 self.database_split[name],
                 self.raw_data_loader,
......
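On the `Compose` fix above: `torchvision.transforms.Compose` expects a single list argument, so the old `Compose(*transforms)` unpacked the list and broke for anything but exactly one transform. A minimal sketch of the corrected behavior; arguably the `if/else` could collapse into one call, since `Compose` accepts an empty list directly:

```python
import torch
import torchvision.transforms

transforms = []  # e.g. no augmentations configured

# Compose(*[]) raises TypeError (missing argument); Compose([]) is valid
# and acts as the identity when applied.
transform = torchvision.transforms.Compose(list(transforms))

x = torch.rand(1, 28, 28)
assert torch.equal(transform(x), x)  # empty pipeline returns input unchanged
```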
@@ -10,11 +10,11 @@ import torch.utils.data
 logger = logging.getLogger(__name__)


-def _get_positive_weights(dataset):
+def _get_positive_weights(dataloader):
     """Compute the positive weights of each class of the dataset to balance the
     BCEWithLogitsLoss criterion.

-    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
+    This function takes as input a :py:class:`torch.utils.data.DataLoader`
     and computes the positive weights of each class to use them to have
     a balanced loss.

@@ -22,9 +22,8 @@ def _get_positive_weights(dataset):
     Parameters
     ----------

-    dataset : torch.utils.data.dataset.Dataset
-        An instance of torch.utils.data.dataset.Dataset
-        ConcatDataset are supported
+    dataloader : :py:class:`torch.utils.data.DataLoader`
+        A DataLoader from which to compute the positive weights. Must contain a 'label' key in the metadata returned by __getitem__().

     Returns
@@ -35,14 +34,8 @@ def _get_positive_weights(dataset):
     """
     targets = []

-    if isinstance(dataset, torch.utils.data.ConcatDataset):
-        for ds in dataset.datasets:
-            for s in ds._samples:
-                targets.append(s["label"])
-    else:
-        for s in dataset._samples:
-            targets.append(s["label"])
-
+    for batch in dataloader:
+        targets.extend(batch[1]["label"])

     targets = torch.tensor(targets)

@@ -71,33 +64,3 @@ def _get_positive_weights(dataset):
     )

     return positive_weights
-
-
-def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid):
-    from torch.nn import BCEWithLogitsLoss
-
-    datamodule.prepare_data()
-    datamodule.setup(stage="fit")
-
-    train_dataset = datamodule.train_dataset
-    validation_dataset = datamodule.validation_dataset
-
-    # Redefine a weighted criterion if possible
-    if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
-        positive_weights = _get_positive_weights(train_dataset)
-        model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
-    else:
-        logger.warning("Weighted criterion not supported")
-
-    if validation_dataset is not None:
-        # Redefine a weighted valid criterion if possible
-        if (
-            isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss)
-            or criterion_valid is None
-        ):
-            positive_weights = _get_positive_weights(validation_dataset)
-            model.hparams.criterion_valid = BCEWithLogitsLoss(
-                pos_weight=positive_weights
-            )
-        else:
-            logger.warning("Weighted valid criterion not supported")
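For reference, the balancing rule this helper implements is the standard one for `BCEWithLogitsLoss`: per class, `pos_weight = n_negatives / n_positives`, so positives of a rare class contribute proportionally more to the loss. A hedged, self-contained sketch; the exact tensor handling inside `_get_positive_weights` may differ:

```python
import torch

def positive_weights_from_labels(labels: torch.Tensor) -> torch.Tensor:
    """Sketch: per-class pos_weight = (# negative samples) / (# positive samples)."""
    labels = labels.float()          # (n_samples,) or (n_samples, n_classes)
    positives = labels.sum(dim=0)
    negatives = labels.shape[0] - positives
    return negatives / positives

# 2 positives out of 6 samples -> pos_weight = 4 / 2 = 2.0
print(positive_weights_from_labels(torch.tensor([1, 0, 0, 1, 0, 0])))
```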
@@ -94,9 +94,7 @@ class LoggingCallback(Callback):
         self.log("total_time", current_time)
         self.log("eta", eta_seconds)
         self.log("loss", numpy.average(self.training_loss))
-        self.log(
-            "learning_rate", pl_module.hparams["optimizer_configs"]["lr"]
-        )
+        self.log("learning_rate", pl_module.optimizer_configs["lr"])
         self.log("validation_loss", numpy.sum(self.validation_loss))

         if len(self.extra_validation_loss) > 0:
......
@@ -51,20 +51,24 @@ def save_model_summary(
     Returns
     -------

-    r
+    summary:
         The model summary in a text format.

-    n
+    total_parameters:
         The number of parameters of the model.
     """

     summary_path = os.path.join(output_folder, "model_summary.txt")
     logger.info(f"Saving model summary at {summary_path}...")
     with open(summary_path, "w") as f:
-        summary = lightning.pytorch.callbacks.ModelSummary(model, max_depth=-1)
+        summary = lightning.pytorch.utilities.model_summary.ModelSummary(
+            model, max_depth=-1
+        )
         f.write(str(summary))

     return (
         summary,
-        lightning.pytorch.callbacks.ModelSummary(model).total_parameters,
+        lightning.pytorch.utilities.model_summary.ModelSummary(
+            model
+        ).total_parameters,
     )
......
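The import-path fix matters because the two identically-named classes are different things: `lightning.pytorch.callbacks.ModelSummary` is a `Trainer` callback and does not take a model, while `lightning.pytorch.utilities.model_summary.ModelSummary` is the summary object used here. A small usage sketch with a throwaway module (the `Tiny` class is an assumption for illustration):

```python
import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.model_summary import ModelSummary

class Tiny(LightningModule):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

summary = ModelSummary(Tiny(), max_depth=-1)
print(str(summary))               # the layer table written to model_summary.txt
print(summary.total_parameters)   # 10 = 4*2 weights + 2 biases
```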
@@ -9,6 +9,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
+import torchvision.transforms

 logger = logging.getLogger(__name__)
@@ -31,7 +32,15 @@ class PASA(pl.LightningModule):
         self.name = "pasa"

-        self.augmentation_transforms = augmentation_transforms
+        self.augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms
+        )
+
+        self.criterion = criterion
+        self.criterion_valid = criterion_valid
+
+        self.optimizer = optimizer
+        self.optimizer_configs = optimizer_configs

         self.normalizer = None
@@ -137,7 +146,7 @@ class PASA(pl.LightningModule):
         Parameters
         ----------

-        dataloader:
+        dataloader: :py:class:`torch.utils.data.DataLoader`
             A torch Dataloader from which to compute the mean and std
         """
         from .normalizer import make_z_normalizer
@@ -148,6 +157,35 @@ class PASA(pl.LightningModule):
         )
         self.normalizer = make_z_normalizer(dataloader)

+    def set_bce_loss_weights(self, datamodule):
+        """Reweights loss weights if BCEWithLogitsLoss is used.
+
+        Parameters
+        ----------
+
+        datamodule:
+            A datamodule implementing train_dataloader() and val_dataloader()
+        """
+        from ..data.dataset import _get_positive_weights
+
+        if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss training criterion.")
+            train_positive_weights = _get_positive_weights(
+                datamodule.train_dataloader()
+            )
+            self.criterion = torch.nn.BCEWithLogitsLoss(
+                pos_weight=train_positive_weights
+            )
+
+        if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
+            validation_positive_weights = _get_positive_weights(
+                datamodule.val_dataloader()["validation"]
+            )
+            self.criterion_valid = torch.nn.BCEWithLogitsLoss(
+                pos_weight=validation_positive_weights
+            )
+
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
@@ -158,12 +196,11 @@ class PASA(pl.LightningModule):
         labels = torch.reshape(labels, (labels.shape[0], 1))

         # Forward pass on the network
-        augmented_images = self.augmentation_transforms(images)
+        augmented_images = [self.augmentation_transforms(img) for img in images]
+        augmented_images = torch.unsqueeze(torch.cat(augmented_images, 0), 1)

         outputs = self(augmented_images)

-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion = self.hparams.criterion.to(self.device)
-        training_loss = self.hparams.criterion(outputs, labels.double())
+        training_loss = self.criterion(outputs, labels.double())

         return {"loss": training_loss}
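The per-image loop above is what makes stochastic augmentations behave per sample: applied to the whole batch tensor at once, one random `ElasticDeformation` would be shared by every image, whereas looping draws fresh parameters each time. A sketch of the shape bookkeeping, assuming single-channel inputs of shape `(B, 1, H, W)` (the identity `augment` is a stand-in, not the real pipeline):

```python
import torch

def augment(img: torch.Tensor) -> torch.Tensor:
    """Stand-in for self.augmentation_transforms (identity here)."""
    return img

batch = torch.rand(8, 1, 28, 28)  # (B, 1, H, W)

# Iterating a (B, 1, H, W) tensor yields (1, H, W) slices; cat along dim 0
# stacks them to (B, H, W), and unsqueeze restores the channel axis.
augmented = [augment(img) for img in batch]
augmented = torch.unsqueeze(torch.cat(augmented, 0), 1)
assert augmented.shape == (8, 1, 28, 28)
```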
@@ -179,11 +216,7 @@ class PASA(pl.LightningModule):
         # data forwarding on the existing network
         outputs = self(images)

-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
-            self.device
-        )
-        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
+        validation_loss = self.criterion_valid(outputs, labels.double())

         if dataloader_idx == 0:
             return {"validation_loss": validation_loss}
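Dropping the manual `.to(self.device)` calls is safe because `BCEWithLogitsLoss` is itself an `nn.Module`: assigning it to `self.criterion` registers it as a child module, so its `pos_weight` buffer now follows the model across devices (storing it under `self.hparams`, as before, did not register it). A minimal demonstration of that registration, using a toy module as a stand-in:

```python
import torch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 1)
        # Plain attribute assignment registers the loss as a submodule,
        # because BCEWithLogitsLoss is an nn.Module with a pos_weight buffer.
        self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([2.0]))

net = Net()
print([name for name, _ in net.named_buffers()])  # ['criterion.pos_weight']
# net.to("cuda") would move this buffer too, so no manual criterion.to(device).
```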
@@ -233,9 +266,5 @@ class PASA(pl.LightningModule):
         # raise NotImplementedError

     def configure_optimizers(self):
-        # Dynamically instantiates the optimizer given the configs
-        optimizer = getattr(torch.optim, self.hparams.optimizer)(
-            self.parameters(), **self.hparams.optimizer_configs
-        )
-
+        optimizer = self.optimizer(self.parameters(), **self.optimizer_configs)
         return optimizer