Skip to content
Snippets Groups Projects
Commit 2e62f3d7 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[data.augmentations] Remove elastic deformation parallelisation

parent 6f23727a
No related branches found
No related tags found
1 merge request!50Remove elastic deformation parallelisation
Pipeline #88084 passed
......@@ -14,7 +14,6 @@ input image.
import functools
import logging
import multiprocessing.pool
import random
import typing
......@@ -131,7 +130,6 @@ def _elastic_deformation_on_batch(
spline_order: int = 1,
mode: str = "nearest",
p: float = 1.0,
pool: multiprocessing.pool.Pool | None = None,
) -> torch.Tensor:
"""Perform elastic deformation on a batch of images.
......@@ -157,8 +155,6 @@ def _elastic_deformation_on_batch(
p
Probability that this transformation will be applied. Meaningful when
using it as a data augmentation technique.
pool
The multiprocessing pool to use.
Returns
-------
......@@ -178,11 +174,7 @@ def _elastic_deformation_on_batch(
# if a mp pool is available, do it in parallel
augmented_images: typing.Any
if pool is None:
augmented_images = map(partial, batch.cpu())
else:
augmented_images = pool.imap(partial, batch.cpu())
augmented_images = map(partial, batch.cpu())
return torch.stack(list(augmented_images))
......@@ -217,13 +209,6 @@ class ElasticDeformation:
p
Probability that this transformation will be applied. Meaningful when
using it as a data augmentation technique.
parallel
Use multiprocessing for processing batches of data: if set to -1
(default), disables multiprocessing. If set to -2, then enable
auto-tune (use the minimum value between the first batch size and total
number of processing cores). Set to 0 to enable as many processes as
processing cores available in the system. Set to >= 1 to enable that
many processes.
"""
def __init__(
......@@ -233,14 +218,12 @@ class ElasticDeformation:
spline_order: int = 1,
mode: str = "nearest",
p: float = 1.0,
parallel: int = -2,
):
self.alpha: float = alpha
self.sigma: float = sigma
self.spline_order: int = spline_order
self.mode: str = mode
self.p: float = p
self.parallel = parallel
def __str__(self) -> str:
parameters = [
......@@ -249,50 +232,11 @@ class ElasticDeformation:
f"spline_order={self.spline_order}",
f"mode={self.mode}",
f"p={self.p}",
f"parallel={self.parallel}",
]
return f"{type(self).__name__}({', '.join(parameters)})"
@property
def parallel(self) -> int:
"""Use multiprocessing for data augmentation.
If set to -1 (default), disables multiprocessing. If set to -2,
then enable auto-tune (use the minimum value between the first
batch size and total number of processing cores). Set to 0 to
enable as many processes as processing cores available in the
system. Set to >= 1 to enable that many processes.
Returns
-------
int
The multiprocessing type.
"""
return self._parallel
@parallel.setter
def parallel(self, value):
self._parallel = value
if value >= 0:
instances = value or multiprocessing.cpu_count()
logger.info(
f"Applying data-augmentation using {instances} processes...",
)
self._mp_pool = multiprocessing.get_context("spawn").Pool(instances)
else:
self._mp_pool = None
def __call__(self, img: torch.Tensor) -> torch.Tensor:
if len(img.shape) == 4:
if self._mp_pool is None and self._parallel == -2:
# auto-tunning on first batch
instances = min(img.shape[0], multiprocessing.cpu_count())
self._mp_pool = multiprocessing.get_context("spawn").Pool(
instances,
)
return _elastic_deformation_on_batch(
img,
self.alpha,
......@@ -300,7 +244,6 @@ class ElasticDeformation:
self.spline_order,
self.mode,
self.p,
self._mp_pool,
).to(img.device)
if len(img.shape) == 3:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment