Skip to content
Snippets Groups Projects
Commit 376c7874 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Setup datamodule only once per stage, only get needed subsets

parent 0b40b241
No related branches found
No related tags found
No related merge requests found
......@@ -37,6 +37,7 @@ class DefaultModule(BaseDataModule):
)
self.cache_samples = cache_samples
self.has_setup_fit = False
def setup(self, stage: str):
if self.cache_samples:
......@@ -56,12 +57,13 @@ class DefaultModule(BaseDataModule):
loader=samples_loader,
)
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.json_dataset, "default")
if not self.has_setup_fit and stage == "fit":
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
) = return_subsets(self.json_dataset, "default", stage)
self.has_setup_fit = True
datamodule = DefaultModule
......@@ -37,6 +37,7 @@ class DefaultModule(BaseDataModule):
)
self.cache_samples = cache_samples
self.has_setup_fit = False
self.post_transforms = [
transforms.ToPILImage(),
......@@ -63,12 +64,13 @@ class DefaultModule(BaseDataModule):
post_transforms=self.post_transforms,
)
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
self.predict_dataset,
) = return_subsets(self.json_dataset, "default")
if not self.has_setup_fit and stage == "fit":
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
) = return_subsets(self.json_dataset, "default", stage)
self.has_setup_fit = True
datamodule = DefaultModule
......@@ -303,49 +303,54 @@ def get_positive_weights(dataset):
return positive_weights
def return_subsets(dataset, protocol):
train_dataset = None
validation_dataset = None
extra_validation_datasets = None
predict_dataset = None
def return_subsets(dataset, protocol, stage):
train_set = None
valid_set = None
extra_valid_sets = None
subsets = dataset.subsets(protocol)
if "train" in subsets.keys():
train_dataset = SampleListDataset(subsets["train"], [])
if "validation" in subsets.keys():
validation_dataset = SampleListDataset(subsets["validation"], [])
else:
logger.warning(
"No validation dataset found, using training set instead."
)
validation_dataset = train_dataset
if "__extra_valid__" in subsets.keys():
if not isinstance(subsets["__extra_valid__"], list):
raise RuntimeError(
f"If present, dataset['__extra_valid__'] must be a list, "
f"but you passed a {type(subsets['__extra_valid__'])}, "
f"which is invalid."
def get_train_subset():
if "train" in subsets.keys():
nonlocal train_set
train_set = SampleListDataset(subsets["train"], [])
def get_valid_subset():
if "validation" in subsets.keys():
nonlocal valid_set
valid_set = SampleListDataset(subsets["validation"], [])
else:
logger.warning(
"No validation dataset found, using training set instead."
)
logger.info(
f"Found {len(subsets['__extra_valid__'])} extra validation "
f"set(s) to be tracked during training"
)
logger.info(
"Extra validation sets are NOT used for model checkpointing!"
)
extra_validation_datasets = SampleListDataset(
subsets["__extra_valid__"], []
)
else:
extra_validation_datasets = None
if train_set is None:
get_train_subset()
valid_set = train_set
def get_extra_valid_subset():
if "__extra_valid__" in subsets.keys():
if not isinstance(subsets["__extra_valid__"], list):
raise RuntimeError(
f"If present, dataset['__extra_valid__'] must be a list, "
f"but you passed a {type(subsets['__extra_valid__'])}, "
f"which is invalid."
)
logger.info(
f"Found {len(subsets['__extra_valid__'])} extra validation "
f"set(s) to be tracked during training"
)
logger.info(
"Extra validation sets are NOT used for model checkpointing!"
)
nonlocal extra_valid_sets
extra_valid_sets = SampleListDataset(subsets["__extra_valid__"], [])
predict_dataset = subsets
if stage == "fit":
get_train_subset()
get_valid_subset()
get_extra_valid_subset()
return (
train_dataset,
validation_dataset,
extra_validation_datasets,
predict_dataset,
)
return train_set, valid_set, extra_valid_sets
else:
raise ValueError(f"Stage {stage} is unknown.")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment