diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py index d109e3cc090e445eedd81fd37bcad2306fec1024..651a783f75c983749d086fbe0e95f0ba5358a0d5 100644 --- a/src/mednet/models/model.py +++ b/src/mednet/models/model.py @@ -161,6 +161,36 @@ class Model(pl.LightningModule): **self._optimizer_arguments, ) + def to(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Self: + """Move model, augmentations and losses to specified device. + + Refer to the method :py:meth:`torch.nn.Module.to` for details. + + Parameters + ---------- + *args + Parameter forwarded to the underlying implementations. + **kwargs + Parameter forwarded to the underlying implementations. + + Returns + ------- + Self. + """ + + super().to(*args, **kwargs) + + self._augmentation_transforms = torchvision.transforms.Compose( + [k.to(*args, **kwargs) for k in self._augmentation_transforms.transforms] + ) + + assert self._train_loss is not None + self._train_loss.to(*args, **kwargs) + assert self._validation_loss is not None + self._validation_loss.to(*args, **kwargs) + + return self + def balance_losses(self, datamodule) -> None: """Balance the loss based on the distribution of targets in the datamodule, if the loss supports it (contains a 'pos_weight' attribute).