From 252156bb93b71ca967dba3b90da126177287fc21 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Fri, 26 Apr 2024 10:17:53 +0200
Subject: [PATCH] [dataloader] Create specialized RawDataLoaders per task

---
 .gitignore                                    |  1 +
 .../config/data/hivtb/datamodule.py           | 12 ++++--
 .../config/data/indian/datamodule.py          | 12 +++---
 .../config/data/montgomery/datamodule.py      | 10 +++--
 .../data/montgomery_shenzhen/datamodule.py    |  6 ++-
 .../montgomery_shenzhen_indian/datamodule.py  |  8 ++--
 .../datamodule.py                             | 10 +++--
 .../datamodule.py                             | 10 +++--
 .../config/data/nih_cxr14/datamodule.py       | 12 ++++--
 .../data/nih_cxr14_padchest/datamodule.py     |  5 +--
 .../config/data/padchest/datamodule.py        | 12 ++++--
 .../config/data/shenzhen/datamodule.py        | 12 ++++--
 .../config/data/tbpoc/datamodule.py           | 12 ++++--
 .../config/data/tbx11k/datamodule.py          |  8 ++--
 .../libs/classification/data/__init__.py      |  0
 src/mednet/libs/classification/data/typing.py | 41 +++++++++++++++++++
 src/mednet/libs/common/data/typing.py         | 20 ---------
 .../config/data/drive/datamodule.py           |  8 ++--
 src/mednet/libs/segmentation/data/typing.py   | 21 ++++++++++
 19 files changed, 150 insertions(+), 70 deletions(-)
 create mode 100644 src/mednet/libs/classification/data/__init__.py
 create mode 100644 src/mednet/libs/classification/data/typing.py
 create mode 100644 src/mednet/libs/segmentation/data/typing.py

diff --git a/.gitignore b/.gitignore
index 2ba7edec..f8dcb218 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,3 +22,4 @@ cache/
 .venv
 .pixi
 results*/
+environment.yaml
diff --git a/src/mednet/libs/classification/config/data/hivtb/datamodule.py b/src/mednet/libs/classification/config/data/hivtb/datamodule.py
index 19083250..5d8170b7 100644
--- a/src/mednet/libs/classification/config/data/hivtb/datamodule.py
+++ b/src/mednet/libs/classification/config/data/hivtb/datamodule.py
@@ -10,11 +10,15 @@ import os
 import pathlib
 
 import PIL.Image
+from mednet.libs.classification.data.typing import (
+    ClassificationRawDataLoader as _ClassificationRawDataLoader,
+)
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.image_utils import remove_black_borders
+
 from mednet.libs.common.data.split import make_split
 from mednet.libs.common.data.typing import Sample
-from mednet.libs.common.data.typing import RawDataLoader as _BaseRawDataLoader
+
 from torchvision.transforms.functional import to_tensor
 
 from ....utils.rc import load_rc
@@ -24,7 +28,7 @@ CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
 database."""
 
 
-class RawDataLoader(_BaseRawDataLoader):
+class ClassificationRawDataLoader(_ClassificationRawDataLoader):
     """A specialized raw-data-loader for the HIV-TB dataset."""
 
     datadir: pathlib.Path
@@ -123,7 +127,7 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=RawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=_ClassificationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],            
             split_name=pathlib.Path(split_filename).stem,
         )
diff --git a/src/mednet/libs/classification/config/data/indian/datamodule.py b/src/mednet/libs/classification/config/data/indian/datamodule.py
index 6b4f9c77..63ff00c2 100644
--- a/src/mednet/libs/classification/config/data/indian/datamodule.py
+++ b/src/mednet/libs/classification/config/data/indian/datamodule.py
@@ -11,8 +11,9 @@ import pathlib
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
 
-from ....config.data.shenzhen.datamodule import RawDataLoader
-
+from mednet.libs.classification.data.typing import (
+    ClassificationRawDataLoader as _ClassificationRawDataLoader,
+)
 CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
 """Key to search for in the configuration file for the root directory of this
 database."""
@@ -60,10 +61,11 @@ class DataModule(CachingDataModule):
     """
 
     def __init__(self, split_filename: str):
-        assert __package__
+        assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=RawDataLoader(config_variable=CONFIGURATION_KEY_DATADIR),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=_ClassificationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],            
             split_name=pathlib.Path(split_filename).stem,
         )
