Commit 748acd57 authored by André Anjos

[data.montgomery_shenzhen] Create first concatenated datamodule

parent 978f04e8
Merge request !6: Making use of LightningDataModule and simplification of data loading
Showing 78 additions and 808 deletions
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated dataset composed of Montgomery and Shenzhen datasets."""
from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets
from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.default import datamodule as mc_datamodule
from ..shenzhen.default import datamodule as ch_datamodule
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
self.train_batch_size = train_batch_size
self.predict_batch_size = predict_batch_size
self.drop_incomplete_batch = drop_incomplete_batch
self.multiproc_kwargs = multiproc_kwargs
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str):
# Instantiate other datamodules and get their datasets
module_args = {
"train_batch_size": self.train_batch_size,
"predict_batch_size": self.predict_batch_size,
"drop_incomplete_batch": self.drop_incomplete_batch,
"multiproc_kwargs": self.multiproc_kwargs,
}
mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
# Combine datasets
self.dataset = {}
self.dataset["__train__"] = ConcatDataset(
[mc["__train__"], ch["__train__"]]
)
self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
self.dataset["__valid__"] = ConcatDataset(
[mc["__valid__"], ch["__valid__"]]
)
self.dataset["validation"] = ConcatDataset(
[mc["validation"], ch["validation"]]
)
self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DefaultModule
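# A hypothetical usage sketch for the legacy module above (illustrative
# only, not part of the commit; the import path is omitted and the
# argument values are arbitrary examples):
#
#     dm = DefaultModule(train_batch_size=4, drop_incomplete_batch=True)
#     dm.setup("fit")           # builds the concatenated subsets
#     len(dm.dataset["train"])  # Montgomery + Shenzhen samples combined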
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
validation fold 0)"""
from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets
from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_0 import datamodule as mc_datamodule
from ..shenzhen.fold_0 import datamodule as ch_datamodule
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
self.train_batch_size = train_batch_size
self.predict_batch_size = predict_batch_size
self.drop_incomplete_batch = drop_incomplete_batch
self.multiproc_kwargs = multiproc_kwargs
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str):
# Instantiate other datamodules and get their datasets
module_args = {
"train_batch_size": self.train_batch_size,
"predict_batch_size": self.predict_batch_size,
"drop_incomplete_batch": self.drop_incomplete_batch,
"multiproc_kwargs": self.multiproc_kwargs,
}
mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
# Combine datasets
self.dataset = {}
self.dataset["__train__"] = ConcatDataset(
[mc["__train__"], ch["__train__"]]
)
self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
self.dataset["__valid__"] = ConcatDataset(
[mc["__valid__"], ch["__valid__"]]
)
self.dataset["validation"] = ConcatDataset(
[mc["validation"], ch["validation"]]
)
self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DefaultModule
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
validation fold 1)"""
from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets
from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_1 import datamodule as mc_datamodule
from ..shenzhen.fold_1 import datamodule as ch_datamodule
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
self.train_batch_size = train_batch_size
self.predict_batch_size = predict_batch_size
self.drop_incomplete_batch = drop_incomplete_batch
self.multiproc_kwargs = multiproc_kwargs
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str):
# Instantiate other datamodules and get their datasets
module_args = {
"train_batch_size": self.train_batch_size,
"predict_batch_size": self.predict_batch_size,
"drop_incomplete_batch": self.drop_incomplete_batch,
"multiproc_kwargs": self.multiproc_kwargs,
}
mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
# Combine datasets
self.dataset = {}
self.dataset["__train__"] = ConcatDataset(
[mc["__train__"], ch["__train__"]]
)
self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
self.dataset["__valid__"] = ConcatDataset(
[mc["__valid__"], ch["__valid__"]]
)
self.dataset["validation"] = ConcatDataset(
[mc["validation"], ch["validation"]]
)
self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DefaultModule
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
validation fold 2)"""
from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets
from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_2 import datamodule as mc_datamodule
from ..shenzhen.fold_2 import datamodule as ch_datamodule
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
self.train_batch_size = train_batch_size
self.predict_batch_size = predict_batch_size
self.drop_incomplete_batch = drop_incomplete_batch
self.multiproc_kwargs = multiproc_kwargs
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str):
# Instantiate other datamodules and get their datasets
module_args = {
"train_batch_size": self.train_batch_size,
"predict_batch_size": self.predict_batch_size,
"drop_incomplete_batch": self.drop_incomplete_batch,
"multiproc_kwargs": self.multiproc_kwargs,
}
mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
# Combine datasets
self.dataset = {}
self.dataset["__train__"] = ConcatDataset(
[mc["__train__"], ch["__train__"]]
)
self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
self.dataset["__valid__"] = ConcatDataset(
[mc["__valid__"], ch["__valid__"]]
)
self.dataset["validation"] = ConcatDataset(
[mc["validation"], ch["validation"]]
)
self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DefaultModule
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
validation fold 3)"""
from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets
from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_3 import datamodule as mc_datamodule
from ..shenzhen.fold_3 import datamodule as ch_datamodule
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
self.train_batch_size = train_batch_size
self.predict_batch_size = predict_batch_size
self.drop_incomplete_batch = drop_incomplete_batch
self.multiproc_kwargs = multiproc_kwargs
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str):
# Instantiate other datamodules and get their datasets
module_args = {
"train_batch_size": self.train_batch_size,
"predict_batch_size": self.predict_batch_size,
"drop_incomplete_batch": self.drop_incomplete_batch,
"multiproc_kwargs": self.multiproc_kwargs,
}
mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
# Combine datasets
self.dataset = {}
self.dataset["__train__"] = ConcatDataset(
[mc["__train__"], ch["__train__"]]
)
self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
self.dataset["__valid__"] = ConcatDataset(
[mc["__valid__"], ch["__valid__"]]
)
self.dataset["validation"] = ConcatDataset(
[mc["validation"], ch["validation"]]
)
self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DefaultModule
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
validation fold 4)"""
from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets
from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_4 import datamodule as mc_datamodule
from ..shenzhen.fold_4 import datamodule as ch_datamodule
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
self.train_batch_size = train_batch_size
self.predict_batch_size = predict_batch_size
self.drop_incomplete_batch = drop_incomplete_batch
self.multiproc_kwargs = multiproc_kwargs
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str):
# Instantiate other datamodules and get their datasets
module_args = {
"train_batch_size": self.train_batch_size,
"predict_batch_size": self.predict_batch_size,
"drop_incomplete_batch": self.drop_incomplete_batch,
"multiproc_kwargs": self.multiproc_kwargs,
}
mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
# Combine datasets
self.dataset = {}
self.dataset["__train__"] = ConcatDataset(
[mc["__train__"], ch["__train__"]]
)
self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
self.dataset["__valid__"] = ConcatDataset(
[mc["__valid__"], ch["__valid__"]]
)
self.dataset["validation"] = ConcatDataset(
[mc["validation"], ch["validation"]]
)
self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DefaultModule
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
validation fold 5)"""
from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets
from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_5 import datamodule as mc_datamodule
from ..shenzhen.fold_5 import datamodule as ch_datamodule
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
self.train_batch_size = train_batch_size
self.predict_batch_size = predict_batch_size
self.drop_incomplete_batch = drop_incomplete_batch
self.multiproc_kwargs = multiproc_kwargs
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str):
# Instantiate other datamodules and get their datasets
module_args = {
"train_batch_size": self.train_batch_size,
"predict_batch_size": self.predict_batch_size,
"drop_incomplete_batch": self.drop_incomplete_batch,
"multiproc_kwargs": self.multiproc_kwargs,
}
mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
# Combine datasets
self.dataset = {}
self.dataset["__train__"] = ConcatDataset(
[mc["__train__"], ch["__train__"]]
)
self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
self.dataset["__valid__"] = ConcatDataset(
[mc["__valid__"], ch["__valid__"]]
)
self.dataset["validation"] = ConcatDataset(
[mc["validation"], ch["validation"]]
)
self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DefaultModule
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
validation fold 6)"""
from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets
from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_6 import datamodule as mc_datamodule
from ..shenzhen.fold_6 import datamodule as ch_datamodule
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
self.train_batch_size = train_batch_size
self.predict_batch_size = predict_batch_size
self.drop_incomplete_batch = drop_incomplete_batch
self.multiproc_kwargs = multiproc_kwargs
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str):
# Instantiate other datamodules and get their datasets
module_args = {
"train_batch_size": self.train_batch_size,
"predict_batch_size": self.predict_batch_size,
"drop_incomplete_batch": self.drop_incomplete_batch,
"multiproc_kwargs": self.multiproc_kwargs,
}
mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
# Combine datasets
self.dataset = {}
self.dataset["__train__"] = ConcatDataset(
[mc["__train__"], ch["__train__"]]
)
self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
self.dataset["__valid__"] = ConcatDataset(
[mc["__valid__"], ch["__valid__"]]
)
self.dataset["validation"] = ConcatDataset(
[mc["validation"], ch["validation"]]
)
self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DefaultModule
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
validation fold 7)"""
from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets
from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_7 import datamodule as mc_datamodule
from ..shenzhen.fold_7 import datamodule as ch_datamodule
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
self.train_batch_size = train_batch_size
self.predict_batch_size = predict_batch_size
self.drop_incomplete_batch = drop_incomplete_batch
self.multiproc_kwargs = multiproc_kwargs
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str):
# Instantiate other datamodules and get their datasets
module_args = {
"train_batch_size": self.train_batch_size,
"predict_batch_size": self.predict_batch_size,
"drop_incomplete_batch": self.drop_incomplete_batch,
"multiproc_kwargs": self.multiproc_kwargs,
}
mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
# Combine datasets
self.dataset = {}
self.dataset["__train__"] = ConcatDataset(
[mc["__train__"], ch["__train__"]]
)
self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
self.dataset["__valid__"] = ConcatDataset(
[mc["__valid__"], ch["__valid__"]]
)
self.dataset["validation"] = ConcatDataset(
[mc["validation"], ch["validation"]]
)
self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DefaultModule
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
validation fold 8)"""
from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets
from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_8 import datamodule as mc_datamodule
from ..shenzhen.fold_8 import datamodule as ch_datamodule
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
self.train_batch_size = train_batch_size
self.predict_batch_size = predict_batch_size
self.drop_incomplete_batch = drop_incomplete_batch
self.multiproc_kwargs = multiproc_kwargs
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str):
# Instantiate other datamodules and get their datasets
module_args = {
"train_batch_size": self.train_batch_size,
"predict_batch_size": self.predict_batch_size,
"drop_incomplete_batch": self.drop_incomplete_batch,
"multiproc_kwargs": self.multiproc_kwargs,
}
mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
# Combine datasets
self.dataset = {}
self.dataset["__train__"] = ConcatDataset(
[mc["__train__"], ch["__train__"]]
)
self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
self.dataset["__valid__"] = ConcatDataset(
[mc["__valid__"], ch["__valid__"]]
)
self.dataset["validation"] = ConcatDataset(
[mc["validation"], ch["validation"]]
)
self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DefaultModule
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Aggregated dataset composed of Montgomery and Shenzhen datasets (cross
validation fold 9)"""
from clapper.logging import setup
from torch.utils.data.dataset import ConcatDataset
from .. import return_subsets
from ..base_datamodule import BaseDataModule, get_dataset_from_module
from ..montgomery.fold_9 import datamodule as mc_datamodule
from ..shenzhen.fold_9 import datamodule as ch_datamodule
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
):
self.train_batch_size = train_batch_size
self.predict_batch_size = predict_batch_size
self.drop_incomplete_batch = drop_incomplete_batch
self.multiproc_kwargs = multiproc_kwargs
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
def setup(self, stage: str):
# Instantiate other datamodules and get their datasets
module_args = {
"train_batch_size": self.train_batch_size,
"predict_batch_size": self.predict_batch_size,
"drop_incomplete_batch": self.drop_incomplete_batch,
"multiproc_kwargs": self.multiproc_kwargs,
}
mc = get_dataset_from_module(mc_datamodule, stage, **module_args)
ch = get_dataset_from_module(ch_datamodule, stage, **module_args)
# Combine datasets
self.dataset = {}
self.dataset["__train__"] = ConcatDataset(
[mc["__train__"], ch["__train__"]]
)
self.dataset["train"] = ConcatDataset([mc["train"], ch["train"]])
self.dataset["__valid__"] = ConcatDataset(
[mc["__valid__"], ch["__valid__"]]
)
self.dataset["validation"] = ConcatDataset(
[mc["validation"], ch["validation"]]
)
self.dataset["test"] = ConcatDataset([mc["test"], ch["test"]])
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.dataset)
datamodule = DefaultModule
# Copyright © 2022 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
from ..datamodule import ConcatDataModule
from ..montgomery.datamodule import RawDataLoader as MontgomeryLoader
from ..montgomery.datamodule import make_split as make_montgomery_split
from ..shenzhen.datamodule import RawDataLoader as ShenzhenLoader
from ..shenzhen.datamodule import make_split as make_shenzhen_split
class DataModule(ConcatDataModule):
"""Aggregated datamodule composed of Montgomery and Shenzhen datasets."""
def __init__(self, split_filename: str):
        # Pass the requested split file through to both databases; the
        # previous hardcoded "default.json" ignored split_filename, which
        # broke the per-fold configurations below.
        montgomery_loader = MontgomeryLoader()
        montgomery_split = make_montgomery_split(split_filename)
        shenzhen_loader = ShenzhenLoader()
        shenzhen_split = make_shenzhen_split(split_filename)
        super().__init__(
            splits={
                "train": [
                    (montgomery_split["train"], montgomery_loader),
                    (shenzhen_split["train"], shenzhen_loader),
                ],
                "validation": [
                    (montgomery_split["validation"], montgomery_loader),
                    (shenzhen_split["validation"], shenzhen_loader),
                ],
                "test": [
                    (montgomery_split["test"], montgomery_loader),
                    (shenzhen_split["test"], shenzhen_loader),
                ],
            }
}
)
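# A minimal, self-contained sketch of the concatenation behaviour the
# splits mapping above relies on (illustrative only; _ListDataset is a
# hypothetical stand-in for the real raw-data loaders):
from torch.utils.data import ConcatDataset, Dataset


class _ListDataset(Dataset):
    """Stand-in dataset; real loaders yield (sample, target) pairs."""

    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


montgomery_train = _ListDataset(["mc-sample-1", "mc-sample-2"])
shenzhen_train = _ListDataset(["ch-sample-1"])
combined = ConcatDataset([montgomery_train, shenzhen_train])
assert len(combined) == 3  # samples from both databases, in order
assert combined[2] == "ch-sample-1"  # indexing spans dataset boundaries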
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
from .datamodule import DataModule
datamodule = DataModule("default.json")
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
from .datamodule import DataModule
datamodule = DataModule("fold_0.json")
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
from .datamodule import DataModule
datamodule = DataModule("fold_1.json")
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
from .datamodule import DataModule
datamodule = DataModule("fold_2.json")
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
from .datamodule import DataModule
datamodule = DataModule("fold_3.json")
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
from .datamodule import DataModule
datamodule = DataModule("fold_4.json")