Skip to content
Snippets Groups Projects
Commit 113d553f authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

Function to instantiate DataModule and retrieve dataset from it (DRY)

parent 0ba35d8c
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
Showing
with 184 additions and 398 deletions
...@@ -92,3 +92,16 @@ class BaseDataModule(pl.LightningDataModule): ...@@ -92,3 +92,16 @@ class BaseDataModule(pl.LightningDataModule):
shuffle=False, shuffle=False,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
) )
def get_dataset_from_module(module, stage, **module_args):
    """Instantiate a DataModule class and return its prepared dataset.

    Runs the standard Lightning lifecycle (``prepare_data`` then
    ``setup``) on a fresh instance before reading ``.dataset``.
    Useful when combining multiple datasets.
    """
    instance = module(**module_args)
    # Follow the LightningDataModule lifecycle order: download/prepare
    # first, then stage-specific setup, before the dataset is available.
    instance.prepare_data()
    instance.setup(stage=stage)
    return instance.dataset
...@@ -8,7 +8,7 @@ from clapper.logging import setup ...@@ -8,7 +8,7 @@ from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets from .. import return_subsets
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.default import datamodule as mc_datamodule from ..montgomery.default import datamodule as mc_datamodule
from ..shenzhen.default import datamodule as ch_datamodule from ..shenzhen.default import datamodule as ch_datamodule
...@@ -37,27 +37,16 @@ class DefaultModule(BaseDataModule): ...@@ -37,27 +37,16 @@ class DefaultModule(BaseDataModule):
def setup(self, stage: str): def setup(self, stage: str):
# Instantiate other datamodules and get their datasets # Instantiate other datamodules and get their datasets
mc_module = mc_datamodule(
train_batch_size=self.train_batch_size,
predict_batch_size=self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs,
)
mc_module.prepare_data()
mc_module.setup(stage=stage)
mc = mc_module.dataset
ch_module = ch_datamodule( module_args = {
train_batch_size=self.train_batch_size, "train_batch_size": self.train_batch_size,
predict_batch_size=self.predict_batch_size, "predict_batch_size": self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch, "drop_incomplete_batch": self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs, "multiproc_kwargs": self.multiproc_kwargs,
) }
ch_module.prepare_data() mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch_module.setup(stage=stage) ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
ch = ch_module.dataset
# Combine datasets # Combine datasets
self.dataset = {} self.dataset = {}
......
...@@ -8,7 +8,7 @@ from clapper.logging import setup ...@@ -8,7 +8,7 @@ from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets from .. import return_subsets
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_0 import datamodule as mc_datamodule from ..montgomery.fold_0 import datamodule as mc_datamodule
from ..shenzhen.fold_0 import datamodule as ch_datamodule from ..shenzhen.fold_0 import datamodule as ch_datamodule
...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): ...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
def setup(self, stage: str): def setup(self, stage: str):
# Instantiate other datamodules and get their datasets # Instantiate other datamodules and get their datasets
mc_module = mc_datamodule( module_args = {
train_batch_size=self.train_batch_size, "train_batch_size": self.train_batch_size,
predict_batch_size=self.predict_batch_size, "predict_batch_size": self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch, "drop_incomplete_batch": self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs, "multiproc_kwargs": self.multiproc_kwargs,
) }
mc_module.prepare_data()
mc_module.setup(stage=stage)
mc = mc_module.dataset
ch_module = ch_datamodule(
train_batch_size=self.train_batch_size,
predict_batch_size=self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs,
)
ch_module.prepare_data() mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch_module.setup(stage=stage) ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
ch = ch_module.dataset
# Combine datasets # Combine datasets
self.dataset = {} self.dataset = {}
......
...@@ -8,7 +8,7 @@ from clapper.logging import setup ...@@ -8,7 +8,7 @@ from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets from .. import return_subsets
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_0_rgb import datamodule as mc_datamodule from ..montgomery.fold_0_rgb import datamodule as mc_datamodule
from ..shenzhen.fold_0_rgb import datamodule as ch_datamodule from ..shenzhen.fold_0_rgb import datamodule as ch_datamodule
...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): ...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
def setup(self, stage: str): def setup(self, stage: str):
# Instantiate other datamodules and get their datasets # Instantiate other datamodules and get their datasets
mc_module = mc_datamodule( module_args = {
train_batch_size=self.train_batch_size, "train_batch_size": self.train_batch_size,
predict_batch_size=self.predict_batch_size, "predict_batch_size": self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch, "drop_incomplete_batch": self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs, "multiproc_kwargs": self.multiproc_kwargs,
) }
mc_module.prepare_data()
mc_module.setup(stage=stage)
mc = mc_module.dataset
ch_module = ch_datamodule(
train_batch_size=self.train_batch_size,
predict_batch_size=self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs,
)
ch_module.prepare_data() mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch_module.setup(stage=stage) ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
ch = ch_module.dataset
# Combine datasets # Combine datasets
self.dataset = {} self.dataset = {}
......
...@@ -8,7 +8,7 @@ from clapper.logging import setup ...@@ -8,7 +8,7 @@ from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets from .. import return_subsets
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_1 import datamodule as mc_datamodule from ..montgomery.fold_1 import datamodule as mc_datamodule
from ..shenzhen.fold_1 import datamodule as ch_datamodule from ..shenzhen.fold_1 import datamodule as ch_datamodule
...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): ...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
def setup(self, stage: str): def setup(self, stage: str):
# Instantiate other datamodules and get their datasets # Instantiate other datamodules and get their datasets
mc_module = mc_datamodule( module_args = {
train_batch_size=self.train_batch_size, "train_batch_size": self.train_batch_size,
predict_batch_size=self.predict_batch_size, "predict_batch_size": self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch, "drop_incomplete_batch": self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs, "multiproc_kwargs": self.multiproc_kwargs,
) }
mc_module.prepare_data()
mc_module.setup(stage=stage)
mc = mc_module.dataset
ch_module = ch_datamodule(
train_batch_size=self.train_batch_size,
predict_batch_size=self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs,
)
ch_module.prepare_data() mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch_module.setup(stage=stage) ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
ch = ch_module.dataset
# Combine datasets # Combine datasets
self.dataset = {} self.dataset = {}
......
...@@ -8,7 +8,7 @@ from clapper.logging import setup ...@@ -8,7 +8,7 @@ from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets from .. import return_subsets
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_1_rgb import datamodule as mc_datamodule from ..montgomery.fold_1_rgb import datamodule as mc_datamodule
from ..shenzhen.fold_1_rgb import datamodule as ch_datamodule from ..shenzhen.fold_1_rgb import datamodule as ch_datamodule
...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): ...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
def setup(self, stage: str): def setup(self, stage: str):
# Instantiate other datamodules and get their datasets # Instantiate other datamodules and get their datasets
mc_module = mc_datamodule( module_args = {
train_batch_size=self.train_batch_size, "train_batch_size": self.train_batch_size,
predict_batch_size=self.predict_batch_size, "predict_batch_size": self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch, "drop_incomplete_batch": self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs, "multiproc_kwargs": self.multiproc_kwargs,
) }
mc_module.prepare_data()
mc_module.setup(stage=stage)
mc = mc_module.dataset
ch_module = ch_datamodule(
train_batch_size=self.train_batch_size,
predict_batch_size=self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs,
)
ch_module.prepare_data() mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch_module.setup(stage=stage) ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
ch = ch_module.dataset
# Combine datasets # Combine datasets
self.dataset = {} self.dataset = {}
......
...@@ -8,7 +8,7 @@ from clapper.logging import setup ...@@ -8,7 +8,7 @@ from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets from .. import return_subsets
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_2 import datamodule as mc_datamodule from ..montgomery.fold_2 import datamodule as mc_datamodule
from ..shenzhen.fold_2 import datamodule as ch_datamodule from ..shenzhen.fold_2 import datamodule as ch_datamodule
...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): ...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
def setup(self, stage: str): def setup(self, stage: str):
# Instantiate other datamodules and get their datasets # Instantiate other datamodules and get their datasets
mc_module = mc_datamodule( module_args = {
train_batch_size=self.train_batch_size, "train_batch_size": self.train_batch_size,
predict_batch_size=self.predict_batch_size, "predict_batch_size": self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch, "drop_incomplete_batch": self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs, "multiproc_kwargs": self.multiproc_kwargs,
) }
mc_module.prepare_data()
mc_module.setup(stage=stage)
mc = mc_module.dataset
ch_module = ch_datamodule(
train_batch_size=self.train_batch_size,
predict_batch_size=self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs,
)
ch_module.prepare_data() mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch_module.setup(stage=stage) ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
ch = ch_module.dataset
# Combine datasets # Combine datasets
self.dataset = {} self.dataset = {}
......
...@@ -8,7 +8,7 @@ from clapper.logging import setup ...@@ -8,7 +8,7 @@ from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets from .. import return_subsets
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_2_rgb import datamodule as mc_datamodule from ..montgomery.fold_2_rgb import datamodule as mc_datamodule
from ..shenzhen.fold_2_rgb import datamodule as ch_datamodule from ..shenzhen.fold_2_rgb import datamodule as ch_datamodule
...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): ...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
def setup(self, stage: str): def setup(self, stage: str):
# Instantiate other datamodules and get their datasets # Instantiate other datamodules and get their datasets
mc_module = mc_datamodule( module_args = {
train_batch_size=self.train_batch_size, "train_batch_size": self.train_batch_size,
predict_batch_size=self.predict_batch_size, "predict_batch_size": self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch, "drop_incomplete_batch": self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs, "multiproc_kwargs": self.multiproc_kwargs,
) }
mc_module.prepare_data()
mc_module.setup(stage=stage)
mc = mc_module.dataset
ch_module = ch_datamodule(
train_batch_size=self.train_batch_size,
predict_batch_size=self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs,
)
ch_module.prepare_data() mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch_module.setup(stage=stage) ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
ch = ch_module.dataset
# Combine datasets # Combine datasets
self.dataset = {} self.dataset = {}
......
...@@ -8,7 +8,7 @@ from clapper.logging import setup ...@@ -8,7 +8,7 @@ from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets from .. import return_subsets
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_3 import datamodule as mc_datamodule from ..montgomery.fold_3 import datamodule as mc_datamodule
from ..shenzhen.fold_3 import datamodule as ch_datamodule from ..shenzhen.fold_3 import datamodule as ch_datamodule
...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): ...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
def setup(self, stage: str): def setup(self, stage: str):
# Instantiate other datamodules and get their datasets # Instantiate other datamodules and get their datasets
mc_module = mc_datamodule( module_args = {
train_batch_size=self.train_batch_size, "train_batch_size": self.train_batch_size,
predict_batch_size=self.predict_batch_size, "predict_batch_size": self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch, "drop_incomplete_batch": self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs, "multiproc_kwargs": self.multiproc_kwargs,
) }
mc_module.prepare_data()
mc_module.setup(stage=stage)
mc = mc_module.dataset
ch_module = ch_datamodule(
train_batch_size=self.train_batch_size,
predict_batch_size=self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs,
)
ch_module.prepare_data() mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch_module.setup(stage=stage) ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
ch = ch_module.dataset
# Combine datasets # Combine datasets
self.dataset = {} self.dataset = {}
......
...@@ -8,7 +8,7 @@ from clapper.logging import setup ...@@ -8,7 +8,7 @@ from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets from .. import return_subsets
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_3_rgb import datamodule as mc_datamodule from ..montgomery.fold_3_rgb import datamodule as mc_datamodule
from ..shenzhen.fold_3_rgb import datamodule as ch_datamodule from ..shenzhen.fold_3_rgb import datamodule as ch_datamodule
...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): ...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
def setup(self, stage: str): def setup(self, stage: str):
# Instantiate other datamodules and get their datasets # Instantiate other datamodules and get their datasets
mc_module = mc_datamodule( module_args = {
train_batch_size=self.train_batch_size, "train_batch_size": self.train_batch_size,
predict_batch_size=self.predict_batch_size, "predict_batch_size": self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch, "drop_incomplete_batch": self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs, "multiproc_kwargs": self.multiproc_kwargs,
) }
mc_module.prepare_data()
mc_module.setup(stage=stage)
mc = mc_module.dataset
ch_module = ch_datamodule(
train_batch_size=self.train_batch_size,
predict_batch_size=self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs,
)
ch_module.prepare_data() mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch_module.setup(stage=stage) ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
ch = ch_module.dataset
# Combine datasets # Combine datasets
self.dataset = {} self.dataset = {}
......
...@@ -8,7 +8,7 @@ from clapper.logging import setup ...@@ -8,7 +8,7 @@ from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets from .. import return_subsets
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_4 import datamodule as mc_datamodule from ..montgomery.fold_4 import datamodule as mc_datamodule
from ..shenzhen.fold_4 import datamodule as ch_datamodule from ..shenzhen.fold_4 import datamodule as ch_datamodule
...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): ...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
def setup(self, stage: str): def setup(self, stage: str):
# Instantiate other datamodules and get their datasets # Instantiate other datamodules and get their datasets
mc_module = mc_datamodule( module_args = {
train_batch_size=self.train_batch_size, "train_batch_size": self.train_batch_size,
predict_batch_size=self.predict_batch_size, "predict_batch_size": self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch, "drop_incomplete_batch": self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs, "multiproc_kwargs": self.multiproc_kwargs,
) }
mc_module.prepare_data()
mc_module.setup(stage=stage)
mc = mc_module.dataset
ch_module = ch_datamodule(
train_batch_size=self.train_batch_size,
predict_batch_size=self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs,
)
ch_module.prepare_data() mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch_module.setup(stage=stage) ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
ch = ch_module.dataset
# Combine datasets # Combine datasets
self.dataset = {} self.dataset = {}
......
...@@ -8,7 +8,7 @@ from clapper.logging import setup ...@@ -8,7 +8,7 @@ from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets from .. import return_subsets
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_4_rgb import datamodule as mc_datamodule from ..montgomery.fold_4_rgb import datamodule as mc_datamodule
from ..shenzhen.fold_4_rgb import datamodule as ch_datamodule from ..shenzhen.fold_4_rgb import datamodule as ch_datamodule
...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): ...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
def setup(self, stage: str): def setup(self, stage: str):
# Instantiate other datamodules and get their datasets # Instantiate other datamodules and get their datasets
mc_module = mc_datamodule( module_args = {
train_batch_size=self.train_batch_size, "train_batch_size": self.train_batch_size,
predict_batch_size=self.predict_batch_size, "predict_batch_size": self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch, "drop_incomplete_batch": self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs, "multiproc_kwargs": self.multiproc_kwargs,
) }
mc_module.prepare_data()
mc_module.setup(stage=stage)
mc = mc_module.dataset
ch_module = ch_datamodule(
train_batch_size=self.train_batch_size,
predict_batch_size=self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs,
)
ch_module.prepare_data() mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch_module.setup(stage=stage) ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
ch = ch_module.dataset
# Combine datasets # Combine datasets
self.dataset = {} self.dataset = {}
......
...@@ -8,7 +8,7 @@ from clapper.logging import setup ...@@ -8,7 +8,7 @@ from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets from .. import return_subsets
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_5 import datamodule as mc_datamodule from ..montgomery.fold_5 import datamodule as mc_datamodule
from ..shenzhen.fold_5 import datamodule as ch_datamodule from ..shenzhen.fold_5 import datamodule as ch_datamodule
...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): ...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
def setup(self, stage: str): def setup(self, stage: str):
# Instantiate other datamodules and get their datasets # Instantiate other datamodules and get their datasets
mc_module = mc_datamodule( module_args = {
train_batch_size=self.train_batch_size, "train_batch_size": self.train_batch_size,
predict_batch_size=self.predict_batch_size, "predict_batch_size": self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch, "drop_incomplete_batch": self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs, "multiproc_kwargs": self.multiproc_kwargs,
) }
mc_module.prepare_data()
mc_module.setup(stage=stage)
mc = mc_module.dataset
ch_module = ch_datamodule(
train_batch_size=self.train_batch_size,
predict_batch_size=self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs,
)
ch_module.prepare_data() mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch_module.setup(stage=stage) ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
ch = ch_module.dataset
# Combine datasets # Combine datasets
self.dataset = {} self.dataset = {}
......
...@@ -8,7 +8,7 @@ from clapper.logging import setup ...@@ -8,7 +8,7 @@ from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets from .. import return_subsets
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_5_rgb import datamodule as mc_datamodule from ..montgomery.fold_5_rgb import datamodule as mc_datamodule
from ..shenzhen.fold_5_rgb import datamodule as ch_datamodule from ..shenzhen.fold_5_rgb import datamodule as ch_datamodule
...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): ...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
def setup(self, stage: str): def setup(self, stage: str):
# Instantiate other datamodules and get their datasets # Instantiate other datamodules and get their datasets
mc_module = mc_datamodule( module_args = {
train_batch_size=self.train_batch_size, "train_batch_size": self.train_batch_size,
predict_batch_size=self.predict_batch_size, "predict_batch_size": self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch, "drop_incomplete_batch": self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs, "multiproc_kwargs": self.multiproc_kwargs,
) }
mc_module.prepare_data()
mc_module.setup(stage=stage)
mc = mc_module.dataset
ch_module = ch_datamodule(
train_batch_size=self.train_batch_size,
predict_batch_size=self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs,
)
ch_module.prepare_data() mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch_module.setup(stage=stage) ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
ch = ch_module.dataset
# Combine datasets # Combine datasets
self.dataset = {} self.dataset = {}
......
...@@ -8,7 +8,7 @@ from clapper.logging import setup ...@@ -8,7 +8,7 @@ from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets from .. import return_subsets
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_6 import datamodule as mc_datamodule from ..montgomery.fold_6 import datamodule as mc_datamodule
from ..shenzhen.fold_6 import datamodule as ch_datamodule from ..shenzhen.fold_6 import datamodule as ch_datamodule
...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): ...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
def setup(self, stage: str): def setup(self, stage: str):
# Instantiate other datamodules and get their datasets # Instantiate other datamodules and get their datasets
mc_module = mc_datamodule( module_args = {
train_batch_size=self.train_batch_size, "train_batch_size": self.train_batch_size,
predict_batch_size=self.predict_batch_size, "predict_batch_size": self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch, "drop_incomplete_batch": self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs, "multiproc_kwargs": self.multiproc_kwargs,
) }
mc_module.prepare_data()
mc_module.setup(stage=stage)
mc = mc_module.dataset
ch_module = ch_datamodule(
train_batch_size=self.train_batch_size,
predict_batch_size=self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs,
)
ch_module.prepare_data() mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch_module.setup(stage=stage) ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
ch = ch_module.dataset
# Combine datasets # Combine datasets
self.dataset = {} self.dataset = {}
......
...@@ -8,7 +8,7 @@ from clapper.logging import setup ...@@ -8,7 +8,7 @@ from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets from .. import return_subsets
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_6_rgb import datamodule as mc_datamodule from ..montgomery.fold_6_rgb import datamodule as mc_datamodule
from ..shenzhen.fold_6_rgb import datamodule as ch_datamodule from ..shenzhen.fold_6_rgb import datamodule as ch_datamodule
...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): ...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
def setup(self, stage: str): def setup(self, stage: str):
# Instantiate other datamodules and get their datasets # Instantiate other datamodules and get their datasets
mc_module = mc_datamodule( module_args = {
train_batch_size=self.train_batch_size, "train_batch_size": self.train_batch_size,
predict_batch_size=self.predict_batch_size, "predict_batch_size": self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch, "drop_incomplete_batch": self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs, "multiproc_kwargs": self.multiproc_kwargs,
) }
mc_module.prepare_data()
mc_module.setup(stage=stage)
mc = mc_module.dataset
ch_module = ch_datamodule(
train_batch_size=self.train_batch_size,
predict_batch_size=self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs,
)
ch_module.prepare_data() mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch_module.setup(stage=stage) ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
ch = ch_module.dataset
# Combine datasets # Combine datasets
self.dataset = {} self.dataset = {}
......
...@@ -8,7 +8,7 @@ from clapper.logging import setup ...@@ -8,7 +8,7 @@ from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets from .. import return_subsets
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_7 import datamodule as mc_datamodule from ..montgomery.fold_7 import datamodule as mc_datamodule
from ..shenzhen.fold_7 import datamodule as ch_datamodule from ..shenzhen.fold_7 import datamodule as ch_datamodule
...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): ...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
def setup(self, stage: str): def setup(self, stage: str):
# Instantiate other datamodules and get their datasets # Instantiate other datamodules and get their datasets
mc_module = mc_datamodule( module_args = {
train_batch_size=self.train_batch_size, "train_batch_size": self.train_batch_size,
predict_batch_size=self.predict_batch_size, "predict_batch_size": self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch, "drop_incomplete_batch": self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs, "multiproc_kwargs": self.multiproc_kwargs,
) }
mc_module.prepare_data()
mc_module.setup(stage=stage)
mc = mc_module.dataset
ch_module = ch_datamodule(
train_batch_size=self.train_batch_size,
predict_batch_size=self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs,
)
ch_module.prepare_data() mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch_module.setup(stage=stage) ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
ch = ch_module.dataset
# Combine datasets # Combine datasets
self.dataset = {} self.dataset = {}
......
...@@ -8,7 +8,7 @@ from clapper.logging import setup ...@@ -8,7 +8,7 @@ from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets from .. import return_subsets
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_7_rgb import datamodule as mc_datamodule from ..montgomery.fold_7_rgb import datamodule as mc_datamodule
from ..shenzhen.fold_7_rgb import datamodule as ch_datamodule from ..shenzhen.fold_7_rgb import datamodule as ch_datamodule
...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): ...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
def setup(self, stage: str): def setup(self, stage: str):
# Instantiate other datamodules and get their datasets # Instantiate other datamodules and get their datasets
mc_module = mc_datamodule( module_args = {
train_batch_size=self.train_batch_size, "train_batch_size": self.train_batch_size,
predict_batch_size=self.predict_batch_size, "predict_batch_size": self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch, "drop_incomplete_batch": self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs, "multiproc_kwargs": self.multiproc_kwargs,
) }
mc_module.prepare_data()
mc_module.setup(stage=stage)
mc = mc_module.dataset
ch_module = ch_datamodule(
train_batch_size=self.train_batch_size,
predict_batch_size=self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs,
)
ch_module.prepare_data() mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch_module.setup(stage=stage) ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
ch = ch_module.dataset
# Combine datasets # Combine datasets
self.dataset = {} self.dataset = {}
......
...@@ -8,7 +8,7 @@ from clapper.logging import setup ...@@ -8,7 +8,7 @@ from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets from .. import return_subsets
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_8 import datamodule as mc_datamodule from ..montgomery.fold_8 import datamodule as mc_datamodule
from ..shenzhen.fold_8 import datamodule as ch_datamodule from ..shenzhen.fold_8 import datamodule as ch_datamodule
...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): ...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
def setup(self, stage: str): def setup(self, stage: str):
# Instantiate other datamodules and get their datasets # Instantiate other datamodules and get their datasets
mc_module = mc_datamodule( module_args = {
train_batch_size=self.train_batch_size, "train_batch_size": self.train_batch_size,
predict_batch_size=self.predict_batch_size, "predict_batch_size": self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch, "drop_incomplete_batch": self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs, "multiproc_kwargs": self.multiproc_kwargs,
) }
mc_module.prepare_data()
mc_module.setup(stage=stage)
mc = mc_module.dataset
ch_module = ch_datamodule(
train_batch_size=self.train_batch_size,
predict_batch_size=self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs,
)
ch_module.prepare_data() mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch_module.setup(stage=stage) ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
ch = ch_module.dataset
# Combine datasets # Combine datasets
self.dataset = {} self.dataset = {}
......
...@@ -8,7 +8,7 @@ from clapper.logging import setup ...@@ -8,7 +8,7 @@ from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets from .. import return_subsets
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_8_rgb import datamodule as mc_datamodule from ..montgomery.fold_8_rgb import datamodule as mc_datamodule
from ..shenzhen.fold_8_rgb import datamodule as ch_datamodule from ..shenzhen.fold_8_rgb import datamodule as ch_datamodule
...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule): ...@@ -37,27 +37,15 @@ class DefaultModule(BaseDataModule):
def setup(self, stage: str): def setup(self, stage: str):
# Instantiate other datamodules and get their datasets # Instantiate other datamodules and get their datasets
mc_module = mc_datamodule( module_args = {
train_batch_size=self.train_batch_size, "train_batch_size": self.train_batch_size,
predict_batch_size=self.predict_batch_size, "predict_batch_size": self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch, "drop_incomplete_batch": self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs, "multiproc_kwargs": self.multiproc_kwargs,
) }
mc_module.prepare_data()
mc_module.setup(stage=stage)
mc = mc_module.dataset
ch_module = ch_datamodule(
train_batch_size=self.train_batch_size,
predict_batch_size=self.predict_batch_size,
drop_incomplete_batch=self.drop_incomplete_batch,
multiproc_kwargs=self.multiproc_kwargs,
)
ch_module.prepare_data() mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch_module.setup(stage=stage) ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
ch = ch_module.dataset
# Combine datasets # Combine datasets
self.dataset = {} self.dataset = {}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment