From 85390375ea2766d7bae4601ff4efb03a638e3fae Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Wed, 24 Jan 2024 17:42:17 +0100
Subject: [PATCH] [config.data.*] Remove center_crop from all raw data loaders
 to avoid lung cropping (closes #23)

---
 src/mednet/config/data/hivtb/datamodule.py      | 3 +--
 src/mednet/config/data/montgomery/datamodule.py | 3 +--
 src/mednet/config/data/padchest/datamodule.py   | 3 +--
 src/mednet/config/data/shenzhen/datamodule.py   | 3 +--
 src/mednet/config/data/tbpoc/datamodule.py      | 3 +--
 5 files changed, 5 insertions(+), 10 deletions(-)

diff --git a/src/mednet/config/data/hivtb/datamodule.py b/src/mednet/config/data/hivtb/datamodule.py
index d5bf8103..206a3ec6 100644
--- a/src/mednet/config/data/hivtb/datamodule.py
+++ b/src/mednet/config/data/hivtb/datamodule.py
@@ -12,7 +12,7 @@ import os
 
 import PIL.Image
 
-from torchvision.transforms.functional import center_crop, to_tensor
+from torchvision.transforms.functional import to_tensor
 
 from ....data.datamodule import CachingDataModule
 from ....data.image_utils import remove_black_borders
@@ -59,7 +59,6 @@ class RawDataLoader(_BaseRawDataLoader):
         )
         image = remove_black_borders(image)
         tensor = to_tensor(image)
-        tensor = center_crop(tensor, min(*tensor.shape[1:]))
 
         # use the code below to view generated images
         # from torchvision.transforms.functional import to_pil_image
diff --git a/src/mednet/config/data/montgomery/datamodule.py b/src/mednet/config/data/montgomery/datamodule.py
index ec1b7c14..ff8d21d7 100644
--- a/src/mednet/config/data/montgomery/datamodule.py
+++ b/src/mednet/config/data/montgomery/datamodule.py
@@ -11,7 +11,7 @@ import os
 
 import PIL.Image
 
-from torchvision.transforms.functional import center_crop, to_tensor
+from torchvision.transforms.functional import to_tensor
 
 from ....data.datamodule import CachingDataModule
 from ....data.image_utils import remove_black_borders
@@ -58,7 +58,6 @@ class RawDataLoader(_BaseRawDataLoader):
         image = PIL.Image.open(os.path.join(self.datadir, sample[0]))
         image = remove_black_borders(image)
         tensor = to_tensor(image)
-        tensor = center_crop(tensor, min(*tensor.shape[1:]))
 
         # use the code below to view generated images
         # from torchvision.transforms.functional import to_pil_image
diff --git a/src/mednet/config/data/padchest/datamodule.py b/src/mednet/config/data/padchest/datamodule.py
index a065dece..436193e8 100644
--- a/src/mednet/config/data/padchest/datamodule.py
+++ b/src/mednet/config/data/padchest/datamodule.py
@@ -12,7 +12,7 @@ import os
 import numpy
 import PIL.Image
 
-from torchvision.transforms.functional import center_crop, to_tensor
+from torchvision.transforms.functional import to_tensor
 
 from ....data.datamodule import CachingDataModule
 from ....data.image_utils import remove_black_borders
@@ -60,7 +60,6 @@ class RawDataLoader(_BaseRawDataLoader):
         image = remove_black_borders(image)
         array = numpy.array(image).astype(numpy.float32) / 65535
         tensor = to_tensor(array)
-        tensor = center_crop(tensor, min(*tensor.shape[1:]))
 
         # use the code below to view generated images
         # from torchvision.transforms.functional import to_pil_image
diff --git a/src/mednet/config/data/shenzhen/datamodule.py b/src/mednet/config/data/shenzhen/datamodule.py
index 64091794..9fe82d47 100644
--- a/src/mednet/config/data/shenzhen/datamodule.py
+++ b/src/mednet/config/data/shenzhen/datamodule.py
@@ -11,7 +11,7 @@ import os
 
 import PIL.Image
 
-from torchvision.transforms.functional import center_crop, to_tensor
+from torchvision.transforms.functional import to_tensor
 
 from ....data.datamodule import CachingDataModule
 from ....data.image_utils import remove_black_borders
@@ -62,7 +62,6 @@ class RawDataLoader(_BaseRawDataLoader):
         )
         image = remove_black_borders(image)
         tensor = to_tensor(image)
-        tensor = center_crop(tensor, min(*tensor.shape[1:]))
 
         # use the code below to view generated images
         # from torchvision.transforms.functional import to_pil_image
diff --git a/src/mednet/config/data/tbpoc/datamodule.py b/src/mednet/config/data/tbpoc/datamodule.py
index ffe59568..c0339747 100644
--- a/src/mednet/config/data/tbpoc/datamodule.py
+++ b/src/mednet/config/data/tbpoc/datamodule.py
@@ -7,7 +7,7 @@ import os
 
 import PIL.Image
 
-from torchvision.transforms.functional import center_crop, to_tensor
+from torchvision.transforms.functional import to_tensor
 
 from ....data.datamodule import CachingDataModule
 from ....data.image_utils import remove_black_borders
@@ -54,7 +54,6 @@ class RawDataLoader(_BaseRawDataLoader):
         image = PIL.Image.open(os.path.join(self.datadir, sample[0]))
         image = remove_black_borders(image)
         tensor = to_tensor(image)
-        tensor = center_crop(tensor, min(*tensor.shape[1:]))
 
         # use the code below to view generated images
         # from torchvision.transforms.functional import to_pil_image
-- 
GitLab