Commit 8769c327 authored by Daniel CARRON

[model] Use a single type of loss for train and validation

parent 28a261b0
Merge request !38: Replace sampler balancing by loss balancing
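In short, the two loss parameter pairs (train_* and validation_*) collapse into a single loss_type/loss_arguments pair. A minimal before/after sketch of a model configuration, using values taken from the hunks below:

    # before this commit: training and validation losses declared separately
    from torch.nn import BCEWithLogitsLoss
    from torch.optim import SGD
    from mednet.models.alexnet import Alexnet

    model = Alexnet(
        train_loss_type=BCEWithLogitsLoss,
        validation_loss_type=BCEWithLogitsLoss,
        optimizer_type=SGD,
        optimizer_arguments=dict(lr=0.01, momentum=0.1),
    )

    # after: one loss type, used for both training and validation
    model = Alexnet(
        loss_type=BCEWithLogitsLoss,
        optimizer_type=SGD,
        optimizer_arguments=dict(lr=0.01, momentum=0.1),
    )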
@@ -15,8 +15,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.alexnet import Alexnet
 
 model = Alexnet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.alexnet import Alexnet
 
 model = Alexnet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
@@ -15,8 +15,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
@@ -16,8 +16,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
     augmentation_transforms=[ElasticDeformation(p=0.2)],
@@ -17,8 +17,7 @@ from mednet.data.augmentations import ElasticDeformation
 from mednet.models.pasa import Pasa
 
 model = Pasa(
-    train_loss_type=BCEWithLogitsLoss,
-    validation_loss_type=BCEWithLogitsLoss,
+    loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=8e-5),
     augmentation_transforms=[ElasticDeformation(p=0.8)],
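All six configuration hunks follow the same pattern. Because the unified loss_arguments dict is forwarded to the single loss_type, a manually balanced configuration would look like the sketch below; the pos_weight value is hypothetical, and balance_losses() can derive the real one from the datamodule:

    import torch
    from torch.nn import BCEWithLogitsLoss
    from torch.optim import Adam
    from mednet.data.augmentations import ElasticDeformation
    from mednet.models.pasa import Pasa

    model = Pasa(
        loss_type=BCEWithLogitsLoss,
        # hypothetical weight; see balance_losses() further down
        loss_arguments=dict(pos_weight=torch.tensor([2.0])),
        optimizer_type=Adam,
        optimizer_arguments=dict(lr=8e-5),
        augmentation_transforms=[ElasticDeformation(p=0.8)],
    )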
@@ -27,26 +27,15 @@ class Alexnet(Model):
     Parameters
     ----------
-    train_loss_type
-        The loss to be used during the training.
+    loss_type
+        The loss to be used for training and evaluation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-    train_loss_arguments
-        Arguments to the training loss.
-    validation_loss_type
-        The loss to be used for validation (may be different from the training
-        loss). If extra-validation sets are provided, the same loss will be
-        used throughout.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss_arguments
-        Arguments to the validation loss.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -63,10 +52,8 @@ class Alexnet(Model):
 
     def __init__(
         self,
-        train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
-        train_loss_arguments: dict[str, typing.Any] = {},
-        validation_loss_type: torch.nn.Module | None = None,
-        validation_loss_arguments: dict[str, typing.Any] = {},
+        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -74,10 +61,8 @@ class Alexnet(Model):
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss_type,
-            train_loss_arguments,
-            validation_loss_type,
-            validation_loss_arguments,
+            loss_type,
+            loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,
@@ -25,26 +25,15 @@ class Densenet(Model):
     Parameters
     ----------
-    train_loss_type
-        The loss to be used during the training.
+    loss_type
+        The loss to be used for training and evaluation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-    train_loss_arguments
-        Arguments to the training loss.
-    validation_loss_type
-        The loss to be used for validation (may be different from the training
-        loss). If extra-validation sets are provided, the same loss will be
-        used throughout.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss_arguments
-        Arguments to the validation loss.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -63,10 +52,8 @@ class Densenet(Model):
 
     def __init__(
         self,
-        train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
-        train_loss_arguments: dict[str, typing.Any] = {},
-        validation_loss_type: torch.nn.Module | None = None,
-        validation_loss_arguments: dict[str, typing.Any] = {},
+        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -75,10 +62,8 @@ class Densenet(Model):
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss_type,
-            train_loss_arguments,
-            validation_loss_type,
-            validation_loss_arguments,
+            loss_type,
+            loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,
@@ -24,26 +24,15 @@ class Model(pl.LightningModule):
     Parameters
     ----------
-    train_loss_type
-        The loss to be used during the training.
+    loss_type
+        The loss to be used for training and evaluation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-    train_loss_arguments
-        Arguments to the training loss.
-    validation_loss_type
-        The loss to be used for validation (may be different from the training
-        loss). If extra-validation sets are provided, the same loss will be
-        used throughout.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss_arguments
-        Arguments to the validation loss.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
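A side note on the batch-average warning retained in these docstrings: for the torch losses used here it simply means keeping the default reduction, e.g.:

    import torch

    # default reduction="mean" returns the batch average the logging system expects
    loss = torch.nn.BCEWithLogitsLoss(reduction="mean")
    # reduction="sum" would return batch sums and skew the logged values
    # loss = torch.nn.BCEWithLogitsLoss(reduction="sum")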
@@ -57,10 +46,8 @@ class Model(pl.LightningModule):
 
     def __init__(
         self,
-        train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
-        train_loss_arguments: dict[str, typing.Any] = {},
-        validation_loss_type: torch.nn.Module | None = None,
-        validation_loss_arguments: dict[str, typing.Any] = {},
+        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
@@ -73,13 +60,13 @@ class Model(pl.LightningModule):
         self.model_transforms: TransformSequence = []
 
-        self._train_loss_type = train_loss_type
-        self._train_loss_arguments = train_loss_arguments
+        self._loss_type = loss_type
         self._train_loss = None
+        self._train_loss_arguments = loss_arguments
 
-        self._validation_loss_type = validation_loss_type
-        self._validation_loss_arguments = validation_loss_arguments
         self.validation_loss = None
+        self._validation_loss_arguments = loss_arguments
 
         self._optimizer_type = optimizer_type
         self._optimizer_arguments = optimizer_arguments
@@ -148,8 +135,8 @@ class Model(pl.LightningModule):
         raise NotImplementedError
 
     def configure_losses(self):
-        self._train_loss = self._train_loss_type(**self._train_loss_arguments)
-        self._validation_loss = self._validation_loss_type(
+        self._train_loss = self._loss_type(**self._train_loss_arguments)
+        self._validation_loss = self._loss_type(
             **self._validation_loss_arguments
         )
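Note the design visible in configure_losses(): a single loss_type is still instantiated twice, since the train and validation argument dicts can diverge once balance_losses() injects split-specific weights. A standalone sketch with hypothetical values:

    import torch

    loss_type = torch.nn.BCEWithLogitsLoss
    train_loss_arguments = {"pos_weight": torch.tensor([2.0])}       # hypothetical train-split ratio
    validation_loss_arguments = {"pos_weight": torch.tensor([1.5])}  # hypothetical validation-split ratio

    # same type, two differently weighted instances
    train_loss = loss_type(**train_loss_arguments)
    validation_loss = loss_type(**validation_loss_arguments)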
@@ -160,7 +147,7 @@
         )
 
     def balance_losses(self, datamodule) -> None:
-        """Balance the loss based on the distribution of targets in the datamodule, if the loss function supports it.
+        """Balance the loss based on the distribution of targets in the datamodule, if the loss supports it (contains a 'pos_weight' attribute).
 
         Parameters
         ----------
@@ -168,29 +155,18 @@
             Instance of a datamodule.
         """
-        logger.info(
-            f"Balancing training loss function {self._train_loss_type}."
-        )
         try:
-            getattr(self._train_loss_type(), "pos_weight")
+            getattr(self._loss_type(), "pos_weight")
         except AttributeError:
             logger.warning(
-                "Training loss does not posess a 'pos_weight' attribute and will not be balanced."
+                f"Loss {self._loss_type} does not posess a 'pos_weight' attribute and will not be balanced."
             )
         else:
+            logger.info(f"Balancing training loss {self._loss_type}.")
             train_weights = get_positive_weights(datamodule.train_dataloader())
             self._train_loss_arguments["pos_weight"] = train_weights
 
-        logger.info(
-            f"Balancing validation loss function {self._validation_loss_type}."
-        )
-        try:
-            getattr(self._validation_loss_type(), "pos_weight")
-        except AttributeError:
-            logger.warning(
-                "Validation loss does not posess a 'pos_weight' attribute and will not be balanced."
-            )
-        else:
+            logger.info(f"Balancing validation loss {self._loss_type}.")
             validation_weights = get_positive_weights(
                 datamodule.val_dataloader()["validation"]
             )
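The balancing above relies on get_positive_weights(), whose implementation is not part of this commit. A plausible minimal sketch, assuming binary targets and (input, target) batches; the real mednet helper may differ:

    import torch

    def get_positive_weights(dataloader) -> torch.Tensor:
        # collect all binary targets from the dataloader (assumed batch layout)
        targets = torch.cat([t.float().flatten() for _, t in dataloader])
        positives = targets.sum()
        negatives = targets.numel() - positives
        # negative/positive ratio: BCEWithLogitsLoss multiplies the positive
        # term by pos_weight, compensating for class imbalance
        return negatives / positives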
@@ -30,26 +30,15 @@ class Pasa(Model):
     Parameters
     ----------
-    train_loss_type
-        The loss to be used during the training.
+    loss_type
+        The loss to be used for training and evaluation.
 
         .. warning::
 
            The loss should be set to always return batch averages (as opposed
            to the batch sum), as our logging system expects it so.
-    train_loss_arguments
-        Arguments to the training loss.
-    validation_loss_type
-        The loss to be used for validation (may be different from the training
-        loss). If extra-validation sets are provided, the same loss will be
-        used throughout.
-
-        .. warning::
-
-           The loss should be set to always return batch averages (as opposed
-           to the batch sum), as our logging system expects it so.
-    validation_loss_arguments
-        Arguments to the validation loss.
+    loss_arguments
+        Arguments to the loss.
     optimizer_type
         The type of optimizer to use for training.
     optimizer_arguments
@@ -63,20 +52,16 @@ class Pasa(Model):
 
     def __init__(
         self,
-        train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
-        train_loss_arguments: dict[str, typing.Any] = {},
-        validation_loss_type: torch.nn.Module | None = None,
-        validation_loss_arguments: dict[str, typing.Any] = {},
+        loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
+        loss_arguments: dict[str, typing.Any] = {},
         optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
         optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
         num_classes: int = 1,
     ):
         super().__init__(
-            train_loss_type,
-            train_loss_arguments,
-            validation_loss_type,
-            validation_loss_arguments,
+            loss_type,
+            loss_arguments,
             optimizer_type,
             optimizer_arguments,
             augmentation_transforms,