+
diff --git a/src/mednet/libs/classification/config/data/montgomery/datamodule.py b/src/mednet/libs/classification/config/data/montgomery/datamodule.py
index ecdd12a8..83753205 100644
--- a/src/mednet/libs/classification/config/data/montgomery/datamodule.py
+++ b/src/mednet/libs/classification/config/data/montgomery/datamodule.py
@@ -10,11 +10,15 @@ import os
 import pathlib
 
 import PIL.Image
+from mednet.libs.classification.data.typing import (
+    ClassificationRawDataLoader as _ClassificationRawDataLoader,
+)
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.image_utils import remove_black_borders
+
 from mednet.libs.common.data.split import make_split
 from mednet.libs.common.data.typing import Sample
-from mednet.libs.common.data.typing import RawDataLoader as _BaseRawDataLoader
+
 from torchvision.transforms.functional import to_tensor
 
 from ....utils.rc import load_rc
@@ -136,7 +140,7 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=RawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=_ClassificationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],            
             split_name=pathlib.Path(split_filename).stem,
         )
diff --git a/src/mednet/libs/classification/config/data/montgomery_shenzhen/datamodule.py b/src/mednet/libs/classification/config/data/montgomery_shenzhen/datamodule.py
index f3bf9814..cce90f7a 100644
--- a/src/mednet/libs/classification/config/data/montgomery_shenzhen/datamodule.py
+++ b/src/mednet/libs/classification/config/data/montgomery_shenzhen/datamodule.py
@@ -8,8 +8,10 @@ import pathlib
 from mednet.libs.common.data.datamodule import ConcatDataModule
 from mednet.libs.common.data.split import make_split
 
-from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader
-from ..shenzhen.datamodule import RawDataLoader as ShenzhenLoader
+from ..montgomery.datamodule import (
+    ClassificationRawDataLoader as MontgomeryLoader,
+)
+from ..shenzhen.datamodule import ClassificationRawDataLoader as ShenzhenLoader
 
 
 class DataModule(ConcatDataModule):
diff --git a/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian/datamodule.py b/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian/datamodule.py
index 73b6d765..b09fa338 100644
--- a/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian/datamodule.py
+++ b/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian/datamodule.py
@@ -12,9 +12,11 @@ from mednet.libs.common.data.split import make_split
 
 from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR
 from ..indian.datamodule import DataModule as IndianDataModule
-from ..indian.datamodule import RawDataLoader as IndianLoader
-from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader
-from ..shenzhen.datamodule import RawDataLoader as ShenzhenLoader
+from ..indian.datamodule import ClassificationRawDataLoader as IndianLoader
+from ..montgomery.datamodule import (
+    ClassificationRawDataLoader as MontgomeryLoader,
+)
+from ..shenzhen.datamodule import ClassificationRawDataLoader as ShenzhenLoader
 
 
 class DataModule(ConcatDataModule):
diff --git a/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian_padchest/datamodule.py b/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian_padchest/datamodule.py
index 942326b2..5daef812 100644
--- a/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian_padchest/datamodule.py
+++ b/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian_padchest/datamodule.py
@@ -11,11 +11,13 @@ from mednet.libs.common.data.datamodule import ConcatDataModule
 from mednet.libs.common.data.split import make_split
 
 from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR
+from ..indian.datamodule import ClassificationRawDataLoader as IndianLoader
 from ..indian.datamodule import DataModule as IndianDataModule
-from ..indian.datamodule import RawDataLoader as IndianLoader
-from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader
-from ..padchest.datamodule import RawDataLoader as PadchestLoader
-from ..shenzhen.datamodule import RawDataLoader as ShenzhenLoader
+from ..montgomery.datamodule import (
+    ClassificationRawDataLoader as MontgomeryLoader,
+)
+from ..padchest.datamodule import ClassificationRawDataLoader as PadchestLoader
+from ..shenzhen.datamodule import ClassificationRawDataLoader as ShenzhenLoader
 
 
 class DataModule(ConcatDataModule):
diff --git a/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py b/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py
index daf575f7..53294527 100644
--- a/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py
+++ b/src/mednet/libs/classification/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py
@@ -12,10 +12,12 @@ from mednet.libs.common.data.split import make_split
 
 from ..indian.datamodule import CONFIGURATION_KEY_DATADIR as INDIAN_KEY_DATADIR
 from ..indian.datamodule import DataModule as IndianDataModule
-from ..indian.datamodule import RawDataLoader as IndianLoader
-from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader
-from ..shenzhen.datamodule import RawDataLoader as ShenzhenLoader
-from ..tbx11k.datamodule import RawDataLoader as TBX11kLoader
+from ..indian.datamodule import ClassificationRawDataLoader as IndianLoader
+from ..montgomery.datamodule import (
+    ClassificationRawDataLoader as MontgomeryLoader,
+)
+from ..shenzhen.datamodule import ClassificationRawDataLoader as ShenzhenLoader
+from ..tbx11k.datamodule import ClassificationRawDataLoader as TBX11kLoader
 
 
 class DataModule(ConcatDataModule):
diff --git a/src/mednet/libs/classification/config/data/nih_cxr14/datamodule.py b/src/mednet/libs/classification/config/data/nih_cxr14/datamodule.py
index 10b02268..900dcc9d 100644
--- a/src/mednet/libs/classification/config/data/nih_cxr14/datamodule.py
+++ b/src/mednet/libs/classification/config/data/nih_cxr14/datamodule.py
@@ -10,10 +10,14 @@ import os
 import pathlib
 
 import PIL.Image
+from mednet.libs.classification.data.typing import (
+    ClassificationRawDataLoader as _ClassificationRawDataLoader,
+)
 from mednet.libs.common.data.datamodule import CachingDataModule
+
 from mednet.libs.common.data.split import make_split
 from mednet.libs.common.data.typing import Sample
-from mednet.libs.common.data.typing import RawDataLoader as _BaseRawDataLoader
+
 from torchvision.transforms.functional import to_tensor
 
 from ....utils.rc import load_rc
@@ -35,7 +39,7 @@ different folder structure, that was adapted to Idiap's requirements
 """
 
 
-class RawDataLoader(_BaseRawDataLoader):
+class ClassificationRawDataLoader(_ClassificationRawDataLoader):
     """A specialized raw-data-loader for the NIH CXR-14 dataset."""
 
     datadir: pathlib.Path
@@ -172,7 +176,7 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=RawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=_ClassificationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],            
             split_name=pathlib.Path(split_filename).stem,
         )
diff --git a/src/mednet/libs/classification/config/data/nih_cxr14_padchest/datamodule.py b/src/mednet/libs/classification/config/data/nih_cxr14_padchest/datamodule.py
index beeb3d96..d5a44ce6 100644
--- a/src/mednet/libs/classification/config/data/nih_cxr14_padchest/datamodule.py
+++ b/src/mednet/libs/classification/config/data/nih_cxr14_padchest/datamodule.py
@@ -8,9 +8,8 @@ import pathlib
 from mednet.libs.common.data.datamodule import ConcatDataModule
 from mednet.libs.common.data.split import make_split
 
-from ..nih_cxr14.datamodule import RawDataLoader as CXR14Loader
-from ..padchest.datamodule import RawDataLoader as PadchestLoader
-
+from ..nih_cxr14.datamodule import ClassificationRawDataLoader as CXR14Loader
+from ..padchest.datamodule import ClassificationRawDataLoader as PadchestLoader
 
 class DataModule(ConcatDataModule):
     """Aggregated dataset composed of NIH CXR14 relabeld and PadChest
diff --git a/src/mednet/libs/classification/config/data/padchest/datamodule.py b/src/mednet/libs/classification/config/data/padchest/datamodule.py
index 8ea0eb7c..0fed49ab 100644
--- a/src/mednet/libs/classification/config/data/padchest/datamodule.py
+++ b/src/mednet/libs/classification/config/data/padchest/datamodule.py
@@ -11,11 +11,15 @@ import pathlib
 
 import numpy
 import PIL.Image
+from mednet.libs.classification.data.typing import (
+    ClassificationRawDataLoader as _ClassificationRawDataLoader,
+)
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.image_utils import remove_black_borders
+
 from mednet.libs.common.data.split import make_split
 from mednet.libs.common.data.typing import Sample
-from mednet.libs.common.data.typing import RawDataLoader as _BaseRawDataLoader
+
 from torchvision.transforms.functional import to_tensor
 
 from ....utils.rc import load_rc
@@ -25,7 +29,7 @@ CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
 database."""
 
 
-class RawDataLoader(_BaseRawDataLoader):
+class ClassificationRawDataLoader(_ClassificationRawDataLoader):
     """A specialized raw-data-loader for the PadChest dataset."""
 
     datadir: pathlib.Path
@@ -325,7 +329,7 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=RawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=_ClassificationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],            
             split_name=pathlib.Path(split_filename).stem,
         )
diff --git a/src/mednet/libs/classification/config/data/shenzhen/datamodule.py b/src/mednet/libs/classification/config/data/shenzhen/datamodule.py
index f68851b6..7ab4d171 100644
--- a/src/mednet/libs/classification/config/data/shenzhen/datamodule.py
+++ b/src/mednet/libs/classification/config/data/shenzhen/datamodule.py
@@ -10,11 +10,15 @@ import os
 import pathlib
 
 import PIL.Image
+from mednet.libs.classification.data.typing import (
+    ClassificationRawDataLoader as _ClassificationRawDataLoader,
+)
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.image_utils import remove_black_borders
+
 from mednet.libs.common.data.split import make_split
 from mednet.libs.common.data.typing import Sample
-from mednet.libs.common.data.typing import RawDataLoader as _BaseRawDataLoader
+
 from torchvision.transforms.functional import to_tensor
 
 from ....utils.rc import load_rc
@@ -24,7 +28,7 @@ CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
 database."""
 
 
-class RawDataLoader(_BaseRawDataLoader):
+class ClassificationRawDataLoader(_ClassificationRawDataLoader):
     """A specialized raw-data-loader for the Shenzhen dataset.
 
     Parameters
@@ -135,7 +139,7 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=RawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=_ClassificationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],            
             split_name=pathlib.Path(split_filename).stem,
         )
diff --git a/src/mednet/libs/classification/config/data/tbpoc/datamodule.py b/src/mednet/libs/classification/config/data/tbpoc/datamodule.py
index 122f6686..2ba9f688 100644
--- a/src/mednet/libs/classification/config/data/tbpoc/datamodule.py
+++ b/src/mednet/libs/classification/config/data/tbpoc/datamodule.py
@@ -6,11 +6,15 @@ import os
 import pathlib
 
 import PIL.Image
+from mednet.libs.classification.data.typing import (
+    ClassificationRawDataLoader as _ClassificationRawDataLoader,
+)
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.image_utils import remove_black_borders
+
 from mednet.libs.common.data.split import make_split
 from mednet.libs.common.data.typing import Sample
-from mednet.libs.common.data.typing import RawDataLoader as _BaseRawDataLoader
+
 from torchvision.transforms.functional import to_tensor
 
 from ....utils.rc import load_rc
@@ -20,7 +24,7 @@ CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
 database."""
 
 
-class RawDataLoader(_BaseRawDataLoader):
+class ClassificationRawDataLoader(_ClassificationRawDataLoader):
     """A specialized raw-data-loader for the Shenzen dataset."""
 
     datadir: pathlib.Path
@@ -121,7 +125,7 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=RawDataLoader(),
-            database_name=__package__.rsplit(".", 1)[1],
+            raw_data_loader=_ClassificationRawDataLoader(),
+            database_name=__package__.rsplit(".", 1)[1],            
             split_name=pathlib.Path(split_filename).stem,
         )
diff --git a/src/mednet/libs/classification/config/data/tbx11k/datamodule.py b/src/mednet/libs/classification/config/data/tbx11k/datamodule.py
index b633d8e9..fb860307 100644
--- a/src/mednet/libs/classification/config/data/tbx11k/datamodule.py
+++ b/src/mednet/libs/classification/config/data/tbx11k/datamodule.py
@@ -10,10 +10,12 @@ import typing
 
 import PIL.Image
 import typing_extensions
+from mednet.libs.classification.data.typing import (
+    ClassificationRawDataLoader as _ClassificationRawDataLoader,
+)
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import make_split
 from mednet.libs.common.data.typing import Sample
-from mednet.libs.common.data.typing import RawDataLoader as _BaseRawDataLoader
 from torch.utils.data._utils.collate import default_collate_fn_map
 from torchvision.transforms.functional import to_tensor
 
@@ -167,7 +169,7 @@ finding locations, as described above.
 """
 
 
-class RawDataLoader(_BaseRawDataLoader):
+class ClassificationRawDataLoader(_ClassificationRawDataLoader):
     """A specialized raw-data-loader for the TBX11k dataset.
 
     Parameters
@@ -377,7 +379,7 @@ class DataModule(CachingDataModule):
         assert __package__ is not None
         super().__init__(
             database_split=make_split(__package__, split_filename),
-            raw_data_loader=RawDataLoader(ignore_bboxes=ignore_bboxes),
+            raw_data_loader=ClassificationRawDataLoader(ignore_bboxes=ignore_bboxes),
             database_name=__package__.rsplit(".", 1)[1],
             split_name=pathlib.Path(split_filename).stem,
         )
