Skip to content
Snippets Groups Projects
Commit 5f2b57a9 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[config.data] Implement attributes for database and split name for issue #60

parent 02238129
No related branches found
No related tags found
1 merge request!24Implement fixes on evaluation (closes #20), and prepare for handling issue #60
Showing
with 73 additions and 8 deletions
...@@ -141,4 +141,6 @@ class DataModule(CachingDataModule): ...@@ -141,4 +141,6 @@ class DataModule(CachingDataModule):
super().__init__( super().__init__(
database_split=make_split(split_filename), database_split=make_split(split_filename),
raw_data_loader=RawDataLoader(), raw_data_loader=RawDataLoader(),
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
) )
...@@ -7,6 +7,7 @@ Database reference: [INDIAN-2013]_ ...@@ -7,6 +7,7 @@ Database reference: [INDIAN-2013]_
""" """
import importlib.resources import importlib.resources
import os
from ....config.data.shenzhen.datamodule import RawDataLoader from ....config.data.shenzhen.datamodule import RawDataLoader
from ....data.datamodule import CachingDataModule from ....data.datamodule import CachingDataModule
...@@ -82,4 +83,6 @@ class DataModule(CachingDataModule): ...@@ -82,4 +83,6 @@ class DataModule(CachingDataModule):
raw_data_loader=RawDataLoader( raw_data_loader=RawDataLoader(
config_variable=CONFIGURATION_KEY_DATADIR config_variable=CONFIGURATION_KEY_DATADIR
), ),
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
) )
...@@ -143,4 +143,6 @@ class DataModule(CachingDataModule): ...@@ -143,4 +143,6 @@ class DataModule(CachingDataModule):
super().__init__( super().__init__(
database_split=make_split(split_filename), database_split=make_split(split_filename),
raw_data_loader=RawDataLoader(), raw_data_loader=RawDataLoader(),
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
) )
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> # Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated DataModule composed of Montgomery and Shenzhen databases."""
import os
from ....data.datamodule import ConcatDataModule from ....data.datamodule import ConcatDataModule
from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader
...@@ -38,5 +41,7 @@ class DataModule(ConcatDataModule): ...@@ -38,5 +41,7 @@ class DataModule(ConcatDataModule):
(montgomery_split["test"], montgomery_loader), (montgomery_split["test"], montgomery_loader),
(shenzhen_split["test"], shenzhen_loader), (shenzhen_split["test"], shenzhen_loader),
], ],
} },
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
) )
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> # Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets.""" """Aggregated DataModule composed of Montgomery, Shenzhen and Indian databases."""
import os
from ....data.datamodule import ConcatDataModule from ....data.datamodule import ConcatDataModule
from ..indian.datamodule import RawDataLoader as IndianLoader from ..indian.datamodule import RawDataLoader as IndianLoader
...@@ -13,7 +15,8 @@ from ..shenzhen.datamodule import make_split as make_shenzhen_split ...@@ -13,7 +15,8 @@ from ..shenzhen.datamodule import make_split as make_shenzhen_split
class DataModule(ConcatDataModule): class DataModule(ConcatDataModule):
"""Aggregated DataModule composed of Montgomery, Shenzhen and Indian datasets. """Aggregated DataModule composed of Montgomery, Shenzhen and Indian
datasets.
Parameters Parameters
---------- ----------
...@@ -46,5 +49,7 @@ class DataModule(ConcatDataModule): ...@@ -46,5 +49,7 @@ class DataModule(ConcatDataModule):
(shenzhen_split["test"], shenzhen_loader), (shenzhen_split["test"], shenzhen_loader),
(indian_split["test"], indian_loader), (indian_split["test"], indian_loader),
], ],
} },
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
) )
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and PadChest datasets.""" """Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and PadChest datasets."""
import os
from ....data.datamodule import ConcatDataModule from ....data.datamodule import ConcatDataModule
from ..indian.datamodule import RawDataLoader as IndianLoader from ..indian.datamodule import RawDataLoader as IndianLoader
from ..indian.datamodule import make_split as make_indian_split from ..indian.datamodule import make_split as make_indian_split
...@@ -57,5 +59,11 @@ class DataModule(ConcatDataModule): ...@@ -57,5 +59,11 @@ class DataModule(ConcatDataModule):
(indian_split["test"], indian_loader), (indian_split["test"], indian_loader),
(padchest_split["test"], padchest_loader), (padchest_split["test"], padchest_loader),
], ],
} },
database_name=__package__.split(".")[-1],
split_name=(
os.path.splitext(split_filename)[0]
+ "+"
+ os.path.splitext(padchest_split_filename)[0]
),
) )
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets.""" """Aggregated DataModule composed of Montgomery, Shenzhen, Indian, and TBX11k datasets."""
import os
from ....data.datamodule import ConcatDataModule from ....data.datamodule import ConcatDataModule
from ..indian.datamodule import RawDataLoader as IndianLoader from ..indian.datamodule import RawDataLoader as IndianLoader
from ..indian.datamodule import make_split as make_indian_split from ..indian.datamodule import make_split as make_indian_split
...@@ -56,5 +58,11 @@ class DataModule(ConcatDataModule): ...@@ -56,5 +58,11 @@ class DataModule(ConcatDataModule):
(indian_split["test"], indian_loader), (indian_split["test"], indian_loader),
(tbx11k_split["test"], tbx11k_loader), (tbx11k_split["test"], tbx11k_loader),
], ],
} },
database_name=__package__.split(".")[-1],
split_name=(
os.path.splitext(split_filename)[0]
+ "+"
+ os.path.splitext(tbx11k_split_filename)[0]
),
) )
...@@ -192,4 +192,6 @@ class DataModule(CachingDataModule): ...@@ -192,4 +192,6 @@ class DataModule(CachingDataModule):
super().__init__( super().__init__(
database_split=make_split(split_filename), database_split=make_split(split_filename),
raw_data_loader=RawDataLoader(), raw_data_loader=RawDataLoader(),
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
) )
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch> # Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated DataModule composed of NIH-CXR-14 and PadChest databases."""
import os
from ....data.datamodule import ConcatDataModule from ....data.datamodule import ConcatDataModule
from ..nih_cxr14.datamodule import RawDataLoader as CXR14Loader from ..nih_cxr14.datamodule import RawDataLoader as CXR14Loader
...@@ -42,5 +45,11 @@ class DataModule(ConcatDataModule): ...@@ -42,5 +45,11 @@ class DataModule(ConcatDataModule):
# there is no test set on padchest # there is no test set on padchest
# (padchest_split["test"], padchest_loader), # (padchest_split["test"], padchest_loader),
], ],
} },
database_name=__package__.split(".")[-1],
split_name=(
os.path.splitext(cxr14_split_filename)[0]
+ "+"
+ os.path.splitext(padchest_split_filename)[0]
),
) )
...@@ -341,4 +341,6 @@ class DataModule(CachingDataModule): ...@@ -341,4 +341,6 @@ class DataModule(CachingDataModule):
super().__init__( super().__init__(
database_split=make_split(split_filename), database_split=make_split(split_filename),
raw_data_loader=RawDataLoader(), raw_data_loader=RawDataLoader(),
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
) )
...@@ -155,4 +155,6 @@ class DataModule(CachingDataModule): ...@@ -155,4 +155,6 @@ class DataModule(CachingDataModule):
super().__init__( super().__init__(
database_split=make_split(split_filename), database_split=make_split(split_filename),
raw_data_loader=RawDataLoader(), raw_data_loader=RawDataLoader(),
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
) )
...@@ -136,4 +136,6 @@ class DataModule(CachingDataModule): ...@@ -136,4 +136,6 @@ class DataModule(CachingDataModule):
super().__init__( super().__init__(
database_split=make_split(split_filename), database_split=make_split(split_filename),
raw_data_loader=RawDataLoader(), raw_data_loader=RawDataLoader(),
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
) )
...@@ -355,4 +355,6 @@ class DataModule(CachingDataModule): ...@@ -355,4 +355,6 @@ class DataModule(CachingDataModule):
super().__init__( super().__init__(
database_split=make_split(split_filename), database_split=make_split(split_filename),
raw_data_loader=RawDataLoader(), raw_data_loader=RawDataLoader(),
database_name=__package__.split(".")[-1],
split_name=os.path.splitext(split_filename)[0],
) )
...@@ -458,6 +458,12 @@ class ConcatDataModule(lightning.LightningDataModule): ...@@ -458,6 +458,12 @@ class ConcatDataModule(lightning.LightningDataModule):
Entries named ``monitor-...`` will be considered extra datasets that do Entries named ``monitor-...`` will be considered extra datasets that do
not influence any early stop criteria during training, and are just not influence any early stop criteria during training, and are just
monitored beyond the ``validation`` dataset. monitored beyond the ``validation`` dataset.
database_name
The name of the database, or aggregated database containing the
raw-samples served by this data module.
split_name
The name of the split used to group the samples into the various
datasets for training, validation and testing.
cache_samples cache_samples
If set, then issue raw data loading during ``prepare_data()``, and If set, then issue raw data loading during ``prepare_data()``, and
serves samples from CPU memory. Otherwise, loads samples from disk on serves samples from CPU memory. Otherwise, loads samples from disk on
...@@ -510,6 +516,8 @@ class ConcatDataModule(lightning.LightningDataModule): ...@@ -510,6 +516,8 @@ class ConcatDataModule(lightning.LightningDataModule):
def __init__( def __init__(
self, self,
splits: ConcatDatabaseSplit, splits: ConcatDatabaseSplit,
database_name: str = "",
split_name: str = "",
cache_samples: bool = False, cache_samples: bool = False,
balance_sampler_by_class: bool = False, balance_sampler_by_class: bool = False,
batch_size: int = 1, batch_size: int = 1,
...@@ -522,10 +530,15 @@ class ConcatDataModule(lightning.LightningDataModule): ...@@ -522,10 +530,15 @@ class ConcatDataModule(lightning.LightningDataModule):
self.set_chunk_size(batch_size, batch_chunk_count) self.set_chunk_size(batch_size, batch_chunk_count)
self.splits = splits self.splits = splits
self.database_name = database_name
self.split_name = split_name
for dataset_name, split_loaders in splits.items(): for dataset_name, split_loaders in splits.items():
count = sum([len(k) for k, _ in split_loaders]) count = sum([len(k) for k, _ in split_loaders])
logger.info(f"Dataset `{dataset_name}` contains {count} samples") logger.info(
f"Dataset `{dataset_name}` (`{database_name}`/`{split_name}`) "
f"contains {count} samples"
)
self.cache_samples = cache_samples self.cache_samples = cache_samples
self._train_sampler = None self._train_sampler = None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment