diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index 6670a3099bd4f855f89eba0001380421a08e176f..633ebd62c476b7674943a4cfb97f79109eae1e15 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -95,7 +95,11 @@ class _DelayedLoadingDataset(torch.utils.data.Dataset):
     ):
         self.split = split
         self.raw_data_loader = raw_data_loader
-        self.transform = torchvision.transforms.Compose(*transforms)
+        # Cannot unpack empty list
+        if len(transforms) > 0:
+            self.transform = torchvision.transforms.Compose([*transforms])
+        else:
+            self.transform = torchvision.transforms.Compose([])
 
     def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
         tensor, metadata = self.raw_data_loader(self.split[key])
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index 52bd27f07b16131deea9e3d0aa11c8c1cf294f88..bc8b3a9ddb192e432a5345c96b3847379e58aa9d 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -57,7 +57,10 @@ class Densenet(pl.LightningModule):
         if self.pretrained:
             from .normalizer import TorchVisionNormalizer
 
-            self.normalizer = TorchVisionNormalizer(..., ...)
+            self.normalizer = TorchVisionNormalizer(
+                torch.Tensor([0.485, 0.456, 0.406]),
+                torch.Tensor([0.229, 0.224, 0.225]),
+            )
         else:
             from .normalizer import get_znorm_normalizer
 
diff --git a/src/ptbench/models/normalizer.py b/src/ptbench/models/normalizer.py
index 10320ab1a7ef2aa3bd2a7b1b43566a21da865150..b9ba7eb3d81de5c61349f5d4b0cbcf9d8fee9d7f 100644
--- a/src/ptbench/models/normalizer.py
+++ b/src/ptbench/models/normalizer.py
@@ -23,12 +23,17 @@ class TorchVisionNormalizer(torch.nn.Module):
 
     def __init__(self, subtract: torch.Tensor, divide: torch.Tensor):
         super().__init__()
-        assert len(subtract) == len(divide), "TODO"
-        assert len(subtract) in (1, 3), "TODO"
-        self.subtract = subtract
-        self.divided = divide
-        subtract = torch.zeros(len(subtract.shape))[None, :, None, None]
-        divide = torch.ones(len(divide.shape))[None, :, None, None]
+        if len(subtract) != len(divide):
+            raise ValueError(
+                "Lengths of 'subtract' and 'divide' tensors should be the same."
+            )
+        if len(subtract) not in (1, 3):
+            raise ValueError(
+                "Length of 'subtract' tensor should be either 1 or 3, depending on the number of color channels."
+            )
+
+        subtract = torch.as_tensor(subtract)[None, :, None, None]
+        divide = torch.as_tensor(divide)[None, :, None, None]
         self.register_buffer("subtract", subtract)
         self.register_buffer("divide", divide)
         self.name = "torchvision-normalizer"
@@ -41,13 +46,35 @@
 def get_znorm_normalizer(
     dataloader: torch.utils.data.DataLoader,
 ) -> TorchVisionNormalizer:
-    # TODO: Fix this function to use unaugmented training set
-    # TODO: This function is only applicable IFF we are not fine-tuning (ie.
-    # model does not re-use weights from imagenet training!)
-    # TODO: Add type hints
-    # TODO: Add documentation
+    """Returns a normalizer with the mean and std computed from a dataloader's
+    unaugmented training set.
+
+    Parameters
+    ----------
+
+    dataloader:
+        A torch Dataloader from which to compute the mean and std
+
+    Returns
+    -------
+    An initialized TorchVisionNormalizer
+    """
+
+    mean = 0.0
+    var = 0.0
+    num_images = 0
+
+    for batch in dataloader:
+        data = batch[0]
+        data = data.view(data.size(0), data.size(1), -1)
+
+        num_images += data.size(0)
+        mean += data.mean(2).sum(0)
+        var += data.var(2).sum(0)
 
-    # 1 extract mean/std from dataloader
+    mean /= num_images
+    var /= num_images
+    std = torch.sqrt(var)
 
-    # 2 return TorchVisionNormalizer(mean, std)
-    pass
+    normalizer = TorchVisionNormalizer(mean, std)
+    return normalizer
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index d14ea5d4f745a4f237c3cc96337d862013318188..cca494f141ca625caa0d4872802aa266f6e70a0f 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -8,8 +8,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
 
-from .normalizer import TorchVisionNormalizer
-
 
 class PASA(pl.LightningModule):
     """PASA module.
@@ -30,7 +28,7 @@ class PASA(pl.LightningModule):
 
         self.name = "pasa"
 
-        self.normalizer = TorchVisionNormalizer(nb_channels=1)
+        self.normalizer = None
 
         # First convolution block
         self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
@@ -81,6 +79,11 @@ class PASA(pl.LightningModule):
         self.dense = nn.Linear(80, 1)  # Fully connected layer
 
     def forward(self, x):
+        if self.normalizer is None:
+            raise TypeError(
+                "The normalizer has not been initialized. Make sure to call set_normalizer() after creation of the model."
+            )
+
         x = self.normalizer(x)
 
         # First convolution block
@@ -129,14 +132,21 @@
         return x
 
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
-        """TODO: Write this function documentation"""
+        """Initializes the normalizer for the current model.
+
+        Parameters
+        ----------
+
+        dataloader:
+            A torch Dataloader from which to compute the mean and std
+        """
         from .normalizer import get_znorm_normalizer
 
         self.normalizer = get_znorm_normalizer(dataloader)
 
     def training_step(self, batch, batch_idx):
-        images = batch["data"]
-        labels = batch["label"]
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -153,8 +163,8 @@
         return {"loss": training_loss}
 
     def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch["data"]
-        labels = batch["label"]
+        images = batch[0]
+        labels = batch[1]["label"]
 
         # Increase label dimension if too low
         # Allows single and multiclass usage
@@ -176,9 +186,9 @@
         return {f"extra_validation_loss_{dataloader_idx}": validation_loss}
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch["name"]
-        images = batch["data"]
-        labels = batch["label"]
+        images = batch[0]
+        labels = batch[1]["label"]
+        names = batch[1]["names"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 4d2a226b5b0b479b3f84c748d24aebd43d8d6dea..eeb5b8696ca9a72f9828021d37e8420d56fdf1d3 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -263,26 +263,21 @@ def train(
     import torch.cuda
     import torch.nn
 
-    from ..data.dataset import normalize_data, reweight_BCEWithLogitsLoss
+    from ..data.dataset import reweight_BCEWithLogitsLoss
     from ..engine.trainer import run
 
    seed_everything(seed)
 
     checkpoint_file = get_checkpoint(output_folder, resume_from)
 
-    datamodule.update_module_properties(
-        batch_size=batch_size,
-        batch_chunk_count=batch_chunk_count,
-        drop_incomplete_batch=drop_incomplete_batch,
-        cache_samples=cache_samples,
-        parallel=parallel,
-    )
+    datamodule.set_chunk_size(batch_size, batch_chunk_count)
+    datamodule.parallel = parallel
 
     datamodule.prepare_data()
     datamodule.setup(stage="fit")
 
     reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid)
-    normalize_data(normalization, model, datamodule)
+    model.set_normalizer(datamodule.unaugmented_train_dataloader())
 
     arguments = {}
     arguments["max_epoch"] = epochs
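The per-channel statistics accumulated by get_znorm_normalizer() above can be checked in isolation. The following standalone sketch is not part of the patch: it uses a toy TensorDataset in place of the unaugmented training dataloader, but performs the same accumulation and prints the values that end up in the TorchVisionNormalizer buffers.

# Standalone sketch (illustrative only): reproduces the per-channel mean/std
# accumulation of get_znorm_normalizer(), with a toy dataset standing in for
# the unaugmented training dataloader.
import torch
from torch.utils.data import DataLoader, TensorDataset

# Fake single-channel 32x32 images with dummy labels; batches therefore
# expose the image tensor at batch[0], matching the layout assumed above.
images = torch.rand(100, 1, 32, 32)
labels = torch.zeros(100, 1)
loader = DataLoader(TensorDataset(images, labels), batch_size=16)

mean = 0.0
var = 0.0
num_images = 0

for batch in loader:
    data = batch[0]
    # Flatten each image to (channels, pixels) so statistics are per channel.
    data = data.view(data.size(0), data.size(1), -1)
    num_images += data.size(0)
    mean += data.mean(2).sum(0)  # sum of per-image channel means
    var += data.var(2).sum(0)  # sum of per-image channel variances

mean /= num_images
std = torch.sqrt(var / num_images)

# These are the per-channel tensors that TorchVisionNormalizer reshapes with
# [None, :, None, None] and registers as buffers.
print("per-channel mean:", mean)
print("per-channel std:", std)

Because mean and std come out with shape (channels,), they broadcast against (N, C, H, W) batches once TorchVisionNormalizer adds the [None, :, None, None] axes in its constructor.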