diff --git a/src/mednet/libs/classification/data/__init__.py b/src/mednet/libs/classification/data/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/mednet/libs/classification/data/typing.py b/src/mednet/libs/classification/data/typing.py
new file mode 100644
index 00000000..8cd21405
--- /dev/null
+++ b/src/mednet/libs/classification/data/typing.py
@@ -0,0 +1,41 @@
+import typing
+
+from mednet.libs.common.data.typing import RawDataLoader, Sample
+
+
+class ClassificationRawDataLoader(RawDataLoader):
+    """A loader object can load samples and labels from storage for classification tasks."""
+
+    def __init__(self):
+        super().__init__()
+
+    def sample(self, _: typing.Any) -> Sample:
+        """Load whole samples from media.
+
+        Parameters
+        ----------
+         _
+            Information about the sample to load. Implementation dependent.
+        """
+
+        raise NotImplementedError("You must implement the `sample()` method")
+
+    def label(self, k: typing.Any) -> int | list[int]:
+        """Load only sample label from media.
+
+        If you do not override this implementation, then, by default,
+        this method will call :py:meth:`sample` to load the whole sample
+        and extract the label.
+
+        Parameters
+        ----------
+        k
+            The sample to load. This is implementation-dependent.
+
+        Returns
+        -------
+        int | list[int]
+            The label corresponding to the specified sample.
+        """
+
+        return self.sample(k)[1]["label"]
diff --git a/src/mednet/libs/common/data/typing.py b/src/mednet/libs/common/data/typing.py
index 93bbfc97..48cce9e7 100644
--- a/src/mednet/libs/common/data/typing.py
+++ b/src/mednet/libs/common/data/typing.py
@@ -35,26 +35,6 @@ class RawDataLoader:
 
         raise NotImplementedError("You must implement the `sample()` method")
 
-    def label(self, k: typing.Any) -> int | list[int]:
-        """Load only sample label from media.
-
-        If you do not override this implementation, then, by default,
-        this method will call :py:meth:`sample` to load the whole sample
-        and extract the label.
-
-        Parameters
-        ----------
-        k
-            The sample to load. This is implementation-dependent.
-
-        Returns
-        -------
-        int | list[int]
-            The label corresponding to the specified sample.
-        """
-
-        return self.sample(k)[1]["label"]
-
 
 Transform: typing.TypeAlias = typing.Callable[[torch.Tensor], torch.Tensor]
 """A callable that transforms tensors into (other) tensors.
diff --git a/src/mednet/libs/segmentation/config/data/drive/datamodule.py b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
index b61941a5..2ef7d1cb 100644
--- a/src/mednet/libs/segmentation/config/data/drive/datamodule.py
+++ b/src/mednet/libs/segmentation/config/data/drive/datamodule.py
@@ -11,7 +11,9 @@ import PIL.Image
 from mednet.libs.common.data.datamodule import CachingDataModule
 from mednet.libs.common.data.split import JSONDatabaseSplit
 from mednet.libs.common.data.typing import DatabaseSplit, Sample
-from mednet.libs.common.data.typing import RawDataLoader as _BaseRawDataLoader
+from mednet.libs.segmentation.data.typing import (
+    SegmentationRawDataLoader as _SegmentationRawDataLoader,
+)
 from torchvision import tv_tensors
 from torchvision.transforms.functional import to_tensor
 
@@ -22,7 +24,7 @@ CONFIGURATION_KEY_DATADIR = "datadir." + (__name__.rsplit(".", 2)[-2])
 database."""
 
 
-class RawDataLoader(_BaseRawDataLoader):
+class SegmentationRawDataLoader(_SegmentationRawDataLoader):
     """A specialized raw-data-loader for the Montgomery dataset."""
 
     datadir: str
@@ -116,5 +118,5 @@ class DataModule(CachingDataModule):
     def __init__(self, split_filename: str):
         super().__init__(
             database_split=make_split(split_filename),
-            raw_data_loader=RawDataLoader(),
+            raw_data_loader=SegmentationRawDataLoader(),
         )
diff --git a/src/mednet/libs/segmentation/data/typing.py b/src/mednet/libs/segmentation/data/typing.py
new file mode 100644
index 00000000..e1d95cdb
--- /dev/null
+++ b/src/mednet/libs/segmentation/data/typing.py
@@ -0,0 +1,21 @@
+import typing
+
+from mednet.libs.common.data.typing import RawDataLoader, Sample
+
+
+class SegmentationRawDataLoader(RawDataLoader):
+    """A loader object can load samples and labels from storage for classification tasks."""
+
+    def __init__(self):
+        super().__init__()
+
+    def sample(self, _: typing.Any) -> Sample:
+        """Load whole samples from media.
+
+        Parameters
+        ----------
+         _
+            Information about the sample to load. Implementation dependent.
+        """
+
+        raise NotImplementedError("You must implement the `sample()` method")
-- 
GitLab