Compare revisions: medai/software/mednet

Changes are shown as if the source revision was being merged into the target revision.
@@ -2,12 +2,15 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms
from .normalizer import TorchVisionNormalizer
logger = logging.getLogger(__name__)
class Alexnet(pl.LightningModule):
@@ -18,25 +21,38 @@ class Alexnet(pl.LightningModule):
def __init__(
self,
criterion,
criterion_valid,
optimizer,
optimizer_configs,
criterion=None,
criterion_valid=None,
optimizer=None,
optimizer_configs=None,
pretrained=False,
augmentation_transforms=[],
):
super().__init__()
self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
self.name = "AlexNet"
# Load pretrained model
weights = (
None if pretrained is False else models.AlexNet_Weights.DEFAULT
self.augmentation_transforms = torchvision.transforms.Compose(
augmentation_transforms
)
self.model_ft = models.alexnet(weights=weights)
self.normalizer = TorchVisionNormalizer(nb_channels=1)
self.criterion = criterion
self.criterion_valid = criterion_valid
self.optimizer = optimizer
self.optimizer_configs = optimizer_configs
self.normalizer = None
self.pretrained = pretrained
# Load pretrained model
if not pretrained:
weights = None
else:
logger.info("Loading pretrained model weights")
weights = models.AlexNet_Weights.DEFAULT
self.model_ft = models.alexnet(weights=weights)
# Adapt output features
self.model_ft.classifier[4] = nn.Linear(4096, 512)
@@ -48,9 +64,69 @@ class Alexnet(pl.LightningModule):
return x
def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
"""Initializes the normalizer for the current model.
If ``pretrained = True``, the normalizer is set to the standard ImageNet
factors from torchvision and the data loader is not used; otherwise,
z-normalization factors are computed from the data loader.
Parameters
----------
dataloader: :py:class:`torch.utils.data.DataLoader`
A torch Dataloader from which to compute the mean and std.
Will not be used if the model is pretrained.
"""
if self.pretrained:
from .normalizer import make_imagenet_normalizer
logger.warning(
"ImageNet pre-trained densenet model - NOT "
"computing z-norm factors from training data. "
"Using preset factors from torchvision."
)
self.normalizer = make_imagenet_normalizer()
else:
from .normalizer import make_z_normalizer
logger.info(
"Uninitialised densenet model - "
"computing z-norm factors from training data."
)
self.normalizer = make_z_normalizer(dataloader)
def set_bce_loss_weights(self, datamodule):
"""Reweights loss weights if BCEWithLogitsLoss is used.
Parameters
----------
datamodule:
A datamodule implementing train_dataloader() and val_dataloader()
"""
from ..data.dataset import _get_positive_weights
if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
logger.info("Reweighting BCEWithLogitsLoss training criterion.")
train_positive_weights = _get_positive_weights(
datamodule.train_dataloader()
)
self.criterion = torch.nn.BCEWithLogitsLoss(
pos_weight=train_positive_weights
)
if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
validation_positive_weights = _get_positive_weights(
datamodule.val_dataloader()["validation"]
)
self.criterion_valid = torch.nn.BCEWithLogitsLoss(
pos_weight=validation_positive_weights
)
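For reference, a minimal sketch of how per-dataset positive weights could be derived from a dataloader. The project's actual ``_get_positive_weights`` helper lives in ``..data.dataset`` and may differ; this stand-in only assumes the ``batch[1]["label"]`` layout used by the models in this changeset:
def _approximate_positive_weights(dataloader):
    # Hypothetical stand-in for _get_positive_weights: counts positive vs.
    # negative binary labels and returns negatives/positives, the value
    # expected by torch.nn.BCEWithLogitsLoss(pos_weight=...).
    positives = 0.0
    total = 0
    for batch in dataloader:
        labels = batch[1]["label"]
        positives += float(labels.sum())
        total += labels.numel()
    negatives = total - positives
    return torch.tensor([negatives / max(positives, 1.0)])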
def training_step(self, batch, batch_idx):
images = batch[1]
labels = batch[2]
images = batch[0]
labels = batch[1]["label"]
# Increase label dimension if too low
# Allows single and multiclass usage
@@ -58,17 +134,20 @@ class Alexnet(pl.LightningModule):
labels = torch.reshape(labels, (labels.shape[0], 1))
# Forward pass on the network
outputs = self(images)
augmented_images = [
self.augmentation_transforms(img).to(self.device) for img in images
]
# Combine list of augmented images back into a tensor
augmented_images = torch.cat(augmented_images, 0).view(images.shape)
outputs = self(augmented_images)
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion = self.hparams.criterion.to(self.device)
training_loss = self.hparams.criterion(outputs, labels.float())
training_loss = self.criterion(outputs, labels.float())
return {"loss": training_loss}
def validation_step(self, batch, batch_idx, dataloader_idx=0):
images = batch[1]
labels = batch[2]
images = batch[0]
labels = batch[1]["label"]
# Increase label dimension if too low
# Allows single and multiclass usage
@@ -78,11 +157,7 @@ class Alexnet(pl.LightningModule):
# data forwarding on the existing network
outputs = self(images)
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion_valid = self.hparams.criterion_valid.to(
self.device
)
validation_loss = self.hparams.criterion_valid(outputs, labels.float())
validation_loss = self.criterion_valid(outputs, labels.float())
if dataloader_idx == 0:
return {"validation_loss": validation_loss}
@@ -90,8 +165,9 @@ class Alexnet(pl.LightningModule):
return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
names = batch[0]
images = batch[1]
images = batch[0]
labels = batch[1]["label"]
names = batch[1]["name"]
outputs = self(images)
probabilities = torch.sigmoid(outputs)
@@ -101,11 +177,8 @@ class Alexnet(pl.LightningModule):
if isinstance(outputs, list):
outputs = outputs[-1]
return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
return names[0], torch.flatten(probabilities), torch.flatten(labels)
def configure_optimizers(self):
optimizer = getattr(torch.optim, self.hparams.optimizer)(
self.parameters(), **self.hparams.optimizer_configs
)
optimizer = self.optimizer(self.parameters(), **self.optimizer_configs)
return optimizer
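With this change the model receives an optimizer callable (for example a ``torch.optim`` class) instead of a string name. A hedged construction sketch, with illustrative argument values only:
# Illustrative construction of the refactored model (values are examples):
model = Alexnet(
    criterion=torch.nn.BCEWithLogitsLoss(),
    criterion_valid=torch.nn.BCEWithLogitsLoss(),
    optimizer=torch.optim.Adam,          # called as optimizer(parameters, **configs)
    optimizer_configs={"lr": 1e-4},
    pretrained=True,
    augmentation_transforms=[],
)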
@@ -2,12 +2,15 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms
from .normalizer import TorchVisionNormalizer
logger = logging.getLogger(__name__)
class Densenet(pl.LightningModule):
@@ -18,23 +21,37 @@ class Densenet(pl.LightningModule):
def __init__(
self,
criterion,
criterion_valid,
optimizer,
optimizer_configs,
criterion=None,
criterion_valid=None,
optimizer=None,
optimizer_configs=None,
pretrained=False,
nb_channels=3,
augmentation_transforms=[],
):
super().__init__()
self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
self.name = "Densenet"
self.normalizer = TorchVisionNormalizer(nb_channels=nb_channels)
self.augmentation_transforms = torchvision.transforms.Compose(
augmentation_transforms
)
self.criterion = criterion
self.criterion_valid = criterion_valid
self.optimizer = optimizer
self.optimizer_configs = optimizer_configs
self.normalizer = None
self.pretrained = pretrained
# Load pretrained model
weights = None if not pretrained else models.DenseNet121_Weights.DEFAULT
if not pretrained:
weights = None
else:
logger.info("Loading pretrained model weights")
weights = models.DenseNet121_Weights.DEFAULT
self.model_ft = models.densenet121(weights=weights)
# Adapt output features
@@ -48,9 +65,69 @@ class Densenet(pl.LightningModule):
return x
def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
"""Initializes the normalizer for the current model.
If ``pretrained = True``, the normalizer is set to the standard ImageNet
factors from torchvision and the data loader is not used; otherwise,
z-normalization factors are computed from the data loader.
Parameters
----------
dataloader: :py:class:`torch.utils.data.DataLoader`
A torch Dataloader from which to compute the mean and std.
Will not be used if the model is pretrained.
"""
if self.pretrained:
from .normalizer import make_imagenet_normalizer
logger.warning(
"ImageNet pre-trained densenet model - NOT "
"computing z-norm factors from training data. "
"Using preset factors from torchvision."
)
self.normalizer = make_imagenet_normalizer()
else:
from .normalizer import make_z_normalizer
logger.info(
"Uninitialised densenet model - "
"computing z-norm factors from training data."
)
self.normalizer = make_z_normalizer(dataloader)
def set_bce_loss_weights(self, datamodule):
"""Reweights loss weights if BCEWithLogitsLoss is used.
Parameters
----------
datamodule:
A datamodule implementing train_dataloader() and val_dataloader()
"""
from ..data.dataset import _get_positive_weights
if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
logger.info("Reweighting BCEWithLogitsLoss training criterion.")
train_positive_weights = _get_positive_weights(
datamodule.train_dataloader()
)
self.criterion = torch.nn.BCEWithLogitsLoss(
pos_weight=train_positive_weights
)
if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
validation_positive_weights = _get_positive_weights(
datamodule.val_dataloader()["validation"]
)
self.criterion_valid = torch.nn.BCEWithLogitsLoss(
pos_weight=validation_positive_weights
)
def training_step(self, batch, batch_idx):
images = batch[1]
labels = batch[2]
images = batch[0]
labels = batch[1]["label"]
# Increase label dimension if too low
# Allows single and multiclass usage
@@ -58,17 +135,20 @@ class Densenet(pl.LightningModule):
labels = torch.reshape(labels, (labels.shape[0], 1))
# Forward pass on the network
outputs = self(images)
augmented_images = [
self.augmentation_transforms(img).to(self.device) for img in images
]
# Combine list of augmented images back into a tensor
augmented_images = torch.cat(augmented_images, 0).view(images.shape)
outputs = self(augmented_images)
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion = self.hparams.criterion.to(self.device)
training_loss = self.hparams.criterion(outputs, labels.float())
training_loss = self.criterion(outputs, labels.float())
return {"loss": training_loss}
def validation_step(self, batch, batch_idx, dataloader_idx=0):
images = batch[1]
labels = batch[2]
images = batch[0]
labels = batch[1]["label"]
# Increase label dimension if too low
# Allows single and multiclass usage
@@ -78,11 +158,7 @@ class Densenet(pl.LightningModule):
# data forwarding on the existing network
outputs = self(images)
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion_valid = self.hparams.criterion_valid.to(
self.device
)
validation_loss = self.hparams.criterion_valid(outputs, labels.float())
validation_loss = self.criterion_valid(outputs, labels.float())
if dataloader_idx == 0:
return {"validation_loss": validation_loss}
@@ -90,8 +166,9 @@ class Densenet(pl.LightningModule):
return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
names = batch[0]
images = batch[1]
images = batch[0]
labels = batch[1]["label"]
names = batch[1]["name"]
outputs = self(images)
probabilities = torch.sigmoid(outputs)
@@ -101,12 +178,8 @@ class Densenet(pl.LightningModule):
if isinstance(outputs, list):
outputs = outputs[-1]
return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
return names[0], torch.flatten(probabilities), torch.flatten(labels)
def configure_optimizers(self):
# Dynamically instantiates the optimizer given the configs
optimizer = getattr(torch.optim, self.hparams.optimizer)(
self.parameters(), **self.hparams.optimizer_configs
)
optimizer = self.optimizer(self.parameters(), **self.optimizer_configs)
return optimizer
@@ -2,37 +2,73 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""A network model that prefixes a z-normalization step to any other module."""
"""A network model that prefixes a subtract/divide step to any other module."""
import torch
import torch.nn
import torch.utils.data
import torchvision.transforms
class TorchVisionNormalizer(torch.nn.Module):
"""A simple normalizer that applies the standard torchvision normalization.
def make_z_normalizer(
dataloader: torch.utils.data.DataLoader,
) -> torchvision.transforms.Normalize:
"""Computes mean and standard deviation from a dataloader.
This function takes a dataloader and computes the mean and standard
deviation by image channel. It works for both monochromatic and color
inputs with 2, 3 or more color planes.
This module does not learn.
Parameters
----------
nb_channels : :py:class:`int`, Optional
Number of images channels fed to the model
dataloader:
A torch Dataloader from which to compute the mean and std
Returns
-------
An initialized normalizer
"""
# Peek the number of channels of batches in the data loader
batch = next(iter(dataloader))
channels = batch[0].shape[1]
# Initialises accumulators
mean = torch.zeros(channels, dtype=batch[0].dtype)
var = torch.zeros(channels, dtype=batch[0].dtype)
num_images = 0
# Evaluates mean and standard deviation
for batch in dataloader:
data = batch[0]
data = data.view(data.size(0), data.size(1), -1)
num_images += data.size(0)
mean += data.mean(2).sum(0)
var += data.var(2).sum(0)
mean /= num_images
var /= num_images
std = torch.sqrt(var)
return torchvision.transforms.Normalize(mean, std)
def make_imagenet_normalizer() -> torchvision.transforms.Normalize:
"""Returns the stock ImageNet normalisation weights from torchvision.
The weights are wrapped in a torch module. This normalizer only works for
**RGB (color) images**.
Returns
-------
An initialized normalizer
"""
def __init__(self, nb_channels=3):
super().__init__()
mean = torch.zeros(nb_channels)[None, :, None, None]
std = torch.ones(nb_channels)[None, :, None, None]
self.register_buffer("mean", mean)
self.register_buffer("std", std)
self.name = "torchvision-normalizer"
def set_mean_std(self, mean, std):
mean = torch.as_tensor(mean)[None, :, None, None]
std = torch.as_tensor(std)[None, :, None, None]
self.register_buffer("mean", mean)
self.register_buffer("std", std)
def forward(self, inputs):
return inputs.sub(self.mean).div(self.std)
return torchvision.transforms.Normalize(
(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
)
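A brief usage sketch for these two factories (illustrative only; ``train_loader`` is a hypothetical dataloader yielding ``(images, metadata)`` batches as in the models above):
# Illustrative usage (not part of this changeset):
normalizer = make_imagenet_normalizer()        # pretrained, RGB inputs
# normalizer = make_z_normalizer(train_loader) # models trained from scratch
images = next(iter(train_loader))[0]           # shape (B, C, H, W)
normalized = normalizer(images)                # Normalize accepts batched tensors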
@@ -2,12 +2,16 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.transforms
from .normalizer import TorchVisionNormalizer
logger = logging.getLogger(__name__)
class PASA(pl.LightningModule):
@@ -22,14 +26,23 @@ class PASA(pl.LightningModule):
criterion_valid,
optimizer,
optimizer_configs,
augmentation_transforms,
):
super().__init__()
self.save_hyperparameters()
self.name = "pasa"
self.normalizer = TorchVisionNormalizer(nb_channels=1)
self.augmentation_transforms = torchvision.transforms.Compose(
augmentation_transforms
)
self.criterion = criterion
self.criterion_valid = criterion_valid
self.optimizer = optimizer
self.optimizer_configs = optimizer_configs
self.normalizer = None
# First convolution block
self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
@@ -80,7 +93,7 @@ class PASA(pl.LightningModule):
self.dense = nn.Linear(80, 1) # Fully connected layer
def forward(self, x):
x = self.normalizer(x)
x = self.normalizer(x) # type: ignore
# First convolution block
_x = x
@@ -127,9 +140,55 @@ class PASA(pl.LightningModule):
return x
def training_step(self, batch, batch_idx):
images = batch["data"]
labels = batch["label"]
def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
"""Initializes the normalizer for the current model.
Parameters
----------
dataloader: :py:class:`torch.utils.data.DataLoader`
A torch Dataloader from which to compute the mean and std
"""
from .normalizer import make_z_normalizer
logger.info(
"Uninitialised densenet model - "
"computing z-norm factors from training data."
)
self.normalizer = make_z_normalizer(dataloader)
def set_bce_loss_weights(self, datamodule):
"""Reweights loss weights if BCEWithLogitsLoss is used.
Parameters
----------
datamodule:
A datamodule implementing train_dataloader() and val_dataloader()
"""
from ..data.dataset import _get_positive_weights
if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
logger.info("Reweighting BCEWithLogitsLoss training criterion.")
train_positive_weights = _get_positive_weights(
datamodule.train_dataloader()
)
self.criterion = torch.nn.BCEWithLogitsLoss(
pos_weight=train_positive_weights
)
if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
validation_positive_weights = _get_positive_weights(
datamodule.val_dataloader()["validation"]
)
self.criterion_valid = torch.nn.BCEWithLogitsLoss(
pos_weight=validation_positive_weights
)
def training_step(self, batch, _):
images = batch[0]
labels = batch[1]["label"]
# Increase label dimension if too low
# Allows single and multiclass usage
@@ -137,17 +196,20 @@ class PASA(pl.LightningModule):
labels = torch.reshape(labels, (labels.shape[0], 1))
# Forward pass on the network
outputs = self(images)
augmented_images = [
self.augmentation_transforms(img).to(self.device) for img in images
]
# Combine list of augmented images back into a tensor
augmented_images = torch.cat(augmented_images, 0).view(images.shape)
outputs = self(augmented_images)
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion = self.hparams.criterion.to(self.device)
training_loss = self.hparams.criterion(outputs, labels.double())
training_loss = self.criterion(outputs, labels.double())
return {"loss": training_loss}
def validation_step(self, batch, batch_idx, dataloader_idx=0):
images = batch["data"]
labels = batch["label"]
images = batch[0]
labels = batch[1]["label"]
# Increase label dimension if too low
# Allows single and multiclass usage
@@ -157,11 +219,7 @@ class PASA(pl.LightningModule):
# data forwarding on the existing network
outputs = self(images)
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion_valid = self.hparams.criterion_valid.to(
self.device
)
validation_loss = self.hparams.criterion_valid(outputs, labels.double())
validation_loss = self.criterion_valid(outputs, labels.double())
if dataloader_idx == 0:
return {"validation_loss": validation_loss}
@@ -169,9 +227,9 @@ class PASA(pl.LightningModule):
return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
names = batch["name"]
images = batch["data"]
labels = batch["label"]
images = batch[0]
labels = batch[1]["label"]
names = batch[1]["names"]
outputs = self(images)
probabilities = torch.sigmoid(outputs)
@@ -211,9 +269,5 @@ class PASA(pl.LightningModule):
# raise NotImplementedError
def configure_optimizers(self):
# Dynamically instantiates the optimizer given the configs
optimizer = getattr(torch.optim, self.hparams.optimizer)(
self.parameters(), **self.hparams.optimizer_configs
)
optimizer = self.optimizer(self.parameters(), **self.optimizer_configs)
return optimizer
@@ -6,9 +6,6 @@ import click
from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.logging import setup
from lightning.pytorch import seed_everything
from ..utils.checkpointer import get_checkpoint
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@@ -65,21 +62,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
required=True,
cls=ResourceOption,
)
@click.option(
"--criterion",
help="A loss function to compute the CNN error for every sample "
"respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
required=True,
cls=ResourceOption,
)
@click.option(
"--criterion-valid",
help="A specific loss function for the validation set to compute the CNN"
"error for every sample respecting the PyTorch API for loss functions"
"(see torch.nn.modules.loss)",
required=False,
cls=ResourceOption,
)
@click.option(
"--batch-size",
"-b",
@@ -159,7 +141,8 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@click.option(
"--accelerator",
"-a",
help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)',
help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). '
"The device can also be specified (gpu:0)",
show_default=True,
required=True,
default="cpu",
@@ -167,7 +150,8 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
)
@click.option(
"--cache-samples",
help="If set to True, loads the sample into memory, otherwise loads them at runtime.",
help="If set to True, loads the sample into memory, "
"otherwise loads them at runtime.",
required=True,
show_default=True,
default=False,
@@ -196,16 +180,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
default=-1,
cls=ResourceOption,
)
@click.option(
"--normalization",
"-n",
help="Z-Normalization of input images: 'imagenet' for ImageNet parameters,"
" 'current' for parameters of the current trainset, "
"'none' for no normalization.",
required=False,
default="none",
cls=ResourceOption,
)
@click.option(
"--monitoring-interval",
"-I",
@@ -224,7 +198,8 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
)
@click.option(
"--resume-from",
help="Which checkpoint to resume training from. Can be one of 'None', 'best', 'last', or a path to a model checkpoint.",
help="Which checkpoint to resume training from. If set, can be one of "
"`best`, `last`, or a path to a model checkpoint.",
type=str,
required=False,
default=None,
@@ -238,20 +213,17 @@ def train(
batch_size,
batch_chunk_count,
drop_incomplete_batch,
criterion,
criterion_valid,
datamodule,
checkpoint_period,
accelerator,
cache_samples,
seed,
parallel,
normalization,
monitoring_interval,
resume_from,
**_,
):
"""Trains an CNN to perform tuberculosis detection.
"""Trains an CNN to perform image classification.
Training is performed for a configurable number of epochs, and
generates at least a final_model.pth. It may also generate a number
@@ -263,26 +235,38 @@ def train(
import torch.cuda
import torch.nn
from ..data.dataset import normalize_data, reweight_BCEWithLogitsLoss
from lightning.pytorch import seed_everything
from ..engine.trainer import run
from ..utils.checkpointer import get_checkpoint
seed_everything(seed)
checkpoint_file = get_checkpoint(output_folder, resume_from)
datamodule.update_module_properties(
batch_size=batch_size,
batch_chunk_count=batch_chunk_count,
drop_incomplete_batch=drop_incomplete_batch,
cache_samples=cache_samples,
parallel=parallel,
)
# reset datamodule with user configurable options
datamodule.set_chunk_size(batch_size, batch_chunk_count)
datamodule.drop_incomplete_batch = drop_incomplete_batch
datamodule.cache_samples = cache_samples
datamodule.parallel = parallel
datamodule.prepare_data()
datamodule.setup(stage="fit")
reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid)
normalize_data(normalization, model, datamodule)
# Sets the model normalizer with the unaugmented train subset.
# This call may be a NOOP if the model was pre-trained and expects
# different weights for the normalisation layer.
if hasattr(model, "set_normalizer"):
model.set_normalizer(datamodule.train_dataloader())
else:
logger.warning(
f"Model {model.name} has no 'set_normalizer' method. No normalization will be applied."
)
# Rebalances the loss criterion based on the relative proportion of class
# examples available in the training set. Also affects the validation loss
# if a validation set is available on the data module.
model.set_bce_loss_weights(datamodule)
arguments = {}
arguments["max_epoch"] = epochs
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import os
from typing import Union
import torch
from PIL.Image import Image
from torchvision import transforms
def save_image(img: Union[torch.Tensor, Image], filepath: str) -> None:
"""Saves a PIL image or a tensor as an image at the specified destination.
Parameters
----------
img:
A torch.Tensor or PIL.Image to save
filepath:
The file in which to save the image. The format is inferred from the file extension, or defaults to png if not specified.
"""
if isinstance(img, torch.Tensor):
img = transforms.ToPILImage()(img)
root, ext = os.path.splitext(filepath)
if len(ext) == 0:
filepath = filepath + ".png"
img.save(filepath)
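Illustrative call (the tensor and file path are hypothetical):
# Example: saves a random 3-channel image tensor; ".png" is appended
# automatically because no extension is given.
save_image(torch.rand(3, 224, 224), "/tmp/example")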
import glob
import logging
import os
import sys
import time
import pkg_resources
logger = logging.getLogger(__name__)
def save_sh_command(output_dir):
"""Records command-line to reproduce this experiment.
This function can record the current command-line used to call the script
being run. It creates an executable ``bash`` script setting up the current
working directory and activating a conda environment, if needed. It
records further information on the date and time the script was run and the
version of the package.
Parameters
----------
output_dir : str
Path leading to the directory where the commands to reproduce the current
run will be recorded. A subdirectory will be created each time this function
is called to match lightning's versioning convention for loggers.
"""
cmd_config_dir = os.path.join(output_dir, "cmd_line_configs")
cmd_config_versions = glob.glob(os.path.join(cmd_config_dir, "version_*"))
if len(cmd_config_versions) > 0:
latest_cmd_config_version = max(
[
int(config.split("version_")[-1])
for config in cmd_config_versions
]
)
current_cmd_config_version = str(latest_cmd_config_version + 1)
else:
current_cmd_config_version = "0"
destfile = os.path.join(
cmd_config_dir,
f"version_{current_cmd_config_version}",
"cmd_line_config.txt",
)
if os.path.exists(destfile):
logger.info(f"Not overwriting existing file '{destfile}'")
return
logger.info(f"Writing command-line for reproduction at '{destfile}'...")
os.makedirs(os.path.dirname(destfile), exist_ok=True)
with open(destfile, "w") as f:
f.write("#!/usr/bin/env sh\n")
f.write(f"# date: {time.asctime()}\n")
version = pkg_resources.require("ptbench")[0].version
f.write(f"# version: {version} (deepdraw)\n")
f.write(f"# platform: {sys.platform}\n")
f.write("\n")
args = []
for k in sys.argv:
if " " in k:
args.append(f'"{k}"')
else:
args.append(k)
if os.environ.get("CONDA_DEFAULT_ENV") is not None:
f.write(f"#conda activate {os.environ['CONDA_DEFAULT_ENV']}\n")
f.write(f"#cd {os.path.realpath(os.curdir)}\n")
f.write(" ".join(args) + "\n")
os.chmod(destfile, 0o755)
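Illustrative call (the output directory is hypothetical); successive calls create version_0, version_1, ... subdirectories to match lightning's logger versioning:
# Example: records the current command line under the experiment's output folder.
save_sh_command("results/experiment-1")
# -> results/experiment-1/cmd_line_configs/version_0/cmd_line_config.txt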