diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py index fb0970f0b2c1ebb1cfcae291abd626428f556bf6..5e656d428c6ebe64dde7858f368149481b8e1c2e 100644 --- a/src/ptbench/data/base_datamodule.py +++ b/src/ptbench/data/base_datamodule.py @@ -92,3 +92,16 @@ class BaseDataModule(pl.LightningDataModule): shuffle=False, pin_memory=self.pin_memory, ) + + +def get_dataset_from_module(module, stage, **module_args): + """Instantiates a DataModule and retrieves the corresponding dataset. + + Useful when combining multiple datasets. + """ + module_instance = module(**module_args) + module_instance.prepare_data() + module_instance.setup(stage=stage) + dataset = module_instance.dataset + + return dataset diff --git a/src/ptbench/data/mc_ch/default.py b/src/ptbench/data/mc_ch/default.py index 0af1d68682ec2b91ae8febf1adcebf22704f0982..cf901fdb41b14037c1827d364fedee1f70b1a665 100644 --- a/src/ptbench/data/mc_ch/default.py +++ b/src/ptbench/data/mc_ch/default.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. 
import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.default import datamodule as mc_datamodule from ..shenzhen.default import datamodule as ch_datamodule @@ -37,27 +37,16 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_0.py b/src/ptbench/data/mc_ch/fold_0.py index e151c0e0eb0703b8a8b3bc83f68d8f020053cc14..eeb8f5128f56dc41f383ebd1d58b67b8cfd5ae27 100644 --- a/src/ptbench/data/mc_ch/fold_0.py +++ b/src/ptbench/data/mc_ch/fold_0.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. 
import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_0 import datamodule as mc_datamodule from ..shenzhen.fold_0 import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_0_rgb.py b/src/ptbench/data/mc_ch/fold_0_rgb.py index 56502b24651879a8106ddd6228eee70e1c4de205..6c8c5aebce6a1b972a8f418a66ff2d7cbdc5ae30 100644 --- a/src/ptbench/data/mc_ch/fold_0_rgb.py +++ b/src/ptbench/data/mc_ch/fold_0_rgb.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. 
import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_0_rgb import datamodule as mc_datamodule from ..shenzhen.fold_0_rgb import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_1.py b/src/ptbench/data/mc_ch/fold_1.py index 732513da40421d271b86a087731a56f6e1bafbd5..b6bcb0f70e15097da82a2b8f824ee7178e5f1972 100644 --- a/src/ptbench/data/mc_ch/fold_1.py +++ b/src/ptbench/data/mc_ch/fold_1.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. 
import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_1 import datamodule as mc_datamodule from ..shenzhen.fold_1 import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_1_rgb.py b/src/ptbench/data/mc_ch/fold_1_rgb.py index 6cfbcb366f71146cc7f20c5d27ea1998cd209bc1..2570951009386568ac10cf4a7de705c54ea71a9b 100644 --- a/src/ptbench/data/mc_ch/fold_1_rgb.py +++ b/src/ptbench/data/mc_ch/fold_1_rgb.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. 
import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_1_rgb import datamodule as mc_datamodule from ..shenzhen.fold_1_rgb import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_2.py b/src/ptbench/data/mc_ch/fold_2.py index 1d4ac5b58bb7d2500932061be0e9a252139e6274..e3ac99ece5ded748c6d1cd772846986e3ff1d9d0 100644 --- a/src/ptbench/data/mc_ch/fold_2.py +++ b/src/ptbench/data/mc_ch/fold_2.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. 
import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_2 import datamodule as mc_datamodule from ..shenzhen.fold_2 import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_2_rgb.py b/src/ptbench/data/mc_ch/fold_2_rgb.py index eec98dca20441bf30fb6ae8250303dd970b68148..1bbce20d34d7d2083fc4fb35effd6022b6028b28 100644 --- a/src/ptbench/data/mc_ch/fold_2_rgb.py +++ b/src/ptbench/data/mc_ch/fold_2_rgb.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. 
import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_2_rgb import datamodule as mc_datamodule from ..shenzhen.fold_2_rgb import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_3.py b/src/ptbench/data/mc_ch/fold_3.py index b97b5e944b94797c518a56b8f0021b5b890766ec..ed58cac744caa2e7a30378c90c967d2406a30ef5 100644 --- a/src/ptbench/data/mc_ch/fold_3.py +++ b/src/ptbench/data/mc_ch/fold_3.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. 
import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_3 import datamodule as mc_datamodule from ..shenzhen.fold_3 import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_3_rgb.py b/src/ptbench/data/mc_ch/fold_3_rgb.py index e380dc33b3653632a3c8f7f5324ab2b0bfb737e2..7d9401f0ece0f316c958cebb2eed35f5ec0cabbc 100644 --- a/src/ptbench/data/mc_ch/fold_3_rgb.py +++ b/src/ptbench/data/mc_ch/fold_3_rgb.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. 
import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_3_rgb import datamodule as mc_datamodule from ..shenzhen.fold_3_rgb import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_4.py b/src/ptbench/data/mc_ch/fold_4.py index 8cd906b53064644499f45a07f082e2f91acb76fc..761d730e0c4a00c230ec47be966d52bfce72c3bc 100644 --- a/src/ptbench/data/mc_ch/fold_4.py +++ b/src/ptbench/data/mc_ch/fold_4.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. 
import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_4 import datamodule as mc_datamodule from ..shenzhen.fold_4 import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_4_rgb.py b/src/ptbench/data/mc_ch/fold_4_rgb.py index 7ba0ecfebd957bf6778fd0c34731974ddec50434..3f1a147634fa41d06299765ec44aab2d1271e429 100644 --- a/src/ptbench/data/mc_ch/fold_4_rgb.py +++ b/src/ptbench/data/mc_ch/fold_4_rgb.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. 
import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_4_rgb import datamodule as mc_datamodule from ..shenzhen.fold_4_rgb import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_5.py b/src/ptbench/data/mc_ch/fold_5.py index 3f20a33b633df0aba8b65e5849709634b96cce6e..01a8f31ebca9871ce5e5cee8e5c9106c86335362 100644 --- a/src/ptbench/data/mc_ch/fold_5.py +++ b/src/ptbench/data/mc_ch/fold_5.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. 
import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_5 import datamodule as mc_datamodule from ..shenzhen.fold_5 import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_5_rgb.py b/src/ptbench/data/mc_ch/fold_5_rgb.py index 61159ed6e37be4e898ac37d3d2b1cf4f71b5347a..cf66986fe9b48119f0a5510ad6d7547513542346 100644 --- a/src/ptbench/data/mc_ch/fold_5_rgb.py +++ b/src/ptbench/data/mc_ch/fold_5_rgb.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. 
import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_5_rgb import datamodule as mc_datamodule from ..shenzhen.fold_5_rgb import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_6.py b/src/ptbench/data/mc_ch/fold_6.py index 7413791054f2d14948f3ee4f5f12817e172872a8..5b488da973f00e5a6718eac6117016ccfe3a359f 100644 --- a/src/ptbench/data/mc_ch/fold_6.py +++ b/src/ptbench/data/mc_ch/fold_6.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. 
import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_6 import datamodule as mc_datamodule from ..shenzhen.fold_6 import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_6_rgb.py b/src/ptbench/data/mc_ch/fold_6_rgb.py index 79abe09b973fe27a54af0cd1db6d855f1e93d2f4..ecadb9bfabfd1348281c63eb97339c5bad990f6a 100644 --- a/src/ptbench/data/mc_ch/fold_6_rgb.py +++ b/src/ptbench/data/mc_ch/fold_6_rgb.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. 
import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_6_rgb import datamodule as mc_datamodule from ..shenzhen.fold_6_rgb import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_7.py b/src/ptbench/data/mc_ch/fold_7.py index a94621e61d846e596ff8422c8abce39c896ee028..6514fd1aba701fff1dcaaaf4791d6ae9b861de22 100644 --- a/src/ptbench/data/mc_ch/fold_7.py +++ b/src/ptbench/data/mc_ch/fold_7.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. 
import return_subsets -from ..base_datamodule import BaseDataModule +from ..base_datamodule import BaseDataModule, get_dataset_from_module from ..montgomery.fold_7 import datamodule as mc_datamodule from ..shenzhen.fold_7 import datamodule as ch_datamodule @@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): def setup(self, stage: str): # Instantiate other datamodules and get their datasets - mc_module = mc_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) - - mc_module.prepare_data() - mc_module.setup(stage=stage) - mc = mc_module.dataset - - ch_module = ch_datamodule( - train_batch_size=self.train_batch_size, - predict_batch_size=self.predict_batch_size, - drop_incomplete_batch=self.drop_incomplete_batch, - multiproc_kwargs=self.multiproc_kwargs, - ) + module_args = { + "train_batch_size": self.train_batch_size, + "predict_batch_size": self.predict_batch_size, + "drop_incomplete_batch": self.drop_incomplete_batch, + "multiproc_kwargs": self.multiproc_kwargs, + } - ch_module.prepare_data() - ch_module.setup(stage=stage) - ch = ch_module.dataset + mc = get_dataset_from_module(mc_datamodule, stage, **module_args) + ch = get_dataset_from_module(ch_datamodule, stage, **module_args) # Combine datasets self.dataset = {} diff --git a/src/ptbench/data/mc_ch/fold_7_rgb.py b/src/ptbench/data/mc_ch/fold_7_rgb.py index 90b866e1d0b506e3d5d03a76472815fb6165294f..9123adf3676c08acbd7d32a1b2747ffd314dae38 100644 --- a/src/ptbench/data/mc_ch/fold_7_rgb.py +++ b/src/ptbench/data/mc_ch/fold_7_rgb.py @@ -8,7 +8,7 @@ from clapper.logging import setup from torch.utils.data.dataset import ConcatDataset from .. 
import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_7_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_7_rgb import datamodule as ch_datamodule
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_8.py b/src/ptbench/data/mc_ch/fold_8.py
index aa52bc818fd9dc4daff55c6691ada122813fbfd8..e60506363bf508a39f3bb2c9806f8bd3bcc9e83c 100644
--- a/src/ptbench/data/mc_ch/fold_8.py
+++ b/src/ptbench/data/mc_ch/fold_8.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_8 import datamodule as mc_datamodule
 from ..shenzhen.fold_8 import datamodule as ch_datamodule
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_8_rgb.py b/src/ptbench/data/mc_ch/fold_8_rgb.py
index 3df1838d21c7c8c3cdb4fb686edad87f053e564f..4ff54baa2f5424993ed747edd4ba136da9324981 100644
--- a/src/ptbench/data/mc_ch/fold_8_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_8_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_8_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_8_rgb import datamodule as ch_datamodule
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_9.py b/src/ptbench/data/mc_ch/fold_9.py
index 4bb4a5a3c9b70a35d4cd7665bc5221ae3948419e..cdd69ba7deeeb2759cdabc1134379145d5537696 100644
--- a/src/ptbench/data/mc_ch/fold_9.py
+++ b/src/ptbench/data/mc_ch/fold_9.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_9 import datamodule as mc_datamodule
 from ..shenzhen.fold_9 import datamodule as ch_datamodule
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/fold_9_rgb.py b/src/ptbench/data/mc_ch/fold_9_rgb.py
index a07ffce4a0b6d567e85eff06162a845a3f4c6bab..465e6c3d3d2c65eb367f7a97d7bb296a3be3dab1 100644
--- a/src/ptbench/data/mc_ch/fold_9_rgb.py
+++ b/src/ptbench/data/mc_ch/fold_9_rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.fold_9_rgb import datamodule as mc_datamodule
 from ..shenzhen.fold_9_rgb import datamodule as ch_datamodule
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}
diff --git a/src/ptbench/data/mc_ch/rgb.py b/src/ptbench/data/mc_ch/rgb.py
index 05407eae2ee826f5706770bbec3feb2e6aa9426f..4eee3c86f6150859a45ecd0e5e9d22dc45adc637 100644
--- a/src/ptbench/data/mc_ch/rgb.py
+++ b/src/ptbench/data/mc_ch/rgb.py
@@ -8,7 +8,7 @@ from clapper.logging import setup
 
 from torch.utils.data.dataset import ConcatDataset
 
 from .. import return_subsets
-from ..base_datamodule import BaseDataModule
+from ..base_datamodule import BaseDataModule, get_dataset_from_module
 from ..montgomery.rgb import datamodule as mc_datamodule
 from ..shenzhen.rgb import datamodule as ch_datamodule
@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
 
     def setup(self, stage: str):
         # Instantiate other datamodules and get their datasets
-        mc_module = mc_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
-
-        mc_module.prepare_data()
-        mc_module.setup(stage=stage)
-        mc = mc_module.dataset
-
-        ch_module = ch_datamodule(
-            train_batch_size=self.train_batch_size,
-            predict_batch_size=self.predict_batch_size,
-            drop_incomplete_batch=self.drop_incomplete_batch,
-            multiproc_kwargs=self.multiproc_kwargs,
-        )
+        module_args = {
+            "train_batch_size": self.train_batch_size,
+            "predict_batch_size": self.predict_batch_size,
+            "drop_incomplete_batch": self.drop_incomplete_batch,
+            "multiproc_kwargs": self.multiproc_kwargs,
+        }
 
-        ch_module.prepare_data()
-        ch_module.setup(stage=stage)
-        ch = ch_module.dataset
+        mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
+        ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
 
         # Combine datasets
         self.dataset = {}