Skip to content
Snippets Groups Projects
Commit ea69e431 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

[model] Model takes loss type and arguments during instanciation

parent 1f0e3938
No related branches found
No related tags found
1 merge request!38Replace sampler balancing by loss balancing
...@@ -15,8 +15,8 @@ from mednet.data.augmentations import ElasticDeformation ...@@ -15,8 +15,8 @@ from mednet.data.augmentations import ElasticDeformation
from mednet.models.alexnet import Alexnet from mednet.models.alexnet import Alexnet
model = Alexnet( model = Alexnet(
train_loss=BCEWithLogitsLoss(), train_loss_type=BCEWithLogitsLoss,
validation_loss=BCEWithLogitsLoss(), validation_loss_type=BCEWithLogitsLoss,
optimizer_type=SGD, optimizer_type=SGD,
optimizer_arguments=dict(lr=0.01, momentum=0.1), optimizer_arguments=dict(lr=0.01, momentum=0.1),
augmentation_transforms=[ElasticDeformation(p=0.8)], augmentation_transforms=[ElasticDeformation(p=0.8)],
......
...@@ -17,8 +17,8 @@ from mednet.data.augmentations import ElasticDeformation ...@@ -17,8 +17,8 @@ from mednet.data.augmentations import ElasticDeformation
from mednet.models.alexnet import Alexnet from mednet.models.alexnet import Alexnet
model = Alexnet( model = Alexnet(
train_loss=BCEWithLogitsLoss(), train_loss_type=BCEWithLogitsLoss,
validation_loss=BCEWithLogitsLoss(), validation_loss_type=BCEWithLogitsLoss,
optimizer_type=SGD, optimizer_type=SGD,
optimizer_arguments=dict(lr=0.01, momentum=0.1), optimizer_arguments=dict(lr=0.01, momentum=0.1),
augmentation_transforms=[ElasticDeformation(p=0.8)], augmentation_transforms=[ElasticDeformation(p=0.8)],
......
...@@ -15,8 +15,8 @@ from mednet.data.augmentations import ElasticDeformation ...@@ -15,8 +15,8 @@ from mednet.data.augmentations import ElasticDeformation
from mednet.models.densenet import Densenet from mednet.models.densenet import Densenet
model = Densenet( model = Densenet(
train_loss=BCEWithLogitsLoss(), train_loss_type=BCEWithLogitsLoss,
validation_loss=BCEWithLogitsLoss(), validation_loss_type=BCEWithLogitsLoss,
optimizer_type=Adam, optimizer_type=Adam,
optimizer_arguments=dict(lr=0.0001), optimizer_arguments=dict(lr=0.0001),
augmentation_transforms=[ElasticDeformation(p=0.2)], augmentation_transforms=[ElasticDeformation(p=0.2)],
......
...@@ -17,8 +17,8 @@ from mednet.data.augmentations import ElasticDeformation ...@@ -17,8 +17,8 @@ from mednet.data.augmentations import ElasticDeformation
from mednet.models.densenet import Densenet from mednet.models.densenet import Densenet
model = Densenet( model = Densenet(
train_loss=BCEWithLogitsLoss(), train_loss_type=BCEWithLogitsLoss,
validation_loss=BCEWithLogitsLoss(), validation_loss_type=BCEWithLogitsLoss,
optimizer_type=Adam, optimizer_type=Adam,
optimizer_arguments=dict(lr=0.0001), optimizer_arguments=dict(lr=0.0001),
augmentation_transforms=[ElasticDeformation(p=0.2)], augmentation_transforms=[ElasticDeformation(p=0.2)],
......
...@@ -16,8 +16,8 @@ from mednet.data.augmentations import ElasticDeformation ...@@ -16,8 +16,8 @@ from mednet.data.augmentations import ElasticDeformation
from mednet.models.densenet import Densenet from mednet.models.densenet import Densenet
model = Densenet( model = Densenet(
train_loss=BCEWithLogitsLoss(), train_loss_type=BCEWithLogitsLoss,
validation_loss=BCEWithLogitsLoss(), validation_loss_type=BCEWithLogitsLoss,
optimizer_type=Adam, optimizer_type=Adam,
optimizer_arguments=dict(lr=0.0001), optimizer_arguments=dict(lr=0.0001),
augmentation_transforms=[ElasticDeformation(p=0.2)], augmentation_transforms=[ElasticDeformation(p=0.2)],
......
...@@ -17,8 +17,8 @@ from mednet.data.augmentations import ElasticDeformation ...@@ -17,8 +17,8 @@ from mednet.data.augmentations import ElasticDeformation
from mednet.models.pasa import Pasa from mednet.models.pasa import Pasa
model = Pasa( model = Pasa(
train_loss=BCEWithLogitsLoss(), train_loss_type=BCEWithLogitsLoss,
validation_loss=BCEWithLogitsLoss(), validation_loss_type=BCEWithLogitsLoss,
optimizer_type=Adam, optimizer_type=Adam,
optimizer_arguments=dict(lr=8e-5), optimizer_arguments=dict(lr=8e-5),
augmentation_transforms=[ElasticDeformation(p=0.8)], augmentation_transforms=[ElasticDeformation(p=0.8)],
......
...@@ -72,6 +72,8 @@ def run( ...@@ -72,6 +72,8 @@ def run(
output_folder.mkdir(parents=True, exist_ok=True) output_folder.mkdir(parents=True, exist_ok=True)
model.configure_losses()
from .loggers import CustomTensorboardLogger from .loggers import CustomTensorboardLogger
log_dir = "logs" log_dir = "logs"
......
...@@ -27,14 +27,16 @@ class Alexnet(Model): ...@@ -27,14 +27,16 @@ class Alexnet(Model):
Parameters Parameters
---------- ----------
train_loss train_loss_type
The loss to be used during the training. The loss to be used during the training.
.. warning:: .. warning::
The loss should be set to always return batch averages (as opposed The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so. to the batch sum), as our logging system expects it so.
validation_loss train_loss_arguments
Arguments to the training loss.
validation_loss_type
The loss to be used for validation (may be different from the training The loss to be used for validation (may be different from the training
loss). If extra-validation sets are provided, the same loss will be loss). If extra-validation sets are provided, the same loss will be
used throughout. used throughout.
...@@ -43,6 +45,8 @@ class Alexnet(Model): ...@@ -43,6 +45,8 @@ class Alexnet(Model):
The loss should be set to always return batch averages (as opposed The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so. to the batch sum), as our logging system expects it so.
validation_loss_arguments
Arguments to the validation loss.
optimizer_type optimizer_type
The type of optimizer to use for training. The type of optimizer to use for training.
optimizer_arguments optimizer_arguments
...@@ -59,8 +63,10 @@ class Alexnet(Model): ...@@ -59,8 +63,10 @@ class Alexnet(Model):
def __init__( def __init__(
self, self,
train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(), train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
validation_loss: torch.nn.Module | None = None, train_loss_arguments: dict[str, typing.Any] = {},
validation_loss_type: torch.nn.Module | None = None,
validation_loss_arguments: dict[str, typing.Any] = {},
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {}, optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [],
...@@ -68,8 +74,10 @@ class Alexnet(Model): ...@@ -68,8 +74,10 @@ class Alexnet(Model):
num_classes: int = 1, num_classes: int = 1,
): ):
super().__init__( super().__init__(
train_loss, train_loss_type,
validation_loss, train_loss_arguments,
validation_loss_type,
validation_loss_arguments,
optimizer_type, optimizer_type,
optimizer_arguments, optimizer_arguments,
augmentation_transforms, augmentation_transforms,
...@@ -166,7 +174,7 @@ class Alexnet(Model): ...@@ -166,7 +174,7 @@ class Alexnet(Model):
# data forwarding on the existing network # data forwarding on the existing network
outputs = self(images) outputs = self(images)
return self._validation_loss[dataloader_idx](outputs, labels.float()) return self._validation_loss(outputs, labels.float())
def predict_step(self, batch, batch_idx, dataloader_idx=0): def predict_step(self, batch, batch_idx, dataloader_idx=0):
outputs = self(batch[0]) outputs = self(batch[0])
......
...@@ -25,14 +25,16 @@ class Densenet(Model): ...@@ -25,14 +25,16 @@ class Densenet(Model):
Parameters Parameters
---------- ----------
train_loss train_loss_type
The loss to be used during the training. The loss to be used during the training.
.. warning:: .. warning::
The loss should be set to always return batch averages (as opposed The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so. to the batch sum), as our logging system expects it so.
validation_loss train_loss_arguments
Arguments to the training loss.
validation_loss_type
The loss to be used for validation (may be different from the training The loss to be used for validation (may be different from the training
loss). If extra-validation sets are provided, the same loss will be loss). If extra-validation sets are provided, the same loss will be
used throughout. used throughout.
...@@ -41,6 +43,8 @@ class Densenet(Model): ...@@ -41,6 +43,8 @@ class Densenet(Model):
The loss should be set to always return batch averages (as opposed The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so. to the batch sum), as our logging system expects it so.
validation_loss_arguments
Arguments to the validation loss.
optimizer_type optimizer_type
The type of optimizer to use for training. The type of optimizer to use for training.
optimizer_arguments optimizer_arguments
...@@ -59,8 +63,10 @@ class Densenet(Model): ...@@ -59,8 +63,10 @@ class Densenet(Model):
def __init__( def __init__(
self, self,
train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(), train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
validation_loss: torch.nn.Module | None = None, train_loss_arguments: dict[str, typing.Any] = {},
validation_loss_type: torch.nn.Module | None = None,
validation_loss_arguments: dict[str, typing.Any] = {},
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {}, optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [],
...@@ -69,8 +75,10 @@ class Densenet(Model): ...@@ -69,8 +75,10 @@ class Densenet(Model):
num_classes: int = 1, num_classes: int = 1,
): ):
super().__init__( super().__init__(
train_loss, train_loss_type,
validation_loss, train_loss_arguments,
validation_loss_type,
validation_loss_arguments,
optimizer_type, optimizer_type,
optimizer_arguments, optimizer_arguments,
augmentation_transforms, augmentation_transforms,
...@@ -164,7 +172,7 @@ class Densenet(Model): ...@@ -164,7 +172,7 @@ class Densenet(Model):
# data forwarding on the existing network # data forwarding on the existing network
outputs = self(images) outputs = self(images)
return self._validation_loss[dataloader_idx](outputs, labels.float()) return self._validation_loss(outputs, labels.float())
def predict_step(self, batch, batch_idx, dataloader_idx=0): def predict_step(self, batch, batch_idx, dataloader_idx=0):
outputs = self(batch[0]) outputs = self(batch[0])
......
...@@ -24,14 +24,16 @@ class Model(pl.LightningModule): ...@@ -24,14 +24,16 @@ class Model(pl.LightningModule):
Parameters Parameters
---------- ----------
train_loss train_loss_type
The loss to be used during the training. The loss to be used during the training.
.. warning:: .. warning::
The loss should be set to always return batch averages (as opposed The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so. to the batch sum), as our logging system expects it so.
validation_loss train_loss_arguments
Arguments to the training loss.
validation_loss_type
The loss to be used for validation (may be different from the training The loss to be used for validation (may be different from the training
loss). If extra-validation sets are provided, the same loss will be loss). If extra-validation sets are provided, the same loss will be
used throughout. used throughout.
...@@ -40,6 +42,8 @@ class Model(pl.LightningModule): ...@@ -40,6 +42,8 @@ class Model(pl.LightningModule):
The loss should be set to always return batch averages (as opposed The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so. to the batch sum), as our logging system expects it so.
validation_loss_arguments
Arguments to the validation loss.
optimizer_type optimizer_type
The type of optimizer to use for training. The type of optimizer to use for training.
optimizer_arguments optimizer_arguments
...@@ -53,8 +57,10 @@ class Model(pl.LightningModule): ...@@ -53,8 +57,10 @@ class Model(pl.LightningModule):
def __init__( def __init__(
self, self,
train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(), train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
validation_loss: torch.nn.Module | None = None, train_loss_arguments: dict[str, typing.Any] = {},
validation_loss_type: torch.nn.Module | None = None,
validation_loss_arguments: dict[str, typing.Any] = {},
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {}, optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [],
...@@ -67,10 +73,13 @@ class Model(pl.LightningModule): ...@@ -67,10 +73,13 @@ class Model(pl.LightningModule):
self.model_transforms: TransformSequence = [] self.model_transforms: TransformSequence = []
self._train_loss = train_loss self._train_loss_type = train_loss_type
self._validation_loss = [ self._train_loss_arguments = train_loss_arguments
(validation_loss if validation_loss is not None else train_loss) self._train_loss = None
]
self._validation_loss_type = validation_loss_type
self._validation_loss_arguments = validation_loss_arguments
self.validation_loss = None
self._optimizer_type = optimizer_type self._optimizer_type = optimizer_type
self._optimizer_arguments = optimizer_arguments self._optimizer_arguments = optimizer_arguments
...@@ -138,6 +147,12 @@ class Model(pl.LightningModule): ...@@ -138,6 +147,12 @@ class Model(pl.LightningModule):
def predict_step(self, batch, batch_idx, dataloader_idx=0): def predict_step(self, batch, batch_idx, dataloader_idx=0):
raise NotImplementedError raise NotImplementedError
def configure_losses(self):
self._train_loss = self._train_loss_type(**self._train_loss_arguments)
self._validation_loss = self._validation_loss_type(
**self._validation_loss_arguments
)
def configure_optimizers(self): def configure_optimizers(self):
return self._optimizer_type( return self._optimizer_type(
self.parameters(), self.parameters(),
...@@ -153,44 +168,30 @@ class Model(pl.LightningModule): ...@@ -153,44 +168,30 @@ class Model(pl.LightningModule):
Instance of a datamodule. Instance of a datamodule.
""" """
logger.info(f"Balancing training loss function {self._train_loss}.") logger.info(
f"Balancing training loss function {self._train_loss_type}."
)
try: try:
getattr(self._train_loss, "pos_weight") getattr(self._train_loss_type(), "pos_weight")
except AttributeError: except AttributeError:
logger.warning( logger.warning(
"Training loss does not posess a 'pos_weight' attribute and will not be balanced." "Training loss does not posess a 'pos_weight' attribute and will not be balanced."
) )
else: else:
train_weights = get_positive_weights(datamodule.train_dataloader()) train_weights = get_positive_weights(datamodule.train_dataloader())
setattr(self._train_loss, "pos_weight", train_weights) self._train_loss_arguments["pos_weight"] = train_weights
logger.info( logger.info(
f"Balancing validation loss function {self._validation_loss[0]}." f"Balancing validation loss function {self._validation_loss_type}."
) )
try: try:
getattr(self._validation_loss[0], "pos_weight") getattr(self._validation_loss_type(), "pos_weight")
except AttributeError: except AttributeError:
logger.warning( logger.warning(
"Validation loss does not posess a 'pos_weight' attribute and will not be balanced." "Validation loss does not posess a 'pos_weight' attribute and will not be balanced."
) )
else: else:
# If multiple validation DataLoaders are used, each one will need to have a loss validation_weights = get_positive_weights(
# that is balanced for that DataLoader datamodule.val_dataloader()["validation"]
new_validation_losses = []
loss_class = self._validation_loss[0].__class__
datamodule_validation_keys = datamodule.val_dataset_keys()
logger.info(
f"Found {len(datamodule_validation_keys)} keys in the validation datamodule. A balanced loss will be created for each key."
) )
self._validation_loss_arguments["pos_weight"] = validation_weights
for val_dataset_key in datamodule_validation_keys:
validation_weights = get_positive_weights(
datamodule.val_dataloader()[val_dataset_key]
)
new_validation_losses.append(
loss_class(pos_weight=validation_weights)
)
self._validation_loss = new_validation_losses
...@@ -30,14 +30,16 @@ class Pasa(Model): ...@@ -30,14 +30,16 @@ class Pasa(Model):
Parameters Parameters
---------- ----------
train_loss train_loss_type
The loss to be used during the training. The loss to be used during the training.
.. warning:: .. warning::
The loss should be set to always return batch averages (as opposed The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so. to the batch sum), as our logging system expects it so.
validation_loss train_loss_arguments
Arguments to the training loss.
validation_loss_type
The loss to be used for validation (may be different from the training The loss to be used for validation (may be different from the training
loss). If extra-validation sets are provided, the same loss will be loss). If extra-validation sets are provided, the same loss will be
used throughout. used throughout.
...@@ -46,6 +48,8 @@ class Pasa(Model): ...@@ -46,6 +48,8 @@ class Pasa(Model):
The loss should be set to always return batch averages (as opposed The loss should be set to always return batch averages (as opposed
to the batch sum), as our logging system expects it so. to the batch sum), as our logging system expects it so.
validation_loss_arguments
Arguments to the validation loss.
optimizer_type optimizer_type
The type of optimizer to use for training. The type of optimizer to use for training.
optimizer_arguments optimizer_arguments
...@@ -59,16 +63,20 @@ class Pasa(Model): ...@@ -59,16 +63,20 @@ class Pasa(Model):
def __init__( def __init__(
self, self,
train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(), train_loss_type: torch.nn.Module = torch.nn.BCEWithLogitsLoss,
validation_loss: torch.nn.Module | None = None, train_loss_arguments: dict[str, typing.Any] = {},
validation_loss_type: torch.nn.Module | None = None,
validation_loss_arguments: dict[str, typing.Any] = {},
optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam, optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
optimizer_arguments: dict[str, typing.Any] = {}, optimizer_arguments: dict[str, typing.Any] = {},
augmentation_transforms: TransformSequence = [], augmentation_transforms: TransformSequence = [],
num_classes: int = 1, num_classes: int = 1,
): ):
super().__init__( super().__init__(
train_loss, train_loss_type,
validation_loss, train_loss_arguments,
validation_loss_type,
validation_loss_arguments,
optimizer_type, optimizer_type,
optimizer_arguments, optimizer_arguments,
augmentation_transforms, augmentation_transforms,
...@@ -233,7 +241,7 @@ class Pasa(Model): ...@@ -233,7 +241,7 @@ class Pasa(Model):
# data forwarding on the existing network # data forwarding on the existing network
outputs = self(images) outputs = self(images)
return self._validation_loss[dataloader_idx](outputs, labels.float()) return self._validation_loss(outputs, labels.float())
def predict_step(self, batch, batch_idx, dataloader_idx=0): def predict_step(self, batch, batch_idx, dataloader_idx=0):
outputs = self(batch[0]) outputs = self(batch[0])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment