From 812e23df2ea020e080f559366ee034ce32096ccf Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Fri, 7 Jun 2024 20:13:08 +0200
Subject: [PATCH] [models.model] Only move loss to device iff a loss is
 available (e.g. not during prediction)

---
 src/mednet/models/model.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py
index 651a783f..4308e26e 100644
--- a/src/mednet/models/model.py
+++ b/src/mednet/models/model.py
@@ -65,7 +65,7 @@ class Model(pl.LightningModule):
         self._train_loss = None
         self._train_loss_arguments = loss_arguments
 
-        self.validation_loss = None
+        self._validation_loss = None
         self._validation_loss_arguments = loss_arguments
 
         self._optimizer_type = optimizer_type
@@ -184,10 +184,10 @@ class Model(pl.LightningModule):
             [k.to(*args, **kwargs) for k in self._augmentation_transforms.transforms]
         )
 
-        assert self._train_loss is not None
-        self._train_loss.to(*args, **kwargs)
-        assert self._validation_loss is not None
-        self._validation_loss.to(*args, **kwargs)
+        if self._train_loss is not None:
+            self._train_loss.to(*args, **kwargs)
+        if self._validation_loss is not None:
+            self._validation_loss.to(*args, **kwargs)
 
         return self
 
-- 
GitLab