diff --git a/src/ptbench/configs/datasets/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py
index 6dc59200b2c751d7e304393a9b28abee18d40fa5..967d40e9358c8f946f876199cac56a921a89606f 100644
--- a/src/ptbench/configs/datasets/shenzhen/default.py
+++ b/src/ptbench/configs/datasets/shenzhen/default.py
@@ -38,6 +38,7 @@ class DefaultModule(BaseDataModule):
 
         self.cache_samples = cache_samples
         self.has_setup_fit = False
+        self.has_setup_predict = False
 
     def setup(self, stage: str):
         if self.cache_samples:
@@ -51,7 +52,7 @@ class DefaultModule(BaseDataModule):
             )
             samples_loader = _delayed_loader
 
-        self.json_dataset = JSONDataset(
+        json_dataset = JSONDataset(
             protocols=_protocols,
             fieldnames=("data", "label"),
             loader=samples_loader,
@@ -62,8 +63,17 @@ class DefaultModule(BaseDataModule):
                 self.train_dataset,
                 self.validation_dataset,
                 self.extra_validation_datasets,
-            ) = return_subsets(self.json_dataset, "default", stage)
+            ) = return_subsets(json_dataset, "default", stage)
             self.has_setup_fit = True
 
+        if not self.has_setup_predict and stage == "predict":
+            (
+                self.train_dataset,
+                self.validation_dataset,
+                self.extra_validation_datasets,
+            ) = return_subsets(json_dataset, "default", stage)
+
+            self.has_setup_predict = True
+
 
-datamodule = DefaultModule
+datamodule = DefaultModule()
diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py
index 5e656d428c6ebe64dde7858f368149481b8e1c2e..1c51a1055dc226efc72292a3e439041c67a6d37b 100644
--- a/src/ptbench/data/base_datamodule.py
+++ b/src/ptbench/data/base_datamodule.py
@@ -18,13 +18,15 @@ class BaseDataModule(pl.LightningDataModule):
         self,
         train_batch_size=1,
         predict_batch_size=1,
+        batch_chunk_count=1,
         drop_incomplete_batch=False,
-        multiproc_kwargs=None,
+        multiproc_kwargs={},
     ):
         super().__init__()
 
         self.train_batch_size = train_batch_size
         self.predict_batch_size = predict_batch_size
+        self.batch_chunk_count = batch_chunk_count
         self.drop_incomplete_batch = drop_incomplete_batch
 
         self.pin_memory = (
@@ -47,7 +49,7 @@ class BaseDataModule(pl.LightningDataModule):
 
         return DataLoader(
             self.train_dataset,
-            batch_size=self.train_batch_size,
+            batch_size=self.compute_chunk_size(self.train_batch_size),
             drop_last=self.drop_incomplete_batch,
             pin_memory=self.pin_memory,
             sampler=train_sampler,
@@ -59,7 +61,7 @@ class BaseDataModule(pl.LightningDataModule):
 
         val_loader = DataLoader(
             dataset=self.validation_dataset,
-            batch_size=self.train_batch_size,
+            batch_size=self.compute_chunk_size(self.train_batch_size),
             shuffle=False,
             drop_last=False,
             pin_memory=self.pin_memory,
@@ -86,12 +88,26 @@ class BaseDataModule(pl.LightningDataModule):
         return loaders_dict
 
     def predict_dataloader(self):
-        return DataLoader(
-            dataset=self.predict_dataset,
-            batch_size=self.predict_batch_size,
-            shuffle=False,
-            pin_memory=self.pin_memory,
-        )
+        loaders_dict = {}
+
+        loaders_dict["train_dataloader"] = self.train_dataloader()
+        for k, v in self.val_dataloader().items():
+            loaders_dict[k] = v
+
+        return loaders_dict
+
+    def compute_chunk_size(self, batch_size):
+        batch_chunk_size = batch_size
+        if batch_size % self.batch_chunk_count != 0:
+            # batch_size must be divisible by batch_chunk_count.
+            raise RuntimeError(
+                f"--batch-size ({batch_size}) must be divisible by "
+                f"--batch-chunk-size ({self.batch_chunk_count})."
+            )
+        else:
+            batch_chunk_size = batch_size // self.batch_chunk_count
+
+        return batch_chunk_size
 
 
 def get_dataset_from_module(module, stage, **module_args):
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index c59c81a3bc8f325ad20d883d58b8fa81bd51bbb7..3606c9f9d2c0179f171acc4b3dcba7987242047e 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -288,22 +288,10 @@ def train(
             "multiprocessing_context"
         ] = multiprocessing.get_context("spawn")
 
-    batch_chunk_size = batch_size
-    if batch_size % batch_chunk_count != 0:
-        # batch_size must be divisible by batch_chunk_count.
-        raise RuntimeError(
-            f"--batch-size ({batch_size}) must be divisible by "
-            f"--batch-chunk-size ({batch_chunk_count})."
-        )
-    else:
-        batch_chunk_size = batch_size // batch_chunk_count
+    datamodule.train_batch_size = batch_size
+    datamodule.batch_chunk_count = batch_chunk_count
+    datamodule.multiproc_kwargs = multiproc_kwargs
 
-    datamodule = datamodule(
-        train_batch_size=batch_chunk_size,
-        drop_incomplete_batch=drop_incomplete_batch,
-        multiproc_kwargs=multiproc_kwargs,
-        cache_samples=cache_samples,
-    )
     # Manually calling these as we need to access some values to reweight the criterion
     datamodule.prepare_data()
     datamodule.setup(stage="fit")
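
The `compute_chunk_size` helper moved into `BaseDataModule` implements the same divisibility rule the training script previously applied inline: the requested `--batch-size` is split into `batch_chunk_count` equal chunks, and each `DataLoader` is built with the chunk size rather than the full batch size. A minimal standalone sketch of that rule follows; `ChunkedModule` is a hypothetical stand-in written only to exercise the same arithmetic, not part of ptbench:

```python
# Sketch of the chunking rule in BaseDataModule.compute_chunk_size.
# ChunkedModule is a hypothetical stand-in, not the real datamodule.


class ChunkedModule:
    def __init__(self, batch_chunk_count=1):
        self.batch_chunk_count = batch_chunk_count

    def compute_chunk_size(self, batch_size):
        # The batch must split evenly across the requested chunk count.
        if batch_size % self.batch_chunk_count != 0:
            raise RuntimeError(
                f"--batch-size ({batch_size}) must be divisible by "
                f"--batch-chunk-size ({self.batch_chunk_count})."
            )
        return batch_size // self.batch_chunk_count


# A batch of 16 split into 4 chunks yields DataLoader batches of 4;
# a chunk count that does not divide the batch size raises RuntimeError.
assert ChunkedModule(batch_chunk_count=4).compute_chunk_size(16) == 4
```

With this helper in place, `train.py` only assigns `batch_size` and `batch_chunk_count` on the datamodule, so the division and its error handling live in one place instead of being duplicated in the script.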