From eb4848f60cd9d7d37572e6e070f73206a558ccb7 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 26 Jun 2024 13:48:42 +0200
Subject: [PATCH] [typing] Redefine Sample type in each lib

---
 .../classification/config/data/montgomery/datamodule.py  | 2 +-
 .../classification/config/data/nih_cxr14/datamodule.py   | 2 +-
 .../classification/config/data/padchest/datamodule.py    | 2 +-
 .../classification/config/data/shenzhen/datamodule.py    | 2 +-
 .../libs/classification/config/data/tbpoc/datamodule.py  | 2 +-
 .../libs/classification/config/data/tbx11k/datamodule.py | 2 +-
 src/mednet/libs/classification/data/typing.py            | 5 ++++-
 .../libs/classification/engine/saliency/completeness.py  | 2 +-
 src/mednet/libs/common/data/typing.py                    | 4 +++-
 .../libs/segmentation/config/data/avdrive/datamodule.py  | 2 +-
 .../libs/segmentation/config/data/chasedb1/datamodule.py | 2 +-
 .../libs/segmentation/config/data/cxr8/datamodule.py     | 2 +-
 .../libs/segmentation/config/data/drhagis/datamodule.py  | 2 +-
 .../libs/segmentation/config/data/drionsdb/datamodule.py | 2 +-
 .../segmentation/config/data/drishtigs1/datamodule.py    | 2 +-
 .../libs/segmentation/config/data/drive/datamodule.py    | 2 +-
 .../libs/segmentation/config/data/hrf/datamodule.py      | 2 +-
 .../libs/segmentation/config/data/iostar/datamodule.py   | 2 +-
 .../libs/segmentation/config/data/jsrt/datamodule.py     | 2 +-
 .../segmentation/config/data/montgomery/datamodule.py    | 2 +-
 .../libs/segmentation/config/data/refuge/datamodule.py   | 2 +-
 .../libs/segmentation/config/data/rimoner3/datamodule.py | 2 +-
 .../libs/segmentation/config/data/shenzhen/datamodule.py | 2 +-
 .../libs/segmentation/config/data/stare/datamodule.py    | 2 +-
 src/mednet/libs/segmentation/data/typing.py              | 9 ++++++++-
 25 files changed, 37 insertions(+), 25 deletions(-)

diff --git a/src/mednet/libs/classification/config/data/montgomery/datamodule.py b/src/mednet/libs/classification/config/data/montgomery/datamodule.py
index 4a61df7d..52f23320 100644
--- a/src/mednet/libs/classification/config/data/montgomery/datamodule.py
+++ b/src/mednet/libs/classification/config/data/montgomery/datamodule.py
@@ -13,10 +13,10 @@ import PIL.Image
 from mednet.libs.classification.data.typing import (
     ClassificationRawDataLoader as _ClassificationRawDataLoader,
 )
+from mednet.libs.classification.data.typing import Sample
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.image_utils import remove_black_borders
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.utils.rc import load_rc
 from torchvision.transforms.functional import to_tensor
 
diff --git a/src/mednet/libs/classification/config/data/nih_cxr14/datamodule.py b/src/mednet/libs/classification/config/data/nih_cxr14/datamodule.py
index 7712b4fe..719a78ad 100644
--- a/src/mednet/libs/classification/config/data/nih_cxr14/datamodule.py
+++ b/src/mednet/libs/classification/config/data/nih_cxr14/datamodule.py
@@ -13,9 +13,9 @@ import PIL.Image
 from mednet.libs.classification.data.typing import (
     ClassificationRawDataLoader as _ClassificationRawDataLoader,
 )
+from mednet.libs.classification.data.typing import Sample
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.utils.rc import load_rc
 from torchvision.transforms.functional import to_tensor
 
diff --git a/src/mednet/libs/classification/config/data/padchest/datamodule.py b/src/mednet/libs/classification/config/data/padchest/datamodule.py
index cc9bd734..20f6e221 100644
--- a/src/mednet/libs/classification/config/data/padchest/datamodule.py
+++ b/src/mednet/libs/classification/config/data/padchest/datamodule.py
@@ -14,10 +14,10 @@ import PIL.Image
 from mednet.libs.classification.data.typing import (
     ClassificationRawDataLoader as _ClassificationRawDataLoader,
 )
