From 507a62075d1288fbf9f30c39eee5ece482ed7207 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Fri, 12 May 2023 13:52:31 +0200
Subject: [PATCH] Move criterion to selected device

As the criterion is not part of the model but instead a hyperparameter
due to the use of configuration files, it is not moved to the GPU if
selected as a device. We therefore manually move the criterion to the proper
device, which is bad practice when using lightning but works.
---
 src/ptbench/models/alexnet.py             |  7 +++++++
 src/ptbench/models/densenet.py            |  7 +++++++
 src/ptbench/models/densenet_rs.py         |  7 +++++++
 src/ptbench/models/logistic_regression.py |  7 +++++++
 src/ptbench/models/pasa.py                | 11 +++++++++--
 src/ptbench/models/signs_to_tb.py         |  7 +++++++
 6 files changed, 44 insertions(+), 2 deletions(-)

diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index 10ecfc72..e871a982 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -60,6 +60,8 @@ class Alexnet(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
+        # Manually move criterion to selected device, since not part of the model.
+        self.hparams.criterion = self.hparams.criterion.to(self.device)
         training_loss = self.hparams.criterion(outputs, labels.float())
 
         return {"loss": training_loss}
@@ -75,6 +77,11 @@ class Alexnet(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
+
+        # Manually move criterion to selected device, since not part of the model.
+        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
+            self.device
+        )
         validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
         return {"validation_loss": validation_loss}
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index 77cbc0a8..ea6e623c 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -60,6 +60,8 @@ class Densenet(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
+        # Manually move criterion to selected device, since not part of the model.
+        self.hparams.criterion = self.hparams.criterion.to(self.device)
         training_loss = self.hparams.criterion(outputs, labels.float())
 
         return {"loss": training_loss}
@@ -75,6 +77,11 @@ class Densenet(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
+
+        # Manually move criterion to selected device, since not part of the model.
+        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
+            self.device
+        )
         validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
         return {"validation_loss": validation_loss}
diff --git a/src/ptbench/models/densenet_rs.py b/src/ptbench/models/densenet_rs.py
index 6e5a3df4..a9d69e27 100644
--- a/src/ptbench/models/densenet_rs.py
+++ b/src/ptbench/models/densenet_rs.py
@@ -54,6 +54,8 @@ class DensenetRS(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
+        # Manually move criterion to selected device, since not part of the model.
+        self.hparams.criterion = self.hparams.criterion.to(self.device)
         training_loss = self.hparams.criterion(outputs, labels.float())
 
         return {"loss": training_loss}
@@ -69,6 +71,11 @@ class DensenetRS(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
+
+        # Manually move criterion to selected device, since not part of the model.
+        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
+            self.device
+        )
         validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
         return {"validation_loss": validation_loss}
diff --git a/src/ptbench/models/logistic_regression.py b/src/ptbench/models/logistic_regression.py
index 485a3967..6efd2a25 100644
--- a/src/ptbench/models/logistic_regression.py
+++ b/src/ptbench/models/logistic_regression.py
@@ -43,6 +43,8 @@ class LogisticRegression(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
+        # Manually move criterion to selected device, since not part of the model.
+        self.hparams.criterion = self.hparams.criterion.to(self.device)
         training_loss = self.hparams.criterion(outputs, labels.float())
 
         return {"loss": training_loss}
@@ -58,6 +60,11 @@ class LogisticRegression(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
+
+        # Manually move criterion to selected device, since not part of the model.
+        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
+            self.device
+        )
         validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
         return {"validation_loss": validation_loss}
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 78dff8e3..125867bd 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -135,7 +135,9 @@ class PASA(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
-        training_loss = self.hparams.criterion(outputs, labels.float())
+        # Manually move criterion to selected device, since not part of the model.
+        self.hparams.criterion = self.hparams.criterion.to(self.device)
+        training_loss = self.hparams.criterion(outputs, labels.double())
 
         return {"loss": training_loss}
 
@@ -150,7 +152,12 @@ class PASA(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.hparams.criterion_valid(outputs, labels.float())
+
+        # Manually move criterion to selected device, since not part of the model.
+        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
+            self.device
+        )
+        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
 
         return {"validation_loss": validation_loss}
 
diff --git a/src/ptbench/models/signs_to_tb.py b/src/ptbench/models/signs_to_tb.py
index 9267e777..aa228645 100644
--- a/src/ptbench/models/signs_to_tb.py
+++ b/src/ptbench/models/signs_to_tb.py
@@ -50,6 +50,8 @@ class SignsToTB(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
+        # Manually move criterion to selected device, since not part of the model.
+        self.hparams.criterion = self.hparams.criterion.to(self.device)
         training_loss = self.hparams.criterion(outputs, labels.float())
 
         return {"loss": training_loss}
@@ -65,6 +67,11 @@ class SignsToTB(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
+
+        # Manually move criterion to selected device, since not part of the model.
+        self.hparams.criterion_valid = self.hparams.criterion_valid.to(
+            self.device
+        )
         validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
         return {"validation_loss": validation_loss}
-- 
GitLab