diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py
index 47324199ff6daf3475d96e5999063621493b985b..49ee76dbd002a9fbd13e99d052eef223c17a19ac 100644
--- a/src/ptbench/configs/models/pasa.py
+++ b/src/ptbench/configs/models/pasa.py
@@ -13,22 +13,24 @@ Reference: [PASA-2019]_
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
+from torch.optim import Adam
 
 from ...data.transforms import ElasticDeformation
 from ...models.pasa import PASA
 
-# config
-optimizer_configs = {"lr": 8e-5}
-
 # optimizer
-optimizer = "Adam"
+optimizer = Adam
+optimizer_configs = {"lr": 8e-5}
 
 # criterion
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
 augmentation_transforms = [ElasticDeformation(p=0.8)]
 
+# from torchvision.transforms.v2 import ElasticTransform, InterpolationMode
+# augmentation_transforms = [ElasticTransform(alpha=1000.0, sigma=30.0, interpolation=InterpolationMode.NEAREST)]
+
 # model
 model = PASA(
     criterion,
diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 8c297823fb649d6f750c2799416504b335b24f60..fc9883d2f89c008fe7fb6c52ed62dea8664df17b 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -13,6 +13,8 @@ import torch
 import torch.utils.data
 import torchvision.transforms
 
+from tqdm import tqdm
+
 logger = logging.getLogger(__name__)
 
 
@@ -150,8 +152,11 @@ class _CachedDataset(torch.utils.data.Dataset):
             typing.Callable[[torch.Tensor], torch.Tensor]
         ] = [],
     ):
-        self.transform = torchvision.transforms.Compose(*transforms)
-        self.data = [raw_data_loader(k) for k in split]
+        # Compose() takes the list of transforms itself; unpacking it
+        # breaks when the list is empty
+        self.transform = torchvision.transforms.Compose(transforms)
+
+        self.data = [raw_data_loader(k) for k in tqdm(split)]
 
     def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
         tensor, metadata = self.data[key]
@@ -446,6 +451,7 @@ class CachingDataModule(lightning.LightningDataModule):
             logger.info(f"Dataset {name} is already setup. Not reloading it.")
             return
         if self.cache_samples:
+            logger.info(f"Caching {name} dataset")
             self._datasets[name] = _CachedDataset(
                 self.database_split[name],
                 self.raw_data_loader,
diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py
index af8a736454608961535caee8cddb432a42199410..3529ee7f0a14c0fc03a01f9d02348cf09a43cb20 100644
--- a/src/ptbench/data/dataset.py
+++ b/src/ptbench/data/dataset.py
@@ -10,11 +10,11 @@ import torch.utils.data
 logger = logging.getLogger(__name__)
 
 
-def _get_positive_weights(dataset):
+def _get_positive_weights(dataloader):
     """Compute the positive weights of each class of the dataset to balance
     the BCEWithLogitsLoss criterion.
 
-    This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
+    This function takes as input a :py:class:`torch.utils.data.DataLoader`
     and computes the positive weights of each class to use them to have a
     balanced loss.
 
@@ -22,9 +22,8 @@
 
     Parameters
     ----------
 
-    dataset : torch.utils.data.dataset.Dataset
-        An instance of torch.utils.data.dataset.Dataset
-        ConcatDataset are supported
+    dataloader : :py:class:`torch.utils.data.DataLoader`
+        A DataLoader whose batches carry a 'label' entry in their metadata, from which the positive weights are computed.
 
Returns @@ -35,14 +34,8 @@ def _get_positive_weights(dataset): """ targets = [] - if isinstance(dataset, torch.utils.data.ConcatDataset): - for ds in dataset.datasets: - for s in ds._samples: - targets.append(s["label"]) - - else: - for s in dataset._samples: - targets.append(s["label"]) + for batch in dataloader: + targets.extend(batch[1]["label"]) targets = torch.tensor(targets) @@ -71,33 +64,3 @@ def _get_positive_weights(dataset): ) return positive_weights - - -def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid): - from torch.nn import BCEWithLogitsLoss - - datamodule.prepare_data() - datamodule.setup(stage="fit") - - train_dataset = datamodule.train_dataset - validation_dataset = datamodule.validation_dataset - - # Redefine a weighted criterion if possible - if isinstance(criterion, torch.nn.BCEWithLogitsLoss): - positive_weights = _get_positive_weights(train_dataset) - model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights) - else: - logger.warning("Weighted criterion not supported") - - if validation_dataset is not None: - # Redefine a weighted valid criterion if possible - if ( - isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss) - or criterion_valid is None - ): - positive_weights = _get_positive_weights(validation_dataset) - model.hparams.criterion_valid = BCEWithLogitsLoss( - pos_weight=positive_weights - ) - else: - logger.warning("Weighted valid criterion not supported") diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py index d0ac43f98e21b8ce6803797d6a1fde38c6302660..580cc26fd24efc6ebb98fbcbb02d51348c6c11ec 100644 --- a/src/ptbench/engine/callbacks.py +++ b/src/ptbench/engine/callbacks.py @@ -94,9 +94,7 @@ class LoggingCallback(Callback): self.log("total_time", current_time) self.log("eta", eta_seconds) self.log("loss", numpy.average(self.training_loss)) - self.log( - "learning_rate", pl_module.hparams["optimizer_configs"]["lr"] - ) + self.log("learning_rate", pl_module.optimizer_configs["lr"]) self.log("validation_loss", numpy.sum(self.validation_loss)) if len(self.extra_validation_loss) > 0: diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py index 6b156f86822209445a709c0af99b8b1eb14c55dc..ecf29153bc56c84185a905edc97867b0730a0a57 100644 --- a/src/ptbench/engine/trainer.py +++ b/src/ptbench/engine/trainer.py @@ -51,20 +51,24 @@ def save_model_summary( Returns ------- - r + summary: The model summary in a text format. - n + total_parameters: The number of parameters of the model. 
""" summary_path = os.path.join(output_folder, "model_summary.txt") logger.info(f"Saving model summary at {summary_path}...") with open(summary_path, "w") as f: - summary = lightning.pytorch.callbacks.ModelSummary(model, max_depth=-1) + summary = lightning.pytorch.utilities.model_summary.ModelSummary( + model, max_depth=-1 + ) f.write(str(summary)) return ( summary, - lightning.pytorch.callbacks.ModelSummary(model).total_parameters, + lightning.pytorch.utilities.model_summary.ModelSummary( + model + ).total_parameters, ) diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index 76327670d047900f10b8a94d90a7bc95a44fe2a0..e93b4e61f3bd5c9a0259e7d21728362f8e26999a 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.data +import torchvision.transforms logger = logging.getLogger(__name__) @@ -31,7 +32,15 @@ class PASA(pl.LightningModule): self.name = "pasa" - self.augmentation_transforms = augmentation_transforms + self.augmentation_transforms = torchvision.transforms.Compose( + augmentation_transforms + ) + + self.criterion = criterion + self.criterion_valid = criterion_valid + + self.optimizer = optimizer + self.optimizer_configs = optimizer_configs self.normalizer = None @@ -137,7 +146,7 @@ class PASA(pl.LightningModule): Parameters ---------- - dataloader: + dataloader: :py:class:`torch.utils.data.DataLoader` A torch Dataloader from which to compute the mean and std """ from .normalizer import make_z_normalizer @@ -148,6 +157,35 @@ class PASA(pl.LightningModule): ) self.normalizer = make_z_normalizer(dataloader) + def set_bce_loss_weights(self, datamodule): + """Reweights loss weights if BCEWithLogitsLoss is used. + + Parameters + ---------- + + datamodule: + A datamodule implementing train_dataloader() and val_dataloader() + """ + from ..data.dataset import _get_positive_weights + + if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss): + logger.info("Reweighting BCEWithLogitsLoss training criterion.") + train_positive_weights = _get_positive_weights( + datamodule.train_dataloader() + ) + self.criterion = torch.nn.BCEWithLogitsLoss( + pos_weight=train_positive_weights + ) + + if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss): + logger.info("Reweighting BCEWithLogitsLoss validation criterion.") + validation_positive_weights = _get_positive_weights( + datamodule.val_dataloader()["validation"] + ) + self.criterion_valid = torch.nn.BCEWithLogitsLoss( + pos_weight=validation_positive_weights + ) + def training_step(self, batch, _): images = batch[0] labels = batch[1]["label"] @@ -158,12 +196,11 @@ class PASA(pl.LightningModule): labels = torch.reshape(labels, (labels.shape[0], 1)) # Forward pass on the network - augmented_images = self.augmentation_transforms(images) + augmented_images = [self.augmentation_transforms(img) for img in images] + augmented_images = torch.unsqueeze(torch.cat(augmented_images, 0), 1) outputs = self(augmented_images) - # Manually move criterion to selected device, since not part of the model. - self.hparams.criterion = self.hparams.criterion.to(self.device) - training_loss = self.hparams.criterion(outputs, labels.double()) + training_loss = self.criterion(outputs, labels.double()) return {"loss": training_loss} @@ -179,11 +216,7 @@ class PASA(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - # Manually move criterion to selected device, since not part of the model. 
- self.hparams.criterion_valid = self.hparams.criterion_valid.to( - self.device - ) - validation_loss = self.hparams.criterion_valid(outputs, labels.double()) + validation_loss = self.criterion_valid(outputs, labels.double()) if dataloader_idx == 0: return {"validation_loss": validation_loss} @@ -233,9 +266,5 @@ class PASA(pl.LightningModule): # raise NotImplementedError def configure_optimizers(self): - # Dynamically instantiates the optimizer given the configs - optimizer = getattr(torch.optim, self.hparams.optimizer)( - self.parameters(), **self.hparams.optimizer_configs - ) - + optimizer = self.optimizer(self.parameters(), **self.optimizer_configs) return optimizer
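
Reviewer notes on the patch above follow; the snippets are illustrative Python sketches, not part of the patch itself.

_get_positive_weights() now consumes a DataLoader instead of reaching into dataset._samples, so it works with any dataset (cached or not) whose batches carry a 'label' entry in their metadata. The sketch below shows the conventional positive-weight computation for BCEWithLogitsLoss, i.e. the ratio of negative to positive counts; the labels are made up, and the exact formula inside ptbench is not visible in this diff:

import torch

labels = torch.tensor([0, 0, 0, 1, 1])   # toy targets: 3 negatives, 2 positives
positives = labels.sum()                 # tensor(2)
negatives = len(labels) - positives      # tensor(3)
pos_weight = negatives / positives       # tensor(1.5000), up-weights the rarer class

criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight.unsqueeze(0))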
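The deleted reweight_BCEWithLogitsLoss() helper is superseded by the new PASA.set_bce_loss_weights() method. A hypothetical call site, mirroring the prepare_data()/setup() sequence of the removed helper; here datamodule and model stand for instances built from the configs in this patch:

datamodule.prepare_data()
datamodule.setup(stage="fit")
model.set_bce_loss_weights(datamodule)  # swaps in criteria weighted by pos_weight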
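On the training_step() change: augmentations are now applied per sample, so the batch must be reassembled afterwards. Iterating over a [B, 1, H, W] batch yields [1, H, W] tensors, torch.cat(..., 0) gives [B, H, W], and unsqueeze(1) restores the channel axis; this round-trips only for the single-channel inputs this model consumes. A quick shape check with arbitrary sizes:

import torch

batch = torch.rand(4, 1, 512, 512)   # B=4 single-channel images
samples = [img for img in batch]     # four [1, 512, 512] tensors
rebuilt = torch.unsqueeze(torch.cat(samples, 0), 1)
assert rebuilt.shape == batch.shape  # torch.Size([4, 1, 512, 512])

For multi-channel inputs, torch.stack(samples) would be the shape-preserving alternative.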
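Finally, the config now exports the optimizer class rather than its name, so configure_optimizers() calls it directly instead of resolving a string through getattr(torch.optim, ...). A minimal sketch of the two equivalent spellings, using a throwaway parameter:

import torch
from torch.optim import Adam

params = [torch.nn.Parameter(torch.zeros(3))]
optimizer_configs = {"lr": 8e-5}

optimizer = Adam(params, **optimizer_configs)                       # new style
legacy = getattr(torch.optim, "Adam")(params, **optimizer_configs)  # old style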