+from mednet.libs.classification.data.typing import Sample
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.image_utils import remove_black_borders
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.utils.rc import load_rc
 from torchvision.transforms.functional import to_tensor
 
diff --git a/src/mednet/libs/classification/config/data/shenzhen/datamodule.py b/src/mednet/libs/classification/config/data/shenzhen/datamodule.py
index ba2ccc5c..255d70b0 100644
--- a/src/mednet/libs/classification/config/data/shenzhen/datamodule.py
+++ b/src/mednet/libs/classification/config/data/shenzhen/datamodule.py
@@ -13,10 +13,10 @@ import PIL.Image
 from mednet.libs.classification.data.typing import (
     ClassificationRawDataLoader as _ClassificationRawDataLoader,
 )
+from mednet.libs.classification.data.typing import Sample
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.image_utils import remove_black_borders
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.utils.rc import load_rc
 from torchvision.transforms.functional import to_tensor
 
diff --git a/src/mednet/libs/classification/config/data/tbpoc/datamodule.py b/src/mednet/libs/classification/config/data/tbpoc/datamodule.py
index 8f46b3be..7d9c9bd8 100644
--- a/src/mednet/libs/classification/config/data/tbpoc/datamodule.py
+++ b/src/mednet/libs/classification/config/data/tbpoc/datamodule.py
@@ -9,10 +9,10 @@ import PIL.Image
 from mednet.libs.classification.data.typing import (
     ClassificationRawDataLoader as _ClassificationRawDataLoader,
 )
+from mednet.libs.classification.data.typing import Sample
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.image_utils import remove_black_borders
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.utils.rc import load_rc
 from torchvision.transforms.functional import to_tensor
 
diff --git a/src/mednet/libs/classification/config/data/tbx11k/datamodule.py b/src/mednet/libs/classification/config/data/tbx11k/datamodule.py
index 9cef49b1..f120a2af 100644
--- a/src/mednet/libs/classification/config/data/tbx11k/datamodule.py
+++ b/src/mednet/libs/classification/config/data/tbx11k/datamodule.py
@@ -13,9 +13,9 @@ import typing_extensions
 from mednet.libs.classification.data.typing import (
     ClassificationRawDataLoader as _ClassificationRawDataLoader,
 )
+from mednet.libs.classification.data.typing import Sample
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.utils.rc import load_rc
 from torch.utils.data._utils.collate import default_collate_fn_map
 from torchvision.transforms.functional import to_tensor
diff --git a/src/mednet/libs/classification/data/typing.py b/src/mednet/libs/classification/data/typing.py
index 0d18220e..10dc375b 100644
--- a/src/mednet/libs/classification/data/typing.py
+++ b/src/mednet/libs/classification/data/typing.py
@@ -1,6 +1,9 @@
 import typing
 
-from mednet.libs.common.data.typing import RawDataLoader, Sample
+import torch
+from mednet.libs.common.data.typing import RawDataLoader
+
+Sample: typing.TypeAlias = tuple[torch.Tensor, typing.Mapping[str, typing.Any]]
 
 
 class ClassificationRawDataLoader(RawDataLoader):
diff --git a/src/mednet/libs/classification/engine/saliency/completeness.py b/src/mednet/libs/classification/engine/saliency/completeness.py
index 2eac5429..f632afee 100644
--- a/src/mednet/libs/classification/engine/saliency/completeness.py
+++ b/src/mednet/libs/classification/engine/saliency/completeness.py
@@ -11,7 +11,7 @@ import lightning.pytorch
 import numpy as np
 import torch
 import tqdm
