From 89a313e563db25b8cd80c5046fe1356f2a3aed04 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Fri, 7 Jun 2024 12:21:51 +0200 Subject: [PATCH] [models.model] Implement to() method to move all computing to the designated accelerator device, including losses and augmentations --- src/mednet/models/model.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py index d109e3cc..651a783f 100644 --- a/src/mednet/models/model.py +++ b/src/mednet/models/model.py @@ -161,6 +161,36 @@ class Model(pl.LightningModule): **self._optimizer_arguments, ) + def to(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Self: + """Move model, augmentations and losses to specified device. + + Refer to the method :py:meth:`torch.nn.Module.to` for details. + + Parameters + ---------- + *args + Parameter forwarded to the underlying implementations. + **kwargs + Parameter forwarded to the underlying implementations. + + Returns + ------- + Self. + """ + + super().to(*args, **kwargs) + + self._augmentation_transforms = torchvision.transforms.Compose( + [k.to(*args, **kwargs) for k in self._augmentation_transforms.transforms] + ) + + assert self._train_loss is not None + self._train_loss.to(*args, **kwargs) + assert self._validation_loss is not None + self._validation_loss.to(*args, **kwargs) + + return self + def balance_losses(self, datamodule) -> None: """Balance the loss based on the distribution of targets in the datamodule, if the loss supports it (contains a 'pos_weight' attribute). -- GitLab