GPU-offloading is quite sub-optimal
I think I just discovered why our GPU utilisation is far from optimal: both losses (and, since !48 was merged, the data augmentations too) were never allocated on the right accelerator device.
The following extra piece of code, added to the central Model class, seems to fix the problem:
diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py
index d109e3c..f9cad35 100644
--- a/src/mednet/models/model.py
+++ b/src/mednet/models/model.py
@@ -161,6 +161,22 @@ class Model(pl.LightningModule):
**self._optimizer_arguments,
)
+    def to(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Self:
+        """Move model, augmentations and losses to specified device."""
+
+        super().to(*args, **kwargs)
+
+        self._augmentation_transforms = torchvision.transforms.Compose([
+            k.to(*args, **kwargs) for k in self._augmentation_transforms.transforms
+        ])
+
+        assert self._train_loss is not None
+        self._train_loss.to(*args, **kwargs)
+        assert self._validation_loss is not None
+        self._validation_loss.to(*args, **kwargs)
+
+        return self
+
def balance_losses(self, datamodule) -> None:
"""Balance the loss based on the distribution of targets in the datamodule, if the loss supports it (contains a 'pos_weight' attribute).
I ran the following tests on my M1-Ultra MacBook:
# Tests without augmentations
\rm -rf results && time mednet train -vv montgomery pasa --batch-size=8 --cache-samples --parallel=6 --epochs=100 --monitoring-interval=1 --device=cpu
...
306.25s user 126.45s system 387% cpu 1:51.79 total
\rm -rf results && time mednet train -vv montgomery pasa --batch-size=8 --cache-samples --parallel=6 --epochs=100 --monitoring-interval=1 --device=mps
112.14s user 38.24s system 245% cpu 1:01.18 total
# Tests with a lot of augmentations
rm -rf results && time mednet train -vv montgomery pasa hflip-jitter-affine --batch-size=8 --cache-samples --parallel=6 --epochs=100 --monitoring-interval=1 --device=cpu
...
349.79s user 157.25s system 400% cpu 2:06.50 total
rm -rf results && time mednet train -vv montgomery pasa hflip-jitter-affine --batch-size=8 --cache-samples --parallel=6 --epochs=100 --monitoring-interval=1 --device=mps
...
115.87s user 43.17s system 219% cpu 1:12.39 total
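To compare the runs above at a glance, the wall-clock totals printed by `time` can be converted to seconds and divided. This is only a small arithmetic sketch: the helper `wall_seconds` and the `runs` dictionary are hypothetical names introduced here, and the four timestamps are copied from the output above.

```python
def wall_seconds(stamp: str) -> float:
    """Convert a 'M:SS.ss' or 'H:MM:SS.ss' total, as printed by `time`, to seconds."""
    seconds = 0.0
    for part in stamp.split(":"):
        seconds = seconds * 60 + float(part)
    return seconds


# (cpu total, mps total) pairs copied from the benchmark output above
runs = {
    "no augmentations": ("1:51.79", "1:01.18"),
    "hflip-jitter-affine": ("2:06.50", "1:12.39"),
}

speedups = {}
for name, (cpu, mps) in runs.items():
    # Ratio > 1 means the mps run finished faster than the cpu run
    speedups[name] = wall_seconds(cpu) / wall_seconds(mps)
    print(f"{name}: mps is {speedups[name]:.2f}x faster than cpu (wall-clock)")
```

Note that this compares wall-clock time only; the `user` and `system` columns measure CPU time consumed, which is a separate quantity.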
We probably never realised this because our implementation of the elastic transformation runs on the CPU and is very sub-optimal. However, when using the transforms available in torchvision, the time on my laptop goes down to about a third of the original processing time (in the augmented runs above, user CPU time drops from 349.79 s with --device=cpu to 115.87 s with --device=mps, and wall-clock time from 2:06.50 to 1:12.39). CPU utilisation also goes down (400% to 219%), indicating the processing is off-loaded to the GPU. I'll execute more tests, but this seems to considerably improve the situation.
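Why an explicit `to()` override helps can be illustrated with a minimal sketch in stock PyTorch. The assumptions are stated loudly: `Holder`, `PatchedHolder` and `_extras` are hypothetical names, not mednet code, and a dtype cast stands in for a device move so the sketch runs on a machine without a GPU. The point is that `nn.Module.to()` only reaches registered submodules, parameters and buffers; modules kept inside a plain container (here a list, and in the patch a `torchvision.transforms.Compose`, which is not an `nn.Module` in the classic transforms API) are skipped unless they are moved by hand.

```python
import torch
from torch import nn


class Holder(nn.Module):
    """Keeps a loss inside a plain list, so it is NOT registered as a submodule."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(2, 1)  # registered: follows .to()
        # Hypothetical stand-in for losses/augmentations held outside the module tree
        self._extras = [nn.BCEWithLogitsLoss(pos_weight=torch.ones(1))]


class PatchedHolder(Holder):
    """Same as Holder, but forwards .to() to the unregistered modules as well."""

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        for module in self._extras:
            module.to(*args, **kwargs)
        return self


# A dtype cast stands in for a device move (assumption: no accelerator available)
plain = Holder().to(torch.float64)
patched = PatchedHolder().to(torch.float64)

print(plain.layer.weight.dtype, plain._extras[0].pos_weight.dtype)
print(patched.layer.weight.dtype, patched._extras[0].pos_weight.dtype)
```

In the unpatched class the `pos_weight` buffer keeps its original dtype while the linear layer is converted; in the patched class both follow the call. The same reasoning would apply to a call such as `.to("mps")`.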