Skip to content
Snippets Groups Projects
Commit 89a313e5 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[models.model] Implement to() method to move all computing to the designated...

[models.model] Implement to() method to move all computing to the designated accelerator device, including losses and augmentations
parent 4a09305d
No related branches found
No related tags found
1 merge request!49Optimise device allocation for all attributes used in the training loop
Pipeline #87958 failed
...@@ -161,6 +161,36 @@ class Model(pl.LightningModule): ...@@ -161,6 +161,36 @@ class Model(pl.LightningModule):
**self._optimizer_arguments, **self._optimizer_arguments,
) )
def to(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Self:
"""Move model, augmentations and losses to specified device.
Refer to the method :py:meth:`torch.nn.Module.to` for details.
Parameters
----------
*args
Parameter forwarded to the underlying implementations.
**kwargs
Parameter forwarded to the underlying implementations.
Returns
-------
Self.
"""
super().to(*args, **kwargs)
self._augmentation_transforms = torchvision.transforms.Compose(
[k.to(*args, **kwargs) for k in self._augmentation_transforms.transforms]
)
assert self._train_loss is not None
self._train_loss.to(*args, **kwargs)
assert self._validation_loss is not None
self._validation_loss.to(*args, **kwargs)
return self
def balance_losses(self, datamodule) -> None: def balance_losses(self, datamodule) -> None:
"""Balance the loss based on the distribution of targets in the datamodule, if the loss supports it (contains a 'pos_weight' attribute). """Balance the loss based on the distribution of targets in the datamodule, if the loss supports it (contains a 'pos_weight' attribute).
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment