diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py
index d109e3cc090e445eedd81fd37bcad2306fec1024..651a783f75c983749d086fbe0e95f0ba5358a0d5 100644
--- a/src/mednet/models/model.py
+++ b/src/mednet/models/model.py
@@ -161,6 +161,36 @@ class Model(pl.LightningModule):
             **self._optimizer_arguments,
         )
 
+    def to(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Self:
+        """Move model, augmentations and losses to specified device.
+
+        Refer to the method :py:meth:`torch.nn.Module.to` for details.
+
+        Parameters
+        ----------
+        *args
+            Parameter forwarded to the underlying implementations.
+        **kwargs
+            Parameter forwarded to the underlying implementations.
+
+        Returns
+        -------
+            Self.
+        """
+
+        super().to(*args, **kwargs)
+
+        self._augmentation_transforms = torchvision.transforms.Compose(
+            [k.to(*args, **kwargs) for k in self._augmentation_transforms.transforms]
+        )
+
+        assert self._train_loss is not None
+        self._train_loss.to(*args, **kwargs)
+        assert self._validation_loss is not None
+        self._validation_loss.to(*args, **kwargs)
+
+        return self
+
     def balance_losses(self, datamodule) -> None:
         """Balance the loss based on the distribution of targets in the datamodule, if the loss supports it (contains a 'pos_weight' attribute).