-from mednet.libs.common.data.typing import Sample
+from mednet.libs.classification.data.typing import Sample
 from mednet.libs.common.engine.device import DeviceManager
 from pytorch_grad_cam.metrics.road import (
     ROADLeastRelevantFirstAverage,
diff --git a/src/mednet/libs/common/data/typing.py b/src/mednet/libs/common/data/typing.py
index 521c8026..455cb1b3 100644
--- a/src/mednet/libs/common/data/typing.py
+++ b/src/mednet/libs/common/data/typing.py
@@ -9,7 +9,9 @@ import typing
 import torch
 import torch.utils.data
 
-Sample: typing.TypeAlias = tuple[torch.Tensor, typing.Mapping[str, typing.Any]]
+Sample: typing.TypeAlias = tuple[
+    torch.Tensor | typing.Mapping[str, torch.Tensor], typing.Mapping[str, typing.Any]
+]
 """Definition of a sample.
 
 First parameter
diff --git a/src/mednet/libs/segmentation/config/data/avdrive/datamodule.py b/src/mednet/libs/segmentation/config/data/avdrive/datamodule.py
index 83156a2c..5567d27c 100644
--- a/src/mednet/libs/segmentation/config/data/avdrive/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/avdrive/datamodule.py
@@ -9,9 +9,9 @@ import pathlib
 import PIL.Image
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.common.utils.rc import load_rc
+from mednet.libs.segmentation.data.typing import Sample
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
diff --git a/src/mednet/libs/segmentation/config/data/chasedb1/datamodule.py b/src/mednet/libs/segmentation/config/data/chasedb1/datamodule.py
index 0cfab7ee..16a12166 100644
--- a/src/mednet/libs/segmentation/config/data/chasedb1/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/chasedb1/datamodule.py
@@ -10,9 +10,9 @@ import PIL.Image
 import pkg_resources
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.common.utils.rc import load_rc
+from mednet.libs.segmentation.data.typing import Sample
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
diff --git a/src/mednet/libs/segmentation/config/data/cxr8/datamodule.py b/src/mednet/libs/segmentation/config/data/cxr8/datamodule.py
index acc6c83b..fee56bf0 100644
--- a/src/mednet/libs/segmentation/config/data/cxr8/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/cxr8/datamodule.py
@@ -11,8 +11,8 @@ import PIL.Image
 import torch
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.utils.rc import load_rc
+from mednet.libs.segmentation.data.typing import Sample
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
diff --git a/src/mednet/libs/segmentation/config/data/drhagis/datamodule.py b/src/mednet/libs/segmentation/config/data/drhagis/datamodule.py
index 153c0abe..1b0807dd 100644
--- a/src/mednet/libs/segmentation/config/data/drhagis/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/drhagis/datamodule.py
@@ -9,9 +9,9 @@ import pathlib
 import PIL.Image
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.common.utils.rc import load_rc
+from mednet.libs.segmentation.data.typing import Sample
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
diff --git a/src/mednet/libs/segmentation/config/data/drionsdb/datamodule.py b/src/mednet/libs/segmentation/config/data/drionsdb/datamodule.py
index 41cc7fb8..5e7076af 100644
--- a/src/mednet/libs/segmentation/config/data/drionsdb/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/drionsdb/datamodule.py
@@ -11,9 +11,9 @@ import PIL.Image
 import pkg_resources
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.common.utils.rc import load_rc
+from mednet.libs.segmentation.data.typing import Sample
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
diff --git a/src/mednet/libs/segmentation/config/data/drishtigs1/datamodule.py b/src/mednet/libs/segmentation/config/data/drishtigs1/datamodule.py
index 1c65e40d..2ddc1b83 100644
--- a/src/mednet/libs/segmentation/config/data/drishtigs1/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/drishtigs1/datamodule.py
@@ -10,9 +10,9 @@ import PIL.Image
 import pkg_resources
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.common.utils.rc import load_rc
+from mednet.libs.segmentation.data.typing import Sample
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
diff --git a/src/mednet/libs/segmentation/config/data/drive/datamodule.py b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
index 17629343..01639da4 100644
--- a/src/mednet/libs/segmentation/config/data/drive/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
@@ -9,9 +9,9 @@ import pathlib
 import PIL.Image
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.common.utils.rc import load_rc
+from mednet.libs.segmentation.data.typing import Sample
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
diff --git a/src/mednet/libs/segmentation/config/data/hrf/datamodule.py b/src/mednet/libs/segmentation/config/data/hrf/datamodule.py
index 811ac24f..ed7ac302 100644
--- a/src/mednet/libs/segmentation/config/data/hrf/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/hrf/datamodule.py
@@ -9,9 +9,9 @@ import pathlib
 import PIL.Image
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.common.utils.rc import load_rc
+from mednet.libs.segmentation.data.typing import Sample
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
diff --git a/src/mednet/libs/segmentation/config/data/iostar/datamodule.py b/src/mednet/libs/segmentation/config/data/iostar/datamodule.py
index ff67de9e..60283259 100644
--- a/src/mednet/libs/segmentation/config/data/iostar/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/iostar/datamodule.py
@@ -9,9 +9,9 @@ import pathlib
 import PIL.Image
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.common.utils.rc import load_rc
+from mednet.libs.segmentation.data.typing import Sample
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
diff --git a/src/mednet/libs/segmentation/config/data/jsrt/datamodule.py b/src/mednet/libs/segmentation/config/data/jsrt/datamodule.py
index 60f185b5..ada65db9 100644
--- a/src/mednet/libs/segmentation/config/data/jsrt/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/jsrt/datamodule.py
@@ -12,8 +12,8 @@ import skimage.exposure
 import torch
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.utils.rc import load_rc
+from mednet.libs.segmentation.data.typing import Sample
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
diff --git a/src/mednet/libs/segmentation/config/data/montgomery/datamodule.py b/src/mednet/libs/segmentation/config/data/montgomery/datamodule.py
index 5792d941..663f9371 100644
--- a/src/mednet/libs/segmentation/config/data/montgomery/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/montgomery/datamodule.py
@@ -12,8 +12,8 @@ import pkg_resources
 import torch
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.utils.rc import load_rc
+from mednet.libs.segmentation.data.typing import Sample
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
diff --git a/src/mednet/libs/segmentation/config/data/refuge/datamodule.py b/src/mednet/libs/segmentation/config/data/refuge/datamodule.py
index a7bcd24d..696a204e 100644
--- a/src/mednet/libs/segmentation/config/data/refuge/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/refuge/datamodule.py
@@ -10,9 +10,9 @@ import PIL.Image
 import pkg_resources
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.common.utils.rc import load_rc
+from mednet.libs.segmentation.data.typing import Sample
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
diff --git a/src/mednet/libs/segmentation/config/data/rimoner3/datamodule.py b/src/mednet/libs/segmentation/config/data/rimoner3/datamodule.py
index e4e9e581..81574441 100644
--- a/src/mednet/libs/segmentation/config/data/rimoner3/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/rimoner3/datamodule.py
@@ -10,8 +10,8 @@ import PIL.Image
 import pkg_resources
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.utils.rc import load_rc
+from mednet.libs.segmentation.data.typing import Sample
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
diff --git a/src/mednet/libs/segmentation/config/data/shenzhen/datamodule.py b/src/mednet/libs/segmentation/config/data/shenzhen/datamodule.py
index f0be322a..a65f3239 100644
--- a/src/mednet/libs/segmentation/config/data/shenzhen/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/shenzhen/datamodule.py
@@ -10,9 +10,9 @@ import PIL.Image
 import torch
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.common.utils.rc import load_rc
+from mednet.libs.segmentation.data.typing import Sample
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
diff --git a/src/mednet/libs/segmentation/config/data/stare/datamodule.py b/src/mednet/libs/segmentation/config/data/stare/datamodule.py
index 9c824782..06bceb50 100644
--- a/src/mednet/libs/segmentation/config/data/stare/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/stare/datamodule.py
@@ -10,9 +10,9 @@ import PIL.Image
 import pkg_resources
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
-from mednet.libs.common.data.typing import Sample
 from mednet.libs.common.models.transforms import crop_image_to_mask
 from mednet.libs.common.utils.rc import load_rc
+from mednet.libs.segmentation.data.typing import Sample
 from mednet.libs.segmentation.data.typing import (
     SegmentationRawDataLoader as _SegmentationRawDataLoader,
 )
diff --git a/src/mednet/libs/segmentation/data/typing.py b/src/mednet/libs/segmentation/data/typing.py
index 626464fd..af27de40 100644
--- a/src/mednet/libs/segmentation/data/typing.py
+++ b/src/mednet/libs/segmentation/data/typing.py
@@ -1,4 +1,11 @@
-from mednet.libs.common.data.typing import RawDataLoader, Sample
+import typing
+
+import torch
+from mednet.libs.common.data.typing import RawDataLoader
+
+Sample: typing.TypeAlias = tuple[
+    typing.Mapping[str, torch.Tensor], typing.Mapping[str, typing.Any]
+]
 
 
 class SegmentationRawDataLoader(RawDataLoader):
-- 
GitLab