Commit 2e62f3d7 authored by André Anjos

[data.augmentations] Remove elastic deformation parallelisation

parent 6f23727a
Merge request !50: Remove elastic deformation parallelisation
Pipeline #88084 passed
@@ -14,7 +14,6 @@ input image.
 import functools
 import logging
-import multiprocessing.pool
 import random
 import typing
@@ -131,7 +130,6 @@ def _elastic_deformation_on_batch(
     spline_order: int = 1,
     mode: str = "nearest",
     p: float = 1.0,
-    pool: multiprocessing.pool.Pool | None = None,
 ) -> torch.Tensor:
     """Perform elastic deformation on a batch of images.
@@ -157,8 +155,6 @@ def _elastic_deformation_on_batch(
     p
         Probability that this transformation will be applied. Meaningful when
         using it as a data augmentation technique.
-    pool
-        The multiprocessing pool to use.

     Returns
     -------
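For context, elastic deformation warps an image by a smoothed random displacement field, in the spirit of Simard et al. (2003). The following is a minimal single-image sketch, assuming (C, H, W) tensors and using scipy; it only illustrates the roles of alpha (field strength), sigma (smoothing), spline_order and mode, and is not the implementation touched by this commit:

import numpy
import torch
from scipy.ndimage import gaussian_filter, map_coordinates


def elastic_deform_example(
    img: torch.Tensor,
    alpha: float = 1000.0,
    sigma: float = 30.0,
    spline_order: int = 1,
    mode: str = "nearest",
) -> torch.Tensor:
    """Illustrative elastic deformation of a single (C, H, W) image."""
    data = img.numpy()
    height, width = data.shape[-2:]
    # random displacement field, smoothed by a Gaussian, scaled by alpha
    dy = alpha * gaussian_filter(
        numpy.random.uniform(-1, 1, (height, width)), sigma, mode="constant"
    )
    dx = alpha * gaussian_filter(
        numpy.random.uniform(-1, 1, (height, width)), sigma, mode="constant"
    )
    y, x = numpy.meshgrid(
        numpy.arange(height), numpy.arange(width), indexing="ij"
    )
    coordinates = numpy.stack([(y + dy).ravel(), (x + dx).ravel()])
    # resample every channel at the displaced coordinates
    warped = numpy.stack(
        [
            map_coordinates(
                channel, coordinates, order=spline_order, mode=mode
            ).reshape(height, width)
            for channel in data
        ]
    )
    return torch.from_numpy(warped)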
@@ -178,11 +174,7 @@ def _elastic_deformation_on_batch(
-    # if a mp pool is available, do it in parallel
     augmented_images: typing.Any
-    if pool is None:
-        augmented_images = map(partial, batch.cpu())
-    else:
-        augmented_images = pool.imap(partial, batch.cpu())
+    augmented_images = map(partial, batch.cpu())

     return torch.stack(list(augmented_images))
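The behavioural core of the commit is the hunk above: per-image work is now always mapped sequentially instead of optionally dispatched to a multiprocessing pool. A hedged sketch of the two strategies, with illustrative names rather than the project's actual code:

import functools
import multiprocessing

import torch


def _deform_one(alpha: float, image: torch.Tensor) -> torch.Tensor:
    """Stand-in for the real per-image elastic deformation."""
    return image


def deform_batch(batch: torch.Tensor, alpha: float = 1000.0) -> torch.Tensor:
    # after this commit: a plain built-in map(), single process
    partial = functools.partial(_deform_one, alpha)
    return torch.stack(list(map(partial, batch.cpu())))


def deform_batch_pooled(batch: torch.Tensor, alpha: float = 1000.0) -> torch.Tensor:
    # what the removed branch did: feed batch items to workers lazily;
    # with the "spawn" context the mapped function must be importable
    partial = functools.partial(_deform_one, alpha)
    with multiprocessing.get_context("spawn").Pool() as pool:
        return torch.stack(list(pool.imap(partial, batch.cpu())))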
@@ -217,13 +209,6 @@ class ElasticDeformation:
    p
        Probability that this transformation will be applied. Meaningful when
        using it as a data augmentation technique.
-    parallel
-        Use multiprocessing for processing batches of data: if set to -1
-        (default), disables multiprocessing. If set to -2, then enable
-        auto-tune (use the minimum value between the first batch size and total
-        number of processing cores). Set to 0 to enable as many processes as
-        processing cores available in the system. Set to >= 1 to enable that
-        many processes.
    """

    def __init__(
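The removed parallel knob packed four modes into one integer. A short sketch of that mapping, reconstructed from the docstring above (the helper name is invented for illustration):

import multiprocessing


def _pool_size(parallel: int, first_batch_size: int) -> int | None:
    """Translate the removed ``parallel`` values into a process count."""
    if parallel == -1:  # disable multiprocessing
        return None
    if parallel == -2:  # auto-tune on the first batch
        return min(first_batch_size, multiprocessing.cpu_count())
    if parallel == 0:  # one process per available core
        return multiprocessing.cpu_count()
    return parallel  # >= 1: exactly this many processes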
@@ -233,14 +218,12 @@ class ElasticDeformation:
         spline_order: int = 1,
         mode: str = "nearest",
         p: float = 1.0,
-        parallel: int = -2,
     ):
         self.alpha: float = alpha
         self.sigma: float = sigma
         self.spline_order: int = spline_order
         self.mode: str = mode
         self.p: float = p
-        self.parallel = parallel

     def __str__(self) -> str:
         parameters = [
@@ -249,50 +232,11 @@ class ElasticDeformation:
             f"spline_order={self.spline_order}",
             f"mode={self.mode}",
             f"p={self.p}",
-            f"parallel={self.parallel}",
         ]
         return f"{type(self).__name__}({', '.join(parameters)})"

-    @property
-    def parallel(self) -> int:
-        """Use multiprocessing for data augmentation.
-
-        If set to -1 (default), disables multiprocessing. If set to -2,
-        then enable auto-tune (use the minimum value between the first
-        batch size and total number of processing cores). Set to 0 to
-        enable as many processes as processing cores available in the
-        system. Set to >= 1 to enable that many processes.
-
-        Returns
-        -------
-        int
-            The multiprocessing type.
-        """
-        return self._parallel
-
-    @parallel.setter
-    def parallel(self, value):
-        self._parallel = value
-        if value >= 0:
-            instances = value or multiprocessing.cpu_count()
-            logger.info(
-                f"Applying data-augmentation using {instances} processes...",
-            )
-            self._mp_pool = multiprocessing.get_context("spawn").Pool(instances)
-        else:
-            self._mp_pool = None
-
     def __call__(self, img: torch.Tensor) -> torch.Tensor:
         if len(img.shape) == 4:
-            if self._mp_pool is None and self._parallel == -2:
-                # auto-tunning on first batch
-                instances = min(img.shape[0], multiprocessing.cpu_count())
-                self._mp_pool = multiprocessing.get_context("spawn").Pool(
-                    instances,
-                )
             return _elastic_deformation_on_batch(
                 img,
                 self.alpha,
@@ -300,7 +244,6 @@ class ElasticDeformation:
                 self.spline_order,
                 self.mode,
                 self.p,
-                self._mp_pool,
             ).to(img.device)

         if len(img.shape) == 3:
...
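After this change, a hypothetical caller constructs the transform without the parallel argument, and batches are processed sequentially in the caller's process (the argument values below are illustrative, not the class defaults):

import torch

transform = ElasticDeformation(alpha=1000.0, sigma=30.0, p=1.0)

image = torch.rand(1, 224, 224)     # single image: (C, H, W)
batch = torch.rand(8, 1, 224, 224)  # batch: (N, C, H, W)

augmented_batch = transform(batch)  # handled by the 4-D branch above
augmented_image = transform(image)  # handled by the 3-D branch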