diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py
index 651a783f75c983749d086fbe0e95f0ba5358a0d5..4308e26ece9de2495777a0e6bac9a4d86b60407d 100644
--- a/src/mednet/models/model.py
+++ b/src/mednet/models/model.py
@@ -65,7 +65,7 @@ class Model(pl.LightningModule):
         self._train_loss = None
         self._train_loss_arguments = loss_arguments
 
-        self.validation_loss = None
+        self._validation_loss = None
         self._validation_loss_arguments = loss_arguments
 
         self._optimizer_type = optimizer_type
@@ -184,10 +184,10 @@ class Model(pl.LightningModule):
             [k.to(*args, **kwargs) for k in self._augmentation_transforms.transforms]
         )
 
-        assert self._train_loss is not None
-        self._train_loss.to(*args, **kwargs)
-        assert self._validation_loss is not None
-        self._validation_loss.to(*args, **kwargs)
+        if self._train_loss is not None:
+            self._train_loss.to(*args, **kwargs)
+        if self._validation_loss is not None:
+            self._validation_loss.to(*args, **kwargs)
 
         return self