Skip to content
Snippets Groups Projects
Commit 7f86dd1f authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

DataModule instantiation directly inside config file

parent 1b773088
No related branches found
No related tags found
1 merge request: !6 "Making use of LightningDataModule and simplification of data loading"
......@@ -38,6 +38,7 @@ class DefaultModule(BaseDataModule):
self.cache_samples = cache_samples
self.has_setup_fit = False
self.has_setup_predict = False
def setup(self, stage: str):
if self.cache_samples:
......@@ -51,7 +52,7 @@ class DefaultModule(BaseDataModule):
)
samples_loader = _delayed_loader
self.json_dataset = JSONDataset(
json_dataset = JSONDataset(
protocols=_protocols,
fieldnames=("data", "label"),
loader=samples_loader,
......@@ -62,8 +63,17 @@ class DefaultModule(BaseDataModule):
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
) = return_subsets(self.json_dataset, "default", stage)
) = return_subsets(json_dataset, "default", stage)
self.has_setup_fit = True
if not self.has_setup_predict and stage == "predict":
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
) = return_subsets(json_dataset, "default", stage)
self.has_setup_predict = True
datamodule = DefaultModule
datamodule = DefaultModule()
......@@ -18,13 +18,15 @@ class BaseDataModule(pl.LightningDataModule):
self,
train_batch_size=1,
predict_batch_size=1,
batch_chunk_count=1,
drop_incomplete_batch=False,
multiproc_kwargs=None,
multiproc_kwargs={},
):
super().__init__()
self.train_batch_size = train_batch_size
self.predict_batch_size = predict_batch_size
self.batch_chunk_count = batch_chunk_count
self.drop_incomplete_batch = drop_incomplete_batch
self.pin_memory = (
......@@ -47,7 +49,7 @@ class BaseDataModule(pl.LightningDataModule):
return DataLoader(
self.train_dataset,
batch_size=self.train_batch_size,
batch_size=self.compute_chunk_size(self.train_batch_size),
drop_last=self.drop_incomplete_batch,
pin_memory=self.pin_memory,
sampler=train_sampler,
......@@ -59,7 +61,7 @@ class BaseDataModule(pl.LightningDataModule):
val_loader = DataLoader(
dataset=self.validation_dataset,
batch_size=self.train_batch_size,
batch_size=self.compute_chunk_size(self.train_batch_size),
shuffle=False,
drop_last=False,
pin_memory=self.pin_memory,
......@@ -86,12 +88,26 @@ class BaseDataModule(pl.LightningDataModule):
return loaders_dict
def predict_dataloader(self):
# NOTE(review): this span appears to interleave two versions of the method
# body taken from a diff whose +/- markers and indentation were lost --
# confirm which version is intended before relying on either.
#
# First body: return a single DataLoader over self.predict_dataset, batched
# by self.predict_batch_size, without shuffling, using self.pin_memory.
return DataLoader(
dataset=self.predict_dataset,
batch_size=self.predict_batch_size,
shuffle=False,
pin_memory=self.pin_memory,
)
# Second body: build a dict that stores the result of self.train_dataloader()
# under the key "train_dataloader" and then copies in every key/value pair
# from self.val_dataloader() (which must therefore expose .items()).
# As written here, these statements follow an unconditional return.
loaders_dict = {}
loaders_dict["train_dataloader"] = self.train_dataloader()
for k, v in self.val_dataloader().items():
loaders_dict[k] = v
return loaders_dict
def compute_chunk_size(self, batch_size):
    """Return the size of one chunk when ``batch_size`` is split evenly.

    The batch is divided into ``self.batch_chunk_count`` chunks of equal
    size, so the result is ``batch_size // self.batch_chunk_count``.

    Args:
        batch_size: Total batch size to be split into chunks.

    Returns:
        The integer number of samples in each chunk.

    Raises:
        RuntimeError: If ``batch_size`` is not an exact multiple of
            ``self.batch_chunk_count``.
        ZeroDivisionError: If ``self.batch_chunk_count`` is zero (raised by
            the divisibility check itself).
    """
    chunk_count = self.batch_chunk_count

    # Every chunk must have the same size, so reject uneven splits up front.
    if batch_size % chunk_count != 0:
        # NOTE(review): the message names "--batch-chunk-size" although the
        # value printed is the chunk *count* -- confirm the flag name against
        # the command-line interface before changing the text.
        raise RuntimeError(
            f"--batch-size ({batch_size}) must be divisible by "
            f"--batch-chunk-size ({chunk_count})."
        )

    return batch_size // chunk_count
def get_dataset_from_module(module, stage, **module_args):
......
......@@ -288,22 +288,10 @@ def train(
"multiprocessing_context"
] = multiprocessing.get_context("spawn")
batch_chunk_size = batch_size
if batch_size % batch_chunk_count != 0:
# batch_size must be divisible by batch_chunk_count.
raise RuntimeError(
f"--batch-size ({batch_size}) must be divisible by "
f"--batch-chunk-size ({batch_chunk_count})."
)
else:
batch_chunk_size = batch_size // batch_chunk_count
datamodule.train_batch_size = batch_size
datamodule.batch_chunk_count = batch_chunk_count
datamodule.multiproc_kwargs = multiproc_kwargs
datamodule = datamodule(
train_batch_size=batch_chunk_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
cache_samples=cache_samples,
)
# Manually calling these as we need to access some values to reweight the criterion
datamodule.prepare_data()
datamodule.setup(stage="fit")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment