Skip to content
Snippets Groups Projects
Commit 507a6207 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Move criterion to selected device

As the criterion is not part of the model but instead a hyperparameter
due to the use of configuration files, it is not moved to the GPU if
selected as a device. We therefore manually move the criterion to the proper
device, which is bad practice when using lightning but works.
parent e26fe458
No related branches found
No related tags found
1 merge request!4Moved code to lightning
Pipeline #73426 passed
......@@ -60,6 +60,8 @@ class Alexnet(pl.LightningModule):
# Forward pass on the network
outputs = self(images)
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion = self.hparams.criterion.to(self.device)
training_loss = self.hparams.criterion(outputs, labels.float())
return {"loss": training_loss}
......@@ -75,6 +77,11 @@ class Alexnet(pl.LightningModule):
# data forwarding on the existing network
outputs = self(images)
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion_valid = self.hparams.criterion_valid.to(
self.device
)
validation_loss = self.hparams.criterion_valid(outputs, labels.float())
return {"validation_loss": validation_loss}
......
......@@ -60,6 +60,8 @@ class Densenet(pl.LightningModule):
# Forward pass on the network
outputs = self(images)
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion = self.hparams.criterion.to(self.device)
training_loss = self.hparams.criterion(outputs, labels.float())
return {"loss": training_loss}
......@@ -75,6 +77,11 @@ class Densenet(pl.LightningModule):
# data forwarding on the existing network
outputs = self(images)
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion_valid = self.hparams.criterion_valid.to(
self.device
)
validation_loss = self.hparams.criterion_valid(outputs, labels.float())
return {"validation_loss": validation_loss}
......
......@@ -54,6 +54,8 @@ class DensenetRS(pl.LightningModule):
# Forward pass on the network
outputs = self(images)
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion = self.hparams.criterion.to(self.device)
training_loss = self.hparams.criterion(outputs, labels.float())
return {"loss": training_loss}
......@@ -69,6 +71,11 @@ class DensenetRS(pl.LightningModule):
# data forwarding on the existing network
outputs = self(images)
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion_valid = self.hparams.criterion_valid.to(
self.device
)
validation_loss = self.hparams.criterion_valid(outputs, labels.float())
return {"validation_loss": validation_loss}
......
......@@ -43,6 +43,8 @@ class LogisticRegression(pl.LightningModule):
# Forward pass on the network
outputs = self(images)
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion = self.hparams.criterion.to(self.device)
training_loss = self.hparams.criterion(outputs, labels.float())
return {"loss": training_loss}
......@@ -58,6 +60,11 @@ class LogisticRegression(pl.LightningModule):
# data forwarding on the existing network
outputs = self(images)
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion_valid = self.hparams.criterion_valid.to(
self.device
)
validation_loss = self.hparams.criterion_valid(outputs, labels.float())
return {"validation_loss": validation_loss}
......
......@@ -135,7 +135,9 @@ class PASA(pl.LightningModule):
# Forward pass on the network
outputs = self(images)
training_loss = self.hparams.criterion(outputs, labels.float())
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion = self.hparams.criterion.to(self.device)
training_loss = self.hparams.criterion(outputs, labels.double())
return {"loss": training_loss}
......@@ -150,7 +152,12 @@ class PASA(pl.LightningModule):
# data forwarding on the existing network
outputs = self(images)
validation_loss = self.hparams.criterion_valid(outputs, labels.float())
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion_valid = self.hparams.criterion_valid.to(
self.device
)
validation_loss = self.hparams.criterion_valid(outputs, labels.double())
return {"validation_loss": validation_loss}
......
......@@ -50,6 +50,8 @@ class SignsToTB(pl.LightningModule):
# Forward pass on the network
outputs = self(images)
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion = self.hparams.criterion.to(self.device)
training_loss = self.hparams.criterion(outputs, labels.float())
return {"loss": training_loss}
......@@ -65,6 +67,11 @@ class SignsToTB(pl.LightningModule):
# data forwarding on the existing network
outputs = self(images)
# Manually move criterion to selected device, since not part of the model.
self.hparams.criterion_valid = self.hparams.criterion_valid.to(
self.device
)
validation_loss = self.hparams.criterion_valid(outputs, labels.float())
return {"validation_loss": validation_loss}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment