diff --git a/src/ptbench/configs/datasets/nih_cxr14_re_pc/__init__.py b/src/ptbench/configs/datasets/nih_cxr14_re_pc/__init__.py deleted file mode 100644 index 7b6d0df211ab0b5bd52db6f3887470373603ed8f..0000000000000000000000000000000000000000 --- a/src/ptbench/configs/datasets/nih_cxr14_re_pc/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -from torch.utils.data.dataset import ConcatDataset - - -def _maker(protocol): - if protocol == "idiap": - from ..nih_cxr14_re import default as nih_cxr14_re - from ..padchest import no_tb_idiap as padchest_no_tb - else: - raise RuntimeError(f"Unsupported protocol: {protocol}") - - nih_cxr14_re = nih_cxr14_re.dataset - padchest_no_tb = padchest_no_tb.dataset - - dataset = {} - dataset["__train__"] = ConcatDataset( - [nih_cxr14_re["__train__"], padchest_no_tb["__train__"]] - ) - dataset["train"] = ConcatDataset( - [nih_cxr14_re["train"], padchest_no_tb["train"]] - ) - dataset["__valid__"] = ConcatDataset( - [nih_cxr14_re["__valid__"], padchest_no_tb["__valid__"]] - ) - dataset["validation"] = ConcatDataset( - [nih_cxr14_re["validation"], padchest_no_tb["validation"]] - ) - dataset["test"] = nih_cxr14_re["test"] - - return dataset diff --git a/src/ptbench/configs/datasets/nih_cxr14_re_pc/idiap.py b/src/ptbench/configs/datasets/nih_cxr14_re_pc/idiap.py deleted file mode 100644 index 092d853953a86d4264cf424358254288d2e70d4e..0000000000000000000000000000000000000000 --- a/src/ptbench/configs/datasets/nih_cxr14_re_pc/idiap.py +++ /dev/null @@ -1,10 +0,0 @@ -# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> -# -# SPDX-License-Identifier: GPL-3.0-or-later - -"""Aggregated dataset composed of NIH CXR14 relabeld and PadChest (normalized) -datasets.""" - -from . import _maker - -dataset = _maker("idiap") diff --git a/src/ptbench/data/nih_cxr14_re_pc/__init__.py b/src/ptbench/data/nih_cxr14_re_pc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..84b9088ea60cbbf9ddee2fdf1bfc14203beda01f --- /dev/null +++ b/src/ptbench/data/nih_cxr14_re_pc/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/src/ptbench/data/nih_cxr14_re_pc/idiap.py b/src/ptbench/data/nih_cxr14_re_pc/idiap.py new file mode 100644 index 0000000000000000000000000000000000000000..2087235a971a268ab57e286090752f7b4a22ec09 --- /dev/null +++ b/src/ptbench/data/nih_cxr14_re_pc/idiap.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Aggregated dataset composed of NIH CXR14 relabeld and PadChest (normalized) +datasets.""" + +from clapper.logging import setup +from torch.utils.data.dataset import ConcatDataset + +from .. import return_subsets +from ..base_datamodule import BaseDataModule, get_dataset_from_module +from ..nih_cxr14_re.default import datamodule as nih_cxr14_re_datamodule +from ..padchest.no_tb_idiap import datamodule as padchest_no_tb_datamodule + +logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") + + +class DefaultModule(BaseDataModule): + def __init__( + self, + train_batch_size=1, + predict_batch_size=1, + drop_incomplete_batch=False, + multiproc_kwargs=None, + ): + self.train_batch_size = train_batch_size + self.predict_batch_size = predict_batch_size + self.drop_incomplete_batch = drop_incomplete_batch + self.multiproc_kwargs = multiproc_kwargs + + super().__init__( + train_batch_size=train_batch_size, + predict_batch_size=predict_batch_size, + drop_incomplete_batch=drop_incomplete_batch, + multiproc_kwargs=multiproc_kwargs, + ) + + def setup(self, stage: str): + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } + + nih_cxr14_re = get_dataset_from_module( + nih_cxr14_re_datamodule, stage, **module_args + ) + padchest_no_tb = get_dataset_from_module( + padchest_no_tb_datamodule, stage, **module_args + ) + + self.dataset = {} + self.dataset["__train__"] = ConcatDataset( + [nih_cxr14_re["__train__"], padchest_no_tb["__train__"]] + ) + self.dataset["train"] = ConcatDataset( + [nih_cxr14_re["train"], padchest_no_tb["train"]] + ) + self.dataset["__valid__"] = ConcatDataset( + [nih_cxr14_re["__valid__"], padchest_no_tb["__valid__"]] + ) + self.dataset["validation"] = ConcatDataset( + [nih_cxr14_re["validation"], padchest_no_tb["validation"]] + ) + self.dataset["test"] = nih_cxr14_re["test"] + + ( + self.train_dataset, + self.validation_dataset, + self.extra_validation_datasets, + self.predict_dataset, + ) = return_subsets(self.dataset) + + +datamodule = DefaultModule