Commit 57048d9b authored by Daniel CARRON

Merge branch 'loss-balancing' into 'main'

Replace sampler balancing by loss balancing

Closes #6

See merge request biosignal/software/mednet!38
parents aaa4d443 236447ec
Pipeline #86803 failed
Showing with 381 additions and 324 deletions
......@@ -45,6 +45,7 @@ CNN and other models implemented.
mednet.models.logistic_regression
mednet.models.loss_weights
mednet.models.mlp
mednet.models.model
mednet.models.normalizer
mednet.models.separate
mednet.models.transforms
......
......@@ -15,8 +15,7 @@ from mednet.data.augmentations import ElasticDeformation
from mednet.models.alexnet import Alexnet
model = Alexnet(
train_loss=BCEWithLogitsLoss(),
validation_loss=BCEWithLogitsLoss(),
loss_type=BCEWithLogitsLoss,
optimizer_type=SGD,
optimizer_arguments=dict(lr=0.01, momentum=0.1),
augmentation_transforms=[ElasticDeformation(p=0.8)],
......
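For context, the updated configurations above and below pass the loss as a class (plus separate keyword arguments) instead of a ready-made instance; the actual loss objects are only created later by Model.configure_losses(). A minimal, hedged sketch of the new calling convention (the torch imports are assumed, the mednet paths are taken from the diffs, and loss_arguments is shown with its default empty value):

from torch.nn import BCEWithLogitsLoss
from torch.optim import SGD

from mednet.data.augmentations import ElasticDeformation
from mednet.models.alexnet import Alexnet

model = Alexnet(
    loss_type=BCEWithLogitsLoss,      # the class itself, not BCEWithLogitsLoss()
    loss_arguments=dict(),            # e.g. a pos_weight entry can be injected by balance_losses()
    optimizer_type=SGD,
    optimizer_arguments=dict(lr=0.01, momentum=0.1),
    augmentation_transforms=[ElasticDeformation(p=0.8)],
)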
......@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
from mednet.models.alexnet import Alexnet
model = Alexnet(
train_loss=BCEWithLogitsLoss(),
validation_loss=BCEWithLogitsLoss(),
loss_type=BCEWithLogitsLoss,
optimizer_type=SGD,
optimizer_arguments=dict(lr=0.01, momentum=0.1),
augmentation_transforms=[ElasticDeformation(p=0.8)],
......
......@@ -15,8 +15,7 @@ from mednet.data.augmentations import ElasticDeformation
from mednet.models.densenet import Densenet
model = Densenet(
train_loss=BCEWithLogitsLoss(),
validation_loss=BCEWithLogitsLoss(),
loss_type=BCEWithLogitsLoss,
optimizer_type=Adam,
optimizer_arguments=dict(lr=0.0001),
augmentation_transforms=[ElasticDeformation(p=0.2)],
......
......@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
from mednet.models.densenet import Densenet
model = Densenet(
train_loss=BCEWithLogitsLoss(),
validation_loss=BCEWithLogitsLoss(),
loss_type=BCEWithLogitsLoss,
optimizer_type=Adam,
optimizer_arguments=dict(lr=0.0001),
augmentation_transforms=[ElasticDeformation(p=0.2)],
......
......@@ -16,8 +16,7 @@ from mednet.data.augmentations import ElasticDeformation
from mednet.models.densenet import Densenet
model = Densenet(
train_loss=BCEWithLogitsLoss(),
validation_loss=BCEWithLogitsLoss(),
loss_type=BCEWithLogitsLoss,
optimizer_type=Adam,
optimizer_arguments=dict(lr=0.0001),
augmentation_transforms=[ElasticDeformation(p=0.2)],
......
......@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
from mednet.models.pasa import Pasa
model = Pasa(
train_loss=BCEWithLogitsLoss(),
validation_loss=BCEWithLogitsLoss(),
loss_type=BCEWithLogitsLoss,
optimizer_type=Adam,
optimizer_arguments=dict(lr=8e-5),
augmentation_transforms=[ElasticDeformation(p=0.8)],
......
......@@ -481,10 +481,6 @@ class ConcatDataModule(lightning.LightningDataModule):
for CPU memory. Sufficient CPU memory must be available before you set
this attribute to ``True``. It is typically useful for relatively small
datasets.
balance_sampler_by_class
If set, then modifies the random sampler used during training and
validation to balance sample picking probability, making sample
across classes **and** datasets equitable.
batch_size
Number of samples in every **training** batch (this parameter affects
memory requirements for the network). If the number of samples in the
......@@ -529,7 +525,6 @@ class ConcatDataModule(lightning.LightningDataModule):
database_name: str = "",
split_name: str = "",
cache_samples: bool = False,
balance_sampler_by_class: bool = False,
batch_size: int = 1,
batch_chunk_count: int = 1,
drop_incomplete_batch: bool = False,
......@@ -552,7 +547,6 @@ class ConcatDataModule(lightning.LightningDataModule):
self.cache_samples = cache_samples
self._train_sampler = None
self.balance_sampler_by_class = balance_sampler_by_class
self._model_transforms: list[Transform] | None = None
......@@ -667,40 +661,6 @@ class ConcatDataModule(lightning.LightningDataModule):
)
self._datasets = {}
@property
def balance_sampler_by_class(self) -> bool:
"""Whether to balance samples across labels/datasets.
If set, then modifies the random sampler used during training
and validation to balance sample picking probability, making
sample across classes **and** datasets equitable.
.. warning::
This method does **NOT** balance the sampler per dataset, in case
multiple datasets compose the same training set. It only balances
samples acording to their ground-truth (labels). If you'd like to
have samples balanced per dataset, then implement your own data
module inheriting from this one.
Returns
-------
bool
True if self._train_sample is set, else False.
"""
return self._train_sampler is not None
@balance_sampler_by_class.setter
def balance_sampler_by_class(self, value: bool):
if value:
if "train" not in self._datasets:
self._setup_dataset("train")
self._train_sampler = _make_balanced_random_sampler(
self._datasets["train"],
)
else:
self._train_sampler = None
def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
"""Coherently set the batch-chunk-size after validation.
......@@ -798,7 +758,7 @@ class ConcatDataModule(lightning.LightningDataModule):
else:
self._datasets[name] = _ConcatDataset(datasets)
def _val_dataset_keys(self) -> list[str]:
def val_dataset_keys(self) -> list[str]:
"""Return list of validation dataset names.
Returns
......@@ -836,11 +796,11 @@ class ConcatDataModule(lightning.LightningDataModule):
"""
if stage == "fit":
for k in ["train"] + self._val_dataset_keys():
for k in ["train"] + self.val_dataset_keys():
self._setup_dataset(k)
elif stage == "validate":
for k in self._val_dataset_keys():
for k in self.val_dataset_keys():
self._setup_dataset(k)
elif stage == "test":
......@@ -929,7 +889,7 @@ class ConcatDataModule(lightning.LightningDataModule):
self._datasets[k],
**validation_loader_opts,
)
for k in self._val_dataset_keys()
for k in self.val_dataset_keys()
}
def test_dataloader(self) -> dict[str, DataLoader]:
......
......@@ -374,4 +374,5 @@ class LoggingCallback(lightning.pytorch.Callback):
on_step=False,
on_epoch=True,
batch_size=batch[0].shape[0],
add_dataloader_idx=False,
)
......@@ -72,6 +72,8 @@ def run(
output_folder.mkdir(parents=True, exist_ok=True)
model.configure_losses()
from .loggers import CustomTensorboardLogger
log_dir = "logs"
......
......@@ -5,7 +5,6 @@
import logging
import typing
import lightning.pytorch as pl
import torch
import torch.nn
import torch.optim.optimizer
......@@ -14,36 +13,29 @@ import torchvision.models as models
import torchvision.transforms
from ..data.typing import TransformSequence
from .model import Model
from .separate import separate
from .transforms import RGB, SquareCenterPad
from .typing import Checkpoint
logger = logging.getLogger(__name__)
class Alexnet(pl.LightningModule):
class Alexnet(Model):
"""Alexnet module.
Note: only usable with a normalized dataset
Parameters
----------
train_loss
The loss to be used during the training.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
validation_loss
The loss to be used for validation (may be different from the training
loss). If extra-validation sets are provided, the same loss will be
used throughout.
loss_type
The loss to be used for training and evaluation.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
loss_arguments
Arguments to the loss.
optimizer_type
The type of optimizer to use for training.
optimizer_arguments
......@@ -60,15 +52,22 @@ class Alexnet(pl.LightningModule):
def __init__(
self,
train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
validation_loss: torch.nn.Module | None = None,
loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
loss_arguments: dict[str, typing.Any] = {},
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
pretrained: bool = False,
num_classes: int = 1,
):
super().__init__()
super().__init__(
loss_type,
loss_arguments,
optimizer_type,
optimizer_arguments,
augmentation_transforms,
num_classes,
)
self.name = "alexnet"
self.num_classes = num_classes
......@@ -79,17 +78,6 @@ class Alexnet(pl.LightningModule):
RGB(),
]
self._train_loss = train_loss
self._validation_loss = (
validation_loss if validation_loss is not None else train_loss
)
self._optimizer_type = optimizer_type
self._optimizer_arguments = optimizer_arguments
self._augmentation_transforms = torchvision.transforms.Compose(
augmentation_transforms,
)
self.pretrained = pretrained
# Load pretrained model
......@@ -109,36 +97,6 @@ class Alexnet(pl.LightningModule):
x = self.normalizer(x) # type: ignore
return self.model_ft(x)
def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Perform actions during checkpoint saving (called by lightning).
Called by Lightning when saving a checkpoint to give you a chance to
store anything else you might want to save. Use on_load_checkpoint() to
restore what additional data is saved here.
Parameters
----------
checkpoint
The checkpoint to save.
"""
checkpoint["normalizer"] = self.normalizer
def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Perform actions during model loading (called by lightning).
If you saved something with on_save_checkpoint() this is your chance to
restore this.
Parameters
----------
checkpoint
The loaded checkpoint.
"""
logger.info("Restoring normalizer from checkpoint.")
self.normalizer = checkpoint["normalizer"]
def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
"""Initialize the normalizer for the current model.
......@@ -201,16 +159,9 @@ class Alexnet(pl.LightningModule):
# data forwarding on the existing network
outputs = self(images)
return self._validation_loss(outputs, labels.float())
def predict_step(self, batch, batch_idx, dataloader_idx=0):
outputs = self(batch[0])
probabilities = torch.sigmoid(outputs)
return separate((probabilities, batch[1]))
def configure_optimizers(self):
return self._optimizer_type(
self.parameters(),
**self._optimizer_arguments,
)
......@@ -5,7 +5,6 @@
import logging
import typing
import lightning.pytorch as pl
import torch
import torch.nn
import torch.optim.optimizer
......@@ -14,34 +13,27 @@ import torchvision.models as models
import torchvision.transforms
from ..data.typing import TransformSequence
from .model import Model
from .separate import separate
from .transforms import RGB, SquareCenterPad
from .typing import Checkpoint
logger = logging.getLogger(__name__)
class Densenet(pl.LightningModule):
class Densenet(Model):
"""Densenet-121 module.
Parameters
----------
train_loss
The loss to be used during the training.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
validation_loss
The loss to be used for validation (may be different from the training
loss). If extra-validation sets are provided, the same loss will be
used throughout.
loss_type
The loss to be used for training and evaluation.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
loss_arguments
Arguments to the loss.
optimizer_type
The type of optimizer to use for training.
optimizer_arguments
......@@ -60,8 +52,8 @@ class Densenet(pl.LightningModule):
def __init__(
self,
train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
validation_loss: torch.nn.Module | None = None,
loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
loss_arguments: dict[str, typing.Any] = {},
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
......@@ -69,7 +61,14 @@ class Densenet(pl.LightningModule):
dropout: float = 0.1,
num_classes: int = 1,
):
super().__init__()
super().__init__(
loss_type,
loss_arguments,
optimizer_type,
optimizer_arguments,
augmentation_transforms,
num_classes,
)
self.name = "densenet-121"
self.num_classes = num_classes
......@@ -80,17 +79,6 @@ class Densenet(pl.LightningModule):
RGB(),
]
self._train_loss = train_loss
self._validation_loss = (
validation_loss if validation_loss is not None else train_loss
)
self._optimizer_type = optimizer_type
self._optimizer_arguments = optimizer_arguments
self._augmentation_transforms = torchvision.transforms.Compose(
augmentation_transforms,
)
self.pretrained = pretrained
# Load pretrained model
......@@ -112,36 +100,6 @@ class Densenet(pl.LightningModule):
x = self.normalizer(x) # type: ignore
return self.model_ft(x)
def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Perform actions during checkpoint saving (called by lightning).
Called by Lightning when saving a checkpoint to give you a chance to
store anything else you might want to save. Use on_load_checkpoint() to
restore what additional data is saved here.
Parameters
----------
checkpoint
The checkpoint to save.
"""
checkpoint["normalizer"] = self.normalizer
def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Perform actions during model loading (called by lightning).
If you saved something with on_save_checkpoint() this is your chance to
restore this.
Parameters
----------
checkpoint
The loaded checkpoint.
"""
logger.info("Restoring normalizer from checkpoint.")
self.normalizer = checkpoint["normalizer"]
def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
"""Initialize the normalizer for the current model.
......@@ -205,9 +163,3 @@ class Densenet(pl.LightningModule):
outputs = self(batch[0])
probabilities = torch.sigmoid(outputs)
return separate((probabilities, batch[1]))
def configure_optimizers(self):
return self._optimizer_type(
self.parameters(),
**self._optimizer_arguments,
)
......@@ -3,87 +3,180 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import typing
from collections import Counter
import torch
import torch.utils.data
from ..data.typing import DataLoader
logger = logging.getLogger(__name__)
def _get_label_weights(
dataloader: torch.utils.data.DataLoader,
) -> torch.Tensor:
"""Compute the weights of each class of a DataLoader.
def compute_binary_weights(targets):
"""Compute the positive weights when using binary targets.
This function inputs a pytorch DataLoader and computes the ratio between
number of negative and positive samples (scalar). The weight can be used
to adjust minimisation criteria to in cases there is a huge data imbalance.
Parameters
----------
targets
A tensor of integer values of length n.
It returns a vector with weights (inverse counts) for each label.
Returns
-------
The positive weights per class.
"""
class_sample_count = [
float((targets == t).sum().item())
for t in torch.unique(targets, sorted=True)
]
# Divide negatives by positives
return torch.tensor(
[class_sample_count[0] / class_sample_count[1]],
).reshape(-1)
def compute_multiclass_weights(targets):
"""Compute the positive weights when using exclusive, multiclass targets.
Parameters
----------
dataloader
A DataLoader from which to compute the positive weights. Entries must
be a dictionary which must contain a ``label`` key.
targets
A [C x n] tensor of integer values, where `C` is the number of target classes and `n` the number of samples.
Returns
-------
torch.Tensor
The positive weight of each class in the dataset given as input.
The positive weights per class.
"""
targets = torch.tensor(
[sample for batch in dataloader for sample in batch[1]["label"]],
class_sample_count = torch.sum(targets, dim=1)
negative_class_sample_count = (
torch.full((targets.size()[0],), float(targets.size()[1]))
- class_sample_count
)
# Binary labels
if len(list(targets.shape)) == 1:
class_sample_count = [
float((targets == t).sum().item())
for t in torch.unique(targets, sorted=True)
]
return negative_class_sample_count / (
class_sample_count + negative_class_sample_count
)
# Divide negatives by positives
positive_weights = torch.tensor(
[class_sample_count[0] / class_sample_count[1]],
).reshape(-1)
# Multiclass labels
else:
class_sample_count = torch.sum(targets, dim=0)
negative_class_sample_count = (
torch.full((targets.size()[1],), float(targets.size()[0]))
- class_sample_count
)
def compute_non_exclusive_multiclass_weights(targets):
"""Compute the positive weights when using non-exclusive, multiclass targets.
positive_weights = negative_class_sample_count / (
class_sample_count + negative_class_sample_count
)
Parameters
----------
targets
A [C x n] tensor of integer values, where `C` is the number of target classes and `n` the number of samples.
return positive_weights
Returns
-------
The positive weights per class.
"""
raise ValueError(
"Computing weights of multi-class, non-exclusive labels is not yet supported."
)
def is_multiclass_exclusive(targets: torch.Tensor) -> bool:
"""Given a [C x n] tensor of integer targets, check whether each sample can belong to only a single class.
Parameters
----------
targets
A [C x n] tensor of integer values, where `C` is the number of target classes and `n` the number of samples.
Returns
-------
True if all samples belong to a single class, False otherwise (a sample can belong to multiple classes).
"""
max_counts = []
transposed_targets = torch.transpose(targets, 0, 1)
for t in transposed_targets:
filtered_list = [i for i in t.tolist() if i != 2]
counts = Counter(filtered_list)
max_counts.append(max(counts.values()))
if set(max_counts) == {1}:
return True
return False
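To make the exclusivity check above concrete, here is a small, hedged illustration of the [C x n] layout it expects (2 classes, 3 samples; values invented; note that the filter above skips the value 2):

import torch

# Each row is a class, each column a sample.
# Every column holds distinct values (one-hot), so the check returns True.
exclusive = torch.tensor([[1, 0, 0],
                          [0, 1, 1]])

# The first column is [1, 1]: one sample carries two classes,
# so the check returns False (non-exclusive labels).
non_exclusive = torch.tensor([[1, 0, 0],
                              [1, 1, 1]])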
def make_balanced_bcewithlogitsloss(
dataloader: DataLoader,
) -> torch.nn.BCEWithLogitsLoss:
"""Return a balanced binary-cross-entropy loss.
def tensor_to_list(tensor) -> list[typing.Any]:
"""Convert a torch.Tensor to a list.
The loss is weighted using the ratio between positives and total examples
available.
This is necessary, as torch.Tensor.tolist returns a plain int when the tensor contains a single value.
Parameters
----------
tensor
The tensor to convert to a list.
Returns
-------
The tensor converted to a list.
"""
tensor = tensor.tolist()
if isinstance(tensor, int):
return [tensor]
return tensor
def get_positive_weights(
dataloader: torch.utils.data.DataLoader,
) -> torch.Tensor:
"""Compute the weights of each class of a DataLoader.
This function takes a pytorch DataLoader and computes the ratio between the
number of negative and positive samples (scalar). The weight can be used
to adjust the minimisation criterion in cases where there is a large data imbalance.
It returns a vector with weights (inverse counts) for each label.
Parameters
----------
dataloader
The DataLoader to use to compute the BCE weights.
A DataLoader from which to compute the positive weights. Entries must
be a dictionary which must contain a ``label`` key.
Returns
-------
torch.nn.BCEWithLogitsLoss
An instance of the weighted loss.
The positive weight of each class in the dataset given as input.
"""
weights = _get_label_weights(dataloader)
return torch.nn.BCEWithLogitsLoss(pos_weight=weights)
from collections import defaultdict
targets = defaultdict(list)
for batch in dataloader:
for class_idx, class_targets in enumerate(batch[1]["label"]):
# Targets are either a single tensor (binary case) or a list of tensors (multilabel)
if isinstance(batch[1]["label"], list):
targets[class_idx].extend(tensor_to_list(class_targets))
else:
targets[0].extend(tensor_to_list(class_targets))
targets_list = []
for k in sorted(list(targets.keys())):
targets_list.append(targets[k])
targets_tensor = torch.tensor(targets_list)
if targets_tensor.shape[0] == 1:
logger.info("Computing positive weights assuming binary labels.")
positive_weights = compute_binary_weights(targets_tensor)
else:
if is_multiclass_exclusive(targets_tensor):
logger.info(
"Computing positive weights assuming multiclass, exclusive labels."
)
positive_weights = compute_multiclass_weights(targets_tensor)
else:
logger.info(
"Computing positive weights assuming multiclass, non-exclusive labels."
)
positive_weights = compute_non_exclusive_multiclass_weights(
targets_tensor
)
return positive_weights
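As a quick numerical illustration of the binary rule implemented above (compute_binary_weights divides the negative count by the positive count), the following hedged, self-contained sketch reproduces the same arithmetic and shows how such a weight plugs into BCEWithLogitsLoss; the target values are invented:

import torch

# Invented binary targets: 6 negatives, 2 positives.
targets = torch.tensor([0, 0, 1, 0, 0, 0, 1, 0])

negatives = float((targets == 0).sum())             # 6.0
positives = float((targets == 1).sum())             # 2.0
pos_weight = torch.tensor([negatives / positives])  # tensor([3.])

# BCEWithLogitsLoss scales the loss of positive samples by pos_weight,
# counter-acting the class imbalance.
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)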
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import typing
import lightning.pytorch as pl
import torch
import torch.nn
import torch.optim.optimizer
import torch.utils.data
import torchvision.transforms
from ..data.typing import TransformSequence
from .loss_weights import get_positive_weights
from .typing import Checkpoint
logger = logging.getLogger(__name__)
class Model(pl.LightningModule):
"""Base class for models.
Parameters
----------
loss_type
The loss to be used for training and evaluation.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
loss_arguments
Arguments to the loss.
optimizer_type
The type of optimizer to use for training.
optimizer_arguments
Arguments to the optimizer after ``params``.
augmentation_transforms
An optional sequence of torch modules containing transforms to be
applied on the input **before** it is fed into the network.
num_classes
Number of outputs (classes) for this model.
"""
def __init__(
self,
loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
loss_arguments: dict[str, typing.Any] = {},
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
num_classes: int = 1,
):
super().__init__()
self.name = "model"
self.num_classes = num_classes
self.model_transforms: TransformSequence = []
self._loss_type = loss_type
self._train_loss = None
self._train_loss_arguments = dict(loss_arguments)  # copy, so train and validation weights can differ
self._validation_loss = None
self._validation_loss_arguments = dict(loss_arguments)
self._optimizer_type = optimizer_type
self._optimizer_arguments = optimizer_arguments
self._augmentation_transforms = torchvision.transforms.Compose(
augmentation_transforms,
)
def forward(self, x):
raise NotImplementedError
def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Perform actions during checkpoint saving (called by lightning).
Called by Lightning when saving a checkpoint to give you a chance to
store anything else you might want to save. Use on_load_checkpoint() to
restore what additional data is saved here.
Parameters
----------
checkpoint
The checkpoint to save.
"""
checkpoint["normalizer"] = self.normalizer
def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Perform actions during model loading (called by lightning).
If you saved something with on_save_checkpoint() this is your chance to
restore this.
Parameters
----------
checkpoint
The loaded checkpoint.
"""
logger.info("Restoring normalizer from checkpoint.")
self.normalizer = checkpoint["normalizer"]
def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
"""Initialize the input normalizer for the current model.
Parameters
----------
dataloader
A torch Dataloader from which to compute the mean and std.
"""
from .normalizer import make_z_normalizer
logger.info(
f"Uninitialised {self.name} model - "
f"computing z-norm factors from train dataloader.",
)
self.normalizer = make_z_normalizer(dataloader)
def training_step(self, batch, _):
raise NotImplementedError
def validation_step(self, batch, batch_idx, dataloader_idx=0):
raise NotImplementedError
def predict_step(self, batch, batch_idx, dataloader_idx=0):
raise NotImplementedError
def configure_losses(self):
self._train_loss = self._loss_type(**self._train_loss_arguments)
self._validation_loss = self._loss_type(
**self._validation_loss_arguments
)
def configure_optimizers(self):
return self._optimizer_type(
self.parameters(),
**self._optimizer_arguments,
)
def balance_losses(self, datamodule) -> None:
"""Balance the loss based on the distribution of targets in the datamodule, if the loss supports it (contains a 'pos_weight' attribute).
Parameters
----------
datamodule
Instance of a datamodule.
"""
try:
getattr(self._loss_type(), "pos_weight")
except AttributeError:
logger.warning(
f"Loss {self._loss_type} does not posess a 'pos_weight' attribute and will not be balanced."
)
else:
logger.info(f"Balancing training loss {self._loss_type}.")
train_weights = get_positive_weights(datamodule.train_dataloader())
self._train_loss_arguments["pos_weight"] = train_weights
logger.info(f"Balancing validation loss {self._loss_type}.")
validation_weights = get_positive_weights(
datamodule.val_dataloader()["validation"]
)
self._validation_loss_arguments["pos_weight"] = validation_weights
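Read together with the changes to the training entry point further below (which now calls model.balance_losses(datamodule)) and the model.configure_losses() call added in the trainer, the intended sequence appears to be: derive the positive weights from the datamodule first, then instantiate the losses. A hedged sketch, where model and datamodule stand for whatever objects the CLI builds:

# Hypothetical call order reconstructed from this merge request:
model.balance_losses(datamodule)  # stores pos_weight in the loss-argument dictionaries
model.configure_losses()          # instantiates the train/validation losses with those arguments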
......@@ -5,7 +5,6 @@
import logging
import typing
import lightning.pytorch as pl
import torch
import torch.nn
import torch.nn.functional as F # noqa: N812
......@@ -14,14 +13,14 @@ import torch.utils.data
import torchvision.transforms
from ..data.typing import TransformSequence
from .model import Model
from .separate import separate
from .transforms import Grayscale, SquareCenterPad
from .typing import Checkpoint
logger = logging.getLogger(__name__)
class Pasa(pl.LightningModule):
class Pasa(Model):
"""Implementation of CNN by Pasa and others.
Simple CNN for classification based on paper by [PASA-2019]_.
......@@ -31,22 +30,15 @@ class Pasa(pl.LightningModule):
Parameters
----------
train_loss
The loss to be used during the training.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
validation_loss
The loss to be used for validation (may be different from the training
loss). If extra-validation sets are provided, the same loss will be
used throughout.
loss_type
The loss to be used for training and evaluation.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
loss_arguments
Arguments to the loss.
optimizer_type
The type of optimizer to use for training.
optimizer_arguments
......@@ -60,14 +52,21 @@ class Pasa(pl.LightningModule):
def __init__(
self,
train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
validation_loss: torch.nn.Module | None = None,
loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
loss_arguments: dict[str, typing.Any] = {},
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
num_classes: int = 1,
):
super().__init__()
super().__init__(
loss_type,
loss_arguments,
optimizer_type,
optimizer_arguments,
augmentation_transforms,
num_classes,
)
self.name = "pasa"
self.num_classes = num_classes
......@@ -82,17 +81,6 @@ class Pasa(pl.LightningModule):
),
]
self._train_loss = train_loss
self._validation_loss = (
validation_loss if validation_loss is not None else train_loss
)
self._optimizer_type = optimizer_type
self._optimizer_arguments = optimizer_arguments
self._augmentation_transforms = torchvision.transforms.Compose(
augmentation_transforms,
)
# First convolution block
self.fc1 = torch.nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
self.fc2 = torch.nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
......@@ -213,53 +201,6 @@ class Pasa(pl.LightningModule):
# x = F.log_softmax(x, dim=1) # 0 is batch size
def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Perform actions during checkpoint saving (called by lightning).
Called by Lightning when saving a checkpoint to give you a chance to
store anything else you might want to save. Use on_load_checkpoint() to
restore what additional data is saved here.
Parameters
----------
checkpoint
The checkpoint to save.
"""
checkpoint["normalizer"] = self.normalizer
def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Perform actions during model loading (called by lightning).
If you saved something with on_save_checkpoint() this is your chance to
restore this.
Parameters
----------
checkpoint
The loaded checkpoint.
"""
logger.info("Restoring normalizer from checkpoint.")
self.normalizer = checkpoint["normalizer"]
def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
"""Initialize the input normalizer for the current model.
Parameters
----------
dataloader
A torch Dataloader from which to compute the mean and std.
"""
from .normalizer import make_z_normalizer
logger.info(
f"Uninitialised {self.name} model - "
f"computing z-norm factors from train dataloader.",
)
self.normalizer = make_z_normalizer(dataloader)
def training_step(self, batch, _):
images = batch[0]
labels = batch[1]["label"]
......@@ -285,16 +226,9 @@ class Pasa(pl.LightningModule):
# data forwarding on the existing network
outputs = self(images)
return self._validation_loss(outputs, labels.float())
def predict_step(self, batch, batch_idx, dataloader_idx=0):
outputs = self(batch[0])
probabilities = torch.sigmoid(outputs)
return separate((probabilities, batch[1]))
def configure_optimizers(self):
return self._optimizer_type(
self.parameters(),
**self._optimizer_arguments,
)
......@@ -296,10 +296,8 @@ def train(
# of class examples available in the training set. Also affects the
# validation loss if a validation set is available on the DataModule.
if balance_classes:
logger.info("Applying DataModule train sampler balancing...")
datamodule.balance_sampler_by_class = True
# logger.info("Applying train/valid loss balancing...")
# model.balance_losses_by_class(datamodule)
logger.info("Applying train/valid loss balancing...")
model.balance_losses(datamodule)
else:
logger.info(
"Skipping sample class/dataset ownership balancing on user request",
......
......@@ -241,8 +241,7 @@ def test_train_pasa_montgomery(temporary_basedir):
keywords = {
r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
r"^Applying DataModule train sampler balancing...$": 1,
r"^Balancing samples from dataset using metadata targets `label`$": 1,
r"^Applying train/valid loss balancing...$": 1,
r"^Training for at most 1 epochs.$": 1,
r"^Uninitialised pasa model - computing z-norm factors from train dataloader.$": 1,
r"^Writing run metadata at.*$": 1,
......@@ -323,8 +322,7 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir):
keywords = {
r"^Loading dataset:`train` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
r"^Loading dataset:`validation` without caching. Trade-off: CPU RAM usage: less | Disk I/O: more.$": 1,
r"^Applying DataModule train sampler balancing...$": 1,
r"^Balancing samples from dataset using metadata targets `label`$": 1,
r"^Applying train/valid loss balancing...$": 1,
r"^Training for at most 2 epochs.$": 1,
r"^Resuming from epoch 0 \(checkpoint file: .*$": 1,
r"^Writing run metadata at.*$": 1,
......