Commit c464ae3e authored by André Anjos

Updates to the data module and models

parent 9e721cb3
1 merge request: !7 Reviewed DataModule design+docs+types
Pipeline #75501 failed
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import collections
import logging
import multiprocessing
import sys
import typing
@@ -12,6 +13,8 @@ import torch
import torch.utils.data
import torchvision.transforms
logger = logging.getLogger(__name__)
def _setup_dataloader_multiproc_parameters(
parallel: int,
@@ -80,7 +83,8 @@ class _DelayedLoadingDataset(torch.utils.data.Dataset):
interval ``[0, 1]``, and a dictionary with further metadata attributes.
transforms
A set of transforms that should be applied on-the-fly for this dataset.
A set of transforms that should be applied on-the-fly for this dataset,
to fit the output of the raw-data-loader to the model of interest.
"""
def __init__(
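For illustration, a minimal sketch of the delayed-loading pattern this docstring describes; the loader and sample key below are hypothetical, only the ``(tensor, metadata)`` contract and the on-the-fly transforms come from the docstring above:

import typing

import torch
import torchvision.transforms


def example_raw_data_loader(key: str) -> tuple[torch.Tensor, typing.Mapping]:
    # Hypothetical loader: reads a single sample on demand and returns a
    # float tensor in [0, 1] plus a metadata dictionary.
    return torch.rand(1, 224, 224), {"label": 0, "name": key}


# On-the-fly pipeline: nothing is loaded until a sample is actually indexed.
transform = torchvision.transforms.Compose(
    [torchvision.transforms.Resize((256, 256))]
)
tensor, metadata = example_raw_data_loader("sample-0001")
tensor = transform(tensor)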
@@ -113,7 +117,8 @@ class _CachedDataset(torch.utils.data.Dataset):
"""Basically, a list of preloaded samples.
This dataset will load all samples from the split during construction
instead of delaying that to the indexing.
instead of delaying that to the indexing. Beyond raw-data-loading,
``transforms`` given upon construction contribute to the cached samples.
Parameters
@@ -130,7 +135,9 @@ class _CachedDataset(torch.utils.data.Dataset):
interval ``[0, 1]``, and a dictionary with further metadata attributes.
transforms
A set of transforms that should be applied on-the-fly for this dataset.
A set of transforms that should be applied to the cached samples for
this dataset, to fit the output of the raw-data-loader to the model of
interest.
"""
def __init__(
@@ -143,8 +150,8 @@ class _CachedDataset(torch.utils.data.Dataset):
typing.Callable[[torch.Tensor], torch.Tensor]
] = [],
):
self.data = [raw_data_loader(k) for k in split]
self.transform = torchvision.transforms.Compose(*transforms)
self.data = [raw_data_loader(k) for k in split]
def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
tensor, metadata = self.data[key]
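By contrast, a rough sketch of the eager-caching trade-off implemented here (names are made up; the point is that the whole loading cost is paid at construction time and indexing becomes a plain list lookup):

import torch


def fake_loader(key: str):
    # Hypothetical stand-in for the raw-data-loader.
    return torch.rand(1, 64, 64), {"label": 0, "name": key}


keys = ["sample-0001", "sample-0002"]  # stand-in for a database split
cache = [fake_loader(k) for k in keys]  # all samples loaded up-front

tensor, metadata = cache[0]  # indexing no longer touches the disk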
@@ -286,21 +293,13 @@ class CachingDataModule(lightning.LightningDataModule):
particularly important in highly unbalanced datasets. The function
:py:func:`get_sample_weights` may help you in this aspect.
data_augmentations
A list of torchvision transforms (torch modules) that will be applied
on training set samples to create data augmentations during the
training of a model. Augmentation transform pipelines are applied
*after* the raw data is loaded, and before ``model_transforms``.
Augmentation transforms assume they receive a torch tensor representing
an image as input (see :py:class:`torchvision.transforms.ToTensor` for
details), in the range ``[0, 1]``.
model_transforms
A list of torchvision transforms (torch modules) that will be applied
after data augmentation transforms, and just before data is fed into
the model for all data loaders produced by this data module. This part
of the pipeline receives data as output by the raw-data-loader, or from
data augmentations, if any is specified.
A list of transforms (torch modules) that will be applied after
raw-data-loading, and just before data is fed into the model or into
eventual data-augmentation transformations, for all data loaders
produced by this data module. This part of the pipeline receives data
as output by the raw-data-loader and adapts it to the model of
interest (e.g. resize adaptations).
batch_size
Number of samples in every **training** batch (this parameter affects
@@ -348,9 +347,6 @@ class CachingDataModule(lightning.LightningDataModule):
],
cache_samples: bool = False,
train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
data_augmentations: list[
typing.Callable[[torch.Tensor], torch.Tensor]
] = [],
model_transforms: list[
typing.Callable[[torch.Tensor], torch.Tensor]
] = [],
@@ -367,7 +363,6 @@ class CachingDataModule(lightning.LightningDataModule):
self.raw_data_loader = raw_data_loader
self.cache_samples = cache_samples
self.train_sampler = train_sampler
self.data_augmentations = data_augmentations
self.model_transforms = model_transforms
self.drop_incomplete_batch = drop_incomplete_batch
@@ -437,10 +432,36 @@ class CachingDataModule(lightning.LightningDataModule):
self._batch_chunk_count = batch_chunk_count
self._chunk_size = self._batch_size // self._batch_chunk_count
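To make the chunking arithmetic concrete, a worked example (not code from this module; it presumably also expects ``batch_size`` to be divisible by ``batch_chunk_count``):

batch_size = 16  # samples contributing to a single optimiser step
batch_chunk_count = 4  # number of pieces each batch is split into

# Integer division as in the module above: each forward/backward pass sees
# only a chunk, and gradients are accumulated over the whole batch.
chunk_size = batch_size // batch_chunk_count
assert chunk_size == 4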
def _setup_dataset(self, name: str) -> None:
"""Sets-up a single dataset from the input data split.
Parameters
----------
name
Name of the dataset to setup.
"""
if name in self._datasets:
logger.info(f"Dataset {name} is already setup. Not reloading it.")
return
if self.cache_samples:
self._datasets[name] = _CachedDataset(
self.database_split[name],
self.raw_data_loader,
self.model_transforms,
)
else:
self._datasets[name] = _DelayedLoadingDataset(
self.database_split[name],
self.raw_data_loader,
self.model_transforms,
)
def setup(self, stage: str) -> None:
"""Sets up datasets for different tasks on the pipeline.
This method should setup (load, pre-process, etc) all subsets required
This method should setup (load, pre-process, etc) all datasets required
for a particular ``stage`` (fit, validate, test, predict), and keep
them ready to be used on one of the `_dataloader()` functions that are
pertinent for such stage.
@@ -463,62 +484,50 @@
* ``predict``: uses only the test dataset
"""
def _setup(name, transforms):
if self.cache_samples:
self._datasets[name] = _CachedDataset(
self.database_split[name], self.raw_data_loader, transforms
)
else:
self._datasets[name] = _DelayedLoadingDataset(
self.database_split[name], self.raw_data_loader, transforms
)
if stage == "fit":
_setup("train", self.data_augmentations + self.model_transforms)
_setup("validation", self.model_transforms)
self._setup_dataset("train")
self._setup_dataset("validation")
for k in self.database_split:
if k.startswith("monitor-"):
_setup(k, self.model_transforms)
self._setup_dataset(k)
elif stage == "validate":
_setup("validation", self.model_transforms)
self._setup_dataset("validation")
for k in self.database_split:
if k.startswith("monitor-"):
_setup(k, self.model_transforms)
self._setup_dataset(k)
elif stage == "test":
_setup("test", self.model_transforms)
self._setup_dataset("test")
elif stage == "predict":
_setup("test", self.model_transforms)
self._setup_dataset("test")
def teardown(self, stage: str) -> None:
"""Unset-up datasets for different tasks on the pipeline.
This method unsets (unload, remove from memory, etc) all datasets required
for a particular ``stage`` (fit, validate, test, predict).
def unaugmented_train_dataloader(self) -> torch.utils.data.DataLoader:
"""Returns a version of the train dataloader without augmentations.
If you have set ``cache_samples`` and samples have been loaded, this may
effectively release all the associated memory.
Use this method to obtain a version of the train dataloader without
augmentations, to compute input normalisation factors (e.g. mean and
standard deviation or min-max parameterisations).
Parameters
----------
Returns
-------
stage
Name of the stage to which the teardown is applicable. Can be one of
``fit``, ``validate``, ``test`` or ``predict``. Each stage
typically uses the following data loaders:
dataloader
The unaugmented train dataloader
* ``fit``: uses both train and validation datasets
* ``validate``: uses only the validation dataset
* ``test``: uses only the test dataset
* ``predict``: uses only the test dataset
"""
dataset = _DelayedLoadingDataset(
self.database_split["train"],
self.raw_data_loader,
self.model_transforms,
)
return torch.utils.data.DataLoader(
dataset,
shuffle=False,
batch_size=self._chunk_size,
drop_last=self.drop_incomplete_batch,
pin_memory=self.pin_memory,
**self._dataloader_multiproc,
)
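A minimal sketch of how such an unaugmented dataloader could be used to derive z-normalisation factors; the accumulation below is illustrative and assumes each batch is a ``(tensor, metadata)`` pair with tensors shaped ``(N, C, H, W)`` in ``[0, 1]``:

import torch


def compute_mean_std(dataloader):
    # Accumulates per-channel first and second moments over the whole set.
    total = 0.0
    sq_total = 0.0
    count = 0
    for batch, _metadata in dataloader:
        total = total + batch.sum(dim=(0, 2, 3))
        sq_total = sq_total + (batch**2).sum(dim=(0, 2, 3))
        count += batch.numel() // batch.shape[1]
    mean = total / count
    std = (sq_total / count - mean**2).sqrt()
    return mean, std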
self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}
def train_dataloader(self) -> torch.utils.data.DataLoader:
"""Returns the train data loader."""
@@ -7,10 +7,10 @@ import logging
import os
import shutil
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.utilities.model_summary import ModelSummary
import lightning.pytorch
import lightning.pytorch.callbacks
import lightning.pytorch.loggers
import lightning.pytorch.utilities.model_summary
import torch.nn
from ..utils.accelerator import AcceleratorProcessor
from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
@@ -19,7 +19,7 @@ from .callbacks import LoggingCallback
logger = logging.getLogger(__name__)
def check_gpu(device):
def check_gpu(device: str) -> None:
"""Check the device type and the availability of GPU.
Parameters
@@ -35,32 +35,37 @@ def check_gpu(device):
), f"Device set to '{device}', but nvidia-smi is not installed"
def save_model_summary(output_folder, model):
def save_model_summary(
output_folder: str, model: torch.nn.Module
) -> tuple[lightning.pytorch.utilities.model_summary.ModelSummary, int]:
"""Save a little summary of the model in a txt file.
Parameters
----------
output_folder : str
output_folder
output path
model : :py:class:`torch.nn.Module`
model
Network (e.g. driu, hed, unet)
Returns
-------
r : str
r
The model summary in a text format.
n : int
n
The number of parameters of the model.
"""
summary_path = os.path.join(output_folder, "model_summary.txt")
logger.info(f"Saving model summary at {summary_path}...")
with open(summary_path, "w") as f:
summary = ModelSummary(model, max_depth=-1)
summary = lightning.pytorch.utilities.model_summary.ModelSummary(
model, max_depth=-1
)
f.write(str(summary))
return summary, ModelSummary(model).total_parameters
return summary, summary.total_parameters
def static_information_to_csv(static_logfile_name, device, n):
@@ -70,7 +75,8 @@ def static_information_to_csv(static_logfile_name, device, n):
----------
static_logfile_name : str
The static file name which is a join between the output folder and "constant.csv"
The static file name which is a join between the output folder and
"constant.csv"
"""
if os.path.exists(static_logfile_name):
backup = static_logfile_name + "~"
@@ -188,7 +194,8 @@ def run(
not save intermediary checkpoints.
accelerator : str
A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)
A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The
device can also be specified (gpu:0)
arguments : dict
Start and end epochs:
@@ -215,10 +222,12 @@
os.makedirs(output_folder, exist_ok=True)
# Save model summary
r, n = save_model_summary(output_folder, model)
_, n = save_model_summary(output_folder, model)
csv_logger = CSVLogger(output_folder, "logs_csv")
tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard")
csv_logger = lightning.pytorch.loggers.CSVLogger(output_folder, "logs_csv")
tensorboard_logger = lightning.pytorch.loggers.TensorBoardLogger(
output_folder, "logs_tensorboard"
)
resource_monitor = ResourceMonitor(
interval=monitoring_interval,
@@ -227,7 +236,7 @@
logging_level=logging.ERROR,
)
checkpoint_callback = ModelCheckpoint(
checkpoint_callback = lightning.pytorch.callbacks.ModelCheckpoint(
output_folder,
"model_lowest_valid_loss",
save_last=True,
@@ -251,7 +260,7 @@ def run(
devices = accelerator_processor.device
with resource_monitor:
trainer = Trainer(
trainer = lightning.pytorch.Trainer(
accelerator=accelerator_processor.accelerator,
devices=devices,
max_epochs=max_epoch,
@@ -2,11 +2,15 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torchvision.models as models
logger = logging.getLogger(__name__)
class Densenet(pl.LightningModule):
"""Densenet module.
@@ -25,6 +29,8 @@ class Densenet(pl.LightningModule):
):
super().__init__()
# Saves all hyper-parameters declared on __init__ into ``self.hparams``.
# You can access those by their name, like ``self.hparams.optimizer``.
self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
self.name = "Densenet"
@@ -55,10 +61,19 @@ class Densenet(pl.LightningModule):
if self.pretrained:
from .normalizer import make_imagenet_normalizer
logger.warning(
"ImageNet pre-trained densenet model - NOT"
"computing z-norm factors from training data. "
"Using preset factors from torchvision."
)
self.normalizer = make_imagenet_normalizer()
else:
from .normalizer import make_z_normalizer
logger.info(
"Uninitialised densenet model - "
"computing z-norm factors from training data."
)
self.normalizer = make_z_normalizer(dataloader)
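For reference, the preset factors mentioned in the warning are the standard ImageNet channel statistics, whereas the other branch derives equivalent factors from the training data; a hedged sketch of what the two normalisers boil down to (the names below are illustrative, not the internals of ``make_imagenet_normalizer``/``make_z_normalizer``):

import torchvision.transforms

# Preset statistics published for torchvision's ImageNet-pretrained weights.
imagenet_normalizer = torchvision.transforms.Normalize(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)


# Data-driven alternative: per-channel statistics computed from an
# unaugmented train dataloader (see the mean/std sketch earlier in this diff).
def make_example_z_normalizer(mean, std):
    return torchvision.transforms.Normalize(mean=mean, std=std)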
def training_step(self, batch, batch_idx):
@@ -2,12 +2,16 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
logger = logging.getLogger(__name__)
class PASA(pl.LightningModule):
"""PASA module.
@@ -24,6 +28,8 @@ class PASA(pl.LightningModule):
):
super().__init__()
# Saves all hyper-parameters declared on __init__ into ``self.hparams``.
# You can access those by their name, like ``self.hparams.criterion``.
self.save_hyperparameters()
self.name = "pasa"
@@ -137,6 +143,10 @@ class PASA(pl.LightningModule):
"""
from .normalizer import make_z_normalizer
logger.info(
"Uninitialised densenet model - "
"computing z-norm factors from training data."
)
self.normalizer = make_z_normalizer(dataloader)
def training_step(self, batch, _):
......
@@ -6,9 +6,6 @@ import click
from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.logging import setup
from lightning.pytorch import seed_everything
from ..utils.checkpointer import get_checkpoint
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@@ -65,21 +62,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
required=True,
cls=ResourceOption,
)
@click.option(
"--criterion",
help="A loss function to compute the CNN error for every sample "
"respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
required=True,
cls=ResourceOption,
)
@click.option(
"--criterion-valid",
help="A specific loss function for the validation set to compute the CNN"
"error for every sample respecting the PyTorch API for loss functions"
"(see torch.nn.modules.loss)",
required=False,
cls=ResourceOption,
)
@click.option(
"--batch-size",
"-b",
@@ -159,7 +141,8 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@click.option(
"--accelerator",
"-a",
help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)',
help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). '
"The device can also be specified (gpu:0)",
show_default=True,
required=True,
default="cpu",
@@ -167,7 +150,8 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
)
@click.option(
"--cache-samples",
help="If set to True, loads the sample into memory, otherwise loads them at runtime.",
help="If set to True, loads the sample into memory, "
"otherwise loads them at runtime.",
required=True,
show_default=True,
default=False,
@@ -196,16 +180,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
default=-1,
cls=ResourceOption,
)
@click.option(
"--normalization",
"-n",
help="Z-Normalization of input images: 'imagenet' for ImageNet parameters,"
" 'current' for parameters of the current trainset, "
"'none' for no normalization.",
required=False,
default="none",
cls=ResourceOption,
)
@click.option(
"--monitoring-interval",
"-I",
@@ -224,7 +198,8 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
)
@click.option(
"--resume-from",
help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a model checkpoint.",
help="Which checkpoint to resume training from. If set, can be one of "
"`best`, `last`, or a path to a model checkpoint.",
type=str,
required=False,
default=None,
@@ -238,20 +213,16 @@ def train(
batch_size,
batch_chunk_count,
drop_incomplete_batch,
criterion,
criterion_valid,
datamodule,
checkpoint_period,
accelerator,
cache_samples,
seed,
parallel,
normalization,
monitoring_interval,
resume_from,
**_,
):
"""Trains an CNN to perform tuberculosis detection.
"""Trains an CNN to perform image classification.
Training is performed for a configurable number of epochs, and
generates at least a final_model.pth. It may also generate a number
@@ -263,21 +234,33 @@
import torch.cuda
import torch.nn
from ..data.dataset import reweight_BCEWithLogitsLoss
from lightning.pytorch import seed_everything
from ..engine.trainer import run
from ..utils.checkpointer import get_checkpoint
seed_everything(seed)
checkpoint_file = get_checkpoint(output_folder, resume_from)
# reset datamodule with user configurable options
datamodule.set_chunk_size(batch_size, batch_chunk_count)
datamodule.drop_incomplete_batch = drop_incomplete_batch
datamodule.cache_samples = cache_samples
datamodule.parallel = parallel
datamodule.prepare_data()
datamodule.setup(stage="fit")
reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid)
model.set_normalizer(datamodule.unaugmented_train_dataloader())
# Sets the model normalizer with the unaugmented-train-subset.
# this call may be a NOOP, if the model was pre-trained and expects
# different weights for the normalisation layer.
model.set_normalizer(datamodule.train_dataloader())
# Rebalances the loss criterion based on the relative proportion of class
# examples available in the training set. Also affects the validation loss
# if a validation set is available on the data module.
model.set_bce_loss_weights(datamodule)
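The re-balancing described above typically boils down to handing a positive-class weight to ``torch.nn.BCEWithLogitsLoss``; a minimal sketch under assumed class counts (not necessarily how ``set_bce_loss_weights`` is implemented):

import torch

# Hypothetical counts taken from the training split of the data module.
positive_count, negative_count = 250, 750

# Up-weight the rarer positive class so both classes contribute comparably.
pos_weight = torch.tensor([negative_count / positive_count])
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)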
arguments = {}
arguments["max_epoch"] = epochs