From 5f2b57a9c475b688e84bcf2a39b0647fcfe9e537 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Tue, 27 Feb 2024 13:12:13 +0100 Subject: [PATCH] [config.data] Implement attributes for database and split name for issue #60 --- src/mednet/config/data/hivtb/datamodule.py | 2 ++ src/mednet/config/data/indian/datamodule.py | 3 +++ src/mednet/config/data/montgomery/datamodule.py | 2 ++ .../config/data/montgomery_shenzhen/datamodule.py | 7 ++++++- .../data/montgomery_shenzhen_indian/datamodule.py | 11 ++++++++--- .../datamodule.py | 10 +++++++++- .../datamodule.py | 10 +++++++++- src/mednet/config/data/nih_cxr14/datamodule.py | 2 ++ .../config/data/nih_cxr14_padchest/datamodule.py | 11 ++++++++++- src/mednet/config/data/padchest/datamodule.py | 2 ++ src/mednet/config/data/shenzhen/datamodule.py | 2 ++ src/mednet/config/data/tbpoc/datamodule.py | 2 ++ src/mednet/config/data/tbx11k/datamodule.py | 2 ++ src/mednet/data/datamodule.py | 15 ++++++++++++++- 14 files changed, 73 insertions(+), 8 deletions(-) diff --git a/src/mednet/config/data/hivtb/datamodule.py b/src/mednet/config/data/hivtb/datamodule.py index 68a7b7a3..cae64f2c 100644 --- a/src/mednet/config/data/hivtb/datamodule.py +++ b/src/mednet/config/data/hivtb/datamodule.py @@ -141,4 +141,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/indian/datamodule.py b/src/mednet/config/data/indian/datamodule.py index 08a50722..2fc0567b 100644 --- a/src/mednet/config/data/indian/datamodule.py +++ b/src/mednet/config/data/indian/datamodule.py @@ -7,6 +7,7 @@ Database reference: [INDIAN-2013]_ """ import importlib.resources +import os from ....config.data.shenzhen.datamodule import RawDataLoader from ....data.datamodule import CachingDataModule @@ -82,4 +83,6 @@ class DataModule(CachingDataModule): raw_data_loader=RawDataLoader( config_variable=CONFIGURATION_KEY_DATADIR ), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/montgomery/datamodule.py b/src/mednet/config/data/montgomery/datamodule.py index 86e9fdb7..5ed7fa50 100644 --- a/src/mednet/config/data/montgomery/datamodule.py +++ b/src/mednet/config/data/montgomery/datamodule.py @@ -143,4 +143,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/montgomery_shenzhen/datamodule.py b/src/mednet/config/data/montgomery_shenzhen/datamodule.py index fa83fdde..6df353ad 100644 --- a/src/mednet/config/data/montgomery_shenzhen/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen/datamodule.py @@ -1,6 +1,9 @@ # Copyright © 2022 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later +"""Aggregated DataModule composed of Montgomery and Shenzhen databases.""" + +import os from ....data.datamodule import ConcatDataModule from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader @@ -38,5 +41,7 @@ class DataModule(ConcatDataModule): (montgomery_split["test"], montgomery_loader), (shenzhen_split["test"], shenzhen_loader), ], - } + }, + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py index 676fa8ef..0a0d497f 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py @@ -1,7 +1,9 @@ # Copyright © 2022 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets.""" +"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian databases.""" + +import os from ....data.datamodule import ConcatDataModule from ..indian.datamodule import RawDataLoader as IndianLoader @@ -13,7 +15,8 @@ from ..shenzhen.datamodule import make_split as make_shenzhen_split class DataModule(ConcatDataModule): - """Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets. + """Aggregated DataModule composed of Montgomery, Shenzhen and Indian + datasets. Parameters ---------- @@ -46,5 +49,7 @@ class DataModule(ConcatDataModule): (shenzhen_split["test"], shenzhen_loader), (indian_split["test"], indian_loader), ], - } + }, + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py index 2876af8f..5442c875 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and PadChest datasets.""" +import os + from ....data.datamodule import ConcatDataModule from ..indian.datamodule import RawDataLoader as IndianLoader from ..indian.datamodule import make_split as make_indian_split @@ -57,5 +59,11 @@ class DataModule(ConcatDataModule): (indian_split["test"], indian_loader), (padchest_split["test"], padchest_loader), ], - } + }, + database_name=__package__.split(".")[-1], + split_name=( + os.path.splitext(split_filename)[0] + + "+" + + os.path.splitext(padchest_split_filename)[0] + ), ) diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py index 8dd83198..ff0c3844 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets.""" +import os + from ....data.datamodule import ConcatDataModule from ..indian.datamodule import RawDataLoader as IndianLoader from ..indian.datamodule import make_split as make_indian_split @@ -56,5 +58,11 @@ class DataModule(ConcatDataModule): (indian_split["test"], indian_loader), (tbx11k_split["test"], tbx11k_loader), ], - } + }, + database_name=__package__.split(".")[-1], + split_name=( + os.path.splitext(split_filename)[0] + + "+" + + os.path.splitext(tbx11k_split_filename)[0] + ), ) diff --git a/src/mednet/config/data/nih_cxr14/datamodule.py b/src/mednet/config/data/nih_cxr14/datamodule.py index 5967ee63..26596b74 100644 --- a/src/mednet/config/data/nih_cxr14/datamodule.py +++ b/src/mednet/config/data/nih_cxr14/datamodule.py @@ -192,4 +192,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/nih_cxr14_padchest/datamodule.py b/src/mednet/config/data/nih_cxr14_padchest/datamodule.py index 2c793c79..6cc38340 100644 --- a/src/mednet/config/data/nih_cxr14_padchest/datamodule.py +++ b/src/mednet/config/data/nih_cxr14_padchest/datamodule.py @@ -1,6 +1,9 @@ # Copyright © 2022 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later +"""Aggregated DataModule composed of NIH-CXR-14 and PadChest databases.""" + +import os from ....data.datamodule import ConcatDataModule from ..nih_cxr14.datamodule import RawDataLoader as CXR14Loader @@ -42,5 +45,11 @@ class DataModule(ConcatDataModule): # there is no test set on padchest # (padchest_split["test"], padchest_loader), ], - } + }, + database_name=__package__.split(".")[-1], + split_name=( + os.path.splitext(cxr14_split_filename)[0] + + "+" + + os.path.splitext(padchest_split_filename)[0] + ), ) diff --git a/src/mednet/config/data/padchest/datamodule.py b/src/mednet/config/data/padchest/datamodule.py index 778d505f..d146fc03 100644 --- a/src/mednet/config/data/padchest/datamodule.py +++ b/src/mednet/config/data/padchest/datamodule.py @@ -341,4 +341,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/shenzhen/datamodule.py b/src/mednet/config/data/shenzhen/datamodule.py index 6853ebe5..81e48f9b 100644 --- a/src/mednet/config/data/shenzhen/datamodule.py +++ b/src/mednet/config/data/shenzhen/datamodule.py @@ -155,4 +155,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/tbpoc/datamodule.py b/src/mednet/config/data/tbpoc/datamodule.py index 67846f6c..14b09e7f 100644 --- a/src/mednet/config/data/tbpoc/datamodule.py +++ b/src/mednet/config/data/tbpoc/datamodule.py @@ -136,4 +136,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/tbx11k/datamodule.py b/src/mednet/config/data/tbx11k/datamodule.py index 9c76bb5a..1735607e 100644 --- a/src/mednet/config/data/tbx11k/datamodule.py +++ b/src/mednet/config/data/tbx11k/datamodule.py @@ -355,4 +355,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py index 4f79857b..71fead5a 100644 --- a/src/mednet/data/datamodule.py +++ b/src/mednet/data/datamodule.py @@ -458,6 +458,12 @@ class ConcatDataModule(lightning.LightningDataModule): Entries named ``monitor-...`` will be considered extra datasets that do not influence any early stop criteria during training, and are just monitored beyond the ``validation`` dataset. + database_name + The name of the database, or aggregated database containing the + raw-samples served by this data module. + split_name + The name of the split used to group the samples into the various + datasets for training, validation and testing. cache_samples If set, then issue raw data loading during ``prepare_data()``, and serves samples from CPU memory. Otherwise, loads samples from disk on @@ -510,6 +516,8 @@ class ConcatDataModule(lightning.LightningDataModule): def __init__( self, splits: ConcatDatabaseSplit, + database_name: str = "", + split_name: str = "", cache_samples: bool = False, balance_sampler_by_class: bool = False, batch_size: int = 1, @@ -522,10 +530,15 @@ class ConcatDataModule(lightning.LightningDataModule): self.set_chunk_size(batch_size, batch_chunk_count) self.splits = splits + self.database_name = database_name + self.split_name = split_name for dataset_name, split_loaders in splits.items(): count = sum([len(k) for k, _ in split_loaders]) - logger.info(f"Dataset `{dataset_name}` contains {count} samples") + logger.info( + f"Dataset `{dataset_name}` (`{database_name}`/`{split_name}`) " + f"contains {count} samples" + ) self.cache_samples = cache_samples self._train_sampler = None -- GitLab