From b0138f2fc36fecff0c995c752216780e1be3452a Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Mon, 5 Jun 2023 16:01:38 +0200 Subject: [PATCH] Function to instantiate DataModule and retrieve dataset from it (DRY) --- src/ptbench/data/base_datamodule.py | 13 ++++++++++++ src/ptbench/data/mc_ch/default.py | 29 +++++++++------------------ src/ptbench/data/mc_ch/fold_0.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/fold_0_rgb.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/fold_1.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/fold_1_rgb.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/fold_2.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/fold_2_rgb.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/fold_3.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/fold_3_rgb.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/fold_4.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/fold_4_rgb.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/fold_5.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/fold_5_rgb.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/fold_6.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/fold_6_rgb.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/fold_7.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/fold_7_rgb.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/fold_8.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/fold_8_rgb.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/fold_9.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/fold_9_rgb.py | 30 +++++++++------------------- src/ptbench/data/mc_ch/rgb.py | 30 +++++++++------------------- 23 files changed, 211 insertions(+), 461 deletions(-) diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py index fb0970f0..5e656d42 100644 --- a/src/ptbench/data/base_datamodule.py +++ b/src/ptbench/data/base_datamodule.py @@ -92,3 +92,16 @@ class BaseDataModule(pl.LightningDataModule): shuffle=False, pin_memory=self.pin_memory, ) + + +def get_dataset_from_module(module, stage, **module_args): + """Instantiates a DataModule and retrieves the corresponding dataset. + + Useful when combining multiple datasets. + """ + module_instance = module(**module_args) + module_instance.prepare_data() + module_instance.setup(stage=stage) + dataset = module_instance.dataset + + return dataset diff --git a/src/ptbench/data/mc_ch/default.py b/src/ptbench/data/mc_ch/default.py index 0af1d686..cf901fdb 100644 --- a/src/ptbench/data/mc_ch/default.py +++ b/src/ptbench/data/mc_ch/default.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.default import datamodule as mc_datamodule from ..shenzhen.default import datamodule as ch_datamodule @@ -37,27 +37,16 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_0.py b/src/ptbench/data/mc_ch/fold_0.py index e151c0e0..eeb8f512 100644 --- a/src/ptbench/data/mc_ch/fold_0.py +++ b/src/ptbench/data/mc_ch/fold_0.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_0 import datamodule as mc_datamodule from ..shenzhen.fold_0 import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_0_rgb.py b/src/ptbench/data/mc_ch/fold_0_rgb.py index 56502b24..6c8c5aeb 100644 --- a/src/ptbench/data/mc_ch/fold_0_rgb.py +++ b/src/ptbench/data/mc_ch/fold_0_rgb.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_0_rgb import datamodule as mc_datamodule from ..shenzhen.fold_0_rgb import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_1.py b/src/ptbench/data/mc_ch/fold_1.py index 732513da..b6bcb0f7 100644 --- a/src/ptbench/data/mc_ch/fold_1.py +++ b/src/ptbench/data/mc_ch/fold_1.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_1 import datamodule as mc_datamodule from ..shenzhen.fold_1 import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_1_rgb.py b/src/ptbench/data/mc_ch/fold_1_rgb.py index 6cfbcb36..25709510 100644 --- a/src/ptbench/data/mc_ch/fold_1_rgb.py +++ b/src/ptbench/data/mc_ch/fold_1_rgb.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_1_rgb import datamodule as mc_datamodule from ..shenzhen.fold_1_rgb import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_2.py b/src/ptbench/data/mc_ch/fold_2.py index 1d4ac5b5..e3ac99ec 100644 --- a/src/ptbench/data/mc_ch/fold_2.py +++ b/src/ptbench/data/mc_ch/fold_2.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_2 import datamodule as mc_datamodule from ..shenzhen.fold_2 import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_2_rgb.py b/src/ptbench/data/mc_ch/fold_2_rgb.py index eec98dca..1bbce20d 100644 --- a/src/ptbench/data/mc_ch/fold_2_rgb.py +++ b/src/ptbench/data/mc_ch/fold_2_rgb.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_2_rgb import datamodule as mc_datamodule from ..shenzhen.fold_2_rgb import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_3.py b/src/ptbench/data/mc_ch/fold_3.py index b97b5e94..ed58cac7 100644 --- a/src/ptbench/data/mc_ch/fold_3.py +++ b/src/ptbench/data/mc_ch/fold_3.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_3 import datamodule as mc_datamodule from ..shenzhen.fold_3 import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_3_rgb.py b/src/ptbench/data/mc_ch/fold_3_rgb.py index e380dc33..7d9401f0 100644 --- a/src/ptbench/data/mc_ch/fold_3_rgb.py +++ b/src/ptbench/data/mc_ch/fold_3_rgb.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_3_rgb import datamodule as mc_datamodule from ..shenzhen.fold_3_rgb import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_4.py b/src/ptbench/data/mc_ch/fold_4.py index 8cd906b5..761d730e 100644 --- a/src/ptbench/data/mc_ch/fold_4.py +++ b/src/ptbench/data/mc_ch/fold_4.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_4 import datamodule as mc_datamodule from ..shenzhen.fold_4 import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_4_rgb.py b/src/ptbench/data/mc_ch/fold_4_rgb.py index 7ba0ecfe..3f1a1476 100644 --- a/src/ptbench/data/mc_ch/fold_4_rgb.py +++ b/src/ptbench/data/mc_ch/fold_4_rgb.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_4_rgb import datamodule as mc_datamodule from ..shenzhen.fold_4_rgb import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_5.py b/src/ptbench/data/mc_ch/fold_5.py index 3f20a33b..01a8f31e 100644 --- a/src/ptbench/data/mc_ch/fold_5.py +++ b/src/ptbench/data/mc_ch/fold_5.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_5 import datamodule as mc_datamodule from ..shenzhen.fold_5 import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_5_rgb.py b/src/ptbench/data/mc_ch/fold_5_rgb.py index 61159ed6..cf66986f 100644 --- a/src/ptbench/data/mc_ch/fold_5_rgb.py +++ b/src/ptbench/data/mc_ch/fold_5_rgb.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_5_rgb import datamodule as mc_datamodule from ..shenzhen.fold_5_rgb import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_6.py b/src/ptbench/data/mc_ch/fold_6.py index 74137910..5b488da9 100644 --- a/src/ptbench/data/mc_ch/fold_6.py +++ b/src/ptbench/data/mc_ch/fold_6.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_6 import datamodule as mc_datamodule from ..shenzhen.fold_6 import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_6_rgb.py b/src/ptbench/data/mc_ch/fold_6_rgb.py index 79abe09b..ecadb9bf 100644 --- a/src/ptbench/data/mc_ch/fold_6_rgb.py +++ b/src/ptbench/data/mc_ch/fold_6_rgb.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_6_rgb import datamodule as mc_datamodule from ..shenzhen.fold_6_rgb import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_7.py b/src/ptbench/data/mc_ch/fold_7.py index a94621e6..6514fd1a 100644 --- a/src/ptbench/data/mc_ch/fold_7.py +++ b/src/ptbench/data/mc_ch/fold_7.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_7 import datamodule as mc_datamodule from ..shenzhen.fold_7 import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_7_rgb.py b/src/ptbench/data/mc_ch/fold_7_rgb.py index 90b866e1..9123adf3 100644 --- a/src/ptbench/data/mc_ch/fold_7_rgb.py +++ b/src/ptbench/data/mc_ch/fold_7_rgb.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_7_rgb import datamodule as mc_datamodule from ..shenzhen.fold_7_rgb import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_8.py b/src/ptbench/data/mc_ch/fold_8.py index aa52bc81..e6050636 100644 --- a/src/ptbench/data/mc_ch/fold_8.py +++ b/src/ptbench/data/mc_ch/fold_8.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_8 import datamodule as mc_datamodule from ..shenzhen.fold_8 import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_8_rgb.py b/src/ptbench/data/mc_ch/fold_8_rgb.py index 3df1838d..4ff54baa 100644 --- a/src/ptbench/data/mc_ch/fold_8_rgb.py +++ b/src/ptbench/data/mc_ch/fold_8_rgb.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_8_rgb import datamodule as mc_datamodule from ..shenzhen.fold_8_rgb import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_9.py b/src/ptbench/data/mc_ch/fold_9.py index 4bb4a5a3..cdd69ba7 100644 --- a/src/ptbench/data/mc_ch/fold_9.py +++ b/src/ptbench/data/mc_ch/fold_9.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_9 import datamodule as mc_datamodule from ..shenzhen.fold_9 import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_9_rgb.py b/src/ptbench/data/mc_ch/fold_9_rgb.py index a07ffce4..465e6c3d 100644 --- a/src/ptbench/data/mc_ch/fold_9_rgb.py +++ b/src/ptbench/data/mc_ch/fold_9_rgb.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_9_rgb import datamodule as mc_datamodule from ..shenzhen.fold_9_rgb import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/rgb.py b/src/ptbench/data/mc_ch/rgb.py index 05407eae..4eee3c86 100644 --- a/src/ptbench/data/mc_ch/rgb.py +++ b/src/ptbench/data/mc_ch/rgb.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.rgb import datamodule as mc_datamodule from ..shenzhen.rgb import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} -- GitLab