diff --git a/src/mednet/data/augmentations.py b/src/mednet/data/augmentations.py index 47fb0fac3b6b9439270e73e9a6793ec3514569d6..fcf8887d06b23bdf7f4f187651e9041b43d5b38e 100644 --- a/src/mednet/data/augmentations.py +++ b/src/mednet/data/augmentations.py @@ -14,7 +14,6 @@ input image. import functools import logging -import multiprocessing.pool import random import typing @@ -131,7 +130,6 @@ def _elastic_deformation_on_batch( spline_order: int = 1, mode: str = "nearest", p: float = 1.0, - pool: multiprocessing.pool.Pool | None = None, ) -> torch.Tensor: """Perform elastic deformation on a batch of images. @@ -157,8 +155,6 @@ def _elastic_deformation_on_batch( p Probability that this transformation will be applied. Meaningful when using it as a data augmentation technique. - pool - The multiprocessing pool to use. Returns ------- @@ -178,11 +174,7 @@ def _elastic_deformation_on_batch( # if a mp pool is available, do it in parallel augmented_images: typing.Any - if pool is None: - augmented_images = map(partial, batch.cpu()) - else: - augmented_images = pool.imap(partial, batch.cpu()) - + augmented_images = map(partial, batch.cpu()) return torch.stack(list(augmented_images)) @@ -217,13 +209,6 @@ class ElasticDeformation: p Probability that this transformation will be applied. Meaningful when using it as a data augmentation technique. - parallel - Use multiprocessing for processing batches of data: if set to -1 - (default), disables multiprocessing. If set to -2, then enable - auto-tune (use the minimum value between the first batch size and total - number of processing cores). Set to 0 to enable as many processes as - processing cores available in the system. Set to >= 1 to enable that - many processes. """ def __init__( @@ -233,14 +218,12 @@ class ElasticDeformation: spline_order: int = 1, mode: str = "nearest", p: float = 1.0, - parallel: int = -2, ): self.alpha: float = alpha self.sigma: float = sigma self.spline_order: int = spline_order self.mode: str = mode self.p: float = p - self.parallel = parallel def __str__(self) -> str: parameters = [ @@ -249,50 +232,11 @@ class ElasticDeformation: f"spline_order={self.spline_order}", f"mode={self.mode}", f"p={self.p}", - f"parallel={self.parallel}", ] return f"{type(self).__name__}({', '.join(parameters)})" - @property - def parallel(self) -> int: - """Use multiprocessing for data augmentation. - - If set to -1 (default), disables multiprocessing. If set to -2, - then enable auto-tune (use the minimum value between the first - batch size and total number of processing cores). Set to 0 to - enable as many processes as processing cores available in the - system. Set to >= 1 to enable that many processes. - - Returns - ------- - int - The multiprocessing type. - """ - - return self._parallel - - @parallel.setter - def parallel(self, value): - self._parallel = value - - if value >= 0: - instances = value or multiprocessing.cpu_count() - logger.info( - f"Applying data-augmentation using {instances} processes...", - ) - self._mp_pool = multiprocessing.get_context("spawn").Pool(instances) - else: - self._mp_pool = None - def __call__(self, img: torch.Tensor) -> torch.Tensor: if len(img.shape) == 4: - if self._mp_pool is None and self._parallel == -2: - # auto-tunning on first batch - instances = min(img.shape[0], multiprocessing.cpu_count()) - self._mp_pool = multiprocessing.get_context("spawn").Pool( - instances, - ) - return _elastic_deformation_on_batch( img, self.alpha, @@ -300,7 +244,6 @@ class ElasticDeformation: self.spline_order, self.mode, self.p, - self._mp_pool, ).to(img.device) if len(img.shape) == 3: