Commit 6290c1bd authored by Daniel CARRON

Moved densenet to lightning

parent a61efdb4
Merge request !4: Moved code to lightning
@@ -4,19 +4,22 @@
 """DenseNet."""

+from torch import empty
 from torch.nn import BCEWithLogitsLoss
-from torch.optim import Adam

-from ...models.densenet import build_densenet
+from ...models.densenet import Densenet

-# config
-lr = 0.0001
-
-# model
-model = build_densenet(pretrained=False)
+optimizer_configs = {"lr": 0.0001}

 # optimizer
-optimizer = Adam(model.parameters(), lr=lr)
+optimizer = "Adam"

 # criterion
-criterion = BCEWithLogitsLoss()
+criterion = BCEWithLogitsLoss(pos_weight=empty(1))
+criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
+
+# model
+model = Densenet(
+    criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False
+)
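For context, a minimal sketch of how a declarative config like the one above can be consumed: the criterion objects and the optimizer name/params pair are plain Python values that the new Densenet constructor stores for Lightning to use later. The import path below is a placeholder, since the real package layout is elided in this diff.

# Sketch only: "mypkg" is a hypothetical import path, not the real package.
import pytorch_lightning as pl
from torch import empty
from torch.nn import BCEWithLogitsLoss

from mypkg.models.densenet import Densenet  # hypothetical path

criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))

model = Densenet(
    criterion, criterion_valid, "Adam", {"lr": 0.0001}, pretrained=False
)

trainer = pl.Trainer(max_epochs=1)
# trainer.fit(model, train_dataloaders=..., val_dataloaders=...)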
@@ -2,23 +2,40 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later

-from collections import OrderedDict
-
+import pytorch_lightning as pl
+import torch
 import torch.nn as nn
 import torchvision.models as models

 from .normalizer import TorchVisionNormalizer


-class Densenet(nn.Module):
+class Densenet(pl.LightningModule):
     """Densenet module.

     Note: only usable with a normalized dataset
     """

-    def __init__(self, pretrained=False):
+    def __init__(
+        self,
+        criterion,
+        criterion_valid,
+        optimizer,
+        optimizer_params,
+        pretrained=False,
+        nb_channels=3,
+    ):
         super().__init__()

+        self.save_hyperparameters()
+
+        self.name = "Densenet"
+
+        self.criterion = criterion
+        self.criterion_valid = criterion_valid
+
+        self.normalizer = TorchVisionNormalizer(nb_channels=nb_channels)
+
         # Load pretrained model
         weights = None if not pretrained else models.DenseNet121_Weights.DEFAULT
         self.model_ft = models.densenet121(weights=weights)
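The save_hyperparameters() call added above is what makes configure_optimizers (further down) work: Lightning records every __init__ argument under self.hparams and persists them in checkpoints. A minimal sketch of that behavior, using an invented Example class:

import pytorch_lightning as pl

class Example(pl.LightningModule):
    def __init__(self, optimizer="Adam", lr=1e-4):
        super().__init__()
        self.save_hyperparameters()  # records all __init__ arguments

m = Example()
print(m.hparams.optimizer)  # "Adam"
print(m.hparams.lr)         # 0.0001

One caveat worth noting: since criterion and criterion_valid are nn.Module instances, save_hyperparameters() will also try to record them, which Lightning typically warns about; a common workaround is save_hyperparameters(ignore=["criterion", "criterion_valid"]).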
@@ -43,23 +60,48 @@ class Densenet(nn.Module):
         tensor : :py:class:`torch.Tensor`
         """

-        return self.model_ft(x)
-
-
-def build_densenet(pretrained=False, nb_channels=3):
-    """Build Densenet CNN.
-
-    Returns
-    -------
-
-    module : :py:class:`torch.nn.Module`
-    """
-
-    model = Densenet(pretrained=pretrained)
-    model = [
-        ("normalizer", TorchVisionNormalizer(nb_channels=nb_channels)),
-        ("model", model),
-    ]
-    model = nn.Sequential(OrderedDict(model))
-    model.name = "Densenet"
-    return model
+        x = self.normalizer(x)
+        x = self.model_ft(x)
+        return x
+
+    def training_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
+
+        # Increase label dimension if too low
+        # Allows single and multiclass usage
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # Forward pass on the network
+        outputs = self(images)
+
+        training_loss = self.criterion(outputs, labels.double())
+
+        return {"loss": training_loss}
+
+    def validation_step(self, batch, batch_idx):
+        images = batch[1]
+        labels = batch[2]
+
+        # Increase label dimension if too low
+        # Allows single and multiclass usage
+        if labels.ndim == 1:
+            labels = torch.reshape(labels, (labels.shape[0], 1))
+
+        # Forward pass on the network
+        outputs = self(images)
+
+        validation_loss = self.criterion_valid(outputs, labels.double())
+
+        return {"validation_loss": validation_loss}
+
+    def configure_optimizers(self):
+        # Dynamically instantiates the optimizer given the configs
+        optimizer = getattr(torch.optim, self.hparams.optimizer)(
+            self.parameters(), **self.hparams.optimizer_params
+        )
+        return optimizer
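The configure_optimizers hook ties the string/dict pair from the config back to a concrete torch.optim class. In isolation, the lookup pattern amounts to the following sketch (build_optimizer is an illustrative helper, not part of the commit):

import torch

def build_optimizer(params, name="Adam", opt_kwargs=None):
    # Resolve e.g. "Adam" -> torch.optim.Adam, then apply the stored kwargs
    cls = getattr(torch.optim, name)
    return cls(params, **(opt_kwargs or {"lr": 0.0001}))

layer = torch.nn.Linear(4, 1)
opt = build_optimizer(layer.parameters())
print(type(opt).__name__)  # Adam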