diff --git a/pyproject.toml b/pyproject.toml
index b304379d7ffa335552d650e9965246fe3c4ef029..7c9f19488a8aa34e72d1faf2be868a255710654c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -465,8 +465,8 @@ hivtb_rs_f7 = "ptbench.configs.datasets.hivtb_RS.fold_7"
 hivtb_rs_f8 = "ptbench.configs.datasets.hivtb_RS.fold_8"
 hivtb_rs_f9 = "ptbench.configs.datasets.hivtb_RS.fold_9"
 # montgomery-shenzhen-indian-padchest aggregated dataset
-mc_ch_in_pc = "ptbench.configs.datasets.mc_ch_in_pc.default"
-mc_ch_in_pc_rgb = "ptbench.configs.datasets.mc_ch_in_pc.rgb"
+mc_ch_in_pc = "ptbench.data.mc_ch_in_pc.default"
+mc_ch_in_pc_rgb = "ptbench.data.mc_ch_in_pc.rgb"
 # extended montgomery-shenzhen-indian-padchest aggregated dataset
 # (with radiological signs)
 mc_ch_in_pc_rs = "ptbench.configs.datasets.mc_ch_in_pc_RS.default"
diff --git a/src/ptbench/configs/datasets/mc_ch_in_pc/__init__.py b/src/ptbench/configs/datasets/mc_ch_in_pc/__init__.py
deleted file mode 100644
index 9a4a1815ae7daafd9812623a8fbb08ac9abd114d..0000000000000000000000000000000000000000
--- a/src/ptbench/configs/datasets/mc_ch_in_pc/__init__.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-from torch.utils.data.dataset import ConcatDataset
-
-
-def _maker(protocol):
-    if protocol == "default":
-        from ..indian import default as indian
-        from ..montgomery import default as mc
-        from ..padchest import tb_idiap as pc
-        from ..shenzhen import default as ch
-    elif protocol == "rgb":
-        from ..indian import rgb as indian
-        from ..montgomery import rgb as mc
-        from ..padchest import tb_idiap_rgb as pc
-        from ..shenzhen import rgb as ch
-
-    mc = mc.dataset
-    ch = ch.dataset
-    indian = indian.dataset
-    pc = pc.dataset
-
-    dataset = {}
-    dataset["__train__"] = ConcatDataset(
-        [mc["__train__"], ch["__train__"], indian["__train__"], pc["__train__"]]
-    )
-    dataset["train"] = ConcatDataset(
-        [mc["train"], ch["train"], indian["train"], pc["train"]]
-    )
-    dataset["__valid__"] = ConcatDataset(
-        [mc["__valid__"], ch["__valid__"], indian["__valid__"], pc["__valid__"]]
-    )
-    dataset["test"] = ConcatDataset(
-        [mc["test"], ch["test"], indian["test"], pc["test"]]
-    )
-
-    return dataset
diff --git a/src/ptbench/configs/datasets/mc_ch_in_pc/default.py b/src/ptbench/configs/datasets/mc_ch_in_pc/default.py
deleted file mode 100644
index 8b75d5f4cdde3fcdfef5a1c4d7dfa220d490aa44..0000000000000000000000000000000000000000
--- a/src/ptbench/configs/datasets/mc_ch_in_pc/default.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and Padchest
-datasets."""
-
-from . import _maker
-
-dataset = _maker("default")
diff --git a/src/ptbench/configs/datasets/mc_ch_in_pc/rgb.py b/src/ptbench/configs/datasets/mc_ch_in_pc/rgb.py
deleted file mode 100644
index 0da28daecb170e392b7f65e3c143e2bff9c0adb0..0000000000000000000000000000000000000000
--- a/src/ptbench/configs/datasets/mc_ch_in_pc/rgb.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and Padchest
-(RGB) datasets."""
-
-from . import _maker
-
-dataset = _maker("rgb")
diff --git a/src/ptbench/data/mc_ch_in_pc/__init__.py b/src/ptbench/data/mc_ch_in_pc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..662d5c1326651b4d9f48d47bc4b503df23d17216
--- /dev/null
+++ b/src/ptbench/data/mc_ch_in_pc/__init__.py
@@ -0,0 +1,3 @@
+# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
diff --git a/src/ptbench/data/mc_ch_in_pc/default.py b/src/ptbench/data/mc_ch_in_pc/default.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b8be8f44d10fb859dcb7c044a6646c1375fda5a
--- /dev/null
+++ b/src/ptbench/data/mc_ch_in_pc/default.py
@@ -0,0 +1,91 @@
+# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and Padchest
+datasets."""
+
+from clapper.logging import setup
+from torch.utils.data.dataset import ConcatDataset
+
+from .. import return_subsets
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
+from ..indian.default import datamodule as indian_datamodule
+from ..montgomery.default import datamodule as mc_datamodule
+from ..padchest.tb_idiap import datamodule as pc_datamodule
+from ..shenzhen.default import datamodule as ch_datamodule
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+class DefaultModule(BaseDataModule):
+    def __init__(
+        self,
+        train_batch_size=1,
+        predict_batch_size=1,
+        drop_incomplete_batch=False,
+        multiproc_kwargs=None,
+    ):
+        self.train_batch_size = train_batch_size
+        self.predict_batch_size = predict_batch_size
+        self.drop_incomplete_batch = drop_incomplete_batch
+        self.multiproc_kwargs = multiproc_kwargs
+
+        super().__init__(
+            train_batch_size=train_batch_size,
+            predict_batch_size=predict_batch_size,
+            drop_incomplete_batch=drop_incomplete_batch,
+            multiproc_kwargs=multiproc_kwargs,
+        )
+
+    def setup(self, stage: str):
+        # Instantiate other datamodules and get their datasets
+
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
+
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
+        indian = get_dataset_from_module(
+            indian_datamodule, stage, **module_args
+        )
+        pc = get_dataset_from_module(pc_datamodule, stage, **module_args)
+
+        # Combine datasets
+        self.dataset = {}
+        self.dataset["__train__"] = ConcatDataset(
+            [
+                mc["__train__"],
+                ch["__train__"],
+                indian["__train__"],
+                pc["__train__"],
+            ]
+        )
+        self.dataset["train"] = ConcatDataset(
+            [mc["train"], ch["train"], indian["train"], pc["train"]]
+        )
+        self.dataset["__valid__"] = ConcatDataset(
+            [
+                mc["__valid__"],
+                ch["__valid__"],
+                indian["__valid__"],
+                pc["__valid__"],
+            ]
+        )
+        self.dataset["test"] = ConcatDataset(
+            [mc["test"], ch["test"], indian["test"], pc["test"]]
+        )
+
+        (
+            self.train_dataset,
+            self.validation_dataset,
+            self.extra_validation_datasets,
+            self.predict_dataset,
+        ) = return_subsets(self.dataset)
+
+
+datamodule = DefaultModule
diff --git a/src/ptbench/data/mc_ch_in_pc/rgb.py b/src/ptbench/data/mc_ch_in_pc/rgb.py
new file mode 100644
index 0000000000000000000000000000000000000000..47b6845d00b5330a4ca140d26e0ae2d2ef357fe4
--- /dev/null
+++ b/src/ptbench/data/mc_ch_in_pc/rgb.py
@@ -0,0 +1,91 @@
+# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and Padchest
+(RGB) datasets."""
+
+from clapper.logging import setup
+from torch.utils.data.dataset import ConcatDataset
+
+from .. import return_subsets
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
+from ..indian.rgb import datamodule as indian_datamodule
+from ..montgomery.rgb import datamodule as mc_datamodule
+from ..padchest.tb_idiap_rgb import datamodule as pc_datamodule
+from ..shenzhen.rgb import datamodule as ch_datamodule
+
+logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
+
+
+class DefaultModule(BaseDataModule):
+    def __init__(
+        self,
+        train_batch_size=1,
+        predict_batch_size=1,
+        drop_incomplete_batch=False,
+        multiproc_kwargs=None,
+    ):
+        self.train_batch_size = train_batch_size
+        self.predict_batch_size = predict_batch_size
+        self.drop_incomplete_batch = drop_incomplete_batch
+        self.multiproc_kwargs = multiproc_kwargs
+
+        super().__init__(
+            train_batch_size=train_batch_size,
+            predict_batch_size=predict_batch_size,
+            drop_incomplete_batch=drop_incomplete_batch,
+            multiproc_kwargs=multiproc_kwargs,
+        )
+
+    def setup(self, stage: str):
+        # Instantiate other datamodules and get their datasets
+
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
+
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
+        indian = get_dataset_from_module(
+            indian_datamodule, stage, **module_args
+        )
+        pc = get_dataset_from_module(pc_datamodule, stage, **module_args)
+
+        # Combine datasets
+        self.dataset = {}
+        self.dataset["__train__"] = ConcatDataset(
+            [
+                mc["__train__"],
+                ch["__train__"],
+                indian["__train__"],
+                pc["__train__"],
+            ]
+        )
+        self.dataset["train"] = ConcatDataset(
+            [mc["train"], ch["train"], indian["train"], pc["train"]]
+        )
+        self.dataset["__valid__"] = ConcatDataset(
+            [
+                mc["__valid__"],
+                ch["__valid__"],
+                indian["__valid__"],
+                pc["__valid__"],
+            ]
+        )
+        self.dataset["test"] = ConcatDataset(
+            [mc["test"], ch["test"], indian["test"], pc["test"]]
+        )
+
+        (
+            self.train_dataset,
+            self.validation_dataset,
+            self.extra_validation_datasets,
+            self.predict_dataset,
+        ) = return_subsets(self.dataset)
+
+
+datamodule = DefaultModule