diff --git a/src/ptbench/data/mc_ch_in_11k/fold_1_rgb.py b/src/ptbench/data/mc_ch_in_11k/fold_1_rgb.py index 32a94a5d68567ab4361686ac11118330e0b911b0..1fdea4b18d3721053845cf0bf7c830cd2bf6fed6 100644 --- a/src/ptbench/data/mc_ch_in_11k/fold_1_rgb.py +++ b/src/ptbench/data/mc_ch_in_11k/fold_1_rgb.py @@ -2,7 +2,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -"""Aggregated dataset composed of Montgomery and Shenzhen datasets.""" +"""Aggregated dataset composed of Montgomery, Shenzhen, Indian and the default +TBX11K-simplified datasets (cross validation fold 1, RGB)""" from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset @@ -12,6 +13,7 @@ from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..indian.fold_1_rgb import datamodule as indian_datamodule from ..montgomery.fold_1_rgb import datamodule as mc_datamodule from ..shenzhen.fold_1_rgb import datamodule as ch_datamodule +from ..tbx11k_simplified.fold_1_rgb import datamodule as tbx11k_datamodule logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -51,23 +53,41 @@ class DefaultModule(BaseDataModule): indian = get_dataset_from_module( indian_datamodule, stage, **module_args ) + tbx11k = get_dataset_from_module( + tbx11k_datamodule, stage, **module_args + ) # Combine datasets self.dataset = {} self.dataset["__train__"] = ConcatDataset( - [mc["__train__"], ch["__train__"], indian["__train__"]] + [ + mc["__train__"], + ch["__train__"], + indian["__train__"], + tbx11k["__train__"], + ] ) self.dataset["train"] = ConcatDataset( - [mc["train"], ch["train"], indian["train"]] + [mc["train"], ch["train"], indian["train"], tbx11k["train"]] ) self.dataset["__valid__"] = ConcatDataset( - [mc["__valid__"], ch["__valid__"], indian["__valid__"]] + [ + mc["__valid__"], + ch["__valid__"], + indian["__valid__"], + tbx11k["__valid__"], + ] ) self.dataset["validation"] = ConcatDataset( - [mc["validation"], ch["validation"], indian["validation"]] + [ + mc["validation"], + ch["validation"], + indian["validation"], + tbx11k["validation"], + ] ) self.dataset["test"] = ConcatDataset( - [mc["test"], ch["test"], indian["test"]] + [mc["test"], ch["test"], indian["test"], tbx11k["test"]] ) (