diff --git a/src/ptbench/configs/models/alexnet.py b/src/ptbench/configs/models/alexnet.py index 815226b517142438d77db25e69f6f4e173cee39c..0028e810ed34658b754ea0fe89b75bf64faae4a0 100644 --- a/src/ptbench/configs/models/alexnet.py +++ b/src/ptbench/configs/models/alexnet.py @@ -7,7 +7,7 @@ from torch.nn import BCEWithLogitsLoss from torch.optim import SGD -from ...data.transforms import ElasticDeformation +from ...data.augmentations import ElasticDeformation from ...models.alexnet import Alexnet model = Alexnet( diff --git a/src/ptbench/configs/models/alexnet_pretrained.py b/src/ptbench/configs/models/alexnet_pretrained.py index f968df50cda171cc94991febc511168d111517c9..9c772f42e935b2b995030e6d73fcc044fdcbcc76 100644 --- a/src/ptbench/configs/models/alexnet_pretrained.py +++ b/src/ptbench/configs/models/alexnet_pretrained.py @@ -7,7 +7,7 @@ from torch.nn import BCEWithLogitsLoss from torch.optim import SGD -from ...data.transforms import ElasticDeformation +from ...data.augmentations import ElasticDeformation from ...models.alexnet import Alexnet model = Alexnet( diff --git a/src/ptbench/configs/models/densenet.py b/src/ptbench/configs/models/densenet.py index 79f8f7dabc58746c1029bbc9760f10137801c202..5d453a6d4ff75514c0e7669a01b43be3e1aa3473 100644 --- a/src/ptbench/configs/models/densenet.py +++ b/src/ptbench/configs/models/densenet.py @@ -7,7 +7,7 @@ from torch.nn import BCEWithLogitsLoss from torch.optim import Adam -from ...data.transforms import ElasticDeformation +from ...data.augmentations import ElasticDeformation from ...models.densenet import Densenet model = Densenet( diff --git a/src/ptbench/configs/models/densenet_pretrained.py b/src/ptbench/configs/models/densenet_pretrained.py index 4bc4616c6de0a19134646a4ad1449c2920be9e50..49b0162f5fb1defd3bfc7b64bcc390c9e2bd0c27 100644 --- a/src/ptbench/configs/models/densenet_pretrained.py +++ b/src/ptbench/configs/models/densenet_pretrained.py @@ -7,7 +7,7 @@ from torch.nn import BCEWithLogitsLoss from 
torch.optim import Adam -from ...data.transforms import ElasticDeformation +from ...data.augmentations import ElasticDeformation from ...models.densenet import Densenet model = Densenet( diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py index d1e1b0a3ae8d9e3e32a7ec19a49e21f01bb694d9..b1a201e5b2d1f5a1b22135ee46ea148d8f655536 100644 --- a/src/ptbench/configs/models/pasa.py +++ b/src/ptbench/configs/models/pasa.py @@ -14,7 +14,7 @@ Reference: [PASA-2019]_ from torch.nn import BCEWithLogitsLoss from torch.optim import Adam -from ...data.transforms import ElasticDeformation +from ...data.augmentations import ElasticDeformation from ...models.pasa import Pasa model = Pasa( diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/augmentations.py similarity index 76% rename from src/ptbench/data/transforms.py rename to src/ptbench/data/augmentations.py index 4e61c4d5f8d46351f83ecf405ad0dcb3de8a4cee..a0e20c57458b78039c465bab4ec7c044b706b745 100644 --- a/src/ptbench/data/transforms.py +++ b/src/ptbench/data/augmentations.py @@ -69,6 +69,13 @@ def _elastic_deformation_on_image( p Probability that this transformation will be applied. Meaningful when using it as a data augmentation technique. + + + Returns + ------- + + tensor + A tensor on the CPU. """ if random.random() < p: @@ -122,7 +129,7 @@ def _elastic_deformation_on_image( ).reshape(img_shape) ) - # wraps numpy array as tensor (with no copy) + # wraps numpy array as a CPU tensor (no device transfer happens here) return torch.as_tensor(output) return img @@ -137,6 +144,47 @@ def _elastic_deformation_on_batch( p: float = 1.0, pool: multiprocessing.pool.Pool | None = None, ) -> torch.Tensor: + """Performs elastic deformation on a batch of images. + + This implementation is based on 2 scipy functions + (:py:func:`scipy.ndimage.gaussian_filter` and + :py:func:`scipy.ndimage.map_coordinates`). 
It is very inefficient since it + requires data to be moved off the current running device and then back. + + + Parameters + ---------- + + batch + The input batch of images to apply elastic deformation at. This batch should + always have this shape: ``[N, C, H, W]``. It should always represent a + tensor on the CPU. + + alpha + A multiplier for the gaussian filter outputs + + sigma + Standard deviation for Gaussian kernel. + + spline_order + The order of the spline interpolation, default is 1. The order has to + be in the range 0-5. + + mode + The mode parameter determines how the input array is extended beyond + its boundaries. + + p + Probability that this transformation will be applied. Meaningful when + using it as a data augmentation technique. + + + Returns + ------- + + tensor + A tensor on the CPU. + """ # transforms our custom functions into simpler callables partial = functools.partial( _elastic_deformation_on_image, @@ -154,7 +202,7 @@ def _elastic_deformation_on_batch( else: augmented_images = pool.imap(partial, batch.cpu()) - return torch.stack(list(augmented_images)).to(batch.device) + return torch.stack(list(augmented_images)) class ElasticDeformation: @@ -196,9 +244,11 @@ class ElasticDeformation: parallel Use multiprocessing for processing batches of data: if set to -1 - (default), disables multiprocessing. Set to 0 to enable as many - processes as processing cores as available in the system. Set to >= 1 - to enable that many processes. + disables multiprocessing. If set to -2 (default), then enable + auto-tune (use the minimum value between the first batch size and total + number of processing cores). Set to 0 to enable as many processes as + processing cores as available in the system. Set to >= 1 to enable that + many processes. 
""" def __init__( @@ -208,7 +258,7 @@ class ElasticDeformation: spline_order: int = 1, mode: str = "nearest", p: float = 1.0, - parallel: int = -1, + parallel: int = -2, ): self.alpha: float = alpha self.sigma: float = sigma @@ -221,10 +271,11 @@ class ElasticDeformation: def parallel(self): """Use multiprocessing for data augmentation. - If set to -1 (default), disables multiprocessing data - augmentation. Set to 0 to enable as many data loading instances - as processing cores as available in the system. Set to >= 1 to - enable that many multiprocessing instances for data loading. + If set to -1, disables multiprocessing. If set to -2 (default), + then enable auto-tune (use the minimum value between the first + batch size and total number of processing cores). Set to 0 to + enable as many processes as processing cores as available in the + system. Set to >= 1 to enable that many processes. """ return self._parallel @@ -243,6 +294,11 @@ class ElasticDeformation: def __call__(self, img: torch.Tensor) -> torch.Tensor: if len(img.shape) == 4: + if self._mp_pool is None and self._parallel == -2: + # auto-tuning on first batch + instances = min(img.shape[0], multiprocessing.cpu_count()) + self._mp_pool = multiprocessing.pool.Pool(instances) + return _elastic_deformation_on_batch( img, self.alpha, @@ -251,7 +307,8 @@ class ElasticDeformation: self.mode, self.p, self._mp_pool, - ) + ).to(img.device) + elif len(img.shape) == 3: return _elastic_deformation_on_image( img.cpu(), diff --git a/src/ptbench/data/hivtb/__init__.py b/src/ptbench/data/hivtb/__init__.py index b5d2753c0ba191b42bcaf8d66c104e36cb4dd839..88401da0f431df46325f0d122b1d4eac7a908d33 100644 --- a/src/ptbench/data/hivtb/__init__.py +++ b/src/ptbench/data/hivtb/__init__.py @@ -62,7 +62,7 @@ json_dataset = JSONDataset( def _maker(protocol, resize_size=512, cc_size=512, RGB=False): from torchvision import transforms - from ..transforms import ElasticDeformation, RemoveBlackBorders + from ..augmentations import 
ElasticDeformation, RemoveBlackBorders post_transforms = [] if RGB: diff --git a/src/ptbench/data/indian/__init__.py b/src/ptbench/data/indian/__init__.py index 72d7567f7087b4646b3e300b5dcf31161af4980f..5255c783daab1ca36b7f184d36339789a96ffd1d 100644 --- a/src/ptbench/data/indian/__init__.py +++ b/src/ptbench/data/indian/__init__.py @@ -62,7 +62,8 @@ json_dataset = JSONDataset( def _maker(protocol, resize_size=512, cc_size=512, RGB=False): from torchvision import transforms - from ..transforms import ElasticDeformation, RemoveBlackBorders + from ..augmentations import ElasticDeformation + from ..image_utils import RemoveBlackBorders post_transforms = [] if RGB: diff --git a/src/ptbench/data/tbpoc/__init__.py b/src/ptbench/data/tbpoc/__init__.py index 02cd488080873ff11feabf895571ca25af113fca..6108b2fba5cad8611a8d6bd6c38467ee8a869807 100644 --- a/src/ptbench/data/tbpoc/__init__.py +++ b/src/ptbench/data/tbpoc/__init__.py @@ -62,7 +62,8 @@ json_dataset = JSONDataset( def _maker(protocol, resize_size=512, cc_size=512, RGB=False): from torchvision import transforms - from ..transforms import ElasticDeformation, RemoveBlackBorders + from ..augmentations import ElasticDeformation + from ..image_utils import RemoveBlackBorders post_transforms = [] if RGB: diff --git a/src/ptbench/data/tbx11k_simplified/__init__.py b/src/ptbench/data/tbx11k_simplified/__init__.py index 3080b63a3427ddb92af82e88c896802d24b6fa24..7e66abc346f88faaeab7f0e61637b1767723aa3f 100644 --- a/src/ptbench/data/tbx11k_simplified/__init__.py +++ b/src/ptbench/data/tbx11k_simplified/__init__.py @@ -94,7 +94,7 @@ def _maker(protocol, RGB=False): from torchvision import transforms from .. 
import make_dataset - from ..transforms import ElasticDeformation + from ..augmentations import ElasticDeformation post_transforms = [] if RGB: diff --git a/src/ptbench/data/tbx11k_simplified_v2/__init__.py b/src/ptbench/data/tbx11k_simplified_v2/__init__.py index d6075e85d366510d8670af362a2b685167e3c8de..57323beed9475e3f0c0992fd8ecb42b6fd3bb1cb 100644 --- a/src/ptbench/data/tbx11k_simplified_v2/__init__.py +++ b/src/ptbench/data/tbx11k_simplified_v2/__init__.py @@ -94,7 +94,7 @@ def _maker(protocol, RGB=False): from torchvision import transforms from .. import make_dataset - from ..transforms import ElasticDeformation + from ..augmentations import ElasticDeformation post_transforms = [] if RGB: diff --git a/tests/test_tranforms.py b/tests/test_tranforms.py index 0dcdf5d64b473f4fccecb72db0baf01d1b67bbd1..c02c80d077e1e43f73451b01d32a9e9066067934 100644 --- a/tests/test_tranforms.py +++ b/tests/test_tranforms.py @@ -8,7 +8,7 @@ import numpy import PIL.Image import torchvision.transforms.functional as F -from ptbench.data.transforms import ElasticDeformation +from ptbench.data.augmentations import ElasticDeformation def test_elastic_deformation(datadir):