Skip to content
Snippets Groups Projects
Commit 812e23df authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[models.model] Only move loss to device iff a loss is available (e.g. not during prediction)

parent 89a313e5
Branches
Tags
1 merge request!49Optimise device allocation for all attributes used in the training loop
......@@ -65,7 +65,7 @@ class Model(pl.LightningModule):
self._train_loss = None
self._train_loss_arguments = loss_arguments
self.validation_loss = None
self._validation_loss = None
self._validation_loss_arguments = loss_arguments
self._optimizer_type = optimizer_type
......@@ -184,10 +184,10 @@ class Model(pl.LightningModule):
[k.to(*args, **kwargs) for k in self._augmentation_transforms.transforms]
)
assert self._train_loss is not None
self._train_loss.to(*args, **kwargs)
assert self._validation_loss is not None
self._validation_loss.to(*args, **kwargs)
if self._train_loss is not None:
self._train_loss.to(*args, **kwargs)
if self._validation_loss is not None:
self._validation_loss.to(*args, **kwargs)
return self
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment