diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index 8b24227753e704281e5c00e979dfbf2192760675..74c07b71ab33627438367cfa008016f4d02976dc 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -74,7 +74,7 @@ class Alexnet(pl.LightningModule): # Forward pass on the network outputs = self(images) - training_loss = self.hparams.criterion(outputs, labels.double()) + training_loss = self.hparams.criterion(outputs, labels.float()) return {"loss": training_loss} @@ -89,7 +89,7 @@ class Alexnet(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - validation_loss = self.hparams.criterion_valid(outputs, labels.double()) + validation_loss = self.hparams.criterion_valid(outputs, labels.float()) return {"validation_loss": validation_loss} diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index f5c58ad68affdabe6c005616599979afb994834c..31abf44d7d8ad27dbc29d6161fa3b29a7de9aa22 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -75,7 +75,7 @@ class Densenet(pl.LightningModule): # Forward pass on the network outputs = self(images) - training_loss = self.hparams.criterion(outputs, labels.double()) + training_loss = self.hparams.criterion(outputs, labels.float()) return {"loss": training_loss} @@ -90,7 +90,7 @@ class Densenet(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - validation_loss = self.hparams.criterion_valid(outputs, labels.double()) + validation_loss = self.hparams.criterion_valid(outputs, labels.float()) return {"validation_loss": validation_loss} diff --git a/src/ptbench/models/densenet_rs.py b/src/ptbench/models/densenet_rs.py index 97cb9bdace46772cd39a829b864eab2116aa1429..557d34a9f5d48914078647e9304c213a16589031 100644 --- a/src/ptbench/models/densenet_rs.py +++ b/src/ptbench/models/densenet_rs.py @@ -69,7 +69,7 @@ class DensenetRS(pl.LightningModule): # Forward pass on the network outputs = self(images) - training_loss = self.hparams.criterion(outputs, labels.double()) + training_loss = self.hparams.criterion(outputs, labels.float()) return {"loss": training_loss} @@ -84,7 +84,7 @@ class DensenetRS(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - validation_loss = self.hparams.criterion_valid(outputs, labels.double()) + validation_loss = self.hparams.criterion_valid(outputs, labels.float()) return {"validation_loss": validation_loss} diff --git a/src/ptbench/models/logistic_regression.py b/src/ptbench/models/logistic_regression.py index deda25aaac71289aa210cb88c4817d5344f57f38..d53f8df044de14205d16227da5fa1347a23fd391 100644 --- a/src/ptbench/models/logistic_regression.py +++ b/src/ptbench/models/logistic_regression.py @@ -57,7 +57,7 @@ class LogisticRegression(pl.LightningModule): # Forward pass on the network outputs = self(images) - training_loss = self.hparams.criterion(outputs, labels.double()) + training_loss = self.hparams.criterion(outputs, labels.float()) return {"loss": training_loss} @@ -72,7 +72,7 @@ class LogisticRegression(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - validation_loss = self.hparams.criterion_valid(outputs, labels.double()) + validation_loss = self.hparams.criterion_valid(outputs, labels.float()) return {"validation_loss": validation_loss} diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index 155aa7d89775868710bfbdc3f0d7a7f9f49df699..76aab5a4270df455268bd6c7fc4edba4e88e42b5 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -166,7 +166,7 @@ class PASA(pl.LightningModule): # Forward pass on the network outputs = self(images) - training_loss = self.hparams.criterion(outputs, labels.double()) + training_loss = self.hparams.criterion(outputs, labels.float()) return {"loss": training_loss} @@ -181,7 +181,7 @@ class PASA(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - validation_loss = self.hparams.criterion_valid(outputs, labels.double()) + validation_loss = self.hparams.criterion_valid(outputs, labels.float()) return {"validation_loss": validation_loss} diff --git a/src/ptbench/models/signs_to_tb.py b/src/ptbench/models/signs_to_tb.py index 9db39f7a967fd1bfa9c04d158120e7b866ed5a88..f88707e95452e6e83a32cb08d39b094d486f5528 100644 --- a/src/ptbench/models/signs_to_tb.py +++ b/src/ptbench/models/signs_to_tb.py @@ -64,7 +64,7 @@ class SignsToTB(pl.LightningModule): # Forward pass on the network outputs = self(images) - training_loss = self.hparams.criterion(outputs, labels.double()) + training_loss = self.hparams.criterion(outputs, labels.float()) return {"loss": training_loss} @@ -79,7 +79,7 @@ class SignsToTB(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - validation_loss = self.hparams.criterion_valid(outputs, labels.double()) + validation_loss = self.hparams.criterion_valid(outputs, labels.float()) return {"validation_loss": validation_loss}