Commit c25e5008 authored by André Anjos

[ptbench.models.pasa] Define new API for modules

parent a0f264f0
Merge request !7: Reviewed DataModule design+docs+types
Pipeline #75719 failed
@@ -11,32 +11,16 @@ Screening and Visualization".
 Reference: [PASA-2019]_
 """
-from torch import empty
 from torch.nn import BCEWithLogitsLoss
 from torch.optim import Adam
-from ...models.pasa import PASA
-
-# optimizer
-optimizer = Adam
-optimizer_configs = {"lr": 8e-5}
-
-# criterion
-criterion = BCEWithLogitsLoss(pos_weight=empty(1))
-criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
 from ...data.transforms import ElasticDeformation
+from ...models.pasa import Pasa
 
-augmentation_transforms = [ElasticDeformation(p=0.8)]
-
-# from torchvision.transforms.v2 import ElasticTransform, InterpolationMode
-# augmentation_transforms = [ElasticTransform(alpha=1000.0, sigma=30.0, interpolation=InterpolationMode.NEAREST)]
-
-# model
-model = PASA(
-    criterion,
-    criterion_valid,
-    optimizer,
-    optimizer_configs,
-    augmentation_transforms=augmentation_transforms,
+model = Pasa(
+    train_loss=BCEWithLogitsLoss(),
+    validation_loss=BCEWithLogitsLoss(),
+    optimizer_type=Adam,
+    optimizer_arguments=dict(lr=8e-5),
+    augmentation_transforms=[ElasticDeformation(p=0.8)],
 )
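For orientation, the new-style configuration above is a plain Python module that builds the model object directly, with losses, optimizer type and optimizer arguments passed as keyword arguments. A standalone sketch of the same construction follows; the absolute ``ptbench.*`` import paths are assumptions inferred from the relative imports in the diff.

# Standalone sketch of the new-style config above (illustrative only).
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam

from ptbench.data.transforms import ElasticDeformation  # assumed absolute path
from ptbench.models.pasa import Pasa                     # assumed absolute path

model = Pasa(
    train_loss=BCEWithLogitsLoss(),
    validation_loss=BCEWithLogitsLoss(),  # may also be None to reuse train_loss
    optimizer_type=Adam,
    optimizer_arguments=dict(lr=8e-5),
    augmentation_transforms=[ElasticDeformation(p=0.8)],
)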
@@ -5,11 +5,13 @@
 import logging
 
 import lightning.pytorch as pl
 import torch
 import torch.nn as nn
+import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
 
+from ..data.typing import DataLoader
+
 logger = logging.getLogger(__name__)
@@ -30,7 +32,7 @@ class Alexnet(pl.LightningModule):
     ):
         super().__init__()
 
-        self.name = "AlexNet"
+        self.name = "alexnet"
 
         self.augmentation_transforms = torchvision.transforms.Compose(
             augmentation_transforms
@@ -49,7 +51,7 @@ class Alexnet(pl.LightningModule):
         if not pretrained:
             weights = None
         else:
-            logger.info("Loading pretrained model weights")
+            logger.info(f"Loading pretrained {self.name} model weights")
             weights = models.AlexNet_Weights.DEFAULT
 
         self.model_ft = models.alexnet(weights=weights)
@@ -81,47 +83,60 @@ class Alexnet(pl.LightningModule):
             from .normalizer import make_imagenet_normalizer
 
             logger.warning(
-                "ImageNet pre-trained densenet model - NOT "
-                "computing z-norm factors from training data. "
-                "Using preset factors from torchvision."
+                f"ImageNet pre-trained {self.name} model - NOT "
+                f"computing z-norm factors from train dataloader. "
+                f"Using preset factors from torchvision."
             )
             self.normalizer = make_imagenet_normalizer()
         else:
             from .normalizer import make_z_normalizer
 
             logger.info(
-                "Uninitialised densenet model - "
-                "computing z-norm factors from training data."
+                f"Uninitialised {self.name} model - "
+                f"computing z-norm factors from train dataloader."
             )
             self.normalizer = make_z_normalizer(dataloader)
 
-    def set_bce_loss_weights(self, datamodule):
-        """Reweights loss weights if BCEWithLogitsLoss is used.
+    def balance_losses_by_class(
+        self, train_dataloader: DataLoader, valid_dataloader: DataLoader
+    ):
+        """Reweights loss weights if possible.
 
         Parameters
         ----------
 
-        datamodule:
-            A datamodule implementing train_dataloader() and val_dataloader()
+        train_dataloader
+            The data loader to use for training
+
+        valid_dataloader
+            The data loader to use for validation
+
+        Raises
+        ------
+
+        RuntimeError
+            If train or validation losses are not of type
+            :py:class:`torch.nn.BCEWithLogitsLoss`.
         """
-        from ..data.dataset import _get_positive_weights
+        from .loss_weights import get_label_weights
 
         if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
             logger.info("Reweighting BCEWithLogitsLoss training criterion.")
-            train_positive_weights = _get_positive_weights(
-                datamodule.train_dataloader()
-            )
-            self.criterion = torch.nn.BCEWithLogitsLoss(
-                pos_weight=train_positive_weights
-            )
+            weights = get_label_weights(train_dataloader)
+            self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+        else:
+            raise RuntimeError(
+                "Training loss is not BCEWithLogitsLoss - dunno how to balance"
+            )
 
         if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
             logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
-            validation_positive_weights = _get_positive_weights(
-                datamodule.val_dataloader()["validation"]
-            )
-            self.criterion_valid = torch.nn.BCEWithLogitsLoss(
-                pos_weight=validation_positive_weights
-            )
+            weights = get_label_weights(valid_dataloader)
+            self.criterion_valid = torch.nn.BCEWithLogitsLoss(weights)
+        else:
+            raise RuntimeError(
+                "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
+            )
 
     def training_step(self, batch, batch_idx):
@@ -172,11 +187,6 @@ class Alexnet(pl.LightningModule):
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
 
-        # necessary check for HED architecture that uses several outputs
-        # for loss calculation instead of just the last concatfuse block
-        if isinstance(outputs, list):
-            outputs = outputs[-1]
-
         return names[0], torch.flatten(probabilities), torch.flatten(labels)
 
     def configure_optimizers(self):
...
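The renamed helper changes the calling convention: instead of handing a whole datamodule to ``set_bce_loss_weights``, callers now pass the dataloaders explicitly to ``balance_losses_by_class``. A minimal sketch of the expected call order for the modules in this commit (only the helper names come from the diff; the wrapper function and variable names are illustrative):

def prepare_model(model, train_dataloader, valid_dataloader):
    # Compute (or load preset, for ImageNet-pretrained backbones) the input
    # normalisation factors from the training data.
    model.set_normalizer(train_dataloader)
    # Re-weight the BCEWithLogitsLoss criteria according to class imbalance;
    # raises RuntimeError if another loss type is configured.
    model.balance_losses_by_class(train_dataloader, valid_dataloader)
    return model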
@@ -6,17 +6,24 @@ import logging
 
 import lightning.pytorch as pl
 import torch
-import torch.nn as nn
+import torch.nn
+import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
 
+from ..data.typing import DataLoader
+
 logger = logging.getLogger(__name__)
 
 
 class Densenet(pl.LightningModule):
-    """Densenet module.
+    """Densenet-121 module.
 
-    Note: only usable with a normalized dataset
+    Parameters
+    ----------
+
+    criterion
+        A dictionary containing the criteria for the
     """
 
     def __init__(
@@ -30,7 +37,7 @@ class Densenet(pl.LightningModule):
     ):
         super().__init__()
 
-        self.name = "Densenet"
+        self.name = "densenet-121"
 
         self.augmentation_transforms = torchvision.transforms.Compose(
             augmentation_transforms
@@ -42,21 +49,20 @@
         self.optimizer = optimizer
         self.optimizer_configs = optimizer_configs
 
-        self.normalizer = None
         self.pretrained = pretrained
 
         # Load pretrained model
         if not pretrained:
             weights = None
         else:
-            logger.info("Loading pretrained model weights")
+            logger.info(f"Loading pretrained {self.name} model weights")
             weights = models.DenseNet121_Weights.DEFAULT
 
         self.model_ft = models.densenet121(weights=weights)
 
         # Adapt output features
-        self.model_ft.classifier = nn.Sequential(
-            nn.Linear(1024, 256), nn.Linear(256, 1)
+        self.model_ft.classifier = torch.nn.Sequential(
+            torch.nn.Linear(1024, 256), torch.nn.Linear(256, 1)
         )
 
     def forward(self, x):
@@ -82,47 +88,62 @@
             from .normalizer import make_imagenet_normalizer
 
             logger.warning(
-                "ImageNet pre-trained densenet model - NOT "
-                "computing z-norm factors from training data. "
-                "Using preset factors from torchvision."
+                f"ImageNet pre-trained {self.name} model - NOT "
+                f"computing z-norm factors from train dataloader. "
+                f"Using preset factors from torchvision."
            )
             self.normalizer = make_imagenet_normalizer()
         else:
             from .normalizer import make_z_normalizer
 
             logger.info(
-                "Uninitialised densenet model - "
-                "computing z-norm factors from training data."
+                f"Uninitialised {self.name} model - "
+                f"computing z-norm factors from train dataloader."
             )
             self.normalizer = make_z_normalizer(dataloader)
 
-    def set_bce_loss_weights(self, datamodule):
-        """Reweights loss weights if BCEWithLogitsLoss is used.
+    def balance_losses_by_class(
+        self,
+        train_dataloader: DataLoader,
+        valid_dataloader: dict[str, DataLoader],
+    ):
+        """Reweights loss weights if possible.
 
         Parameters
         ----------
 
-        datamodule:
-            A datamodule implementing train_dataloader() and val_dataloader()
+        train_dataloader
+            The data loader to use for training
+
+        valid_dataloader
+            The data loaders to use for each of the validation sets
+
+        Raises
+        ------
+
+        RuntimeError
+            If train or validation losses are not of type
+            :py:class:`torch.nn.BCEWithLogitsLoss`.
         """
-        from ..data.dataset import _get_positive_weights
+        from .loss_weights import get_label_weights
 
         if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
             logger.info("Reweighting BCEWithLogitsLoss training criterion.")
-            train_positive_weights = _get_positive_weights(
-                datamodule.train_dataloader()
-            )
-            self.criterion = torch.nn.BCEWithLogitsLoss(
-                pos_weight=train_positive_weights
-            )
+            weights = get_label_weights(train_dataloader)
+            self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
+        else:
+            raise RuntimeError(
+                "Training loss is not BCEWithLogitsLoss - dunno how to balance"
+            )
 
         if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
             logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
-            validation_positive_weights = _get_positive_weights(
-                datamodule.val_dataloader()["validation"]
-            )
-            self.criterion_valid = torch.nn.BCEWithLogitsLoss(
-                pos_weight=validation_positive_weights
-            )
+            weights = get_label_weights(valid_dataloader)
+            self.criterion_valid = torch.nn.BCEWithLogitsLoss(weights)
+        else:
+            raise RuntimeError(
+                "Validation loss is not BCEWithLogitsLoss - dunno how to balance"
            )
 
     def training_step(self, batch, batch_idx):
@@ -173,11 +194,6 @@ class Densenet(pl.LightningModule):
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
 
-        # necessary check for HED architecture that uses several outputs
-        # for loss calculation instead of just the last concatfuse block
-        if isinstance(outputs, list):
-            outputs = outputs[-1]
-
        return names[0], torch.flatten(probabilities), torch.flatten(labels)
 
     def configure_optimizers(self):
...
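The classifier replacement above swaps torchvision's stock densenet-121 head (a single ``Linear(1024, 1000)`` layer) for a two-layer head producing one logit. A self-contained sketch of just that adaptation, independent of the surrounding Lightning module:

import torch
import torchvision.models as models

# weights=None avoids downloading ImageNet weights for this illustration
backbone = models.densenet121(weights=None)
backbone.classifier = torch.nn.Sequential(
    torch.nn.Linear(1024, 256), torch.nn.Linear(256, 1)
)

logits = backbone(torch.randn(2, 3, 224, 224))  # -> shape (2, 1)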
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

import logging

import torch
import torch.utils.data

logger = logging.getLogger(__name__)


def get_label_weights(
    dataloader: torch.utils.data.DataLoader,
) -> torch.Tensor:
    """Computes the weights of each class of a DataLoader.

    This function inputs a pytorch DataLoader and computes, for binary
    problems, the ratio between the number of negative and positive samples
    (a scalar).  For multi-label problems, it returns a vector with one
    weight (inverse relative count) per label.  The weights can be used to
    adjust a minimisation criterion when there is a large class imbalance.

    Parameters
    ----------

    dataloader
        A DataLoader from which to compute the positive weights.  Entries must
        be a dictionary which must contain a ``label`` key.

    Returns
    -------

    positive_weights
        The positive weight of each class in the dataset given as input
    """

    targets = torch.tensor(
        [sample for batch in dataloader for sample in batch[1]["label"]]
    )

    # Binary labels
    if len(list(targets.shape)) == 1:
        class_sample_count = [
            float((targets == t).sum().item())
            for t in torch.unique(targets, sorted=True)
        ]

        # Divide negatives by positives
        positive_weights = torch.tensor(
            [class_sample_count[0] / class_sample_count[1]]
        ).reshape(-1)

    # Multiclass labels
    else:
        class_sample_count = torch.sum(targets, dim=0)
        negative_class_sample_count = (
            torch.full((targets.size()[1],), float(targets.size()[0]))
            - class_sample_count
        )

        positive_weights = negative_class_sample_count / (
            class_sample_count + negative_class_sample_count
        )

    return positive_weights
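A quick synthetic check of the binary branch of ``get_label_weights``: with 6 negatives and 2 positives, the returned pos_weight is 6/2 = 3.0. The absolute import path is an assumption; inside the package the function is imported relatively as ``from .loss_weights import get_label_weights``.

import torch
import torch.utils.data

from ptbench.models.loss_weights import get_label_weights  # assumed path

# Entries follow the (data, {"label": ...}) layout expected by the function.
dataset = [(torch.zeros(1), {"label": label}) for label in [0, 0, 0, 1, 0, 0, 0, 1]]
loader = torch.utils.data.DataLoader(dataset, batch_size=4)

weights = get_label_weights(loader)  # tensor([3.])
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weights)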
@@ -8,6 +8,7 @@ import torch
 import torch.nn
 import torch.utils.data
 import torchvision.transforms
+import tqdm
 
 
 def make_z_normalizer(
@@ -42,7 +43,7 @@ def make_z_normalizer(
     num_images = 0
 
     # Evaluates mean and standard deviation
-    for batch in dataloader:
+    for batch in tqdm.tqdm(dataloader, unit="batch"):
         data = batch[0]
         data = data.view(data.size(0), data.size(1), -1)
...
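For context, the loop being instrumented with ``tqdm`` accumulates per-channel statistics over the whole dataloader. A simplified sketch of that accumulation is shown below; the real ``make_z_normalizer`` presumably also wraps the resulting factors into a torchvision normalisation transform, which is not reproduced here.

import torch

def channel_stats(dataloader):
    """Simplified per-channel mean/std accumulation, as sketched above."""
    mean = 0.0
    std = 0.0
    num_images = 0
    for batch in dataloader:
        data = batch[0]                                   # (N, C, H, W)
        data = data.view(data.size(0), data.size(1), -1)  # flatten pixels
        mean = mean + data.mean(2).sum(0)
        std = std + data.std(2).sum(0)
        num_images += data.size(0)
    return mean / num_images, std / num_images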
@@ -3,94 +3,139 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 
 import logging
+import typing
 
 import lightning.pytorch as pl
 import torch
-import torch.nn as nn
+import torch.nn
 import torch.nn.functional as F
+import torch.optim.optimizer
 import torch.utils.data
 import torchvision.transforms
 
+from ..data.typing import TransformSequence
+
 logger = logging.getLogger(__name__)
 
 
-class PASA(pl.LightningModule):
-    """PASA module.
+class Pasa(pl.LightningModule):
+    """Implementation of CNN by Pasa.
 
-    Based on paper by [PASA-2019]_.
+    Simple CNN for classification based on paper by [PASA-2019]_.
+
+    Parameters
+    ----------
+
+    train_loss
+        The loss to be used during the training.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+
+    validation_loss
+        The loss to be used for validation (may be different from the training
+        loss).  If extra-validation sets are provided, the same loss will be
+        used throughout.
+
+        .. warning::
+
+           The loss should be set to always return batch averages (as opposed
+           to the batch sum), as our logging system expects it so.
+
+    optimizer_type
+        The type of optimizer to use for training
+
+    optimizer_arguments
+        Arguments to the optimizer after ``params``.
+
+    augmentation_transforms
+        An optional sequence of torch modules containing transforms to be
+        applied on the input **before** it is fed into the network.
     """
 
     def __init__(
         self,
-        criterion,
-        criterion_valid,
-        optimizer,
-        optimizer_configs,
-        augmentation_transforms,
+        train_loss: torch.nn.Module,
+        validation_loss: torch.nn.Module | None,
+        optimizer_type: type[torch.optim.Optimizer],
+        optimizer_arguments: dict[str, typing.Any],
+        augmentation_transforms: TransformSequence = [],
     ):
         super().__init__()
 
         self.name = "pasa"
 
-        self.augmentation_transforms = torchvision.transforms.Compose(
-            augmentation_transforms
-        )
-
-        self.criterion = criterion
-        self.criterion_valid = criterion_valid
-
-        self.optimizer = optimizer
-        self.optimizer_configs = optimizer_configs
-
-        self.normalizer = None
+        self._train_loss = train_loss
+        self._validation_loss = (
+            validation_loss if validation_loss is not None else train_loss
+        )
+        self._optimizer_type = optimizer_type
+        self._optimizer_arguments = optimizer_arguments
+
+        self._augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms
+        )
         # First convolution block
-        self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
-        self.fc2 = nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
-        self.fc3 = nn.Conv2d(1, 16, (1, 1), (4, 4))
+        self.fc1 = torch.nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
+        self.fc2 = torch.nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
+        self.fc3 = torch.nn.Conv2d(1, 16, (1, 1), (4, 4))
 
-        self.batchNorm2d_4 = nn.BatchNorm2d(4)
-        self.batchNorm2d_16 = nn.BatchNorm2d(16)
-        self.batchNorm2d_16_2 = nn.BatchNorm2d(16)
+        self.batchNorm2d_4 = torch.nn.BatchNorm2d(4)
+        self.batchNorm2d_16 = torch.nn.BatchNorm2d(16)
+        self.batchNorm2d_16_2 = torch.nn.BatchNorm2d(16)
 
         # Second convolution block
-        self.fc4 = nn.Conv2d(16, 24, (3, 3), (1, 1), (1, 1))
-        self.fc5 = nn.Conv2d(24, 32, (3, 3), (1, 1), (1, 1))
-        self.fc6 = nn.Conv2d(16, 32, (1, 1), (1, 1))  # Original stride (2, 2)
+        self.fc4 = torch.nn.Conv2d(16, 24, (3, 3), (1, 1), (1, 1))
+        self.fc5 = torch.nn.Conv2d(24, 32, (3, 3), (1, 1), (1, 1))
+        self.fc6 = torch.nn.Conv2d(
+            16, 32, (1, 1), (1, 1)
+        )  # Original stride (2, 2)
 
-        self.batchNorm2d_24 = nn.BatchNorm2d(24)
-        self.batchNorm2d_32 = nn.BatchNorm2d(32)
-        self.batchNorm2d_32_2 = nn.BatchNorm2d(32)
+        self.batchNorm2d_24 = torch.nn.BatchNorm2d(24)
+        self.batchNorm2d_32 = torch.nn.BatchNorm2d(32)
+        self.batchNorm2d_32_2 = torch.nn.BatchNorm2d(32)
 
         # Third convolution block
-        self.fc7 = nn.Conv2d(32, 40, (3, 3), (1, 1), (1, 1))
-        self.fc8 = nn.Conv2d(40, 48, (3, 3), (1, 1), (1, 1))
-        self.fc9 = nn.Conv2d(32, 48, (1, 1), (1, 1))  # Original stride (2, 2)
+        self.fc7 = torch.nn.Conv2d(32, 40, (3, 3), (1, 1), (1, 1))
+        self.fc8 = torch.nn.Conv2d(40, 48, (3, 3), (1, 1), (1, 1))
+        self.fc9 = torch.nn.Conv2d(
+            32, 48, (1, 1), (1, 1)
+        )  # Original stride (2, 2)
 
-        self.batchNorm2d_40 = nn.BatchNorm2d(40)
-        self.batchNorm2d_48 = nn.BatchNorm2d(48)
-        self.batchNorm2d_48_2 = nn.BatchNorm2d(48)
+        self.batchNorm2d_40 = torch.nn.BatchNorm2d(40)
+        self.batchNorm2d_48 = torch.nn.BatchNorm2d(48)
+        self.batchNorm2d_48_2 = torch.nn.BatchNorm2d(48)
 
         # Fourth convolution block
-        self.fc10 = nn.Conv2d(48, 56, (3, 3), (1, 1), (1, 1))
-        self.fc11 = nn.Conv2d(56, 64, (3, 3), (1, 1), (1, 1))
-        self.fc12 = nn.Conv2d(48, 64, (1, 1), (1, 1))  # Original stride (2, 2)
+        self.fc10 = torch.nn.Conv2d(48, 56, (3, 3), (1, 1), (1, 1))
+        self.fc11 = torch.nn.Conv2d(56, 64, (3, 3), (1, 1), (1, 1))
+        self.fc12 = torch.nn.Conv2d(
+            48, 64, (1, 1), (1, 1)
+        )  # Original stride (2, 2)
 
-        self.batchNorm2d_56 = nn.BatchNorm2d(56)
-        self.batchNorm2d_64 = nn.BatchNorm2d(64)
-        self.batchNorm2d_64_2 = nn.BatchNorm2d(64)
+        self.batchNorm2d_56 = torch.nn.BatchNorm2d(56)
+        self.batchNorm2d_64 = torch.nn.BatchNorm2d(64)
+        self.batchNorm2d_64_2 = torch.nn.BatchNorm2d(64)
 
         # Fifth convolution block
-        self.fc13 = nn.Conv2d(64, 72, (3, 3), (1, 1), (1, 1))
-        self.fc14 = nn.Conv2d(72, 80, (3, 3), (1, 1), (1, 1))
-        self.fc15 = nn.Conv2d(64, 80, (1, 1), (1, 1))  # Original stride (2, 2)
+        self.fc13 = torch.nn.Conv2d(64, 72, (3, 3), (1, 1), (1, 1))
+        self.fc14 = torch.nn.Conv2d(72, 80, (3, 3), (1, 1), (1, 1))
+        self.fc15 = torch.nn.Conv2d(
+            64, 80, (1, 1), (1, 1)
+        )  # Original stride (2, 2)
 
-        self.batchNorm2d_72 = nn.BatchNorm2d(72)
-        self.batchNorm2d_80 = nn.BatchNorm2d(80)
-        self.batchNorm2d_80_2 = nn.BatchNorm2d(80)
+        self.batchNorm2d_72 = torch.nn.BatchNorm2d(72)
+        self.batchNorm2d_80 = torch.nn.BatchNorm2d(80)
+        self.batchNorm2d_80_2 = torch.nn.BatchNorm2d(80)
 
-        self.pool2d = nn.MaxPool2d((3, 3), (2, 2))  # Pool after conv. block
-        self.dense = nn.Linear(80, 1)  # Fully connected layer
+        self.pool2d = torch.nn.MaxPool2d(
+            (3, 3), (2, 2)
+        )  # Pool after conv. block
+        self.dense = torch.nn.Linear(80, 1)  # Fully connected layer
 
     def forward(self, x):
         x = self.normalizer(x)  # type: ignore
@@ -141,51 +186,22 @@ class PASA(pl.LightningModule):
         return x
 
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
-        """Initializes the normalizer for the current model.
+        """Initializes the input normalizer for the current model.
 
         Parameters
         ----------
 
-        dataloader: :py:class:`torch.utils.data.DataLoader`
+        dataloader
             A torch Dataloader from which to compute the mean and std
         """
         from .normalizer import make_z_normalizer
 
         logger.info(
-            "Uninitialised densenet model - "
-            "computing z-norm factors from training data."
+            f"Uninitialised {self.name} model - "
+            f"computing z-norm factors from train dataloader."
         )
         self.normalizer = make_z_normalizer(dataloader)
 
-    def set_bce_loss_weights(self, datamodule):
-        """Reweights loss weights if BCEWithLogitsLoss is used.
-
-        Parameters
-        ----------
-
-        datamodule:
-            A datamodule implementing train_dataloader() and val_dataloader()
-        """
-        from ..data.dataset import _get_positive_weights
-
-        if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss training criterion.")
-            train_positive_weights = _get_positive_weights(
-                datamodule.train_dataloader()
-            )
-            self.criterion = torch.nn.BCEWithLogitsLoss(
-                pos_weight=train_positive_weights
-            )
-
-        if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
-            logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
-            validation_positive_weights = _get_positive_weights(
-                datamodule.val_dataloader()["validation"]
-            )
-            self.criterion_valid = torch.nn.BCEWithLogitsLoss(
-                pos_weight=validation_positive_weights
-            )
-
     def training_step(self, batch, _):
         images = batch[0]
         labels = batch[1]["label"]
@@ -197,15 +213,13 @@ class PASA(pl.LightningModule):
         # Forward pass on the network
         augmented_images = [
-            self.augmentation_transforms(img).to(self.device) for img in images
+            self._augmentation_transforms(img).to(self.device) for img in images
         ]
         # Combine list of augmented images back into a tensor
         augmented_images = torch.cat(augmented_images, 0).view(images.shape)
 
         outputs = self(augmented_images)
 
-        training_loss = self.criterion(outputs, labels.double())
-
-        return {"loss": training_loss}
+        return self._train_loss(outputs, labels.float())
 
     def validation_step(self, batch, batch_idx, dataloader_idx=0):
         images = batch[0]
@@ -219,12 +233,7 @@ class PASA(pl.LightningModule):
         # data forwarding on the existing network
         outputs = self(images)
 
-        validation_loss = self.criterion_valid(outputs, labels.double())
-
-        if dataloader_idx == 0:
-            return {"validation_loss": validation_loss}
-        else:
-            return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
+        return self._validation_loss(outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         images = batch[0]
@@ -234,40 +243,13 @@ class PASA(pl.LightningModule):
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
 
-        # necessary check for HED architecture that uses several outputs
-        # for loss calculation instead of just the last concatfuse block
-        if isinstance(outputs, list):
-            outputs = outputs[-1]
-
-        results = (
+        return (
             names[0],
             torch.flatten(probabilities),
             torch.flatten(labels),
         )
-        return results
-        # {
-        #     f"dataloader_{dataloader_idx}_predictions": (
-        #         names[0],
-        #         torch.flatten(probabilities),
-        #         torch.flatten(labels),
-        #     )
-        # }
-
-    # def on_predict_epoch_end(self):
-    #     retval = defaultdict(list)
-    #     for dataloader_name, predictions in self.predictions_cache.items():
-    #         for prediction in predictions:
-    #             retval[dataloader_name]["name"].append(prediction[0])
-    #             retval[dataloader_name]["prediction"].append(prediction[1])
-    #             retval[dataloader_name]["label"].append(prediction[2])
-    #     Need to cache predictions in the predict step, then reorder by key
-    #     Clear prediction dict
-    #     raise NotImplementedError
 
     def configure_optimizers(self):
-        optimizer = self.optimizer(self.parameters(), **self.optimizer_configs)
-        return optimizer
+        return self._optimizer_type(
+            self.parameters(), **self._optimizer_arguments
+        )
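Putting the new constructor together with ``configure_optimizers``, a hypothetical usage sketch follows. The absolute import path is an assumption, and only the pieces visible in this diff are exercised; actual training would go through ``lightning.pytorch.Trainer.fit()`` after the normalizer is set.

from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam

from ptbench.models.pasa import Pasa  # assumed import path

model = Pasa(
    train_loss=BCEWithLogitsLoss(),
    validation_loss=None,  # falls back to the training loss
    optimizer_type=Adam,
    optimizer_arguments=dict(lr=8e-5),
)

# The optimizer is only built when Lightning asks for it:
optimizer = model.configure_optimizers()  # Adam(model.parameters(), lr=8e-5)

# Next steps (not shown here): model.set_normalizer(train_dataloader), then
# lightning.pytorch.Trainer(...).fit(model, datamodule_or_dataloaders).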
import glob
import logging
import os
import sys
import time

import pkg_resources

logger = logging.getLogger(__name__)


def save_sh_command(output_dir):
    """Records command-line to reproduce this experiment.

    This function records the current command-line used to call the script
    being run.  It creates an executable shell script setting up the current
    working directory and activating a conda environment, if needed.  It
    records further information on the date and time the script was run and
    the version of the package.

    Parameters
    ----------

    output_dir : str
        Path leading to the directory where the commands to reproduce the
        current run will be recorded.  A subdirectory will be created each
        time this function is called to match lightning's versioning
        convention for loggers.
    """

    cmd_config_dir = os.path.join(output_dir, "cmd_line_configs")
    cmd_config_versions = glob.glob(os.path.join(cmd_config_dir, "version_*"))

    if len(cmd_config_versions) > 0:
        latest_cmd_config_version = max(
            [
                int(config.split("version_")[-1])
                for config in cmd_config_versions
            ]
        )
        current_cmd_config_version = str(latest_cmd_config_version + 1)
    else:
        current_cmd_config_version = "0"

    destfile = os.path.join(
        cmd_config_dir,
        f"version_{current_cmd_config_version}",
        "cmd_line_config.txt",
    )

    if os.path.exists(destfile):
        logger.info(f"Not overwriting existing file '{destfile}'")
        return

    logger.info(f"Writing command-line for reproduction at '{destfile}'...")
    os.makedirs(os.path.dirname(destfile), exist_ok=True)

    with open(destfile, "w") as f:
        f.write("#!/usr/bin/env sh\n")
        f.write(f"# date: {time.asctime()}\n")
        version = pkg_resources.require("ptbench")[0].version
        f.write(f"# version: {version} (deepdraw)\n")
        f.write(f"# platform: {sys.platform}\n")
        f.write("\n")

        args = []
        for k in sys.argv:
            if " " in k:
                args.append(f'"{k}"')
            else:
                args.append(k)

        if os.environ.get("CONDA_DEFAULT_ENV") is not None:
            f.write(f"#conda activate {os.environ['CONDA_DEFAULT_ENV']}\n")

        f.write(f"#cd {os.path.realpath(os.curdir)}\n")
        f.write(" ".join(args) + "\n")

    os.chmod(destfile, 0o755)
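Typical use is a single call near the start of an experiment script. The output directory below is just an example; the module path of ``save_sh_command`` inside ptbench is not shown in this diff, so the import is omitted here.

# save_sh_command() is the function defined above.
save_sh_command("results/pasa-experiment")
# creates, e.g.:
#   results/pasa-experiment/cmd_line_configs/version_0/cmd_line_config.txt
# containing the date, package version, platform and the exact command line,
# written as an executable sh script.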