diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py index 49ee76dbd002a9fbd13e99d052eef223c17a19ac..d1e1b0a3ae8d9e3e32a7ec19a49e21f01bb694d9 100644 --- a/src/ptbench/configs/models/pasa.py +++ b/src/ptbench/configs/models/pasa.py @@ -11,32 +11,16 @@ Screening and Visualization". Reference: [PASA-2019]_ """ -from torch import empty from torch.nn import BCEWithLogitsLoss from torch.optim import Adam -from ...models.pasa import PASA - -# optimizer -optimizer = Adam -optimizer_configs = {"lr": 8e-5} - -# criterion -criterion = BCEWithLogitsLoss(pos_weight=empty(1)) -criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) - from ...data.transforms import ElasticDeformation - -augmentation_transforms = [ElasticDeformation(p=0.8)] - -# from torchvision.transforms.v2 import ElasticTransform, InterpolationMode -# augmentation_transforms = [ElasticTransform(alpha=1000.0, sigma=30.0, interpolation=InterpolationMode.NEAREST)] - -# model -model = PASA( - criterion, - criterion_valid, - optimizer, - optimizer_configs, - augmentation_transforms=augmentation_transforms, +from ...models.pasa import Pasa + +model = Pasa( + train_loss=BCEWithLogitsLoss(), + validation_loss=BCEWithLogitsLoss(), + optimizer_type=Adam, + optimizer_arguments=dict(lr=8e-5), + augmentation_transforms=[ElasticDeformation(p=0.8)], ) diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index 55898b6759e4e471607bbe87cff0de3fb074724c..e8643b46ceef967e8c94f7425d9cecd9bd21a0b3 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -5,11 +5,13 @@ import logging import lightning.pytorch as pl -import torch import torch.nn as nn +import torch.utils.data import torchvision.models as models import torchvision.transforms +from ..data.typing import DataLoader + logger = logging.getLogger(__name__) @@ -30,7 +32,7 @@ class Alexnet(pl.LightningModule): ): super().__init__() - self.name = "AlexNet" + self.name = "alexnet" self.augmentation_transforms = torchvision.transforms.Compose( augmentation_transforms @@ -49,7 +51,7 @@ class Alexnet(pl.LightningModule): if not pretrained: weights = None else: - logger.info("Loading pretrained model weights") + logger.info(f"Loading pretrained {self.name} model weights") weights = models.AlexNet_Weights.DEFAULT self.model_ft = models.alexnet(weights=weights) @@ -81,47 +83,60 @@ class Alexnet(pl.LightningModule): from .normalizer import make_imagenet_normalizer logger.warning( - "ImageNet pre-trained densenet model - NOT " - "computing z-norm factors from training data. " - "Using preset factors from torchvision." + f"ImageNet pre-trained {self.name} model - NOT " + f"computing z-norm factors from train dataloader. " + f"Using preset factors from torchvision." ) self.normalizer = make_imagenet_normalizer() else: from .normalizer import make_z_normalizer logger.info( - "Uninitialised densenet model - " - "computing z-norm factors from training data." + f"Uninitialised {self.name} model - " + f"computing z-norm factors from train dataloader." ) self.normalizer = make_z_normalizer(dataloader) - def set_bce_loss_weights(self, datamodule): - """Reweights loss weights if BCEWithLogitsLoss is used. + def balance_losses_by_class( + self, train_dataloader: DataLoader, valid_dataloader: DataLoader + ): + """Reweights loss weights if possible. Parameters ---------- - datamodule: - A datamodule implementing train_dataloader() and val_dataloader() + train_dataloader + The data loader to use for training + + valid_dataloader + The data loader to use for validation + + + Raises + ------ + + RuntimeError + If train or validation losses are not of type + :py:class:`torch.nn.BCEWithLogitsLoss`. """ - from ..data.dataset import _get_positive_weights + from .loss_weights import get_label_weights if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss): logger.info("Reweighting BCEWithLogitsLoss training criterion.") - train_positive_weights = _get_positive_weights( - datamodule.train_dataloader() - ) - self.criterion = torch.nn.BCEWithLogitsLoss( - pos_weight=train_positive_weights + weights = get_label_weights(train_dataloader) + self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weights) + else: + raise RuntimeError( + "Training loss is not BCEWithLogitsLoss - dunno how to balance" ) if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss): logger.info("Reweighting BCEWithLogitsLoss validation criterion.") - validation_positive_weights = _get_positive_weights( - datamodule.val_dataloader()["validation"] - ) - self.criterion_valid = torch.nn.BCEWithLogitsLoss( - pos_weight=validation_positive_weights + weights = get_label_weights(valid_dataloader) + self.criterion_valid = torch.nn.BCEWithLogitsLoss(weights) + else: + raise RuntimeError( + "Validation loss is not BCEWithLogitsLoss - dunno how to balance" ) def training_step(self, batch, batch_idx): @@ -172,11 +187,6 @@ class Alexnet(pl.LightningModule): outputs = self(images) probabilities = torch.sigmoid(outputs) - # necessary check for HED architecture that uses several outputs - # for loss calculation instead of just the last concatfuse block - if isinstance(outputs, list): - outputs = outputs[-1] - return names[0], torch.flatten(probabilities), torch.flatten(labels) def configure_optimizers(self): diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index ae866e720ba4df0be292bd2ee3f23878a714818a..25d1d8ff344e15bc1c6729d7ebc7a324ce461e2a 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -6,17 +6,24 @@ import logging import lightning.pytorch as pl import torch -import torch.nn as nn +import torch.nn +import torch.utils.data import torchvision.models as models import torchvision.transforms +from ..data.typing import DataLoader + logger = logging.getLogger(__name__) class Densenet(pl.LightningModule): - """Densenet module. + """Densenet-121 module. + + Parameters + ---------- - Note: only usable with a normalized dataset + criterion + A dictionary containing the criteria for the """ def __init__( @@ -30,7 +37,7 @@ class Densenet(pl.LightningModule): ): super().__init__() - self.name = "Densenet" + self.name = "densenet-121" self.augmentation_transforms = torchvision.transforms.Compose( augmentation_transforms @@ -42,21 +49,20 @@ class Densenet(pl.LightningModule): self.optimizer = optimizer self.optimizer_configs = optimizer_configs - self.normalizer = None self.pretrained = pretrained # Load pretrained model if not pretrained: weights = None else: - logger.info("Loading pretrained model weights") + logger.info(f"Loading pretrained {self.name} model weights") weights = models.DenseNet121_Weights.DEFAULT self.model_ft = models.densenet121(weights=weights) # Adapt output features - self.model_ft.classifier = nn.Sequential( - nn.Linear(1024, 256), nn.Linear(256, 1) + self.model_ft.classifier = torch.nn.Sequential( + torch.nn.Linear(1024, 256), torch.nn.Linear(256, 1) ) def forward(self, x): @@ -82,47 +88,62 @@ class Densenet(pl.LightningModule): from .normalizer import make_imagenet_normalizer logger.warning( - "ImageNet pre-trained densenet model - NOT " - "computing z-norm factors from training data. " - "Using preset factors from torchvision." + f"ImageNet pre-trained {self.name} model - NOT " + f"computing z-norm factors from train dataloader. " + f"Using preset factors from torchvision." ) self.normalizer = make_imagenet_normalizer() else: from .normalizer import make_z_normalizer logger.info( - "Uninitialised densenet model - " - "computing z-norm factors from training data." + f"Uninitialised {self.name} model - " + f"computing z-norm factors from train dataloader." ) self.normalizer = make_z_normalizer(dataloader) - def set_bce_loss_weights(self, datamodule): - """Reweights loss weights if BCEWithLogitsLoss is used. + def balance_losses_by_class( + self, + train_dataloader: DataLoader, + valid_dataloader: dict[str, DataLoader], + ): + """Reweights loss weights if possible. Parameters ---------- - datamodule: - A datamodule implementing train_dataloader() and val_dataloader() + train_dataloader + The data loader to use for training + + valid_dataloader + The data loaders to use for each of the validation sets + + + Raises + ------ + + RuntimeError + If train or validation losses are not of type + :py:class:`torch.nn.BCEWithLogitsLoss`. """ - from ..data.dataset import _get_positive_weights + from .loss_weights import get_label_weights if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss): logger.info("Reweighting BCEWithLogitsLoss training criterion.") - train_positive_weights = _get_positive_weights( - datamodule.train_dataloader() - ) - self.criterion = torch.nn.BCEWithLogitsLoss( - pos_weight=train_positive_weights + weights = get_label_weights(train_dataloader) + self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weights) + else: + raise RuntimeError( + "Training loss is not BCEWithLogitsLoss - dunno how to balance" ) if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss): logger.info("Reweighting BCEWithLogitsLoss validation criterion.") - validation_positive_weights = _get_positive_weights( - datamodule.val_dataloader()["validation"] - ) - self.criterion_valid = torch.nn.BCEWithLogitsLoss( - pos_weight=validation_positive_weights + weights = get_label_weights(valid_dataloader) + self.criterion_valid = torch.nn.BCEWithLogitsLoss(weights) + else: + raise RuntimeError( + "Validation loss is not BCEWithLogitsLoss - dunno how to balance" ) def training_step(self, batch, batch_idx): @@ -173,11 +194,6 @@ class Densenet(pl.LightningModule): outputs = self(images) probabilities = torch.sigmoid(outputs) - # necessary check for HED architecture that uses several outputs - # for loss calculation instead of just the last concatfuse block - if isinstance(outputs, list): - outputs = outputs[-1] - return names[0], torch.flatten(probabilities), torch.flatten(labels) def configure_optimizers(self): diff --git a/src/ptbench/models/loss_weights.py b/src/ptbench/models/loss_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..6889b2539fb1071e228c4487540ad9272aad808c --- /dev/null +++ b/src/ptbench/models/loss_weights.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import logging + +import torch +import torch.utils.data + +logger = logging.getLogger(__name__) + + +def get_label_weights( + dataloader: torch.utils.data.DataLoader, +) -> torch.Tensor: + """Computes the weights of each class of a DataLoader. + + This function inputs a pytorch DataLoader and computes the ratio between + number of negative and positive samples (scalar). The weight can be used + to adjust minimisation criteria to in cases there is a huge data imbalance. + + If + + It returns a vector with weights (inverse counts) for each label. + + + Parameters + ---------- + + dataloader + A DataLoader from which to compute the positive weights. Entries must + be a dictionary which must contain a ``label`` key. + + + Returns + ------- + + positive_weights + the positive weight of each class in the dataset given as input + """ + + targets = torch.tensor( + [sample for batch in dataloader for sample in batch[1]["label"]] + ) + + # Binary labels + if len(list(targets.shape)) == 1: + class_sample_count = [ + float((targets == t).sum().item()) + for t in torch.unique(targets, sorted=True) + ] + + # Divide negatives by positives + positive_weights = torch.tensor( + [class_sample_count[0] / class_sample_count[1]] + ).reshape(-1) + + # Multiclass labels + else: + class_sample_count = torch.sum(targets, dim=0) + negative_class_sample_count = ( + torch.full((targets.size()[1],), float(targets.size()[0])) + - class_sample_count + ) + + positive_weights = negative_class_sample_count / ( + class_sample_count + negative_class_sample_count + ) + + return positive_weights diff --git a/src/ptbench/models/normalizer.py b/src/ptbench/models/normalizer.py index 2cc4b956f17ee1e690031f42e891ef726b548e2a..ce68f4b558b2812d17a8f54954e197a086233a52 100644 --- a/src/ptbench/models/normalizer.py +++ b/src/ptbench/models/normalizer.py @@ -8,6 +8,7 @@ import torch import torch.nn import torch.utils.data import torchvision.transforms +import tqdm def make_z_normalizer( @@ -42,7 +43,7 @@ def make_z_normalizer( num_images = 0 # Evaluates mean and standard deviation - for batch in dataloader: + for batch in tqdm.tqdm(dataloader, unit="batch"): data = batch[0] data = data.view(data.size(0), data.size(1), -1) diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index e257a4fc54190c778fdaa6c8d36522fa713fbd76..34e5c67fbb01684a51fe1e14ddc963ac0dadb6fb 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -3,94 +3,139 @@ # SPDX-License-Identifier: GPL-3.0-or-later import logging +import typing import lightning.pytorch as pl import torch -import torch.nn as nn +import torch.nn import torch.nn.functional as F +import torch.optim.optimizer import torch.utils.data import torchvision.transforms +from ..data.typing import TransformSequence + logger = logging.getLogger(__name__) -class PASA(pl.LightningModule): - """PASA module. +class Pasa(pl.LightningModule): + """Implementation of CNN by Pasa. + + Simple CNN for classification based on paper by [PASA-2019]_. + + + Parameters + ---------- + + train_loss + The loss to be used during the training. + + .. warning:: + + The loss should be set to always return batch averages (as opposed + to the batch sum), as our logging system expects it so. + + validation_loss + The loss to be used for validation (may be different from the training + loss). If extra-validation sets are provided, the same loss will be + used throughout. - Based on paper by [PASA-2019]_. + .. warning:: + + The loss should be set to always return batch averages (as opposed + to the batch sum), as our logging system expects it so. + + optimizer_type + The type of optimizer to use for training + + optimizer_arguments + Arguments to the optimizer after ``params``. + + augmentation_transforms + An optional sequence of torch modules containing transforms to be + applied on the input **before** it is fed into the network. """ def __init__( self, - criterion, - criterion_valid, - optimizer, - optimizer_configs, - augmentation_transforms, + train_loss: torch.nn.Module, + validation_loss: torch.nn.Module | None, + optimizer_type: type[torch.optim.Optimizer], + optimizer_arguments: dict[str, typing.Any], + augmentation_transforms: TransformSequence = [], ): super().__init__() self.name = "pasa" - self.augmentation_transforms = torchvision.transforms.Compose( - augmentation_transforms + self._train_loss = train_loss + self._validation_loss = ( + validation_loss if validation_loss is not None else train_loss ) + self._optimizer_type = optimizer_type + self._optimizer_arguments = optimizer_arguments - self.criterion = criterion - self.criterion_valid = criterion_valid - - self.optimizer = optimizer - self.optimizer_configs = optimizer_configs - - self.normalizer = None + self._augmentation_transforms = torchvision.transforms.Compose( + augmentation_transforms + ) # First convolution block - self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1)) - self.fc2 = nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1)) - self.fc3 = nn.Conv2d(1, 16, (1, 1), (4, 4)) + self.fc1 = torch.nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1)) + self.fc2 = torch.nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1)) + self.fc3 = torch.nn.Conv2d(1, 16, (1, 1), (4, 4)) - self.batchNorm2d_4 = nn.BatchNorm2d(4) - self.batchNorm2d_16 = nn.BatchNorm2d(16) - self.batchNorm2d_16_2 = nn.BatchNorm2d(16) + self.batchNorm2d_4 = torch.nn.BatchNorm2d(4) + self.batchNorm2d_16 = torch.nn.BatchNorm2d(16) + self.batchNorm2d_16_2 = torch.nn.BatchNorm2d(16) # Second convolution block - self.fc4 = nn.Conv2d(16, 24, (3, 3), (1, 1), (1, 1)) - self.fc5 = nn.Conv2d(24, 32, (3, 3), (1, 1), (1, 1)) - self.fc6 = nn.Conv2d(16, 32, (1, 1), (1, 1)) # Original stride (2, 2) + self.fc4 = torch.nn.Conv2d(16, 24, (3, 3), (1, 1), (1, 1)) + self.fc5 = torch.nn.Conv2d(24, 32, (3, 3), (1, 1), (1, 1)) + self.fc6 = torch.nn.Conv2d( + 16, 32, (1, 1), (1, 1) + ) # Original stride (2, 2) - self.batchNorm2d_24 = nn.BatchNorm2d(24) - self.batchNorm2d_32 = nn.BatchNorm2d(32) - self.batchNorm2d_32_2 = nn.BatchNorm2d(32) + self.batchNorm2d_24 = torch.nn.BatchNorm2d(24) + self.batchNorm2d_32 = torch.nn.BatchNorm2d(32) + self.batchNorm2d_32_2 = torch.nn.BatchNorm2d(32) # Third convolution block - self.fc7 = nn.Conv2d(32, 40, (3, 3), (1, 1), (1, 1)) - self.fc8 = nn.Conv2d(40, 48, (3, 3), (1, 1), (1, 1)) - self.fc9 = nn.Conv2d(32, 48, (1, 1), (1, 1)) # Original stride (2, 2) + self.fc7 = torch.nn.Conv2d(32, 40, (3, 3), (1, 1), (1, 1)) + self.fc8 = torch.nn.Conv2d(40, 48, (3, 3), (1, 1), (1, 1)) + self.fc9 = torch.nn.Conv2d( + 32, 48, (1, 1), (1, 1) + ) # Original stride (2, 2) - self.batchNorm2d_40 = nn.BatchNorm2d(40) - self.batchNorm2d_48 = nn.BatchNorm2d(48) - self.batchNorm2d_48_2 = nn.BatchNorm2d(48) + self.batchNorm2d_40 = torch.nn.BatchNorm2d(40) + self.batchNorm2d_48 = torch.nn.BatchNorm2d(48) + self.batchNorm2d_48_2 = torch.nn.BatchNorm2d(48) # Fourth convolution block - self.fc10 = nn.Conv2d(48, 56, (3, 3), (1, 1), (1, 1)) - self.fc11 = nn.Conv2d(56, 64, (3, 3), (1, 1), (1, 1)) - self.fc12 = nn.Conv2d(48, 64, (1, 1), (1, 1)) # Original stride (2, 2) + self.fc10 = torch.nn.Conv2d(48, 56, (3, 3), (1, 1), (1, 1)) + self.fc11 = torch.nn.Conv2d(56, 64, (3, 3), (1, 1), (1, 1)) + self.fc12 = torch.nn.Conv2d( + 48, 64, (1, 1), (1, 1) + ) # Original stride (2, 2) - self.batchNorm2d_56 = nn.BatchNorm2d(56) - self.batchNorm2d_64 = nn.BatchNorm2d(64) - self.batchNorm2d_64_2 = nn.BatchNorm2d(64) + self.batchNorm2d_56 = torch.nn.BatchNorm2d(56) + self.batchNorm2d_64 = torch.nn.BatchNorm2d(64) + self.batchNorm2d_64_2 = torch.nn.BatchNorm2d(64) # Fifth convolution block - self.fc13 = nn.Conv2d(64, 72, (3, 3), (1, 1), (1, 1)) - self.fc14 = nn.Conv2d(72, 80, (3, 3), (1, 1), (1, 1)) - self.fc15 = nn.Conv2d(64, 80, (1, 1), (1, 1)) # Original stride (2, 2) + self.fc13 = torch.nn.Conv2d(64, 72, (3, 3), (1, 1), (1, 1)) + self.fc14 = torch.nn.Conv2d(72, 80, (3, 3), (1, 1), (1, 1)) + self.fc15 = torch.nn.Conv2d( + 64, 80, (1, 1), (1, 1) + ) # Original stride (2, 2) - self.batchNorm2d_72 = nn.BatchNorm2d(72) - self.batchNorm2d_80 = nn.BatchNorm2d(80) - self.batchNorm2d_80_2 = nn.BatchNorm2d(80) + self.batchNorm2d_72 = torch.nn.BatchNorm2d(72) + self.batchNorm2d_80 = torch.nn.BatchNorm2d(80) + self.batchNorm2d_80_2 = torch.nn.BatchNorm2d(80) - self.pool2d = nn.MaxPool2d((3, 3), (2, 2)) # Pool after conv. block - self.dense = nn.Linear(80, 1) # Fully connected layer + self.pool2d = torch.nn.MaxPool2d( + (3, 3), (2, 2) + ) # Pool after conv. block + self.dense = torch.nn.Linear(80, 1) # Fully connected layer def forward(self, x): x = self.normalizer(x) # type: ignore @@ -141,51 +186,22 @@ class PASA(pl.LightningModule): return x def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: - """Initializes the normalizer for the current model. + """Initializes the input normalizer for the current model. Parameters ---------- - dataloader: :py:class:`torch.utils.data.DataLoader` + dataloader A torch Dataloader from which to compute the mean and std """ from .normalizer import make_z_normalizer logger.info( - "Uninitialised densenet model - " - "computing z-norm factors from training data." + f"Uninitialised {self.name} model - " + f"computing z-norm factors from train dataloader." ) self.normalizer = make_z_normalizer(dataloader) - def set_bce_loss_weights(self, datamodule): - """Reweights loss weights if BCEWithLogitsLoss is used. - - Parameters - ---------- - - datamodule: - A datamodule implementing train_dataloader() and val_dataloader() - """ - from ..data.dataset import _get_positive_weights - - if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss): - logger.info("Reweighting BCEWithLogitsLoss training criterion.") - train_positive_weights = _get_positive_weights( - datamodule.train_dataloader() - ) - self.criterion = torch.nn.BCEWithLogitsLoss( - pos_weight=train_positive_weights - ) - - if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss): - logger.info("Reweighting BCEWithLogitsLoss validation criterion.") - validation_positive_weights = _get_positive_weights( - datamodule.val_dataloader()["validation"] - ) - self.criterion_valid = torch.nn.BCEWithLogitsLoss( - pos_weight=validation_positive_weights - ) - def training_step(self, batch, _): images = batch[0] labels = batch[1]["label"] @@ -197,15 +213,13 @@ class PASA(pl.LightningModule): # Forward pass on the network augmented_images = [ - self.augmentation_transforms(img).to(self.device) for img in images + self._augmentation_transforms(img).to(self.device) for img in images ] # Combine list of augmented images back into a tensor augmented_images = torch.cat(augmented_images, 0).view(images.shape) outputs = self(augmented_images) - training_loss = self.criterion(outputs, labels.double()) - - return {"loss": training_loss} + return self._train_loss(outputs, labels.float()) def validation_step(self, batch, batch_idx, dataloader_idx=0): images = batch[0] @@ -219,12 +233,7 @@ class PASA(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - validation_loss = self.criterion_valid(outputs, labels.double()) - - if dataloader_idx == 0: - return {"validation_loss": validation_loss} - else: - return {f"extra_validation_loss_{dataloader_idx}": validation_loss} + return self._validation_loss(outputs, labels.float()) def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False): images = batch[0] @@ -234,40 +243,13 @@ class PASA(pl.LightningModule): outputs = self(images) probabilities = torch.sigmoid(outputs) - # necessary check for HED architecture that uses several outputs - # for loss calculation instead of just the last concatfuse block - if isinstance(outputs, list): - outputs = outputs[-1] - - results = ( + return ( names[0], torch.flatten(probabilities), torch.flatten(labels), ) - return results - # { - # f"dataloader_{dataloader_idx}_predictions": ( - # names[0], - # torch.flatten(probabilities), - # torch.flatten(labels), - # ) - # } - - # def on_predict_epoch_end(self): - - # retval = defaultdict(list) - - # for dataloader_name, predictions in self.predictions_cache.items(): - # for prediction in predictions: - # retval[dataloader_name]["name"].append(prediction[0]) - # retval[dataloader_name]["prediction"].append(prediction[1]) - # retval[dataloader_name]["label"].append(prediction[2]) - - # Need to cache predictions in the predict step, then reorder by key - # Clear prediction dict - # raise NotImplementedError - def configure_optimizers(self): - optimizer = self.optimizer(self.parameters(), **self.optimizer_configs) - return optimizer + return self._optimizer_type( + self.parameters(), **self._optimizer_arguments + ) diff --git a/src/ptbench/utils/save_sh_command.py b/src/ptbench/utils/save_sh_command.py deleted file mode 100644 index e0a7d379c00caddb7ade2513668a1392e98b21f2..0000000000000000000000000000000000000000 --- a/src/ptbench/utils/save_sh_command.py +++ /dev/null @@ -1,74 +0,0 @@ -import glob -import logging -import os -import sys -import time - -import pkg_resources - -logger = logging.getLogger(__name__) - - -def save_sh_command(output_dir): - """Records command-line to reproduce this experiment. - - This function can record the current command-line used to call the script - being run. It creates an executable ``bash`` script setting up the current - working directory and activating a conda environment, if needed. It - records further information on the date and time the script was run and the - version of the package. - - - Parameters - ---------- - - output_folder : str - Path leading to the directory where the commands to reproduce the current - run will be recorded. A subdirectory will be created each time this function - is called to match lightning's versioning convention for loggers. - """ - - cmd_config_dir = os.path.join(output_dir, "cmd_line_configs") - cmd_config_versions = glob.glob(os.path.join(cmd_config_dir, "version_*")) - if len(cmd_config_versions) > 0: - latest_cmd_config_version = max( - [ - int(config.split("version_")[-1]) - for config in cmd_config_versions - ] - ) - current_cmd_config_version = str(latest_cmd_config_version + 1) - else: - current_cmd_config_version = "0" - - destfile = os.path.join( - cmd_config_dir, - f"version_{current_cmd_config_version}", - "cmd_line_config.txt", - ) - - if os.path.exists(destfile): - logger.info(f"Not overwriting existing file '{destfile}'") - return - - logger.info(f"Writing command-line for reproduction at '{destfile}'...") - os.makedirs(os.path.dirname(destfile), exist_ok=True) - - with open(destfile, "w") as f: - f.write("#!/usr/bin/env sh\n") - f.write(f"# date: {time.asctime()}\n") - version = pkg_resources.require("ptbench")[0].version - f.write(f"# version: {version} (deepdraw)\n") - f.write(f"# platform: {sys.platform}\n") - f.write("\n") - args = [] - for k in sys.argv: - if " " in k: - args.append(f'"{k}"') - else: - args.append(k) - if os.environ.get("CONDA_DEFAULT_ENV") is not None: - f.write(f"#conda activate {os.environ['CONDA_DEFAULT_ENV']}\n") - f.write(f"#cd {os.path.realpath(os.curdir)}\n") - f.write(" ".join(args) + "\n") - os.chmod(destfile, 0o755)