Commit 222a3762 authored by Daniel CARRON

[train] Add support for multiple validation dataloaders

During validation logging, Lightning appends "/dataloader_idx_n" to
the key we define when multiple validation dataloaders are used.
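
For context, a minimal sketch of that Lightning behavior (the module and data below are hypothetical; only the logging key mirrors this codebase):

import torch
import lightning.pytorch as pl
from torch.utils.data import DataLoader, TensorDataset


class TinyModule(pl.LightningModule):  # hypothetical, for illustration only
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        (x,) = batch
        loss = self.layer(x).abs().mean()
        # With a single validation loader this is logged as "loss/validation";
        # with two loaders Lightning logs "loss/validation/dataloader_idx_0"
        # and "loss/validation/dataloader_idx_1" instead.
        self.log("loss/validation", loss)
        return loss

    def val_dataloader(self):
        ds = TensorDataset(torch.randn(8, 4))
        # two loaders trigger the key suffixing described above
        return [DataLoader(ds, batch_size=4), DataLoader(ds, batch_size=4)]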
parent 51b5d6f7
Merge request !38: Replace sampler balancing by loss balancing
@@ -758,7 +758,7 @@ class ConcatDataModule(lightning.LightningDataModule):
         else:
             self._datasets[name] = _ConcatDataset(datasets)
 
-    def _val_dataset_keys(self) -> list[str]:
+    def val_dataset_keys(self) -> list[str]:
         """Return list of validation dataset names.
 
         Returns
@@ -796,11 +796,11 @@ class ConcatDataModule(lightning.LightningDataModule):
         """
         if stage == "fit":
-            for k in ["train"] + self._val_dataset_keys():
+            for k in ["train"] + self.val_dataset_keys():
                 self._setup_dataset(k)
 
         elif stage == "validate":
-            for k in self._val_dataset_keys():
+            for k in self.val_dataset_keys():
                 self._setup_dataset(k)
 
         elif stage == "test":
@@ -889,7 +889,7 @@ class ConcatDataModule(lightning.LightningDataModule):
                 self._datasets[k],
                 **validation_loader_opts,
             )
-            for k in self._val_dataset_keys()
+            for k in self.val_dataset_keys()
         }
 
     def test_dataloader(self) -> dict[str, DataLoader]:
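With the leading underscore dropped, val_dataset_keys() becomes part of the datamodule's public interface, which the training entry point below relies on to detect the multi-dataloader case. A sketch of how a caller might use it (the split names are illustrative):

def describe_validation_splits(datamodule) -> None:
    """Print the number of batches in each named validation split.

    Sketch only: assumes a ConcatDataModule-like object whose
    val_dataloader() returns a dict of DataLoaders keyed by split name.
    """
    datamodule.setup("validate")
    loaders = datamodule.val_dataloader()
    for key in datamodule.val_dataset_keys():
        print(f"{key}: {len(loaders[key])} batches")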
@@ -362,13 +362,8 @@ class LoggingCallback(lightning.pytorch.Callback):
        out which dataset was used for this validation epoch.
        """
 
-        if dataloader_idx == 0:
-            key = "loss/validation"
-        else:
-            key = f"loss/validation-{dataloader_idx}"
-
        pl_module.log(
-            key,
+            "loss/validation",
            outputs.item(),
            prog_bar=False,
            on_step=False,
@@ -91,6 +91,10 @@ def run(
         main_pid=os.getpid(),
     )
 
+    monitor_key = "loss/validation"
+    if len(datamodule.val_dataset_keys()) > 1:
+        monitor_key = "loss/validation/dataloader_idx_0"
+
     # This checkpointer will operate at the end of every validation epoch
     # (which happens at each checkpoint period), it will then save the lowest
     # validation loss model observed. It will also save the last trained model
@@ -98,7 +102,7 @@ def run(
         dirpath=output_folder,
         filename=CHECKPOINT_ALIASES["best"],
         save_last=True,  # will (re)create the last trained model, at every iteration
-        monitor="loss/validation",
+        monitor=monitor_key,
         mode="min",
         save_on_train_epoch_end=True,
         every_n_epochs=validation_period,  # frequency at which it checks the "monitor"
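Because of the key suffixing described in the commit message, the checkpoint monitor must target the suffixed metric name whenever more than one validation loader exists. In isolation, the wiring added here looks roughly like this (a sketch; output_folder, validation_period and CHECKPOINT_ALIASES come from this repository's run() context):

from lightning.pytorch.callbacks import ModelCheckpoint


def make_best_checkpointer(datamodule, output_folder, validation_period):
    # Follow the suffixed metric name when Lightning will disambiguate
    # multiple validation dataloaders; track the first loader's loss.
    monitor_key = "loss/validation"
    if len(datamodule.val_dataset_keys()) > 1:
        monitor_key = "loss/validation/dataloader_idx_0"
    return ModelCheckpoint(
        dirpath=output_folder,
        save_last=True,
        monitor=monitor_key,
        mode="min",
        save_on_train_epoch_end=True,
        every_n_epochs=validation_period,
    )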
@@ -166,8 +166,7 @@ class Alexnet(Model):
         # data forwarding on the existing network
         outputs = self(images)
-
-        return self._validation_loss(outputs, labels.float())
+        return self._validation_loss[dataloader_idx](outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
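The same one-line change appears in Alexnet above and in Densenet and Pasa below: validation_step now indexes the per-dataloader loss list that Model maintains. Abstracted from the diffs (batch unpacking is abridged):

def validation_step(self, batch, batch_idx, dataloader_idx=0):
    images, labels = batch  # abridged; the real models unpack a richer batch
    outputs = self(images)
    # _validation_loss now holds one (balanced) loss per validation loader
    return self._validation_loss[dataloader_idx](outputs, labels.float())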
@@ -164,7 +164,7 @@ class Densenet(Model):
         # data forwarding on the existing network
         outputs = self(images)
 
-        return self._validation_loss(outputs, labels.float())
+        return self._validation_loss[dataloader_idx](outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])
@@ -68,9 +68,10 @@ class Model(pl.LightningModule):
         self.model_transforms: TransformSequence = []
 
         self._train_loss = train_loss
-        self._validation_loss = (
-            validation_loss if validation_loss is not None else train_loss
-        )
+        self._validation_loss = [
+            (validation_loss if validation_loss is not None else train_loss)
+        ]
+
         self._optimizer_type = optimizer_type
         self._optimizer_arguments = optimizer_arguments
@@ -163,16 +164,32 @@ class Model(pl.LightningModule):
         setattr(self._train_loss, "pos_weight", train_weights)
 
         logger.info(
-            f"Balancing validation loss function {self._validation_loss}."
+            f"Balancing validation loss function {self._validation_loss[0]}."
         )
         try:
-            getattr(self._validation_loss, "pos_weight")
+            getattr(self._validation_loss[0], "pos_weight")
         except AttributeError:
             logger.warning(
                 "Validation loss does not posess a 'pos_weight' attribute and will not be balanced."
             )
         else:
-            validation_weights = _get_label_weights(
-                datamodule.val_dataloader()["validation"]
-            )
-            setattr(self._validation_loss, "pos_weight", validation_weights)
+            # If multiple validation DataLoaders are used, each one will need to have a loss
+            # that is balanced for that DataLoader
+            new_validation_losses = []
+            loss_class = self._validation_loss[0].__class__
+            datamodule_validation_keys = datamodule.val_dataset_keys()
+            logger.info(
+                f"Found {len(datamodule_validation_keys)} keys in the validation datamodule. A balanced loss will be created for each key."
+            )
+
+            for val_dataset_key in datamodule_validation_keys:
+                validation_weights = _get_label_weights(
+                    datamodule.val_dataloader()[val_dataset_key]
+                )
+                new_validation_losses.append(
+                    loss_class(pos_weight=validation_weights)
+                )
+            self._validation_loss = new_validation_losses
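Concretely, with a BCE-style loss the loop above builds one loss object per validation split, each carrying its own pos_weight. A hypothetical illustration (the weights stand in for what _get_label_weights would compute from each DataLoader's label distribution):

import torch

# Hypothetical per-split positive-class weights, standing in for the
# values _get_label_weights() derives from each validation DataLoader.
weights_per_split = {
    "validation": torch.tensor([2.0]),
    "validation-2": torch.tensor([0.5]),
}

loss_class = torch.nn.BCEWithLogitsLoss  # same class as the original loss
validation_losses = [
    loss_class(pos_weight=w) for w in weights_per_split.values()
]
# validation_step later selects validation_losses[dataloader_idx].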
@@ -233,8 +233,7 @@ class Pasa(Model):
         # data forwarding on the existing network
         outputs = self(images)
-
-        return self._validation_loss(outputs, labels.float())
+        return self._validation_loss[dataloader_idx](outputs, labels.float())
 
     def predict_step(self, batch, batch_idx, dataloader_idx=0):
         outputs = self(batch[0])