diff --git a/src/ptbench/configs/models/densenet.py b/src/ptbench/configs/models/densenet.py index 6759490854fd05ac8d8d3a9eec5b1494e0cfb0f2..5d612b2a18146ba306b419a808936e9c3c7042f7 100644 --- a/src/ptbench/configs/models/densenet.py +++ b/src/ptbench/configs/models/densenet.py @@ -6,20 +6,30 @@ from torch import empty from torch.nn import BCEWithLogitsLoss +from torch.optim import Adam from ...models.densenet import Densenet -# config -optimizer_configs = {"lr": 0.0001} - # optimizer -optimizer = "Adam" +optimizer = Adam +optimizer_configs = {"lr": 0.0001} # criterion criterion = BCEWithLogitsLoss(pos_weight=empty(1)) criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) +from ...data.transforms import ElasticDeformation + +augmentation_transforms = [ + ElasticDeformation(p=0.8), +] + # model model = Densenet( - criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False + criterion, + criterion_valid, + optimizer, + optimizer_configs, + pretrained=False, + augmentation_transforms=augmentation_transforms, ) diff --git a/src/ptbench/data/shenzhen/rgb.py b/src/ptbench/data/shenzhen/rgb.py index 2d93c07edb4c0824caa8149737ec42783966ad4c..f45f601d6e84a04e4610319d1f27631ff32ee69c 100644 --- a/src/ptbench/data/shenzhen/rgb.py +++ b/src/ptbench/data/shenzhen/rgb.py @@ -2,81 +2,40 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -"""Shenzhen dataset for TB detection (cross validation fold 0, RGB) +"""Shenzhen datamodule for computer-aided diagnosis (default protocol) -* Split reference: first 80% of TB and healthy CXR for "train", rest for "test" -* This configuration resolution: 512 x 512 (default) -* See :py:mod:`ptbench.data.shenzhen` for dataset details -""" - -from clapper.logging import setup - -from ....data import return_subsets -from ....data.base_datamodule import BaseDataModule -from ....data.dataset import JSONProtocol -from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols +See :py:mod:`ptbench.data.shenzhen` for dataset details. -logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") - - -class DefaultModule(BaseDataModule): - def __init__( - self, - train_batch_size=1, - predict_batch_size=1, - drop_incomplete_batch=False, - cache_samples=False, - multiproc_kwargs=None, - data_transforms=[], - model_transforms=[], - train_transforms=[], - ): - super().__init__( - train_batch_size=train_batch_size, - predict_batch_size=predict_batch_size, - drop_incomplete_batch=drop_incomplete_batch, - multiproc_kwargs=multiproc_kwargs, - ) - - self.cache_samples = cache_samples - self.has_setup_fit = False +This configuration: +* raw data (default): :py:obj:`ptbench.data.shenzhen._tranforms` +* augmentations: elastic deformation (probability = 80%) +* output image resolution: 512x512 pixels +""" - self.data_transforms = data_transforms - self.model_transforms = model_transforms - self.train_transforms = train_transforms +import importlib.resources - """[ - transforms.ToPILImage(), - transforms.Lambda(lambda x: x.convert("RGB")), - transforms.ToTensor(), - ]""" +from torchvision import transforms - def setup(self, stage: str): - if self.cache_samples: - logger.info( - "Argument cache_samples set to True. Samples will be loaded in memory." - ) - samples_loader = _cached_loader - else: - logger.info( - "Argument cache_samples set to False. Samples will be loaded at runtime." - ) - samples_loader = _delayed_loader +from ..datamodule import CachingDataModule +from ..split import JSONDatabaseSplit +from .raw_data_loader import raw_data_loader - self.json_protocol = JSONProtocol( - protocols=_protocols, - fieldnames=("data", "label"), - loader=samples_loader, - post_transforms=self.post_transforms, +datamodule = CachingDataModule( + database_split=JSONDatabaseSplit( + importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath( + "default.json.bz2" ) - - if not self.has_setup_fit and stage == "fit": - ( - self.train_dataset, - self.validation_dataset, - self.extra_validation_datasets, - ) = return_subsets(self.json_protocol, "default", stage) - self.has_setup_fit = True - - -datamodule = DefaultModule + ), + raw_data_loader=raw_data_loader, + cache_samples=False, + # train_sampler: typing.Optional[torch.utils.data.Sampler] = None, + model_transforms=[ + transforms.ToPILImage(), + transforms.Lambda(lambda x: x.convert("RGB")), + transforms.ToTensor(), + ], + # batch_size = 1, + # batch_chunk_count = 1, + # drop_incomplete_batch = False, + # parallel = -1, +) diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index e61d7dece53bce91df406048145382b86cfa5f81..ae866e720ba4df0be292bd2ee3f23878a714818a 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -8,6 +8,7 @@ import lightning.pytorch as pl import torch import torch.nn as nn import torchvision.models as models +import torchvision.transforms logger = logging.getLogger(__name__) @@ -20,25 +21,37 @@ class Densenet(pl.LightningModule): def __init__( self, - criterion, - criterion_valid, - optimizer, - optimizer_configs, + criterion=None, + criterion_valid=None, + optimizer=None, + optimizer_configs=None, pretrained=False, - nb_channels=3, + augmentation_transforms=[], ): super().__init__() - # Saves all hyper parameters declared on __init__ into ``self.hparams`. - # You can access those by their name, like `self.hparams.optimizer` - self.save_hyperparameters(ignore=["criterion", "criterion_valid"]) - self.name = "Densenet" + self.augmentation_transforms = torchvision.transforms.Compose( + augmentation_transforms + ) + + self.criterion = criterion + self.criterion_valid = criterion_valid + + self.optimizer = optimizer + self.optimizer_configs = optimizer_configs + self.normalizer = None + self.pretrained = pretrained # Load pretrained model - weights = None if not pretrained else models.DenseNet121_Weights.DEFAULT + if not pretrained: + weights = None + else: + logger.info("Loading pretrained model weights") + weights = models.DenseNet121_Weights.DEFAULT + self.model_ft = models.densenet121(weights=weights) # Adapt output features @@ -52,17 +65,24 @@ class Densenet(pl.LightningModule): return x - def set_normalizer(self, dataloader): - """TODO: Write this function to set the Normalizer + def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None: + """Initializes the normalizer for the current model. This function is NOOP if ``pretrained = True`` (normalizer set to imagenet weights, during contruction). + + Parameters + ---------- + + dataloader: :py:class:`torch.utils.data.DataLoader` + A torch Dataloader from which to compute the mean and std. + Will not be used if the model is pretrained. """ if self.pretrained: from .normalizer import make_imagenet_normalizer logger.warning( - "ImageNet pre-trained densenet model - NOT" + "ImageNet pre-trained densenet model - NOT " "computing z-norm factors from training data. " "Using preset factors from torchvision." ) @@ -76,9 +96,38 @@ class Densenet(pl.LightningModule): ) self.normalizer = make_z_normalizer(dataloader) + def set_bce_loss_weights(self, datamodule): + """Reweights loss weights if BCEWithLogitsLoss is used. + + Parameters + ---------- + + datamodule: + A datamodule implementing train_dataloader() and val_dataloader() + """ + from ..data.dataset import _get_positive_weights + + if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss): + logger.info("Reweighting BCEWithLogitsLoss training criterion.") + train_positive_weights = _get_positive_weights( + datamodule.train_dataloader() + ) + self.criterion = torch.nn.BCEWithLogitsLoss( + pos_weight=train_positive_weights + ) + + if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss): + logger.info("Reweighting BCEWithLogitsLoss validation criterion.") + validation_positive_weights = _get_positive_weights( + datamodule.val_dataloader()["validation"] + ) + self.criterion_valid = torch.nn.BCEWithLogitsLoss( + pos_weight=validation_positive_weights + ) + def training_step(self, batch, batch_idx): - images = batch[1] - labels = batch[2] + images = batch[0] + labels = batch[1]["label"] # Increase label dimension if too low # Allows single and multiclass usage @@ -86,17 +135,20 @@ class Densenet(pl.LightningModule): labels = torch.reshape(labels, (labels.shape[0], 1)) # Forward pass on the network - outputs = self(images) + augmented_images = [ + self.augmentation_transforms(img).to(self.device) for img in images + ] + # Combine list of augmented images back into a tensor + augmented_images = torch.cat(augmented_images, 0).view(images.shape) + outputs = self(augmented_images) - # Manually move criterion to selected device, since not part of the model. - self.hparams.criterion = self.hparams.criterion.to(self.device) - training_loss = self.hparams.criterion(outputs, labels.float()) + training_loss = self.criterion(outputs, labels.float()) return {"loss": training_loss} def validation_step(self, batch, batch_idx, dataloader_idx=0): - images = batch[1] - labels = batch[2] + images = batch[0] + labels = batch[1]["label"] # Increase label dimension if too low # Allows single and multiclass usage @@ -106,11 +158,7 @@ class Densenet(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - # Manually move criterion to selected device, since not part of the model. - self.hparams.criterion_valid = self.hparams.criterion_valid.to( - self.device - ) - validation_loss = self.hparams.criterion_valid(outputs, labels.float()) + validation_loss = self.criterion_valid(outputs, labels.float()) if dataloader_idx == 0: return {"validation_loss": validation_loss} @@ -118,8 +166,9 @@ class Densenet(pl.LightningModule): return {f"extra_validation_loss_{dataloader_idx}": validation_loss} def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False): - names = batch[0] - images = batch[1] + images = batch[0] + labels = batch[1]["label"] + names = batch[1]["name"] outputs = self(images) probabilities = torch.sigmoid(outputs) @@ -129,12 +178,8 @@ class Densenet(pl.LightningModule): if isinstance(outputs, list): outputs = outputs[-1] - return names[0], torch.flatten(probabilities), torch.flatten(batch[2]) + return names[0], torch.flatten(probabilities), torch.flatten(labels) def configure_optimizers(self): - # Dynamically instantiates the optimizer given the configs - optimizer = getattr(torch.optim, self.hparams.optimizer)( - self.parameters(), **self.hparams.optimizer_configs - ) - + optimizer = self.optimizer(self.parameters(), **self.optimizer_configs) return optimizer