Skip to content
Snippets Groups Projects
Commit 8769c327 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

[model] Use a single type of loss for train and validation

parent 28a261b0
No related branches found
No related tags found
1 merge request!38Replace sampler balancing by loss balancing
......@@ -15,8 +15,7 @@ from mednet.data.augmentations import ElasticDeformation
from mednet.models.alexnet import Alexnet
model = Alexnet(
train_loss_type=BCEWithLogitsLoss,
validation_loss_type=BCEWithLogitsLoss,
loss_type=BCEWithLogitsLoss,
optimizer_type=SGD,
optimizer_arguments=dict(lr=0.01, momentum=0.1),
augmentation_transforms=[ElasticDeformation(p=0.8)],
......
......@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
from mednet.models.alexnet import Alexnet
model = Alexnet(
train_loss_type=BCEWithLogitsLoss,
validation_loss_type=BCEWithLogitsLoss,
loss_type=BCEWithLogitsLoss,
optimizer_type=SGD,
optimizer_arguments=dict(lr=0.01, momentum=0.1),
augmentation_transforms=[ElasticDeformation(p=0.8)],
......
......@@ -15,8 +15,7 @@ from mednet.data.augmentations import ElasticDeformation
from mednet.models.densenet import Densenet
model = Densenet(
train_loss_type=BCEWithLogitsLoss,
validation_loss_type=BCEWithLogitsLoss,
loss_type=BCEWithLogitsLoss,
optimizer_type=Adam,
optimizer_arguments=dict(lr=0.0001),
augmentation_transforms=[ElasticDeformation(p=0.2)],
......
......@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
from mednet.models.densenet import Densenet
model = Densenet(
train_loss_type=BCEWithLogitsLoss,
validation_loss_type=BCEWithLogitsLoss,
loss_type=BCEWithLogitsLoss,
optimizer_type=Adam,
optimizer_arguments=dict(lr=0.0001),
augmentation_transforms=[ElasticDeformation(p=0.2)],
......
......@@ -16,8 +16,7 @@ from mednet.data.augmentations import ElasticDeformation
from mednet.models.densenet import Densenet
model = Densenet(
train_loss_type=BCEWithLogitsLoss,
validation_loss_type=BCEWithLogitsLoss,
loss_type=BCEWithLogitsLoss,
optimizer_type=Adam,
optimizer_arguments=dict(lr=0.0001),
augmentation_transforms=[ElasticDeformation(p=0.2)],
......
......@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
from mednet.models.pasa import Pasa
model = Pasa(
train_loss_type=BCEWithLogitsLoss,
validation_loss_type=BCEWithLogitsLoss,
loss_type=BCEWithLogitsLoss,
optimizer_type=Adam,
optimizer_arguments=dict(lr=8e-5),
augmentation_transforms=[ElasticDeformation(p=0.8)],
......
......@@ -27,26 +27,15 @@ class Alexnet(Model):
Parameters
----------
train_loss_type
The loss to be used during the training.
loss_type
The loss to be used for training and evaluation.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
train_loss_arguments
Arguments to the training loss.
validation_loss_type
The loss to be used for validation (may be different from the training
loss). If extra-validation sets are provided, the same loss will be
used throughout.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
validation_loss_arguments
Arguments to the validation loss.
loss_arguments
Arguments to the loss.
optimizer_type
The type of optimizer to use for training.
optimizer_arguments
......@@ -63,10 +52,8 @@ class Alexnet(Model):
def __init__(
self,
train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
train_loss_arguments: dict[str, typing.Any] = {},
validation_loss_type: torch.nn.Module | None = None,
validation_loss_arguments: dict[str, typing.Any] = {},
loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
loss_arguments: dict[str, typing.Any] = {},
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
......@@ -74,10 +61,8 @@ class Alexnet(Model):
num_classes: int = 1,
):
super().__init__(
train_loss_type,
train_loss_arguments,
validation_loss_type,
validation_loss_arguments,
loss_type,
loss_arguments,
optimizer_type,
optimizer_arguments,
augmentation_transforms,
......
......@@ -25,26 +25,15 @@ class Densenet(Model):
Parameters
----------
train_loss_type
The loss to be used during the training.
loss_type
The loss to be used for training and evaluation.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
train_loss_arguments
Arguments to the training loss.
validation_loss_type
The loss to be used for validation (may be different from the training
loss). If extra-validation sets are provided, the same loss will be
used throughout.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
validation_loss_arguments
Arguments to the validation loss.
loss_arguments
Arguments to the loss.
optimizer_type
The type of optimizer to use for training.
optimizer_arguments
......@@ -63,10 +52,8 @@ class Densenet(Model):
def __init__(
self,
train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
train_loss_arguments: dict[str, typing.Any] = {},
validation_loss_type: torch.nn.Module | None = None,
validation_loss_arguments: dict[str, typing.Any] = {},
loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
loss_arguments: dict[str, typing.Any] = {},
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
......@@ -75,10 +62,8 @@ class Densenet(Model):
num_classes: int = 1,
):
super().__init__(
train_loss_type,
train_loss_arguments,
validation_loss_type,
validation_loss_arguments,
loss_type,
loss_arguments,
optimizer_type,
optimizer_arguments,
augmentation_transforms,
......
......@@ -24,26 +24,15 @@ class Model(pl.LightningModule):
Parameters
----------
train_loss_type
The loss to be used during the training.
loss_type
The loss to be used for training and evaluation.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
train_loss_arguments
Arguments to the training loss.
validation_loss_type
The loss to be used for validation (may be different from the training
loss). If extra-validation sets are provided, the same loss will be
used throughout.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
validation_loss_arguments
Arguments to the validation loss.
loss_arguments
Arguments to the loss.
optimizer_type
The type of optimizer to use for training.
optimizer_arguments
......@@ -57,10 +46,8 @@ class Model(pl.LightningModule):
def __init__(
self,
train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
train_loss_arguments: dict[str, typing.Any] = {},
validation_loss_type: torch.nn.Module | None = None,
validation_loss_arguments: dict[str, typing.Any] = {},
loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
loss_arguments: dict[str, typing.Any] = {},
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
......@@ -73,13 +60,13 @@ class Model(pl.LightningModule):
self.model_transforms: TransformSequence = []
self._train_loss_type = train_loss_type
self._train_loss_arguments = train_loss_arguments
self._loss_type = loss_type
self._train_loss = None
self._train_loss_arguments = loss_arguments
self._validation_loss_type = validation_loss_type
self._validation_loss_arguments = validation_loss_arguments
self.validation_loss = None
self._validation_loss_arguments = loss_arguments
self._optimizer_type = optimizer_type
self._optimizer_arguments = optimizer_arguments
......@@ -148,8 +135,8 @@ class Model(pl.LightningModule):
raise NotImplementedError
def configure_losses(self):
self._train_loss = self._train_loss_type(**self._train_loss_arguments)
self._validation_loss = self._validation_loss_type(
self._train_loss = self._loss_type(**self._train_loss_arguments)
self._validation_loss = self._loss_type(
**self._validation_loss_arguments
)
......@@ -160,7 +147,7 @@ class Model(pl.LightningModule):
)
def balance_losses(self, datamodule) -> None:
"""Balance the loss based on the distribution of targets in the datamodule, if the loss function supports it.
"""Balance the loss based on the distribution of targets in the datamodule, if the loss supports it (contains a 'pos_weight' attribute).
Parameters
----------
......@@ -168,29 +155,18 @@ class Model(pl.LightningModule):
Instance of a datamodule.
"""
logger.info(
f"Balancing training loss function {self._train_loss_type}."
)
try:
getattr(self._train_loss_type(), "pos_weight")
getattr(self._loss_type(), "pos_weight")
except AttributeError:
logger.warning(
"Training loss does not posess a 'pos_weight' attribute and will not be balanced."
f"Loss {self._loss_type} does not posess a 'pos_weight' attribute and will not be balanced."
)
else:
logger.info(f"Balancing training loss {self._loss_type}.")
train_weights = get_positive_weights(datamodule.train_dataloader())
self._train_loss_arguments["pos_weight"] = train_weights
logger.info(
f"Balancing validation loss function {self._validation_loss_type}."
)
try:
getattr(self._validation_loss_type(), "pos_weight")
except AttributeError:
logger.warning(
"Validation loss does not posess a 'pos_weight' attribute and will not be balanced."
)
else:
logger.info(f"Balancing validation loss {self._loss_type}.")
validation_weights = get_positive_weights(
datamodule.val_dataloader()["validation"]
)
......
......@@ -30,26 +30,15 @@ class Pasa(Model):
Parameters
----------
train_loss_type
The loss to be used during the training.
loss_type
The loss to be used for training and evaluation.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
train_loss_arguments
Arguments to the training loss.
validation_loss_type
The loss to be used for validation (may be different from the training
loss). If extra-validation sets are provided, the same loss will be
used throughout.
.. warning::
The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so.
validation_loss_arguments
Arguments to the validation loss.
loss_arguments
Arguments to the loss.
optimizer_type
The type of optimizer to use for training.
optimizer_arguments
......@@ -63,20 +52,16 @@ class Pasa(Model):
def __init__(
self,
train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
train_loss_arguments: dict[str, typing.Any] = {},
validation_loss_type: torch.nn.Module | None = None,
validation_loss_arguments: dict[str, typing.Any] = {},
loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
loss_arguments: dict[str, typing.Any] = {},
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [],
num_classes: int = 1,
):
super().__init__(
train_loss_type,
train_loss_arguments,
validation_loss_type,
validation_loss_arguments,
loss_type,
loss_arguments,
optimizer_type,
optimizer_arguments,
augmentation_transforms,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment