Commit ea69e431 authored by Daniel CARRON

[model] Model takes loss type and arguments during instantiation

parent 1f0e3938
1 merge request: !38 Replace sampler balancing by loss balancing
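In short: models no longer receive ready-made loss instances; they receive a loss class plus a dictionary of constructor arguments, and build the actual loss objects later via configure_losses(). A minimal sketch of the new calling convention (the Alexnet import and parameter names come from the diff below; the pos_weight value is purely illustrative):

import torch
from mednet.models.alexnet import Alexnet

# Before this commit, callers passed an instance:
#   model = Alexnet(train_loss=torch.nn.BCEWithLogitsLoss())
# Now they pass the class and its constructor arguments separately:
model = Alexnet(
    train_loss_type=torch.nn.BCEWithLogitsLoss,
    train_loss_arguments=dict(pos_weight=torch.tensor([2.0])),  # illustrative value
)
model.configure_losses()  # instantiates the stored (type, arguments) pairs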
@@ -15,8 +15,8 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.alexnet import Alexnet

 model = Alexnet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    train_loss_type=BCEWithLogitsLoss,
+    validation_loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
...
@@ -17,8 +17,8 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.alexnet import Alexnet

 model = Alexnet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    train_loss_type=BCEWithLogitsLoss,
+    validation_loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
...
@@ -15,8 +15,8 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet

 model = Densenet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    train_loss_type=BCEWithLogitsLoss,
+    validation_loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
...
@@ -17,8 +17,8 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet

 model = Densenet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    train_loss_type=BCEWithLogitsLoss,
+    validation_loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
...
@@ -16,8 +16,8 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet

 model = Densenet(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    train_loss_type=BCEWithLogitsLoss,
+    validation_loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
...
@@ -17,8 +17,8 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.pasa import Pasa

 model = Pasa(
-    train_loss=BCEWithLogitsLoss(),
-    validation_loss=BCEWithLogitsLoss(),
+    train_loss_type=BCEWithLogitsLoss,
+    validation_loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=8e-5),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
...
@@ -72,6 +72,8 @@ def run(
     output_folder.mkdir(parents=True, exist_ok=True)

+    model.configure_losses()
+
     from .loggers import CustomTensorboardLogger

     log_dir = "logs"
...
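The ordering matters: because balancing now mutates the stored argument dictionaries rather than live loss objects, it has to run before configure_losses() materializes the losses. Roughly (the balancing call and trainer setup are assumptions; this diff does not show where balancing is invoked):

# Hypothetical orchestration, inferred from this commit:
# model.balance_losses(datamodule)  # assumed name; updates *_loss_arguments in place
model.configure_losses()            # builds losses from the final argument dicts
trainer.fit(model, datamodule)      # training proceeds with the materialized losses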
@@ -27,14 +27,16 @@ class Alexnet(Model):
     Parameters
     ----------
-    train_loss
+    train_loss_type
         The loss to be used during training.

         .. warning::

            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it.

-    validation_loss
+    train_loss_arguments
+        Arguments to the training loss.
+    validation_loss_type
         The loss to be used for validation (may be different from the training
         loss). If extra-validation sets are provided, the same loss will be
         used throughout.
@@ -43,6 +45,8 @@ class Alexnet(Model):
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it.

+    validation_loss_arguments
+        Arguments to the validation loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -59,8 +63,10 @@ class Alexnet(Model):
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        train_loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        train_loss_arguments: dict[str, typing.Any] = {},
+        validation_loss_type: type[torch.nn.Module] | None = None,
+        validation_loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -68,8 +74,10 @@ class Alexnet(Model):
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss,
-            validation_loss,
+            train_loss_type,
+            train_loss_arguments,
+            validation_loss_type,
+            validation_loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,
@@ -166,7 +174,7 @@ class Alexnet(Model):
         # data forwarding on the existing network
         outputs = self(images)

-        return self._validation_loss[dataloader_idx](outputs, labels.float())
+        return self._validation_loss(outputs, labels.float())

     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
...
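The last hunk above is the tail of the validation step: with a single shared validation loss, the per-DataLoader indexing disappears. For context, the surrounding method presumably reduces to something like this sketch (the batch unpacking is an assumption; only the final lines appear in the diff):

def validation_step(self, batch, batch_idx, dataloader_idx=0):
    images, labels = batch  # assumed batch layout; not shown in this diff
    outputs = self(images)
    # One balanced loss is now shared across every validation DataLoader.
    return self._validation_loss(outputs, labels.float())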
@@ -25,14 +25,16 @@ class Densenet(Model):
     Parameters
     ----------
-    train_loss
+    train_loss_type
         The loss to be used during training.

         .. warning::

            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it.

-    validation_loss
+    train_loss_arguments
+        Arguments to the training loss.
+    validation_loss_type
         The loss to be used for validation (may be different from the training
         loss). If extra-validation sets are provided, the same loss will be
         used throughout.
@@ -41,6 +43,8 @@ class Densenet(Model):
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it.

+    validation_loss_arguments
+        Arguments to the validation loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -59,8 +63,10 @@ class Densenet(Model):
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        train_loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        train_loss_arguments: dict[str, typing.Any] = {},
+        validation_loss_type: type[torch.nn.Module] | None = None,
+        validation_loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -69,8 +75,10 @@ class Densenet(Model):
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss,
-            validation_loss,
+            train_loss_type,
+            train_loss_arguments,
+            validation_loss_type,
+            validation_loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,
@@ -164,7 +172,7 @@ class Densenet(Model):
         # data forwarding on the existing network
         outputs = self(images)

-        return self._validation_loss[dataloader_idx](outputs, labels.float())
+        return self._validation_loss(outputs, labels.float())

     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
...
@@ -24,14 +24,16 @@ class Model(pl.LightningModule):
     Parameters
     ----------
-    train_loss
+    train_loss_type
         The loss to be used during training.

         .. warning::

            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it.

-    validation_loss
+    train_loss_arguments
+        Arguments to the training loss.
+    validation_loss_type
         The loss to be used for validation (may be different from the training
         loss). If extra-validation sets are provided, the same loss will be
         used throughout.
@@ -40,6 +42,8 @@ class Model(pl.LightningModule):
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it.

+    validation_loss_arguments
+        Arguments to the validation loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -53,8 +57,10 @@ class Model(pl.LightningModule):
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        train_loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        train_loss_arguments: dict[str, typing.Any] = {},
+        validation_loss_type: type[torch.nn.Module] | None = None,
+        validation_loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -67,10 +73,13 @@ class Model(pl.LightningModule):
         self.model_transforms: TransformSequence = []

-        self._train_loss = train_loss
-        self._validation_loss = [
-            (validation_loss if validation_loss is not None else train_loss)
-        ]
+        self._train_loss_type = train_loss_type
+        self._train_loss_arguments = train_loss_arguments
+        self._train_loss = None
+        self._validation_loss_type = validation_loss_type
+        self._validation_loss_arguments = validation_loss_arguments
+        self._validation_loss = None
         self._optimizer_type = optimizer_type
         self._optimizer_arguments = optimizer_arguments
@@ -138,6 +147,12 @@ class Model(pl.LightningModule):
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         raise NotImplementedError

+    def configure_losses(self):
+        self._train_loss = self._train_loss_type(**self._train_loss_arguments)
+        self._validation_loss = self._validation_loss_type(
+            **self._validation_loss_arguments
+        )
+
     def configure_optimizers(self):
         return self._optimizer_type(
             self.parameters(),
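configure_losses() is the heart of the change: construction is deferred so that anything mutating the argument dictionaries beforehand (balancing, user configs) is picked up. A generic sketch of the pattern, with a fallback to the training loss when no validation loss type is given (mirroring the old default behavior; that guard is an assumption, not shown in the hunk above):

import typing

import torch

class DeferredLosses:
    """Holds (loss class, arguments) pairs and instantiates them on demand."""

    def __init__(
        self,
        train_loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
        train_loss_arguments: dict[str, typing.Any] | None = None,
        validation_loss_type: type[torch.nn.Module] | None = None,
        validation_loss_arguments: dict[str, typing.Any] | None = None,
    ):
        self._train_loss_type = train_loss_type
        self._train_loss_arguments = dict(train_loss_arguments or {})
        self._validation_loss_type = validation_loss_type
        self._validation_loss_arguments = dict(validation_loss_arguments or {})

    def configure_losses(self) -> None:
        # Instantiation happens last, so earlier code may still edit the
        # argument dictionaries (e.g. to inject a pos_weight).
        self._train_loss = self._train_loss_type(**self._train_loss_arguments)
        if self._validation_loss_type is None:
            # assumption: fall back to the training loss, as the old code did
            self._validation_loss = self._train_loss_type(**self._train_loss_arguments)
        else:
            self._validation_loss = self._validation_loss_type(
                **self._validation_loss_arguments
            )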
@@ -153,44 +168,30 @@ class Model(pl.LightningModule):
             Instance of a datamodule.
         """

-        logger.info(f"Balancing training loss function {self._train_loss}.")
+        logger.info(
+            f"Balancing training loss function {self._train_loss_type}."
+        )
         try:
-            getattr(self._train_loss, "pos_weight")
+            getattr(self._train_loss_type(), "pos_weight")
         except AttributeError:
             logger.warning(
                 "Training loss does not possess a 'pos_weight' attribute and will not be balanced."
             )
         else:
             train_weights = get_positive_weights(datamodule.train_dataloader())
-            setattr(self._train_loss, "pos_weight", train_weights)
+            self._train_loss_arguments["pos_weight"] = train_weights

         logger.info(
-            f"Balancing validation loss function {self._validation_loss[0]}."
+            f"Balancing validation loss function {self._validation_loss_type}."
         )
         try:
-            getattr(self._validation_loss[0], "pos_weight")
+            getattr(self._validation_loss_type(), "pos_weight")
         except AttributeError:
             logger.warning(
                 "Validation loss does not possess a 'pos_weight' attribute and will not be balanced."
             )
         else:
-            # If multiple validation DataLoaders are used, each one will need
-            # to have a loss that is balanced for that DataLoader.
-            new_validation_losses = []
-            loss_class = self._validation_loss[0].__class__
-
-            datamodule_validation_keys = datamodule.val_dataset_keys()
-            logger.info(
-                f"Found {len(datamodule_validation_keys)} keys in the validation datamodule. A balanced loss will be created for each key."
-            )
-
-            for val_dataset_key in datamodule_validation_keys:
-                validation_weights = get_positive_weights(
-                    datamodule.val_dataloader()[val_dataset_key]
-                )
-                new_validation_losses.append(
-                    loss_class(pos_weight=validation_weights)
-                )
-            self._validation_loss = new_validation_losses
+            validation_weights = get_positive_weights(
+                datamodule.val_dataloader()["validation"]
+            )
+            self._validation_loss_arguments["pos_weight"] = validation_weights
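get_positive_weights itself is not part of this diff. For BCEWithLogitsLoss, pos_weight is conventionally the per-class ratio of negative to positive samples, so an illustrative stand-in could look like the following (an assumption about the helper's behavior, not mednet's actual code):

import torch

def get_positive_weights(dataloader) -> torch.Tensor:
    """Per-class negative/positive ratio over a dataloader's labels.

    Assumes each batch is an (inputs, labels) pair with binary labels of
    shape (batch, num_classes); illustrative only.
    """
    positives = None
    total = 0
    for _, labels in dataloader:
        labels = labels.float()
        batch_pos = labels.sum(dim=0)
        positives = batch_pos if positives is None else positives + batch_pos
        total += labels.shape[0]
    negatives = total - positives
    # clamping avoids division by zero when a class has no positive samples
    return negatives / positives.clamp(min=1)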
@@ -30,14 +30,16 @@ class Pasa(Model):
     Parameters
     ----------
-    train_loss
+    train_loss_type
         The loss to be used during training.

         .. warning::

            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it.

-    validation_loss
+    train_loss_arguments
+        Arguments to the training loss.
+    validation_loss_type
         The loss to be used for validation (may be different from the training
         loss). If extra-validation sets are provided, the same loss will be
         used throughout.
@@ -46,6 +48,8 @@ class Pasa(Model):
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it.

+    validation_loss_arguments
+        Arguments to the validation loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -59,16 +63,20 @@ class Pasa(Model):
     def __init__(
         self,
-        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
-        validation_loss: torch.nn.Module | None = None,
+        train_loss_type: type[torch.nn.Module] = torch.nn.BCEWithLogitsLoss,
+        train_loss_arguments: dict[str, typing.Any] = {},
+        validation_loss_type: type[torch.nn.Module] | None = None,
+        validation_loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss,
-            validation_loss,
+            train_loss_type,
+            train_loss_arguments,
+            validation_loss_type,
+            validation_loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,
@@ -233,7 +241,7 @@ class Pasa(Model):
         # data forwarding on the existing network
         outputs = self(images)

-        return self._validation_loss[dataloader_idx](outputs, labels.float())
+        return self._validation_loss(outputs, labels.float())

     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
...