From 7f86dd1f2a552ea847d365b17c31b8044ecab183 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Mon, 19 Jun 2023 09:57:23 +0200
Subject: [PATCH] DataModule instantiation directly inside config file

---
 .../configs/datasets/shenzhen/default.py | 16 +++++++--
 src/ptbench/data/base_datamodule.py      | 34 ++++++++++++++-----
 src/ptbench/scripts/train.py             | 18 ++--------
 3 files changed, 41 insertions(+), 27 deletions(-)

diff --git a/src/ptbench/configs/datasets/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py
index 6dc59200..967d40e9 100644
--- a/src/ptbench/configs/datasets/shenzhen/default.py
+++ b/src/ptbench/configs/datasets/shenzhen/default.py
@@ -38,6 +38,7 @@ class DefaultModule(BaseDataModule):
         self.cache_samples = cache_samples
 
         self.has_setup_fit = False
+        self.has_setup_predict = False
 
     def setup(self, stage: str):
         if self.cache_samples:
@@ -51,7 +52,7 @@ class DefaultModule(BaseDataModule):
             )
             samples_loader = _delayed_loader
 
-        self.json_dataset = JSONDataset(
+        json_dataset = JSONDataset(
             protocols=_protocols,
             fieldnames=("data", "label"),
             loader=samples_loader,
@@ -62,8 +63,17 @@ class DefaultModule(BaseDataModule):
                 self.train_dataset,
                 self.validation_dataset,
                 self.extra_validation_datasets,
-            ) = return_subsets(self.json_dataset, "default", stage)
+            ) = return_subsets(json_dataset, "default", stage)
 
             self.has_setup_fit = True
 
+        if not self.has_setup_predict and stage == "predict":
+            (
+                self.train_dataset,
+                self.validation_dataset,
+                self.extra_validation_datasets,
+            ) = return_subsets(json_dataset, "default", stage)
+
+            self.has_setup_predict = True
+
-datamodule = DefaultModule
+datamodule = DefaultModule()
diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py
index 5e656d42..1c51a105 100644
--- a/src/ptbench/data/base_datamodule.py
+++ b/src/ptbench/data/base_datamodule.py
@@ -18,13 +18,15 @@ class BaseDataModule(pl.LightningDataModule):
         self,
         train_batch_size=1,
         predict_batch_size=1,
+        batch_chunk_count=1,
         drop_incomplete_batch=False,
-        multiproc_kwargs=None,
+        multiproc_kwargs={},
     ):
         super().__init__()
 
         self.train_batch_size = train_batch_size
         self.predict_batch_size = predict_batch_size
+        self.batch_chunk_count = batch_chunk_count
         self.drop_incomplete_batch = drop_incomplete_batch
 
         self.pin_memory = (
@@ -47,7 +49,7 @@ class BaseDataModule(pl.LightningDataModule):
 
         return DataLoader(
             self.train_dataset,
-            batch_size=self.train_batch_size,
+            batch_size=self.compute_chunk_size(self.train_batch_size),
             drop_last=self.drop_incomplete_batch,
             pin_memory=self.pin_memory,
             sampler=train_sampler,
@@ -59,7 +61,7 @@ class BaseDataModule(pl.LightningDataModule):
 
         val_loader = DataLoader(
             dataset=self.validation_dataset,
-            batch_size=self.train_batch_size,
+            batch_size=self.compute_chunk_size(self.train_batch_size),
             shuffle=False,
             drop_last=False,
             pin_memory=self.pin_memory,
@@ -86,12 +88,26 @@ class BaseDataModule(pl.LightningDataModule):
         return loaders_dict
 
     def predict_dataloader(self):
-        return DataLoader(
-            dataset=self.predict_dataset,
-            batch_size=self.predict_batch_size,
-            shuffle=False,
-            pin_memory=self.pin_memory,
-        )
+        loaders_dict = {}
+
+        loaders_dict["train_dataloader"] = self.train_dataloader()
+        for k, v in self.val_dataloader().items():
+            loaders_dict[k] = v
+
+        return loaders_dict
+
+    def compute_chunk_size(self, batch_size):
+        batch_chunk_size = batch_size
+        if batch_size % self.batch_chunk_count != 0:
+            # batch_size must be divisible by batch_chunk_count.
+            raise RuntimeError(
+                f"--batch-size ({batch_size}) must be divisible by "
+                f"--batch-chunk-count ({self.batch_chunk_count})."
+            )
+        else:
+            batch_chunk_size = batch_size // self.batch_chunk_count
+
+        return batch_chunk_size
 
 
 def get_dataset_from_module(module, stage, **module_args):
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index c59c81a3..3606c9f9 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -288,22 +288,10 @@ def train(
             "multiprocessing_context"
         ] = multiprocessing.get_context("spawn")
 
-    batch_chunk_size = batch_size
-    if batch_size % batch_chunk_count != 0:
-        # batch_size must be divisible by batch_chunk_count.
-        raise RuntimeError(
-            f"--batch-size ({batch_size}) must be divisible by "
-            f"--batch-chunk-size ({batch_chunk_count})."
-        )
-    else:
-        batch_chunk_size = batch_size // batch_chunk_count
+    datamodule.train_batch_size = batch_size
+    datamodule.batch_chunk_count = batch_chunk_count
+    datamodule.multiproc_kwargs = multiproc_kwargs
 
-    datamodule = datamodule(
-        train_batch_size=batch_chunk_size,
-        drop_incomplete_batch=drop_incomplete_batch,
-        multiproc_kwargs=multiproc_kwargs,
-        cache_samples=cache_samples,
-    )
     # Manually calling these as we need to access some values to reweight the criterion
     datamodule.prepare_data()
     datamodule.setup(stage="fit")
-- 
GitLab
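
For reference, a standalone sketch of the chunking arithmetic implemented by the new BaseDataModule.compute_chunk_size() method above. The free function and the example values are illustrative only, not part of the patch; they mirror the logic moved out of train.py:

# Mirrors compute_chunk_size() from the patch: the effective DataLoader
# batch size becomes batch_size // batch_chunk_count, so each logical
# batch is delivered as batch_chunk_count smaller chunks.
def compute_chunk_size(batch_size: int, batch_chunk_count: int) -> int:
    if batch_size % batch_chunk_count != 0:
        # batch_size must be divisible by batch_chunk_count
        raise RuntimeError(
            f"--batch-size ({batch_size}) must be divisible by "
            f"--batch-chunk-count ({batch_chunk_count})."
        )
    return batch_size // batch_chunk_count

assert compute_chunk_size(16, 4) == 4   # 16-sample batches arrive as 4 chunks of 4
assert compute_chunk_size(16, 1) == 16  # default chunk count: full batches
# compute_chunk_size(10, 4) would raise RuntimeError (10 is not divisible by 4)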