Skip to content
Snippets Groups Projects
Commit 55f4eac1 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Convert labels to float instead of double

parent 4964077c
No related branches found
No related tags found
1 merge request!4Moved code to lightning
Pipeline #73160 failed
......@@ -74,7 +74,7 @@ class Alexnet(pl.LightningModule):
# Forward pass on the network
outputs = self(images)
training_loss = self.hparams.criterion(outputs, labels.double())
training_loss = self.hparams.criterion(outputs, labels.float())
return {"loss": training_loss}
......@@ -89,7 +89,7 @@ class Alexnet(pl.LightningModule):
# data forwarding on the existing network
outputs = self(images)
validation_loss = self.hparams.criterion_valid(outputs, labels.double())
validation_loss = self.hparams.criterion_valid(outputs, labels.float())
return {"validation_loss": validation_loss}
......
......@@ -75,7 +75,7 @@ class Densenet(pl.LightningModule):
# Forward pass on the network
outputs = self(images)
training_loss = self.hparams.criterion(outputs, labels.double())
training_loss = self.hparams.criterion(outputs, labels.float())
return {"loss": training_loss}
......@@ -90,7 +90,7 @@ class Densenet(pl.LightningModule):
# data forwarding on the existing network
outputs = self(images)
validation_loss = self.hparams.criterion_valid(outputs, labels.double())
validation_loss = self.hparams.criterion_valid(outputs, labels.float())
return {"validation_loss": validation_loss}
......
......@@ -69,7 +69,7 @@ class DensenetRS(pl.LightningModule):
# Forward pass on the network
outputs = self(images)
training_loss = self.hparams.criterion(outputs, labels.double())
training_loss = self.hparams.criterion(outputs, labels.float())
return {"loss": training_loss}
......@@ -84,7 +84,7 @@ class DensenetRS(pl.LightningModule):
# data forwarding on the existing network
outputs = self(images)
validation_loss = self.hparams.criterion_valid(outputs, labels.double())
validation_loss = self.hparams.criterion_valid(outputs, labels.float())
return {"validation_loss": validation_loss}
......
......@@ -57,7 +57,7 @@ class LogisticRegression(pl.LightningModule):
# Forward pass on the network
outputs = self(images)
training_loss = self.hparams.criterion(outputs, labels.double())
training_loss = self.hparams.criterion(outputs, labels.float())
return {"loss": training_loss}
......@@ -72,7 +72,7 @@ class LogisticRegression(pl.LightningModule):
# data forwarding on the existing network
outputs = self(images)
validation_loss = self.hparams.criterion_valid(outputs, labels.double())
validation_loss = self.hparams.criterion_valid(outputs, labels.float())
return {"validation_loss": validation_loss}
......
......@@ -166,7 +166,7 @@ class PASA(pl.LightningModule):
# Forward pass on the network
outputs = self(images)
training_loss = self.hparams.criterion(outputs, labels.double())
training_loss = self.hparams.criterion(outputs, labels.float())
return {"loss": training_loss}
......@@ -181,7 +181,7 @@ class PASA(pl.LightningModule):
# data forwarding on the existing network
outputs = self(images)
validation_loss = self.hparams.criterion_valid(outputs, labels.double())
validation_loss = self.hparams.criterion_valid(outputs, labels.float())
return {"validation_loss": validation_loss}
......
......@@ -64,7 +64,7 @@ class SignsToTB(pl.LightningModule):
# Forward pass on the network
outputs = self(images)
training_loss = self.hparams.criterion(outputs, labels.double())
training_loss = self.hparams.criterion(outputs, labels.float())
return {"loss": training_loss}
......@@ -79,7 +79,7 @@ class SignsToTB(pl.LightningModule):
# data forwarding on the existing network
outputs = self(images)
validation_loss = self.hparams.criterion_valid(outputs, labels.double())
validation_loss = self.hparams.criterion_valid(outputs, labels.float())
return {"validation_loss": validation_loss}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment