Commit 4b512c5f authored by Daniel CARRON, committed by André Anjos
Added DataModule in shenzhen, started support for extra-valid datasets

parent 2bdffabc
Merge request !6: Making use of LightningDataModule and simplification of data loading
@@ -10,6 +10,8 @@
 * See :py:mod:`ptbench.data.shenzhen` for dataset details
 """
+from ....data.shenzhen.datamodule import ShenzhenDataModule
+
 from . import _maker

 dataset = _maker("default")

+datamodule = ShenzhenDataModule
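To make the wiring concrete: the config module exports the DataModule class itself, and the train script (further down in this commit) instantiates it once batch size and multiprocessing options are known. A minimal sketch of that pattern, with illustrative values that are not part of this commit:

# the config exports the class, not an instance, so the train script can
# choose the constructor arguments; `dataset` is a placeholder split dict here
from ptbench.data.shenzhen.datamodule import ShenzhenDataModule

dataset = {"train": [], "__valid__": []}  # placeholder splits, for illustration
datamodule = ShenzhenDataModule  # what this config exports
dm = datamodule(  # what the train script later does
    dataset,
    train_batch_size=4,  # assumed value
    multiproc_kwargs={"num_workers": 2},  # assumed value
)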
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import lightning as pl
+import torch
+
+from clapper.logging import setup
+from torch.utils.data import DataLoader, WeightedRandomSampler
+
+from ptbench.configs.datasets import get_samples_weights
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+class ShenzhenDataModule(pl.LightningDataModule):
+    def __init__(
+        self,
+        dataset,
+        train_batch_size=1,
+        predict_batch_size=1,
+        drop_incomplete_batch=False,
+        multiproc_kwargs=None,
+    ):
+        super().__init__()
+        self.dataset = dataset
+        self.train_batch_size = train_batch_size
+        self.predict_batch_size = predict_batch_size
+        self.drop_incomplete_batch = drop_incomplete_batch
+        # should only be true if a GPU is available and being used
+        self.pin_memory = torch.cuda.is_available()
+        # guard against None so `**self.multiproc_kwargs` below stays valid
+        self.multiproc_kwargs = multiproc_kwargs or {}
+
+    def setup(self, stage: str):
+        if stage == "fit":
+            if "__train__" in self.dataset:
+                logger.info("Found (dedicated) '__train__' set for training")
+                self.train_dataset = self.dataset["__train__"]
+            else:
+                self.train_dataset = self.dataset["train"]
+
+            if "__valid__" in self.dataset:
+                logger.info("Found (dedicated) '__valid__' set for validation")
+                self.validation_dataset = self.dataset["__valid__"]
+
+            if "__extra_valid__" in self.dataset:
+                if not isinstance(self.dataset["__extra_valid__"], list):
+                    raise RuntimeError(
+                        f"If present, dataset['__extra_valid__'] must be a list, "
+                        f"but you passed a {type(self.dataset['__extra_valid__'])}, "
+                        f"which is invalid."
+                    )
+                logger.info(
+                    f"Found {len(self.dataset['__extra_valid__'])} extra validation "
+                    f"set(s) to be tracked during training"
+                )
+                logger.info(
+                    "Extra validation sets are NOT used for model checkpointing!"
+                )
+                self.extra_validation_datasets = self.dataset["__extra_valid__"]
+            else:
+                self.extra_validation_datasets = None
+
+        if stage == "predict":
+            self.predict_dataset = []
+            for split_key in self.dataset.keys():
+                if split_key.startswith("_"):
+                    logger.info(
+                        f"Skipping dataset '{split_key}' (not to be evaluated)"
+                    )
+                    continue
+                else:
+                    self.predict_dataset.append(self.dataset[split_key])
+
+    def train_dataloader(self):
+        # oversample so that each batch is class-balanced in expectation
+        train_samples_weights = get_samples_weights(self.train_dataset)
+        train_sampler = WeightedRandomSampler(
+            train_samples_weights, len(train_samples_weights), replacement=True
+        )
+
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.train_batch_size,
+            drop_last=self.drop_incomplete_batch,
+            pin_memory=self.pin_memory,
+            sampler=train_sampler,
+            **self.multiproc_kwargs,
+        )
+
+    def val_dataloader(self):
+        loaders_dict = {}
+
+        val_loader = DataLoader(
+            dataset=self.validation_dataset,
+            batch_size=self.train_batch_size,
+            shuffle=False,
+            drop_last=False,
+            pin_memory=self.pin_memory,
+            **self.multiproc_kwargs,
+        )
+        loaders_dict["validation_loader"] = val_loader
+
+        if self.extra_validation_datasets is not None:
+            for set_idx, extra_set in enumerate(self.extra_validation_datasets):
+                extra_val_loader = DataLoader(
+                    dataset=extra_set,
+                    batch_size=self.train_batch_size,
+                    shuffle=False,
+                    drop_last=False,
+                    pin_memory=self.pin_memory,
+                    **self.multiproc_kwargs,
+                )
+                loaders_dict[
+                    f"extra_validation_loader{set_idx}"
+                ] = extra_val_loader
+
+        return loaders_dict
+
+    def predict_dataloader(self):
+        loaders_dict = {}
+
+        for set_idx, pred_set in enumerate(self.predict_dataset):
+            loaders_dict[set_idx] = DataLoader(
+                dataset=pred_set,
+                batch_size=self.predict_batch_size,
+                shuffle=False,
+                pin_memory=self.pin_memory,
+            )
+
+        return loaders_dict
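The split dictionary consumed by setup() follows ptbench's naming protocol: a dedicated "__train__" entry wins over "train", "__valid__" drives checkpointing, and "__extra_valid__", if present, must be a list of additional sets that are tracked but never checkpointed. A self-contained sketch of that protocol (toy tensors standing in for the real Shenzhen subsets):

import torch
from torch.utils.data import TensorDataset

from ptbench.data.shenzhen.datamodule import ShenzhenDataModule


def toy_split(n):
    # (name, image, label) triples, mimicking the batch layout the models expect
    return TensorDataset(
        torch.arange(n), torch.randn(n, 3), torch.randint(0, 2, (n,))
    )


dataset = {
    "train": toy_split(8),  # used, since there is no dedicated "__train__" entry
    "__valid__": toy_split(4),  # main validation set, drives checkpointing
    "__extra_valid__": [toy_split(4)],  # tracked during training only
}

dm = ShenzhenDataModule(dataset, train_batch_size=2)
dm.setup(stage="fit")
loaders = dm.val_dataloader()
print(list(loaders))  # ['validation_loader', 'extra_validation_loader0']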
 import csv
 import time
+
+from collections import defaultdict

 import numpy

 from lightning.pytorch import Callback
@@ -17,6 +19,7 @@ class LoggingCallback(Callback):
         super().__init__()
         self.training_loss = []
         self.validation_loss = []
+        self.extra_validation_loss = defaultdict(list)

         self.start_training_time = 0
         self.start_epoch_time = 0
@@ -37,6 +40,13 @@ class LoggingCallback(Callback):
     ):
         self.validation_loss.append(outputs["validation_loss"].item())

+        if len(outputs) > 1:
+            # every key other than the main validation loss belongs to an
+            # extra validation set
+            extra_validation_keys = set(outputs.keys()) - {"validation_loss"}
+            for extra_validation_loss_key in extra_validation_keys:
+                self.extra_validation_loss[extra_validation_loss_key].append(
+                    outputs[extra_validation_loss_key]
+                )
+
     def on_validation_epoch_end(self, trainer, pl_module):
         self.resource_monitor.trigger_summary()
@@ -52,6 +62,15 @@ class LoggingCallback(Callback):
         self.log("learning_rate", pl_module.hparams["optimizer_configs"]["lr"])
         self.log("validation_loss", numpy.average(self.validation_loss))

+        if len(self.extra_validation_loss) > 0:
+            for (
+                extra_valid_loss_key,
+                extra_valid_loss_values,
+            ) in self.extra_validation_loss.items():
+                self.log(
+                    extra_valid_loss_key, numpy.average(extra_valid_loss_values)
+                )
+
         queue_retries = 0
         # In case the resource monitor takes longer to fetch data from the queue, we wait
         # Give up after self.resource_monitor.interval * self.max_queue_retries if cannot retrieve metrics from queue
...
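To illustrate the bookkeeping above (loss values are made up): each validation batch contributes one entry per returned key, and at epoch end every extra key is averaged and logged under its own name, alongside the main validation_loss:

from collections import defaultdict

import numpy

extra_validation_loss = defaultdict(list)

# outputs dicts as validation_step() returns them for an extra dataloader
for outputs in [{"extra_validation_loss_1": 0.52}, {"extra_validation_loss_1": 0.48}]:
    for key in set(outputs.keys()) - {"validation_loss"}:
        extra_validation_loss[key].append(outputs[key])

for key, values in extra_validation_loss.items():
    print(key, numpy.average(values))  # extra_validation_loss_1 0.5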
@@ -149,9 +149,7 @@ def create_logfile_fields(valid_loader, extra_valid_loaders, device):
 def run(
     model,
-    data_loader,
-    valid_loader,
-    extra_valid_loaders,
+    datamodule,
     checkpoint_period,
     accelerator,
     arguments,
@@ -263,4 +261,4 @@ def run(
         callbacks=[LoggingCallback(resource_monitor), checkpoint_callback],
     )

-    _ = trainer.fit(model, data_loader, valid_loader, ckpt_path=checkpoint)
+    _ = trainer.fit(model, datamodule, ckpt_path=checkpoint)
@@ -141,7 +141,7 @@ class PASA(pl.LightningModule):

         return {"loss": training_loss}

-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[1]
         labels = batch[2]
@@ -159,9 +159,12 @@ class PASA(pl.LightningModule):
         )
         validation_loss = self.hparams.criterion_valid(outputs, labels.double())

-        return {"validation_loss": validation_loss}
+        if dataloader_idx == 0:
+            return {"validation_loss": validation_loss}
+        else:
+            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}

-    def predict_step(self, batch, batch_idx, grad_cams=False):
+    def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         names = batch[0]
         images = batch[1]
...
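For context on the dataloader_idx parameter: when a model is validated against several dataloaders, Lightning calls validation_step() once per batch of each loader and passes the loader's position as dataloader_idx, which is what lets the code above name extra losses after their loader. A standalone sketch with a toy model and data (not from this commit):

import lightning as pl
import torch
from torch.nn.functional import mse_loss
from torch.utils.data import DataLoader, TensorDataset


class Toy(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return mse_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        x, y = batch
        loss = mse_loss(self.layer(x), y)
        # loader 0 is the main validation set; the others are "extra" sets
        if dataloader_idx == 0:
            return {"validation_loss": loss}
        return {f"extra_validation_loss_{dataloader_idx}": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def loader():
    return DataLoader(
        TensorDataset(torch.randn(8, 3), torch.randn(8, 1)), batch_size=4
    )


trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
trainer.fit(Toy(), loader(), val_dataloaders=[loader(), loader()])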
@@ -13,25 +13,6 @@ from ..utils.checkpointer import get_checkpoint

 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")

-def set_reproducible_cuda():
-    """Turns off all CUDA optimizations that would affect reproducibility.
-
-    For full reproducibility, also ensure not to use multiple (parallel) data
-    loaders, i.e. set ``num_workers=0``.
-
-    Reference: `PyTorch page for reproducibility
-    <https://pytorch.org/docs/stable/notes/randomness.html>`_.
-    """
-    import torch.backends.cudnn
-
-    # ensure to use only optimization algos for cuda that are known to have
-    # a deterministic effect (not random)
-    torch.backends.cudnn.deterministic = True
-
-    # turns off any optimization tricks
-    torch.backends.cudnn.benchmark = False
 @click.command(
     entry_point_group="ptbench.config",
     cls=ConfigCommand,
@@ -62,6 +43,12 @@ def set_reproducible_cuda():
     required=True,
     cls=ResourceOption,
 )
+@click.option(
+    "--datamodule",
+    help="A LightningDataModule class wrapping the datasets to be used for training",
+    required=True,
+    cls=ResourceOption,
+)
 @click.option(
     "--dataset",
     "-d",
@@ -235,7 +222,7 @@ def set_reproducible_cuda():
 )
 @click.option(
     "--resume-from",
     help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a model checkpoint.",
     type=str,
     required=False,
     default=None,
@@ -251,6 +238,7 @@ def train(
     drop_incomplete_batch,
     criterion,
     criterion_valid,
+    datamodule,
     dataset,
     checkpoint_period,
     accelerator,
@@ -277,45 +265,13 @@ def train(
     import torch.nn

     from torch.nn import BCEWithLogitsLoss
-    from torch.utils.data import DataLoader, WeightedRandomSampler
+    from torch.utils.data import DataLoader

-    from ..configs.datasets import get_positive_weights, get_samples_weights
+    from ..configs.datasets import get_positive_weights
     from ..engine.trainer import run

     seed_everything(seed)
-    use_dataset = dataset
-    validation_dataset = None
-    extra_validation_datasets = []
-
-    if isinstance(dataset, dict):
-        if "__train__" in dataset:
-            logger.info("Found (dedicated) '__train__' set for training")
-            use_dataset = dataset["__train__"]
-        else:
-            use_dataset = dataset["train"]
-
-        if "__valid__" in dataset:
-            logger.info("Found (dedicated) '__valid__' set for validation")
-            logger.info("Will checkpoint lowest loss model on validation set")
-            validation_dataset = dataset["__valid__"]
-
-        if "__extra_valid__" in dataset:
-            if not isinstance(dataset["__extra_valid__"], list):
-                raise RuntimeError(
-                    f"If present, dataset['__extra_valid__'] must be a list, "
-                    f"but you passed a {type(dataset['__extra_valid__'])}, "
-                    f"which is invalid."
-                )
-            logger.info(
-                f"Found {len(dataset['__extra_valid__'])} extra validation "
-                f"set(s) to be tracked during training"
-            )
-            logger.info(
-                "Extra validation sets are NOT used for model checkpointing!"
-            )
-            extra_validation_datasets = dataset["__extra_valid__"]

     # PyTorch dataloader
     multiproc_kwargs = dict()
     if parallel < 0:
@@ -340,31 +296,25 @@ def train(
     else:
         batch_chunk_size = batch_size // batch_chunk_count

-    # Create weighted random sampler
-    train_samples_weights = get_samples_weights(use_dataset)
-    train_sampler = WeightedRandomSampler(
-        train_samples_weights, len(train_samples_weights), replacement=True
-    )
+    datamodule = datamodule(
+        dataset,
+        train_batch_size=batch_chunk_size,
+        multiproc_kwargs=multiproc_kwargs,
+    )
+
+    # Manually calling these, as we need to access some values to reweight the criterion
+    datamodule.prepare_data()
+    datamodule.setup(stage="fit")
+
+    train_dataset = datamodule.train_dataset
+    validation_dataset = datamodule.validation_dataset

     # Redefine a weighted criterion if possible
     if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
-        positive_weights = get_positive_weights(use_dataset)
+        positive_weights = get_positive_weights(train_dataset)
         model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
     else:
         logger.warning("Weighted criterion not supported")
-    # PyTorch dataloader
-    data_loader = DataLoader(
-        dataset=use_dataset,
-        batch_size=batch_chunk_size,
-        drop_last=drop_incomplete_batch,
-        pin_memory=torch.cuda.is_available(),
-        sampler=train_sampler,
-        **multiproc_kwargs,
-    )

-    valid_loader = None
     if validation_dataset is not None:
         # Redefine a weighted valid criterion if possible
         if (
@@ -378,27 +328,6 @@ def train(
         else:
             logger.warning("Weighted valid criterion not supported")

-        valid_loader = DataLoader(
-            dataset=validation_dataset,
-            batch_size=batch_chunk_size,
-            shuffle=False,
-            drop_last=False,
-            pin_memory=torch.cuda.is_available(),
-            **multiproc_kwargs,
-        )

-    extra_valid_loaders = [
-        DataLoader(
-            dataset=k,
-            batch_size=batch_chunk_size,
-            shuffle=False,
-            drop_last=False,
-            pin_memory=torch.cuda.is_available(),
-            **multiproc_kwargs,
-        )
-        for k in extra_validation_datasets
-    ]
     # Create z-normalization model layer if needed
     if normalization == "imagenet":
         model.normalizer.set_mean_std(
@@ -407,7 +336,9 @@ def train(
         logger.info("Z-normalization with ImageNet mean and std")
     elif normalization == "current":
         # Compute mean/std of current train subset
-        temp_dl = DataLoader(dataset=use_dataset, batch_size=len(use_dataset))
+        temp_dl = DataLoader(
+            dataset=train_dataset, batch_size=len(train_dataset)
+        )

         data = next(iter(temp_dl))
         mean = data[1].mean(dim=[0, 2, 3])
@@ -446,9 +377,7 @@ def train(
     run(
         model=model,
-        data_loader=data_loader,
-        valid_loader=valid_loader,
-        extra_valid_loaders=extra_valid_loaders,
+        datamodule=datamodule,
         checkpoint_period=checkpoint_period,
         accelerator=accelerator,
         arguments=arguments,
...
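On the criterion reweighting that remains in this script: get_positive_weights is ptbench's helper, but conceptually the pos_weight handed to BCEWithLogitsLoss is the negative-to-positive ratio of the training labels, so that the rarer positive class weighs proportionally more in the loss. A sketch of that standard recipe (not ptbench's exact code):

import torch
from torch.nn import BCEWithLogitsLoss

labels = torch.tensor([0, 0, 0, 1])  # toy labels: three negatives, one positive
positives = labels.sum()
pos_weight = (len(labels) - positives) / positives  # 3.0: positives count triple
criterion = BCEWithLogitsLoss(pos_weight=pos_weight)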