Skip to content
Snippets Groups Projects
Commit 5f2b57a9 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[config.data] Implement attributes for database and split name for issue #60

parent 02238129
No related branches found
No related tags found
1 merge request: !24 "Implement fixes on evaluation (closes #20), and prepare for handling issue #60"
Showing
with 73 additions and 8 deletions
......@@ -141,4 +141,6 @@ class DataModule(CachingDataModule):
super().__init__(
database_split=make_split(split_filename),
raw_data_loader=RawDataLoader(),
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
)
......@@ -7,6 +7,7 @@ Database reference: [INDIAN-2013]_
"""
import importlib.resources
import os
from ....config.data.shenzhen.datamodule import RawDataLoader
from ....data.datamodule import CachingDataModule
......@@ -82,4 +83,6 @@ class DataModule(CachingDataModule):
raw_data_loader=RawDataLoader(
config_variable=CONFIGURATION_KEY_DATADIR
),
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
)
......@@ -143,4 +143,6 @@ class DataModule(CachingDataModule):
super().__init__(
database_split=make_split(split_filename),
raw_data_loader=RawDataLoader(),
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
)
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated DataModule composed of Montgomery and Shenzhen databases."""
import os
from ....data.datamodule import ConcatDataModule
from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader
......@@ -38,5 +41,7 @@ class DataModule(ConcatDataModule):
(montgomery_split["test"], montgomery_loader),
(shenzhen_split["test"], shenzhen_loader),
],
}
},
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
)
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets."""
"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian databases."""
import os
from ....data.datamodule import ConcatDataModule
from ..indian.datamodule import RawDataLoader as IndianLoader
......@@ -13,7 +15,8 @@ from ..shenzhen.datamodule import make_split as make_shenzhen_split
class DataModule(ConcatDataModule):
"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets.
"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian
datasets.
Parameters
----------
......@@ -46,5 +49,7 @@ class DataModule(ConcatDataModule):
(shenzhen_split["test"], shenzhen_loader),
(indian_split["test"], indian_loader),
],
}
},
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
)
......@@ -3,6 +3,8 @@
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and PadChest datasets."""
import os
from ....data.datamodule import ConcatDataModule
from ..indian.datamodule import RawDataLoader as IndianLoader
from ..indian.datamodule import make_split as make_indian_split
......@@ -57,5 +59,11 @@ class DataModule(ConcatDataModule):
(indian_split["test"], indian_loader),
(padchest_split["test"], padchest_loader),
],
}
},
database_name=__package__.split(".")[-1],
split_name=(
os.path.splitext(split_filename)[0]
+ "+"
+ os.path.splitext(padchest_split_filename)[0]
),
)
......@@ -3,6 +3,8 @@
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets."""
import os
from ....data.datamodule import ConcatDataModule
from ..indian.datamodule import RawDataLoader as IndianLoader
from ..indian.datamodule import make_split as make_indian_split
......@@ -56,5 +58,11 @@ class DataModule(ConcatDataModule):
(indian_split["test"], indian_loader),
(tbx11k_split["test"], tbx11k_loader),
],
}
},
database_name=__package__.split(".")[-1],
split_name=(
os.path.splitext(split_filename)[0]
+ "+"
+ os.path.splitext(tbx11k_split_filename)[0]
),
)
......@@ -192,4 +192,6 @@ class DataModule(CachingDataModule):
super().__init__(
database_split=make_split(split_filename),
raw_data_loader=RawDataLoader(),
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
)
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated DataModule composed of NIH-CXR-14 and PadChest databases."""
import os
from ....data.datamodule import ConcatDataModule
from ..nih_cxr14.datamodule import RawDataLoader as CXR14Loader
......@@ -42,5 +45,11 @@ class DataModule(ConcatDataModule):
# there is no test set on padchest
# (padchest_split["test"], padchest_loader),
],
}
},
database_name=__package__.split(".")[-1],
split_name=(
os.path.splitext(cxr14_split_filename)[0]
+ "+"
+ os.path.splitext(padchest_split_filename)[0]
),
)
......@@ -341,4 +341,6 @@ class DataModule(CachingDataModule):
super().__init__(
database_split=make_split(split_filename),
raw_data_loader=RawDataLoader(),
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
)
......@@ -155,4 +155,6 @@ class DataModule(CachingDataModule):
super().__init__(
database_split=make_split(split_filename),
raw_data_loader=RawDataLoader(),
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
)
......@@ -136,4 +136,6 @@ class DataModule(CachingDataModule):
super().__init__(
database_split=make_split(split_filename),
raw_data_loader=RawDataLoader(),
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
)
......@@ -355,4 +355,6 @@ class DataModule(CachingDataModule):
super().__init__(
database_split=make_split(split_filename),
raw_data_loader=RawDataLoader(),
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
)
......@@ -458,6 +458,12 @@ class ConcatDataModule(lightning.LightningDataModule):
Entries named ``monitor-...`` will be considered extra datasets that do
not influence any early stop criteria during training, and are just
monitored beyond the ``validation`` dataset.
database_name
The name of the database, or aggregated database containing the
raw-samples served by this data module.
split_name
The name of the split used to group the samples into the various
datasets for training, validation and testing.
cache_samples
If set, then issue raw data loading during ``prepare_data()``, and
serves samples from CPU memory. Otherwise, loads samples from disk on
......@@ -510,6 +516,8 @@ class ConcatDataModule(lightning.LightningDataModule):
def __init__(
self,
splits: ConcatDatabaseSplit,
database_name: str = "",
split_name: str = "",
cache_samples: bool = False,
balance_sampler_by_class: bool = False,
batch_size: int = 1,
......@@ -522,10 +530,15 @@ class ConcatDataModule(lightning.LightningDataModule):
self.set_chunk_size(batch_size, batch_chunk_count)
self.splits = splits
self.database_name = database_name
self.split_name = split_name
for dataset_name, split_loaders in splits.items():
count = sum([len(k) for k, _ in split_loaders])
logger.info(f"Dataset `{dataset_name}` contains {count} samples")
logger.info(
f"Dataset `{dataset_name}` (`{database_name}`/`{split_name}`) "
f"contains {count} samples"
)
self.cache_samples = cache_samples
self._train_sampler = None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment