Skip to content
Snippets Groups Projects
Commit 51e2e13e authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[models.densenet_rs] Remove outdated module (functionality is now incorporated...

[models.densenet_rs] Remove outdated module (functionality is now incorporated at stock densenet model)
parent 262e52a7
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
Pipeline #76712 canceled
...@@ -72,7 +72,6 @@ ptbench = "ptbench.scripts.cli:cli" ...@@ -72,7 +72,6 @@ ptbench = "ptbench.scripts.cli:cli"
pasa = "ptbench.configs.models.pasa" pasa = "ptbench.configs.models.pasa"
signs-to-tb = "ptbench.configs.models.signs_to_tb" signs-to-tb = "ptbench.configs.models.signs_to_tb"
logistic-regression = "ptbench.configs.models.logistic_regression" logistic-regression = "ptbench.configs.models.logistic_regression"
densenet-rs = "ptbench.configs.models_datasets.densenet_rs"
alexnet = "ptbench.configs.models.alexnet" alexnet = "ptbench.configs.models.alexnet"
alexnet-pretrained = "ptbench.configs.models.alexnet_pretrained" alexnet-pretrained = "ptbench.configs.models.alexnet_pretrained"
densenet = "ptbench.configs.models.densenet" densenet = "ptbench.configs.models.densenet"
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torchvision.models as models
from .normalizer import TorchVisionNormalizer
class DensenetRS(pl.LightningModule):
    """Densenet-121 module for radiological sign extraction.

    Wraps a torchvision Densenet-121 pre-trained on ImageNet and replaces
    its classifier head with a 14-output linear layer (presumably one
    output per radiological sign — confirm against the dataset labels).

    Parameters
    ----------
    criterion
        Loss module used during training.  Not registered as a submodule,
        so it is moved to the active device manually at step time.
    criterion_valid
        Loss module used during validation (same device caveat).
    optimizer
        Name of a class in ``torch.optim`` (e.g. ``"Adam"``), resolved
        dynamically in :py:meth:`configure_optimizers`.
    optimizer_configs
        Keyword arguments forwarded to the optimizer constructor.
    """

    def __init__(
        self,
        criterion,
        criterion_valid,
        optimizer,
        optimizer_configs,
    ):
        super().__init__()

        # The criteria may be stateful modules; keep them out of the
        # checkpointed hyper-parameters.
        self.save_hyperparameters(ignore=["criterion", "criterion_valid"])

        self.name = "DensenetRS"

        self.normalizer = TorchVisionNormalizer()

        # Load the ImageNet-pretrained backbone
        self.model_ft = models.densenet121(
            weights=models.DenseNet121_Weights.DEFAULT
        )

        # Replace the classifier head to output the 14 target signs
        num_ftrs = self.model_ft.classifier.in_features
        self.model_ft.classifier = nn.Linear(num_ftrs, 14)

    def forward(self, x):
        """Normalize the input batch and run it through the network."""
        x = self.normalizer(x)
        x = self.model_ft(x)
        return x

    @staticmethod
    def _ensure_2d(labels):
        # Adds a trailing dimension to 1-d label vectors so single- and
        # multi-label batches share the same loss code path.
        if labels.ndim == 1:
            labels = torch.reshape(labels, (labels.shape[0], 1))
        return labels

    def training_step(self, batch, batch_idx):
        """Run one training step and return the training loss.

        ``batch`` is expected as ``(names, images, labels)`` — TODO
        confirm against the data module.
        """
        images = batch[1]
        labels = self._ensure_2d(batch[2])

        # Forward pass on the network
        outputs = self(images)

        # The criterion is not a registered submodule, so Lightning does
        # not move it to the selected device automatically.
        self.hparams.criterion = self.hparams.criterion.to(self.device)
        training_loss = self.hparams.criterion(outputs, labels.float())

        return {"loss": training_loss}

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Run one validation step and return the validation loss.

        Losses from extra validation dataloaders (``dataloader_idx > 0``)
        are returned under a distinct, index-suffixed key.
        """
        images = batch[1]
        labels = self._ensure_2d(batch[2])

        # data forwarding on the existing network
        outputs = self(images)

        # Manually move criterion to selected device, since not part of
        # the model.
        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
            self.device
        )
        validation_loss = self.hparams.criterion_valid(outputs, labels.float())

        if dataloader_idx == 0:
            return {"validation_loss": validation_loss}
        return {f"extra_validation_loss_{dataloader_idx}": validation_loss}

    def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
        """Predict sign probabilities for one batch.

        Returns a tuple ``(name, probabilities, ground_truth)`` for the
        first sample name in the batch, with probabilities and labels
        flattened to 1-d tensors.

        ``grad_cams`` is accepted for interface compatibility but is
        currently unused.
        """
        names = batch[0]
        images = batch[1]

        outputs = self(images)

        # HED-style architectures return several intermediate outputs for
        # loss computation; only the last (fused) one is the prediction.
        # BUGFIX: this unwrapping must happen *before* the sigmoid (which
        # cannot operate on a list) — previously it was dead code placed
        # after it.
        if isinstance(outputs, list):
            outputs = outputs[-1]

        probabilities = torch.sigmoid(outputs)

        return names[0], torch.flatten(probabilities), torch.flatten(batch[2])

    def configure_optimizers(self):
        """Dynamically instantiate the optimizer from the stored configs.

        Only parameters with ``requires_grad`` set are optimized, so
        frozen backbone layers (if any) are left untouched.
        """
        optimizer = getattr(torch.optim, self.hparams.optimizer)(
            filter(lambda p: p.requires_grad, self.model_ft.parameters()),
            **self.hparams.optimizer_configs,
        )
        return optimizer
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment