We might be able to use torchvision's ElasticTransform instead of our custom ElasticDeformation implementation for data augmentation. It would be more robust and work on batches.
transform_default_parameters is the torchvision implementation with its default parameters: alpha=50.0, sigma=5.0, and bilinear interpolation.
transform_defined_parameters is the torchvision implementation with parameters similar to our custom implementation: alpha=1000.0, sigma=30.0, and nearest interpolation.
The important comparison is between transform_custom and transform_defined_parameters. The results are not identical, but the effect is similar. The torchvision implementation also seems to introduce some aliasing.
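For reference, a minimal sketch of how the three compared transforms could be instantiated (the import path for our ElasticDeformation is an assumption; the torchvision parameters are the ones listed above):

```python
# Sketch only: the three transforms compared above.
# The ptbench import path for ElasticDeformation is an assumption.
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from ptbench.data.augmentations import ElasticDeformation  # assumed location

# our custom implementation, always applied
transform_custom = ElasticDeformation(p=1.0)

# torchvision with its default parameters (bilinear interpolation)
transform_default_parameters = transforms.ElasticTransform(
    alpha=50.0, sigma=5.0, interpolation=InterpolationMode.BILINEAR
)

# torchvision with parameters matching our custom implementation
transform_defined_parameters = transforms.ElasticTransform(
    alpha=1000.0, sigma=30.0, interpolation=InterpolationMode.NEAREST
)
```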
Using the torchvision implementation in an environment built from the dev-profile on a machine with CUDA support fails with:

```
RuntimeError: [enforce fail at alloc_cpu.cpp:75] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 60902342656 bytes. Error code 12 (Cannot allocate memory).
```
Full traceback:

```
Traceback (most recent call last):
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/bin/ptbench", line 8, in <module>
    sys.exit(cli())
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/click/core.py", line 1657, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "/remote/idiap.svm/user.active/dcarron/ptbench/src/ptbench/scripts/train.py", line 283, in train
    run(
  File "/remote/idiap.svm/user.active/dcarron/ptbench/src/ptbench/engine/trainer.py", line 282, in run
    _ = trainer.fit(model, datamodule, ckpt_path=checkpoint)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 531, in fit
    call._call_and_handle_interrupt(
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 42, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 570, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 975, in _run
    results = self._run_stage()
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1018, in _run_stage
    self.fit_loop.run()
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 201, in run
    self.advance()
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 354, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 133, in run
    self.advance(data_fetcher)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 218, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 185, in run
    self._optimizer_step(kwargs.get("batch_idx", 0), closure)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 260, in _optimizer_step
    call._call_lightning_module_hook(
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 140, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1256, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py", line 155, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 225, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py", line 114, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/torch/optim/optimizer.py", line 280, in wrapper
    out = func(*args, **kwargs)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/torch/optim/optimizer.py", line 33, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/torch/optim/adam.py", line 121, in step
    loss = closure()
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py", line 101, in _wrap_closure
    closure_result = closure()
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 140, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 126, in closure
    step_output = self._step_fn()
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 307, in _training_step
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 287, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 367, in training_step
    return self.model.training_step(*args, **kwargs)
  File "/remote/idiap.svm/user.active/dcarron/ptbench/src/ptbench/models/pasa.py", line 199, in training_step
    augmented_images = [
  File "/remote/idiap.svm/user.active/dcarron/ptbench/src/ptbench/models/pasa.py", line 200, in <listcomp>
    self.augmentation_transforms(img).to(self.device) for img in images
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/torchvision/transforms/transforms.py", line 95, in __call__
    img = t(img)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/torchvision/transforms/v2/_transform.py", line 40, in forward
    params = self._get_params(
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/torchvision/transforms/v2/_geometry.py", line 1083, in _get_params
    dx = F.gaussian_blur(dx, [kx, kx], list(self.sigma))
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/torchvision/transforms/v2/functional/_misc.py", line 175, in gaussian_blur
    return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
  File "/idiap/temp/dcarron/miniconda3-202011/envs/ptbench-gpu/lib/python3.10/site-packages/torchvision/transforms/v2/functional/_misc.py", line 140, in gaussian_blur_image_tensor
    output = conv2d(output, kernel, groups=shape[-3])
RuntimeError: [enforce fail at alloc_cpu.cpp:75] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 60902342656 bytes. Error code 12 (Cannot allocate memory)
```
This happens both when running on the CPU and on the GPU.
It works fine in a conda environment built without CUDA support, albeit slowly.
I am not sure whether this is an issue with my particular environment, with our dev-profile requirements, or somewhere in torch.
The elastic deformation augmentation is the reason for the slow execution time during training. The throughput on an MPS backend, with and without ElasticDeformation, is the following:
- With only ElasticDeformation as data augmentation (p=0.8): 3 batches/second
- With only ElasticDeformation as data augmentation (p=0.1): 11 batches/second
- Without any data augmentation: 20 batches/second
For testing, we can either remove the deformation or reduce its probability to 10%; this will substantially speed up testing without affecting the functionality.
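For instance (hypothetical configuration snippet, reusing the assumed ElasticDeformation import from the sketch above):

```python
# keep the deformation during tests, but apply it to only ~10% of the images
augmentations = [ElasticDeformation(p=0.1)]
```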
I suspect that the reason for your issues, @dcarron, is that we're bottlenecking the processing of transforms in the models.
All of the transforms should be able to handle batches of images. As of now, our own ElasticDeformation does not and, as a consequence, we break the batches down into images in each model and then reassemble them into an appropriate tensor. I wonder whether your problem persists if we implement ElasticDeformation to correctly handle batches and remove that custom code.
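A rough illustration of that per-image workaround (the helper below is illustrative, not the actual ptbench code):

```python
import torch

def augment_batch(batch: torch.Tensor, transform) -> torch.Tensor:
    # batch has shape (N, C, H, W); the transform only accepts single images,
    # so we split the batch, transform each image, and re-stack the results.
    return torch.stack([transform(img) for img in batch])

# If ElasticDeformation accepted batches directly, this would reduce to a
# single call: augmented = transform(batch)
```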
This issue may also account for the slowness you observed.
Please also note that, for now, ElasticDeformation is not scriptable; therefore, we have to handle it on the CPU during the training loop, which of course is a bottleneck. The point of having the augmentations close to the model was to be able to move them onto the target torch device. However, if we can't, we need to figure out another way to "parallelise" it.
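Roughly, the intent is something like the following (sketch only; method and helper names are illustrative, not the current ptbench code):

```python
# Sketch: with a scriptable, batch-aware augmentation, the transform could
# run on the training device inside the training step.
def training_step(self, batch, batch_idx):
    images, labels = batch
    images = images.to(self.device)                 # move the batch to GPU/MPS
    images = self.augmentation_transforms(images)   # augment on-device
    return self._loss(self(images), labels)         # hypothetical loss helper
```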
@dcarron: The commit above (66f7051a) implements a cleaner way to handle data augmentations, which will create a more solid base for the future:
I modified ElasticDeformation to properly work with single images or batches. Incidentally, I also think I made it do fewer copies, which sped it up a (tiny) bit.
I simplified all models to account for that factor, which, BTW, makes them future-ready while allowing torch to maximally optimise throughput. I believe this should fix your reported issue above. (Please re-test your use case so we can be sure the issue is gone.)
I updated the unit tests concerning the elastic transformation and visually checked that it works with the test image.
As a bonus, I implemented (optional) parallelism, off by default, for when ElasticDeformation has to handle batches. You can configure the number of workers when you create the transform:
```python
ElasticDeformation(p=0.8, parallel=4)
```
We could make this "automatic" by waiting for the first batch to arrive and configuring the number of parallel workers as the minimum between the total number of cores and the batch size. For now, it is not. Passing -1, as in other cases, turns this feature completely off. Our configuration comes with the feature OFF by default.
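A possible shape for that automatic selection (not implemented; the function below is hypothetical):

```python
import multiprocessing

def auto_workers(batch_size: int) -> int:
    # proposed heuristic: never use more workers than CPU cores or images in the batch
    return min(multiprocessing.cpu_count(), batch_size)
```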
With these changes, running with parallel=4 on my machine gives me roughly a 2x speed boost. If I run on device=mps, I get a 3x speed boost. I'll test the automated tuning of this flag tomorrow.
Very well. Let's leave this open and keep monitoring fixes for this. Did you report this on the pytorch bug tracker? Maybe we should, and then cross-reference it here.
Not yet; I want to make sure it's not an issue with my environment first.
Also, that comparison table is not accurate, as I left p=0.8 while torchvision applies the transform to all images. I'll re-run the experiments after I'm done with the tensorboard-to-csv conversion.
Here is a more accurate table. Values are the times, in seconds, to train the first epoch, taken from epoch-duration-seconds/train in the logs.
"custom" is our implementation: ElasticDeformation(p=1).
"torchvision" is torchvision.transforms.ElasticTransform(alpha=1000.0, sigma=30.0, interpolation=InterpolationMode.NEAREST). The parameters correspond to our defaults.
Comparing cached vs. non-cached datasets is not relevant to this issue, but I included the values anyway as they may be interesting.
|                      | torchvision | custom |
|----------------------|------------:|-------:|
| pasa_cpu             | 226         | 226    |
| pasa_cuda:0          | ?           | 222    |
| pasa_cpu_cached      | 120         | 120    |
| pasa_cuda:0_cached   | ?           | 116    |
| alexnet_cpu          | 236         | 251    |
| alexnet_cuda:0       | ?           | 212    |
| alexnet_cpu_cached   | 130         | 145    |
| alexnet_cuda:0_cached| ?           | 106    |
It appears that training pasa (greyscale) takes the same amount of time with either implementation. Training alexnet (RGB) is ~15 seconds faster with the torchvision implementation.
I retested this after a few torchvision iterations. The transform provided by that package continues to be quite slow for a similar effect. Here is the program I used to obtain the images below:
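A minimal sketch of such a comparison (the ptbench import path and the checkerboard-pattern details are assumptions, not the original script):

```python
# Sketch only: generate a 512x512 black-and-white pattern, deform it with our
# ElasticDeformation and with torchvision's ElasticTransform, and time both.
import time

import torch
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.utils import save_image

from ptbench.data.augmentations import ElasticDeformation  # assumed location

# 512x512 single-channel black-and-white checkerboard pattern
tile = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
pattern = tile.repeat(256, 256).unsqueeze(0)  # shape (1, 512, 512)
save_image(pattern, "pattern.png")

candidates = {
    "custom": ElasticDeformation(p=1.0),
    "torchvision": transforms.ElasticTransform(
        alpha=1000.0, sigma=30.0, interpolation=InterpolationMode.NEAREST
    ),
}

for name, transform in candidates.items():
    start = time.perf_counter()
    deformed = transform(pattern)
    print(f"{name}: {time.perf_counter() - start:.3f} s")
    save_image(deformed, f"{name}.png")
```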
These three images show a 512x512 reference black-and-white pattern, the result of running our ElasticDeformation implementation on it, and the result of running torchvision's ElasticTransform on it.
The generation times for a 512x512 pixel pattern grid are: