From 5f2b57a9c475b688e84bcf2a39b0647fcfe9e537 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Tue, 27 Feb 2024 13:12:13 +0100
Subject: [PATCH] [config.data] Implement attributes for database and split
 name for issue #60

---
 src/mednet/config/data/hivtb/datamodule.py        |  2 ++
 src/mednet/config/data/indian/datamodule.py       |  3 +++
 src/mednet/config/data/montgomery/datamodule.py   |  2 ++
 .../config/data/montgomery_shenzhen/datamodule.py |  7 ++++++-
 .../data/montgomery_shenzhen_indian/datamodule.py | 11 ++++++++---
 .../datamodule.py                                 | 10 +++++++++-
 .../datamodule.py                                 | 10 +++++++++-
 src/mednet/config/data/nih_cxr14/datamodule.py    |  2 ++
 .../config/data/nih_cxr14_padchest/datamodule.py  | 11 ++++++++++-
 src/mednet/config/data/padchest/datamodule.py     |  2 ++
 src/mednet/config/data/shenzhen/datamodule.py     |  2 ++
 src/mednet/config/data/tbpoc/datamodule.py        |  2 ++
 src/mednet/config/data/tbx11k/datamodule.py       |  2 ++
 src/mednet/data/datamodule.py                     | 15 ++++++++++++++-
 14 files changed, 73 insertions(+), 8 deletions(-)

diff --git a/src/mednet/config/data/hivtb/datamodule.py b/src/mednet/config/data/hivtb/datamodule.py
index 68a7b7a3..cae64f2c 100644
--- a/src/mednet/config/data/hivtb/datamodule.py
+++ b/src/mednet/config/data/hivtb/datamodule.py
@@ -141,4 +141,6 @@ class DataModule(CachingDataModule):
         super().__init__(
             database_split=make_split(split_filename),
             raw_data_loader=RawDataLoader(),
+            database_name=__package__.split(".")[-1],
+            split_name=os.path.splitext(split_filename)[0],
         )
diff --git a/src/mednet/config/data/indian/datamodule.py b/src/mednet/config/data/indian/datamodule.py
index 08a50722..2fc0567b 100644
--- a/src/mednet/config/data/indian/datamodule.py
+++ b/src/mednet/config/data/indian/datamodule.py
@@ -7,6 +7,7 @@ Database reference: [INDIAN-2013]_
 """
 
 import importlib.resources
+import os
 
 from ....config.data.shenzhen.datamodule import RawDataLoader
 from ....data.datamodule import CachingDataModule
@@ -82,4 +83,6 @@ class DataModule(CachingDataModule):
             raw_data_loader=RawDataLoader(
                 config_variable=CONFIGURATION_KEY_DATADIR
             ),
+            database_name=__package__.split(".")[-1],
+            split_name=os.path.splitext(split_filename)[0],
         )
diff --git a/src/mednet/config/data/montgomery/datamodule.py b/src/mednet/config/data/montgomery/datamodule.py
index 86e9fdb7..5ed7fa50 100644
--- a/src/mednet/config/data/montgomery/datamodule.py
+++ b/src/mednet/config/data/montgomery/datamodule.py
@@ -143,4 +143,6 @@ class DataModule(CachingDataModule):
         super().__init__(
             database_split=make_split(split_filename),
             raw_data_loader=RawDataLoader(),
+            database_name=__package__.split(".")[-1],
+            split_name=os.path.splitext(split_filename)[0],
         )
diff --git a/src/mednet/config/data/montgomery_shenzhen/datamodule.py b/src/mednet/config/data/montgomery_shenzhen/datamodule.py
index fa83fdde..6df353ad 100644
--- a/src/mednet/config/data/montgomery_shenzhen/datamodule.py
+++ b/src/mednet/config/data/montgomery_shenzhen/datamodule.py
@@ -1,6 +1,9 @@
 # Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+"""Aggregated DataModule composed of Montgomery and Shenzhen databases."""
+
+import os
 
 from ....data.datamodule import ConcatDataModule
 from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader
@@ -38,5 +41,7 @@ class DataModule(ConcatDataModule):
                     (montgomery_split["test"], montgomery_loader),
                     (shenzhen_split["test"], shenzhen_loader),
                 ],
-            }
+            },
+            database_name=__package__.split(".")[-1],
+            split_name=os.path.splitext(split_filename)[0],
         )
diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py
index 676fa8ef..0a0d497f 100644
--- a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py
+++ b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py
@@ -1,7 +1,9 @@
 # Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets."""
+"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian databases."""
+
+import os
 
 from ....data.datamodule import ConcatDataModule
 from ..indian.datamodule import RawDataLoader as IndianLoader
@@ -13,7 +15,8 @@ from ..shenzhen.datamodule import make_split as make_shenzhen_split
 
 
 class DataModule(ConcatDataModule):
-    """Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets.
+    """Aggregated DataModule composed of Montgomery, Shenzhen and Indian
+    datasets.
 
     Parameters
     ----------
@@ -46,5 +49,7 @@ class DataModule(ConcatDataModule):
                     (shenzhen_split["test"], shenzhen_loader),
                     (indian_split["test"], indian_loader),
                 ],
-            }
+            },
+            database_name=__package__.split(".")[-1],
+            split_name=os.path.splitext(split_filename)[0],
         )
diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py
index 2876af8f..5442c875 100644
--- a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py
+++ b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py
@@ -3,6 +3,8 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 """Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and PadChest datasets."""
 
+import os
+
 from ....data.datamodule import ConcatDataModule
 from ..indian.datamodule import RawDataLoader as IndianLoader
 from ..indian.datamodule import make_split as make_indian_split
@@ -57,5 +59,11 @@ class DataModule(ConcatDataModule):
                     (indian_split["test"], indian_loader),
                     (padchest_split["test"], padchest_loader),
                 ],
-            }
+            },
+            database_name=__package__.split(".")[-1],
+            split_name=(
+                os.path.splitext(split_filename)[0]
+                + "+"
+                + os.path.splitext(padchest_split_filename)[0]
+            ),
         )
diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py
index 8dd83198..ff0c3844 100644
--- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py
+++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py
@@ -3,6 +3,8 @@
 # SPDX-License-Identifier: GPL-3.0-or-later
 """Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets."""
 
+import os
+
 from ....data.datamodule import ConcatDataModule
 from ..indian.datamodule import RawDataLoader as IndianLoader
 from ..indian.datamodule import make_split as make_indian_split
@@ -56,5 +58,11 @@ class DataModule(ConcatDataModule):
                     (indian_split["test"], indian_loader),
                     (tbx11k_split["test"], tbx11k_loader),
                 ],
-            }
+            },
+            database_name=__package__.split(".")[-1],
+            split_name=(
+                os.path.splitext(split_filename)[0]
+                + "+"
+                + os.path.splitext(tbx11k_split_filename)[0]
+            ),
         )
diff --git a/src/mednet/config/data/nih_cxr14/datamodule.py b/src/mednet/config/data/nih_cxr14/datamodule.py
index 5967ee63..26596b74 100644
--- a/src/mednet/config/data/nih_cxr14/datamodule.py
+++ b/src/mednet/config/data/nih_cxr14/datamodule.py
@@ -192,4 +192,6 @@ class DataModule(CachingDataModule):
         super().__init__(
             database_split=make_split(split_filename),
             raw_data_loader=RawDataLoader(),
+            database_name=__package__.split(".")[-1],
+            split_name=os.path.splitext(split_filename)[0],
         )
diff --git a/src/mednet/config/data/nih_cxr14_padchest/datamodule.py b/src/mednet/config/data/nih_cxr14_padchest/datamodule.py
index 2c793c79..6cc38340 100644
--- a/src/mednet/config/data/nih_cxr14_padchest/datamodule.py
+++ b/src/mednet/config/data/nih_cxr14_padchest/datamodule.py
@@ -1,6 +1,9 @@
 # Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
+"""Aggregated DataModule composed of NIH-CXR-14 and PadChest databases."""
+
+import os
 
 from ....data.datamodule import ConcatDataModule
 from ..nih_cxr14.datamodule import RawDataLoader as CXR14Loader
@@ -42,5 +45,11 @@ class DataModule(ConcatDataModule):
                     # there is no test set on padchest
                     # (padchest_split["test"], padchest_loader),
                 ],
-            }
+            },
+            database_name=__package__.split(".")[-1],
+            split_name=(
+                os.path.splitext(cxr14_split_filename)[0]
+                + "+"
+                + os.path.splitext(padchest_split_filename)[0]
+            ),
         )
diff --git a/src/mednet/config/data/padchest/datamodule.py b/src/mednet/config/data/padchest/datamodule.py
index 778d505f..d146fc03 100644
--- a/src/mednet/config/data/padchest/datamodule.py
+++ b/src/mednet/config/data/padchest/datamodule.py
@@ -341,4 +341,6 @@ class DataModule(CachingDataModule):
         super().__init__(
             database_split=make_split(split_filename),
             raw_data_loader=RawDataLoader(),
+            database_name=__package__.split(".")[-1],
+            split_name=os.path.splitext(split_filename)[0],
         )
diff --git a/src/mednet/config/data/shenzhen/datamodule.py b/src/mednet/config/data/shenzhen/datamodule.py
index 6853ebe5..81e48f9b 100644
--- a/src/mednet/config/data/shenzhen/datamodule.py
+++ b/src/mednet/config/data/shenzhen/datamodule.py
@@ -155,4 +155,6 @@ class DataModule(CachingDataModule):
         super().__init__(
             database_split=make_split(split_filename),
             raw_data_loader=RawDataLoader(),
+            database_name=__package__.split(".")[-1],
+            split_name=os.path.splitext(split_filename)[0],
         )
diff --git a/src/mednet/config/data/tbpoc/datamodule.py b/src/mednet/config/data/tbpoc/datamodule.py
index 67846f6c..14b09e7f 100644
--- a/src/mednet/config/data/tbpoc/datamodule.py
+++ b/src/mednet/config/data/tbpoc/datamodule.py
@@ -136,4 +136,6 @@ class DataModule(CachingDataModule):
         super().__init__(
             database_split=make_split(split_filename),
             raw_data_loader=RawDataLoader(),
+            database_name=__package__.split(".")[-1],
+            split_name=os.path.splitext(split_filename)[0],
         )
diff --git a/src/mednet/config/data/tbx11k/datamodule.py b/src/mednet/config/data/tbx11k/datamodule.py
index 9c76bb5a..1735607e 100644
--- a/src/mednet/config/data/tbx11k/datamodule.py
+++ b/src/mednet/config/data/tbx11k/datamodule.py
@@ -355,4 +355,6 @@ class DataModule(CachingDataModule):
         super().__init__(
             database_split=make_split(split_filename),
             raw_data_loader=RawDataLoader(),
+            database_name=__package__.split(".")[-1],
+            split_name=os.path.splitext(split_filename)[0],
         )
diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py
index 4f79857b..71fead5a 100644
--- a/src/mednet/data/datamodule.py
+++ b/src/mednet/data/datamodule.py
@@ -458,6 +458,12 @@ class ConcatDataModule(lightning.LightningDataModule):
         Entries named ``monitor-...`` will be considered extra datasets that do
         not influence any early stop criteria during training, and are just
         monitored beyond the ``validation`` dataset.
+    database_name
+        The name of the database, or aggregated database containing the
+        raw-samples served by this data module.
+    split_name
+        The name of the split used to group the samples into the various
+        datasets for training, validation and testing.
     cache_samples
         If set, then issue raw data loading during ``prepare_data()``, and
         serves samples from CPU memory.  Otherwise, loads samples from disk on
@@ -510,6 +516,8 @@ class ConcatDataModule(lightning.LightningDataModule):
     def __init__(
         self,
         splits: ConcatDatabaseSplit,
+        database_name: str = "",
+        split_name: str = "",
         cache_samples: bool = False,
         balance_sampler_by_class: bool = False,
         batch_size: int = 1,
@@ -522,10 +530,15 @@ class ConcatDataModule(lightning.LightningDataModule):
         self.set_chunk_size(batch_size, batch_chunk_count)
 
         self.splits = splits
+        self.database_name = database_name
+        self.split_name = split_name
 
         for dataset_name, split_loaders in splits.items():
             count = sum([len(k) for k, _ in split_loaders])
-            logger.info(f"Dataset `{dataset_name}` contains {count} samples")
+            logger.info(
+                f"Dataset `{dataset_name}` (`{database_name}`/`{split_name}`) "
+                f"contains {count} samples"
+            )
 
         self.cache_samples = cache_samples
         self._train_sampler = None
-- 
GitLab