Commit 4ccba39c authored by Daniel CARRON

Functional densenet model

parent 2baa8e0b
1 merge request: !7 Reviewed DataModule design+docs+types
Densenet model configuration:

@@ -6,20 +6,30 @@
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
+from torch.optim import Adam

 from ...models.densenet import Densenet

-# config
-optimizer_configs = {"lr": 0.0001}
-
 # optimizer
-optimizer = "Adam"
+optimizer = Adam
+optimizer_configs = {"lr": 0.0001}

 # criterion
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))

+from ...data.transforms import ElasticDeformation
+
+augmentation_transforms = [
+    ElasticDeformation(p=0.8),
+]
+
 # model
 model = Densenet(
-    criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False
+    criterion,
+    criterion_valid,
+    optimizer,
+    optimizer_configs,
+    pretrained=False,
+    augmentation_transforms=augmentation_transforms,
 )
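The notable change in this config is that `optimizer` is now the `torch.optim.Adam` class itself rather than the string `"Adam"`, so the model can call it directly instead of resolving the name via `getattr`. A minimal standalone sketch of the pattern (the `nn.Linear` stand-in model is hypothetical, for illustration only):

import torch.nn as nn
from torch.optim import Adam

# Hypothetical stand-in for any LightningModule's parameters
example_model = nn.Linear(16, 1)
optimizer_configs = {"lr": 0.0001}

# Old config: optimizer = "Adam", resolved at runtime via
#   getattr(torch.optim, "Adam")(params, **optimizer_configs)
# New config: the class is passed around and instantiated directly:
optimizer = Adam
opt = optimizer(example_model.parameters(), **optimizer_configs)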
Shenzhen datamodule configuration (default protocol):

@@ -2,81 +2,40 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later

-"""Shenzhen dataset for TB detection (cross validation fold 0, RGB)
-
-* Split reference: first 80% of TB and healthy CXR for "train", rest for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.shenzhen` for dataset details
-"""
-
-from clapper.logging import setup
-
-from ....data import return_subsets
-from ....data.base_datamodule import BaseDataModule
-from ....data.dataset import JSONProtocol
-from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols
-
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-
-
-class DefaultModule(BaseDataModule):
-    def __init__(
-        self,
-        train_batch_size=1,
-        predict_batch_size=1,
-        drop_incomplete_batch=False,
-        cache_samples=False,
-        multiproc_kwargs=None,
-        data_transforms=[],
-        model_transforms=[],
-        train_transforms=[],
-    ):
-        super().__init__(
-            train_batch_size=train_batch_size,
-            predict_batch_size=predict_batch_size,
-            drop_incomplete_batch=drop_incomplete_batch,
-            multiproc_kwargs=multiproc_kwargs,
-        )
-
-        self.cache_samples = cache_samples
-        self.has_setup_fit = False
-
-        self.data_transforms = data_transforms
-        self.model_transforms = model_transforms
-        self.train_transforms = train_transforms
-        """[
-            transforms.ToPILImage(),
-            transforms.Lambda(lambda x: x.convert("RGB")),
-            transforms.ToTensor(),
-        ]"""
-
-    def setup(self, stage: str):
-        if self.cache_samples:
-            logger.info(
-                "Argument cache_samples set to True. Samples will be loaded in memory."
-            )
-            samples_loader = _cached_loader
-        else:
-            logger.info(
-                "Argument cache_samples set to False. Samples will be loaded at runtime."
-            )
-            samples_loader = _delayed_loader
-
-        self.json_protocol = JSONProtocol(
-            protocols=_protocols,
-            fieldnames=("data", "label"),
-            loader=samples_loader,
-            post_transforms=self.post_transforms,
-        )
-
-        if not self.has_setup_fit and stage == "fit":
-            (
-                self.train_dataset,
-                self.validation_dataset,
-                self.extra_validation_datasets,
-            ) = return_subsets(self.json_protocol, "default", stage)
-            self.has_setup_fit = True
-
-
-datamodule = DefaultModule
+"""Shenzhen datamodule for computer-aided diagnosis (default protocol)
+
+See :py:mod:`ptbench.data.shenzhen` for dataset details.
+
+This configuration:
+
+* raw data (default): :py:obj:`ptbench.data.shenzhen._tranforms`
+* augmentations: elastic deformation (probability = 80%)
+* output image resolution: 512x512 pixels
+"""
+
+import importlib.resources
+
+from torchvision import transforms
+
+from ..datamodule import CachingDataModule
+from ..split import JSONDatabaseSplit
+from .raw_data_loader import raw_data_loader
+
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
+            "default.json.bz2"
+        )
+    ),
+    raw_data_loader=raw_data_loader,
+    cache_samples=False,
+    # train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
+    model_transforms=[
+        transforms.ToPILImage(),
+        transforms.Lambda(lambda x: x.convert("RGB")),
+        transforms.ToTensor(),
+    ],
+    # batch_size = 1,
+    # batch_chunk_count = 1,
+    # drop_incomplete_batch = False,
+    # parallel = -1,
+)
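A datamodule defined this way plugs straight into a Lightning trainer. A hedged usage sketch, assuming the `model` and `datamodule` objects exported by the two config modules above are importable:

import lightning.pytorch as pl

# `model` and `datamodule` come from the config modules above (assumed in scope)
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, datamodule=datamodule)  # Lightning drives setup() and the dataloaders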
Densenet model implementation:

@@ -8,6 +8,7 @@
 import lightning.pytorch as pl
 import torch
 import torch.nn as nn
 import torchvision.models as models
+import torchvision.transforms

 logger = logging.getLogger(__name__)
@@ -20,25 +21,37 @@
     def __init__(
         self,
-        criterion,
-        criterion_valid,
-        optimizer,
-        optimizer_configs,
+        criterion=None,
+        criterion_valid=None,
+        optimizer=None,
+        optimizer_configs=None,
         pretrained=False,
-        nb_channels=3,
+        augmentation_transforms=[],
     ):
         super().__init__()

-        # Saves all hyper parameters declared on __init__ into ``self.hparams``.
-        # You can access those by their name, like `self.hparams.optimizer`
-        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])
-
         self.name = "Densenet"

+        self.augmentation_transforms = torchvision.transforms.Compose(
+            augmentation_transforms
+        )
+
+        self.criterion = criterion
+        self.criterion_valid = criterion_valid
+
+        self.optimizer = optimizer
+        self.optimizer_configs = optimizer_configs
+
         self.normalizer = None
+        self.pretrained = pretrained

         # Load pretrained model
-        weights = None if not pretrained else models.DenseNet121_Weights.DEFAULT
+        if not pretrained:
+            weights = None
+        else:
+            logger.info("Loading pretrained model weights")
+            weights = models.DenseNet121_Weights.DEFAULT
+
         self.model_ft = models.densenet121(weights=weights)

         # Adapt output features
@@ -52,17 +65,24 @@
         return x

-    def set_normalizer(self, dataloader):
-        """TODO: Write this function to set the Normalizer
+    def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
+        """Initializes the normalizer for the current model.

         This function is NOOP if ``pretrained = True`` (normalizer set to
         imagenet weights, during construction).
+
+        Parameters
+        ----------
+
+        dataloader: :py:class:`torch.utils.data.DataLoader`
+            A torch Dataloader from which to compute the mean and std.
+            Will not be used if the model is pretrained.
         """
         if self.pretrained:
             from .normalizer import make_imagenet_normalizer

             logger.warning(
-                "ImageNet pre-trained densenet model - NOT"
+                "ImageNet pre-trained densenet model - NOT "
                 "computing z-norm factors from training data. "
                 "Using preset factors from torchvision."
             )
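`make_z_normalizer`, used in the non-pretrained branch below, is project code not shown in this diff. A minimal sketch of what such a helper typically does, under the assumption of 3-channel inputs and the batch layout used elsewhere in this commit (this is not the project's actual implementation):

import torch
import torchvision.transforms

def make_z_normalizer(dataloader: torch.utils.data.DataLoader):
    """Computes per-channel mean/std over a dataloader, returns a Normalize transform."""
    total = torch.zeros(3)     # assumes 3-channel (RGB) images
    total_sq = torch.zeros(3)
    count = 0
    for batch in dataloader:
        images = batch[0]      # (N, C, H, W), matching the batch layout used above
        total += images.sum(dim=(0, 2, 3))
        total_sq += (images**2).sum(dim=(0, 2, 3))
        count += images.shape[0] * images.shape[2] * images.shape[3]
    mean = total / count
    std = (total_sq / count - mean**2).sqrt()
    return torchvision.transforms.Normalize(mean.tolist(), std.tolist())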
@@ -76,9 +96,38 @@
             )
             self.normalizer = make_z_normalizer(dataloader)

+    def set_bce_loss_weights(self, datamodule):
+        """Reweights loss weights if BCEWithLogitsLoss is used.
+
+        Parameters
+        ----------
+
+        datamodule:
+            A datamodule implementing train_dataloader() and val_dataloader()
+        """
+        from ..data.dataset import _get_positive_weights
+
+        if isinstance(self.criterion, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss training criterion.")
+            train_positive_weights = _get_positive_weights(
+                datamodule.train_dataloader()
+            )
+            self.criterion = torch.nn.BCEWithLogitsLoss(
+                pos_weight=train_positive_weights
+            )
+
+        if isinstance(self.criterion_valid, torch.nn.BCEWithLogitsLoss):
+            logger.info("Reweighting BCEWithLogitsLoss validation criterion.")
+            validation_positive_weights = _get_positive_weights(
+                datamodule.val_dataloader()["validation"]
+            )
+            self.criterion_valid = torch.nn.BCEWithLogitsLoss(
+                pos_weight=validation_positive_weights
+            )
+
     def training_step(self, batch, batch_idx):
-        images = batch[1]
-        labels = batch[2]
+        images = batch[0]
+        labels = batch[1]["label"]

         # Increase label dimension if too low
         # Allows single and multiclass usage
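`_get_positive_weights` is imported from the project's data package and not shown in this diff. For `BCEWithLogitsLoss`, `pos_weight` is conventionally the ratio of negative to positive samples; a hedged sketch of what such a computation may look like for binary labels (assumed behavior, not the project's implementation):

import torch

def get_positive_weights(dataloader: torch.utils.data.DataLoader) -> torch.Tensor:
    """Returns (#negatives / #positives), usable as a BCE pos_weight."""
    labels = torch.cat(
        [batch[1]["label"].view(-1).float() for batch in dataloader]
    )
    positives = labels.sum()
    negatives = labels.numel() - positives
    return (negatives / positives).unsqueeze(0)  # shape (1,), matching pos_weight=empty(1)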
@@ -86,17 +135,20 @@
             labels = torch.reshape(labels, (labels.shape[0], 1))

         # Forward pass on the network
-        outputs = self(images)
-
-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion = self.hparams.criterion.to(self.device)
-
-        training_loss = self.hparams.criterion(outputs, labels.float())
+        augmented_images = [
+            self.augmentation_transforms(img).to(self.device) for img in images
+        ]
+        # Combine list of augmented images back into a tensor
+        augmented_images = torch.cat(augmented_images, 0).view(images.shape)
+        outputs = self(augmented_images)
+
+        training_loss = self.criterion(outputs, labels.float())

         return {"loss": training_loss}

     def validation_step(self, batch, batch_idx, dataloader_idx=0):
-        images = batch[1]
-        labels = batch[2]
+        images = batch[0]
+        labels = batch[1]["label"]

         # Increase label dimension if too low
         # Allows single and multiclass usage
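The rebatching in the training step above, `torch.cat(...).view(images.shape)`, assumes every augmentation preserves the image shape. An equivalent and arguably more explicit alternative (a sketch of a drop-in replacement for those lines, not the committed code) uses `torch.stack`:

# Sketch: stack keeps the (N, C, H, W) layout explicit, under the same
# assumption that each augmentation returns a tensor of the original shape.
augmented_images = torch.stack(
    [self.augmentation_transforms(img) for img in images]
).to(self.device)
outputs = self(augmented_images)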
@@ -106,11 +158,7 @@
         # data forwarding on the existing network
         outputs = self(images)

-        # Manually move criterion to selected device, since not part of the model.
-        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
-            self.device
-        )
-        validation_loss = self.hparams.criterion_valid(outputs, labels.float())
+        validation_loss = self.criterion_valid(outputs, labels.float())

         if dataloader_idx == 0:
             return {"validation_loss": validation_loss}
@@ -118,8 +166,9 @@
             return {f"extra_validation_loss_{dataloader_idx}": validation_loss}

     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
-        names = batch[0]
-        images = batch[1]
+        images = batch[0]
+        labels = batch[1]["label"]
+        names = batch[1]["name"]

         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
@@ -129,12 +178,8 @@
         if isinstance(outputs, list):
             outputs = outputs[-1]

-        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])
+        return names[0], torch.flatten(probabilities), torch.flatten(labels)

     def configure_optimizers(self):
-        # Dynamically instantiates the optimizer given the configs
-        optimizer = getattr(torch.optim, self.hparams.optimizer)(
-            self.parameters(), **self.hparams.optimizer_configs
-        )
+        optimizer = self.optimizer(self.parameters(), **self.optimizer_configs)
         return optimizer
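Since the optimizer is now stored as a class plus a kwargs dict, swapping optimizers becomes a config-only change. A hypothetical variant of the Densenet config above (names mirror that file; the values are illustrative only):

from torch.optim import AdamW

# Hypothetical alternative: AdamW with weight decay instead of Adam
optimizer = AdamW
optimizer_configs = {"lr": 0.0001, "weight_decay": 0.01}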