diff --git a/src/mednet/config/data/hivtb/datamodule.py b/src/mednet/config/data/hivtb/datamodule.py index 68a7b7a3bdfea302ecd41bc750355b354b85c5f2..cae64f2c53b2fda0992aea1f662307ddab27d616 100644 --- a/src/mednet/config/data/hivtb/datamodule.py +++ b/src/mednet/config/data/hivtb/datamodule.py @@ -141,4 +141,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/indian/datamodule.py b/src/mednet/config/data/indian/datamodule.py index 08a507223436be949b4ae03de1c2671060697c1c..2fc0567bbbefd65ec4c96fc19592c1ec1db7215c 100644 --- a/src/mednet/config/data/indian/datamodule.py +++ b/src/mednet/config/data/indian/datamodule.py @@ -7,6 +7,7 @@ Database reference: [INDIAN-2013]_ """ import importlib.resources +import os from ....config.data.shenzhen.datamodule import RawDataLoader from ....data.datamodule import CachingDataModule @@ -82,4 +83,6 @@ class DataModule(CachingDataModule): raw_data_loader=RawDataLoader( config_variable=CONFIGURATION_KEY_DATADIR ), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/montgomery/datamodule.py b/src/mednet/config/data/montgomery/datamodule.py index 86e9fdb7058c3e25351212b93be7a8167c9ecc03..5ed7fa50e325810f3b8523a883aa5f9f8c1c301b 100644 --- a/src/mednet/config/data/montgomery/datamodule.py +++ b/src/mednet/config/data/montgomery/datamodule.py @@ -143,4 +143,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/montgomery_shenzhen/datamodule.py b/src/mednet/config/data/montgomery_shenzhen/datamodule.py index fa83fdde5165e24c0d08b0b9c123b87bd98e3805..6df353ad70d701d8585be9b46dc4ea540adccdb4 100644 --- a/src/mednet/config/data/montgomery_shenzhen/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen/datamodule.py @@ -1,6 +1,9 @@ # Copyright © 2022 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later +"""Aggregated DataModule composed of Montgomery and Shenzhen databases.""" + +import os from ....data.datamodule import ConcatDataModule from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader @@ -38,5 +41,7 @@ class DataModule(ConcatDataModule): (montgomery_split["test"], montgomery_loader), (shenzhen_split["test"], shenzhen_loader), ], - } + }, + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py index 676fa8ef96d26b794341dee5c5cd09caf55d06d0..0a0d497f14b70e23c5f0208904ddbb830f7171b4 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian/datamodule.py @@ -1,7 +1,9 @@ # Copyright © 2022 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets.""" +"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian databases.""" + +import os from ....data.datamodule import ConcatDataModule from ..indian.datamodule import RawDataLoader as IndianLoader @@ -13,7 +15,8 @@ from ..shenzhen.datamodule import make_split as make_shenzhen_split class DataModule(ConcatDataModule): - """Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets. + """Aggregated DataModule composed of Montgomery, Shenzhen and Indian + datasets. Parameters ---------- @@ -46,5 +49,7 @@ class DataModule(ConcatDataModule): (shenzhen_split["test"], shenzhen_loader), (indian_split["test"], indian_loader), ], - } + }, + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py index 2876af8f0335acd87313c02ca18ad8b47dd21bcc..5442c875a63b6bbcd0587d138dc4ee3d2e23fea7 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_padchest/datamodule.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and PadChest datasets.""" +import os + from ....data.datamodule import ConcatDataModule from ..indian.datamodule import RawDataLoader as IndianLoader from ..indian.datamodule import make_split as make_indian_split @@ -57,5 +59,11 @@ class DataModule(ConcatDataModule): (indian_split["test"], indian_loader), (padchest_split["test"], padchest_loader), ], - } + }, + database_name=__package__.split(".")[-1], + split_name=( + os.path.splitext(split_filename)[0] + + "+" + + os.path.splitext(padchest_split_filename)[0] + ), ) diff --git a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py index 8dd831981465af7afa276d67287de5c09008bc36..ff0c3844e89f0606c0681eac7f7ab3995b4a6afc 100644 --- a/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py +++ b/src/mednet/config/data/montgomery_shenzhen_indian_tbx11k/datamodule.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later """Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets.""" +import os + from ....data.datamodule import ConcatDataModule from ..indian.datamodule import RawDataLoader as IndianLoader from ..indian.datamodule import make_split as make_indian_split @@ -56,5 +58,11 @@ class DataModule(ConcatDataModule): (indian_split["test"], indian_loader), (tbx11k_split["test"], tbx11k_loader), ], - } + }, + database_name=__package__.split(".")[-1], + split_name=( + os.path.splitext(split_filename)[0] + + "+" + + os.path.splitext(tbx11k_split_filename)[0] + ), ) diff --git a/src/mednet/config/data/nih_cxr14/datamodule.py b/src/mednet/config/data/nih_cxr14/datamodule.py index 5967ee63152c6e6dac0aa9b1fb6bd13d3f24a438..26596b74c427e712d7ef3b2b141d4508d1eae395 100644 --- a/src/mednet/config/data/nih_cxr14/datamodule.py +++ b/src/mednet/config/data/nih_cxr14/datamodule.py @@ -192,4 +192,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/nih_cxr14_padchest/datamodule.py b/src/mednet/config/data/nih_cxr14_padchest/datamodule.py index 2c793c7980156f14f0ab4946652bd1f627cd22c0..6cc383405093db4adca141ad6001bab9572fefa9 100644 --- a/src/mednet/config/data/nih_cxr14_padchest/datamodule.py +++ b/src/mednet/config/data/nih_cxr14_padchest/datamodule.py @@ -1,6 +1,9 @@ # Copyright © 2022 Idiap Research Institute <contact@idiap.ch> # # SPDX-License-Identifier: GPL-3.0-or-later +"""Aggregated DataModule composed of NIH-CXR-14 and PadChest databases.""" + +import os from ....data.datamodule import ConcatDataModule from ..nih_cxr14.datamodule import RawDataLoader as CXR14Loader @@ -42,5 +45,11 @@ class DataModule(ConcatDataModule): # there is no test set on padchest # (padchest_split["test"], padchest_loader), ], - } + }, + database_name=__package__.split(".")[-1], + split_name=( + os.path.splitext(cxr14_split_filename)[0] + + "+" + + os.path.splitext(padchest_split_filename)[0] + ), ) diff --git a/src/mednet/config/data/padchest/datamodule.py b/src/mednet/config/data/padchest/datamodule.py index 778d505f96813f28b348b9a251d91f4213871d77..d146fc0357659134151084d765945b96a8f8a305 100644 --- a/src/mednet/config/data/padchest/datamodule.py +++ b/src/mednet/config/data/padchest/datamodule.py @@ -341,4 +341,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/shenzhen/datamodule.py b/src/mednet/config/data/shenzhen/datamodule.py index 6853ebe56718c38db9e1e1d72a2c08ded11627dd..81e48f9b70dc887a97a312a0d8062afc809682e2 100644 --- a/src/mednet/config/data/shenzhen/datamodule.py +++ b/src/mednet/config/data/shenzhen/datamodule.py @@ -155,4 +155,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/tbpoc/datamodule.py b/src/mednet/config/data/tbpoc/datamodule.py index 67846f6c7c0842c18f24d83a4018beb1ccc908b9..14b09e7f23e7c72779f43f7f9bb8503ca9954414 100644 --- a/src/mednet/config/data/tbpoc/datamodule.py +++ b/src/mednet/config/data/tbpoc/datamodule.py @@ -136,4 +136,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/config/data/tbx11k/datamodule.py b/src/mednet/config/data/tbx11k/datamodule.py index 9c76bb5af431793e2668ef4a181e4e98b28e5f25..1735607ee72cb7f1298bca1e1adca57bb5d41d6a 100644 --- a/src/mednet/config/data/tbx11k/datamodule.py +++ b/src/mednet/config/data/tbx11k/datamodule.py @@ -355,4 +355,6 @@ class DataModule(CachingDataModule): super().__init__( database_split=make_split(split_filename), raw_data_loader=RawDataLoader(), + database_name=__package__.split(".")[-1], + split_name=os.path.splitext(split_filename)[0], ) diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py index 4f79857b88eeb7a7fada2bf40f29ebb9dcfe7209..71fead5a2d275b2aa33f28dcfc9c894ac687a2b2 100644 --- a/src/mednet/data/datamodule.py +++ b/src/mednet/data/datamodule.py @@ -458,6 +458,12 @@ class ConcatDataModule(lightning.LightningDataModule): Entries named ``monitor-...`` will be considered extra datasets that do not influence any early stop criteria during training, and are just monitored beyond the ``validation`` dataset. + database_name + The name of the database, or aggregated database containing the + raw-samples served by this data module. + split_name + The name of the split used to group the samples into the various + datasets for training, validation and testing. cache_samples If set, then issue raw data loading during ``prepare_data()``, and serves samples from CPU memory. Otherwise, loads samples from disk on @@ -510,6 +516,8 @@ class ConcatDataModule(lightning.LightningDataModule): def __init__( self, splits: ConcatDatabaseSplit, + database_name: str = "", + split_name: str = "", cache_samples: bool = False, balance_sampler_by_class: bool = False, batch_size: int = 1, @@ -522,10 +530,15 @@ class ConcatDataModule(lightning.LightningDataModule): self.set_chunk_size(batch_size, batch_chunk_count) self.splits = splits + self.database_name = database_name + self.split_name = split_name for dataset_name, split_loaders in splits.items(): count = sum([len(k) for k, _ in split_loaders]) - logger.info(f"Dataset `{dataset_name}` contains {count} samples") + logger.info( + f"Dataset `{dataset_name}` (`{database_name}`/`{split_name}`) " + f"contains {count} samples" + ) self.cache_samples = cache_samples self._train_sampler = None