diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py index 651a783f75c983749d086fbe0e95f0ba5358a0d5..4308e26ece9de2495777a0e6bac9a4d86b60407d 100644 --- a/src/mednet/models/model.py +++ b/src/mednet/models/model.py @@ -65,7 +65,7 @@ class Model(pl.LightningModule): self._train_loss = None self._train_loss_arguments = loss_arguments - self.validation_loss = None + self._validation_loss = None self._validation_loss_arguments = loss_arguments self._optimizer_type = optimizer_type @@ -184,10 +184,10 @@ class Model(pl.LightningModule): [k.to(*args, **kwargs) for k in self._augmentation_transforms.transforms] ) - assert self._train_loss is not None - self._train_loss.to(*args, **kwargs) - assert self._validation_loss is not None - self._validation_loss.to(*args, **kwargs) + if self._train_loss is not None: + self._train_loss.to(*args, **kwargs) + if self._validation_loss is not None: + self._validation_loss.to(*args, **kwargs) return self