diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index 10ecfc7215ecd0903d848812669d8ae4debf9ff6..e871a982ea393c919aab6c819ef2dd2bb70fcc96 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -60,6 +60,8 @@ class Alexnet(pl.LightningModule): # Forward pass on the network outputs = self(images) + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion = self.hparams.criterion.to(self.device) training_loss = self.hparams.criterion(outputs, labels.float()) return {"loss": training_loss} @@ -75,6 +77,11 @@ class Alexnet(pl.LightningModule): # data forwarding on the existing network outputs = self(images) + + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion_valid = self.hparams.criterion_valid.to( + self.device + ) validation_loss = self.hparams.criterion_valid(outputs, labels.float()) return {"validation_loss": validation_loss} diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index 77cbc0a8d3443a940489cc7f482655cd183c0cc5..ea6e623c3a9cdfc3f9d6896ea77a681f7a2f5cc7 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -60,6 +60,8 @@ class Densenet(pl.LightningModule): # Forward pass on the network outputs = self(images) + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion = self.hparams.criterion.to(self.device) training_loss = self.hparams.criterion(outputs, labels.float()) return {"loss": training_loss} @@ -75,6 +77,11 @@ class Densenet(pl.LightningModule): # data forwarding on the existing network outputs = self(images) + + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion_valid = self.hparams.criterion_valid.to( + self.device + ) validation_loss = self.hparams.criterion_valid(outputs, labels.float()) return {"validation_loss": validation_loss} diff --git a/src/ptbench/models/densenet_rs.py b/src/ptbench/models/densenet_rs.py index 6e5a3df4f46db4d4b2fb70dfcdf55b1b2a7decfa..a9d69e27928d5ec9a3d525d1a043370deeacb119 100644 --- a/src/ptbench/models/densenet_rs.py +++ b/src/ptbench/models/densenet_rs.py @@ -54,6 +54,8 @@ class DensenetRS(pl.LightningModule): # Forward pass on the network outputs = self(images) + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion = self.hparams.criterion.to(self.device) training_loss = self.hparams.criterion(outputs, labels.float()) return {"loss": training_loss} @@ -69,6 +71,11 @@ class DensenetRS(pl.LightningModule): # data forwarding on the existing network outputs = self(images) + + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion_valid = self.hparams.criterion_valid.to( + self.device + ) validation_loss = self.hparams.criterion_valid(outputs, labels.float()) return {"validation_loss": validation_loss} diff --git a/src/ptbench/models/logistic_regression.py b/src/ptbench/models/logistic_regression.py index 485a396760facfc3bf84fbbaf1b66db40f327aa3..6efd2a25c9726d5aeb081ae2a7ed22192b9befcd 100644 --- a/src/ptbench/models/logistic_regression.py +++ b/src/ptbench/models/logistic_regression.py @@ -43,6 +43,8 @@ class LogisticRegression(pl.LightningModule): # Forward pass on the network outputs = self(images) + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion = self.hparams.criterion.to(self.device) training_loss = self.hparams.criterion(outputs, labels.float()) return {"loss": training_loss} @@ -58,6 +60,11 @@ class LogisticRegression(pl.LightningModule): # data forwarding on the existing network outputs = self(images) + + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion_valid = self.hparams.criterion_valid.to( + self.device + ) validation_loss = self.hparams.criterion_valid(outputs, labels.float()) return {"validation_loss": validation_loss} diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index 78dff8e3fb402b4cccdec93f046096f9e7a8f2cb..125867bda1aa4f6e2317708cc5010d9120518f46 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -135,7 +135,9 @@ class PASA(pl.LightningModule): # Forward pass on the network outputs = self(images) - training_loss = self.hparams.criterion(outputs, labels.float()) + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion = self.hparams.criterion.to(self.device) + training_loss = self.hparams.criterion(outputs, labels.double()) return {"loss": training_loss} @@ -150,7 +152,12 @@ class PASA(pl.LightningModule): # data forwarding on the existing network outputs = self(images) - validation_loss = self.hparams.criterion_valid(outputs, labels.float()) + + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion_valid = self.hparams.criterion_valid.to( + self.device + ) + validation_loss = self.hparams.criterion_valid(outputs, labels.double()) return {"validation_loss": validation_loss} diff --git a/src/ptbench/models/signs_to_tb.py b/src/ptbench/models/signs_to_tb.py index 9267e7778dacc46d58972f016f79e25e0b47366a..aa22864558aec7340cfd53b7e3b9622e72980e8a 100644 --- a/src/ptbench/models/signs_to_tb.py +++ b/src/ptbench/models/signs_to_tb.py @@ -50,6 +50,8 @@ class SignsToTB(pl.LightningModule): # Forward pass on the network outputs = self(images) + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion = self.hparams.criterion.to(self.device) training_loss = self.hparams.criterion(outputs, labels.float()) return {"loss": training_loss} @@ -65,6 +67,11 @@ class SignsToTB(pl.LightningModule): # data forwarding on the existing network outputs = self(images) + + # Manually move criterion to selected device, since not part of the model. + self.hparams.criterion_valid = self.hparams.criterion_valid.to( + self.device + ) validation_loss = self.hparams.criterion_valid(outputs, labels.float()) return {"validation_loss": validation_loss}