From 55f4eac1cd0492234553f7716dc64aec5a6c041c Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Mon, 8 May 2023 15:31:07 +0200
Subject: [PATCH] Convert labels to float instead of double

---
 src/ptbench/models/alexnet.py             | 4 ++--
 src/ptbench/models/densenet.py            | 4 ++--
 src/ptbench/models/densenet_rs.py         | 4 ++--
 src/ptbench/models/logistic_regression.py | 4 ++--
 src/ptbench/models/pasa.py                | 4 ++--
 src/ptbench/models/signs_to_tb.py         | 4 ++--
 6 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index 8b242277..74c07b71 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -74,7 +74,7 @@ class Alexnet(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
-        training_loss = self.hparams.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.float())
 
         return {"loss": training_loss}
 
@@ -89,7 +89,7 @@ class Alexnet(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
         return {"validation_loss": validation_loss}
 
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index f5c58ad6..31abf44d 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -75,7 +75,7 @@ class Densenet(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
-        training_loss = self.hparams.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.float())
 
         return {"loss": training_loss}
 
@@ -90,7 +90,7 @@ class Densenet(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
         return {"validation_loss": validation_loss}
 
diff --git a/src/ptbench/models/densenet_rs.py b/src/ptbench/models/densenet_rs.py
index 97cb9bda..557d34a9 100644
--- a/src/ptbench/models/densenet_rs.py
+++ b/src/ptbench/models/densenet_rs.py
@@ -69,7 +69,7 @@ class DensenetRS(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
-        training_loss = self.hparams.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.float())
 
         return {"loss": training_loss}
 
@@ -84,7 +84,7 @@ class DensenetRS(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
         return {"validation_loss": validation_loss}
 
diff --git a/src/ptbench/models/logistic_regression.py b/src/ptbench/models/logistic_regression.py
index deda25aa..d53f8df0 100644
--- a/src/ptbench/models/logistic_regression.py
+++ b/src/ptbench/models/logistic_regression.py
@@ -57,7 +57,7 @@ class LogisticRegression(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
-        training_loss = self.hparams.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.float())
 
         return {"loss": training_loss}
 
@@ -72,7 +72,7 @@ class LogisticRegression(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
         return {"validation_loss": validation_loss}
 
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 155aa7d8..76aab5a4 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -166,7 +166,7 @@ class PASA(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
-        training_loss = self.hparams.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.float())
 
         return {"loss": training_loss}
 
@@ -181,7 +181,7 @@ class PASA(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
         return {"validation_loss": validation_loss}
 
diff --git a/src/ptbench/models/signs_to_tb.py b/src/ptbench/models/signs_to_tb.py
index 9db39f7a..f88707e9 100644
--- a/src/ptbench/models/signs_to_tb.py
+++ b/src/ptbench/models/signs_to_tb.py
@@ -64,7 +64,7 @@ class SignsToTB(pl.LightningModule):
         # Forward pass on the network
         outputs = self(images)
 
-        training_loss = self.hparams.criterion(outputs, labels.double())
+        training_loss = self.hparams.criterion(outputs, labels.float())
 
         return {"loss": training_loss}
 
@@ -79,7 +79,7 @@ class SignsToTB(pl.LightningModule):
 
         # data forwarding on the existing network
         outputs = self(images)
-        validation_loss = self.hparams.criterion_valid(outputs, labels.double())
+        validation_loss = self.hparams.criterion_valid(outputs, labels.float())
 
         return {"validation_loss": validation_loss}
 
-- 
GitLab