Skip to content
Snippets Groups Projects
Commit b61397ae authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

Setup datamodule only once per stage, only get needed subsets

parent 62c0c872
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
...@@ -37,6 +37,7 @@ class DefaultModule(BaseDataModule): ...@@ -37,6 +37,7 @@ class DefaultModule(BaseDataModule):
) )
self.cache_samples = cache_samples self.cache_samples = cache_samples
self.has_setup_fit = False
def setup(self, stage: str): def setup(self, stage: str):
if self.cache_samples: if self.cache_samples:
...@@ -56,12 +57,13 @@ class DefaultModule(BaseDataModule): ...@@ -56,12 +57,13 @@ class DefaultModule(BaseDataModule):
loader=samples_loader, loader=samples_loader,
) )
( if not self.has_setup_fit and stage == "fit":
self.train_dataset, (
self.validation_dataset, self.train_dataset,
self.extra_validation_datasets, self.validation_dataset,
self.predict_dataset, self.extra_validation_datasets,
) = return_subsets(self.json_dataset, "default") ) = return_subsets(self.json_dataset, "default", stage)
self.has_setup_fit = True
datamodule = DefaultModule datamodule = DefaultModule
...@@ -37,6 +37,7 @@ class DefaultModule(BaseDataModule): ...@@ -37,6 +37,7 @@ class DefaultModule(BaseDataModule):
) )
self.cache_samples = cache_samples self.cache_samples = cache_samples
self.has_setup_fit = False
self.post_transforms = [ self.post_transforms = [
transforms.ToPILImage(), transforms.ToPILImage(),
...@@ -63,12 +64,13 @@ class DefaultModule(BaseDataModule): ...@@ -63,12 +64,13 @@ class DefaultModule(BaseDataModule):
post_transforms=self.post_transforms, post_transforms=self.post_transforms,
) )
( if not self.has_setup_fit and stage == "fit":
self.train_dataset, (
self.validation_dataset, self.train_dataset,
self.extra_validation_datasets, self.validation_dataset,
self.predict_dataset, self.extra_validation_datasets,
) = return_subsets(self.json_dataset, "default") ) = return_subsets(self.json_dataset, "default", stage)
self.has_setup_fit = True
datamodule = DefaultModule datamodule = DefaultModule
...@@ -303,49 +303,54 @@ def get_positive_weights(dataset): ...@@ -303,49 +303,54 @@ def get_positive_weights(dataset):
return positive_weights return positive_weights
def return_subsets(dataset, protocol, stage):
    """Return the dataset subsets required for a given Lightning stage.

    Parameters
    ----------
    dataset
        Object exposing a ``subsets(protocol)`` method that maps subset
        names (``"train"``, ``"validation"``, ``"__extra_valid__"``) to
        lists of samples.
    protocol : str
        Protocol name, passed through to ``dataset.subsets()``.
    stage : str
        Lightning stage; only ``"fit"`` is currently supported.

    Returns
    -------
    tuple
        ``(train_set, valid_set, extra_valid_sets)``, where each element
        is a ``SampleListDataset`` or ``None`` when the corresponding
        subset is absent.

    Raises
    ------
    ValueError
        If ``stage`` is anything other than ``"fit"``.
    RuntimeError
        If ``subsets["__extra_valid__"]`` is present but not a list.
    """
    subsets = dataset.subsets(protocol)

    # Only the fit stage is implemented so far; fail loudly for others.
    if stage != "fit":
        raise ValueError(f"Stage {stage} is unknown.")

    # Training subset (may legitimately be absent).
    train_set = None
    if "train" in subsets:
        train_set = SampleListDataset(subsets["train"], [])

    # Validation subset; fall back to the training set when missing so
    # monitoring during fit still has data to evaluate on.
    if "validation" in subsets:
        valid_set = SampleListDataset(subsets["validation"], [])
    else:
        logger.warning(
            "No validation dataset found, using training set instead."
        )
        valid_set = train_set

    # Optional extra validation sets, tracked during training only.
    extra_valid_sets = None
    if "__extra_valid__" in subsets:
        if not isinstance(subsets["__extra_valid__"], list):
            raise RuntimeError(
                f"If present, dataset['__extra_valid__'] must be a list, "
                f"but you passed a {type(subsets['__extra_valid__'])}, "
                f"which is invalid."
            )
        logger.info(
            f"Found {len(subsets['__extra_valid__'])} extra validation "
            f"set(s) to be tracked during training"
        )
        logger.info(
            "Extra validation sets are NOT used for model checkpointing!"
        )
        extra_valid_sets = SampleListDataset(subsets["__extra_valid__"], [])

    return train_set, valid_set, extra_valid_sets
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment