From 89a313e563db25b8cd80c5046fe1356f2a3aed04 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Fri, 7 Jun 2024 12:21:51 +0200
Subject: [PATCH] [models.model] Implement to() method to move all computing to
 the designated accelerator device, including losses and augmentations

---
 src/mednet/models/model.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py
index d109e3cc..651a783f 100644
--- a/src/mednet/models/model.py
+++ b/src/mednet/models/model.py
@@ -161,6 +161,36 @@ class Model(pl.LightningModule):
             **self._optimizer_arguments,
         )
 
+    def to(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Self:
+        """Move model, augmentations and losses to specified device.
+
+        Refer to the method :py:meth:`torch.nn.Module.to` for details.
+
+        Parameters
+        ----------
+        *args
+            Parameter forwarded to the underlying implementations.
+        **kwargs
+            Parameter forwarded to the underlying implementations.
+
+        Returns
+        -------
+            Self.
+        """
+
+        super().to(*args, **kwargs)
+
+        self._augmentation_transforms = torchvision.transforms.Compose(
+            [k.to(*args, **kwargs) for k in self._augmentation_transforms.transforms]
+        )
+
+        assert self._train_loss is not None
+        self._train_loss.to(*args, **kwargs)
+        assert self._validation_loss is not None
+        self._validation_loss.to(*args, **kwargs)
+
+        return self
+
     def balance_losses(self, datamodule) -> None:
         """Balance the loss based on the distribution of targets in the datamodule, if the loss supports it (contains a 'pos_weight' attribute).
 
-- 
GitLab