diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py
index 5cfa1620d14a186ec034821b95c3d1a2a0fdbb8b..5dfbbd7e6e0b00df0d1a6208a7bb70351b1cb417 100644
--- a/src/ptbench/data/base_datamodule.py
+++ b/src/ptbench/data/base_datamodule.py
@@ -60,7 +60,9 @@ class BaseDataModule(pl.LightningDataModule):
 
         return DataLoader(
             self.train_dataset,
-            batch_size=self._compute_chunk_size(self.batch_size),
+            batch_size=self._compute_chunk_size(
+                self.batch_size, self.batch_chunk_count
+            ),
             drop_last=self.drop_incomplete_batch,
             pin_memory=self.pin_memory,
             sampler=train_sampler,
@@ -74,7 +76,9 @@ class BaseDataModule(pl.LightningDataModule):
 
         val_loader = DataLoader(
             dataset=self.validation_dataset,
-            batch_size=self._compute_chunk_size(self.batch_size),
+            batch_size=self._compute_chunk_size(
+                self.batch_size, self.batch_chunk_count
+            ),
             shuffle=False,
             drop_last=False,
             pin_memory=self.pin_memory,
@@ -87,7 +91,9 @@ class BaseDataModule(pl.LightningDataModule):
         for set_idx, extra_set in enumerate(self.extra_validation_datasets):
             extra_val_loader = DataLoader(
                 dataset=extra_set,
-                batch_size=self._compute_chunk_size(self.batch_size),
+                batch_size=self._compute_chunk_size(
+                    self.batch_size, self.batch_chunk_count
+                ),
                 shuffle=False,
                 drop_last=False,
                 pin_memory=self.pin_memory,
@@ -109,16 +115,16 @@ class BaseDataModule(pl.LightningDataModule):
 
         return loaders_dict
 
-    def _compute_chunk_size(self, batch_size):
+    def _compute_chunk_size(self, batch_size, chunk_count):
         batch_chunk_size = batch_size
-        if batch_size % self.batch_chunk_count != 0:
+        if batch_size % chunk_count != 0:
             # batch_size must be divisible by batch_chunk_count.
             raise RuntimeError(
                 f"--batch-size ({batch_size}) must be divisible by "
-                f"--batch-chunk-size ({self.batch_chunk_count})."
+                f"--batch-chunk-size ({chunk_count})."
             )
         else:
-            batch_chunk_size = batch_size // self.batch_chunk_count
+            batch_chunk_size = batch_size // chunk_count
 
         return batch_chunk_size
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py
index f471d426e20e25c3517050b68f54650ce3074314..bf75eacb36b83f121e53ecb19ab97edfdd77d9ab 100644
--- a/src/ptbench/data/shenzhen/default.py
+++ b/src/ptbench/data/shenzhen/default.py
@@ -68,6 +68,7 @@ class DefaultModule(BaseDataModule):
             self._build_transforms(is_train=True),
             cache_samples=self._cache_samples,
         )
+
         self.validation_dataset = TBDataset(
             json_protocol,
             self._protocol,
diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index 962a761cce0b2a82bd6f41a60c92c49333f59da8..c78d52ea806310c087ecda943b269ec4129a4d11 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -56,20 +56,56 @@ class LoggingCallback(Callback):
         )
 
         current_time = time.time() - self.start_training_time
-        self.log("total_time", current_time)
-        self.log("eta", eta_seconds)
-        self.log("loss", numpy.average(self.training_loss))
-        self.log("learning_rate", pl_module.hparams["optimizer_configs"]["lr"])
-        self.log("validation_loss", numpy.average(self.validation_loss))
-
-        if len(self.extra_validation_loss) > 0:
-            for (
-                extra_valid_loss_key,
-                extra_valid_loss_values,
-            ) in self.extra_validation_loss.items:
-                self.log(
-                    extra_valid_loss_key, numpy.average(extra_valid_loss_values)
-                )
+        def _compute_batch_loss(losses, num_chunks):
+            # When accumulating gradients, partial losses need to be summed per batch before averaging
+            if num_chunks != 1:
+                # The loss we get is scaled by the number of accumulation steps
+                losses = numpy.multiply(losses, num_chunks)
+
+                if len(losses) % num_chunks > 0:
+                    num_splits = (len(losses) // num_chunks) + 1
+                else:
+                    num_splits = len(losses) // num_chunks
+
+                batched_losses = numpy.array_split(losses, num_splits)
+
+                summed_batch_losses = []
+
+                for b in batched_losses:
+                    summed_batch_losses.append(numpy.average(b))
+
+                return summed_batch_losses
+
+            # No gradient accumulation, we already have the batch losses
+            else:
+                return losses
+
+        # Do not log during sanity check as results are not relevant
+        if not trainer.sanity_checking:
+            # We get partial losses when using gradient accumulation
+            self.training_loss = _compute_batch_loss(
+                self.training_loss, trainer.accumulate_grad_batches
+            )
+            self.validation_loss = _compute_batch_loss(
+                self.validation_loss, trainer.accumulate_grad_batches
+            )
+
+            self.log("total_time", current_time)
+            self.log("eta", eta_seconds)
+            self.log("loss", numpy.average(self.training_loss))
+            self.log(
+                "learning_rate", pl_module.hparams["optimizer_configs"]["lr"]
+            )
+            self.log("validation_loss", numpy.sum(self.validation_loss))
+
+            if len(self.extra_validation_loss) > 0:
+                for (
+                    extra_valid_loss_key,
+                    extra_valid_loss_values,
+                ) in self.extra_validation_loss.items():
+                    self.log(
+                        extra_valid_loss_key, numpy.sum(extra_valid_loss_values)
+                    )
 
         queue_retries = 0
         # In case the resource monitor takes longer to fetch data from the queue, we wait
@@ -91,8 +127,10 @@ class LoggingCallback(Callback):
 
         assert self.resource_monitor.q.empty()
 
-        for metric_name, metric_value in self.resource_monitor.data:
-            self.log(metric_name, float(metric_value))
+        # Do not log during sanity check as results are not relevant
+        if not trainer.sanity_checking:
+            for metric_name, metric_value in self.resource_monitor.data:
+                self.log(metric_name, float(metric_value))
 
         self.resource_monitor.data = None
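
For context, the arithmetic behind the two computations this patch touches can
be checked in isolation. The sketch below is not part of the patch:
compute_chunk_size and compute_batch_loss are hypothetical standalone copies of
the patched code, and the numbers are invented. It assumes, as the patch's
comments state, that each partial loss recorded under gradient accumulation
arrives scaled by 1 / num_chunks.

# Standalone sketch (not part of the patch) of the chunking arithmetic,
# assuming partial losses are recorded scaled by 1 / num_chunks.
import numpy


def compute_chunk_size(batch_size, chunk_count):
    # Mirrors BaseDataModule._compute_chunk_size: each DataLoader batch is
    # the user-facing batch size divided evenly across the chunks.
    if batch_size % chunk_count != 0:
        raise RuntimeError(
            f"--batch-size ({batch_size}) must be divisible by "
            f"--batch-chunk-size ({chunk_count})."
        )
    return batch_size // chunk_count


def compute_batch_loss(losses, num_chunks):
    # Mirrors the new _compute_batch_loss helper: undo the 1/num_chunks
    # scaling, then collapse each group of num_chunks partial losses back
    # into one per-batch value.
    if num_chunks == 1:
        return losses  # no accumulation: these are already batch losses
    losses = numpy.multiply(losses, num_chunks)
    num_splits = len(losses) // num_chunks
    if len(losses) % num_chunks > 0:
        num_splits += 1  # trailing, incomplete group of chunks
    return [numpy.average(b) for b in numpy.array_split(losses, num_splits)]


# A batch size of 16 spread over 4 chunks gives DataLoader batches of 4:
assert compute_chunk_size(16, 4) == 4

# Two batches accumulated over 2 chunks each, with chunk losses
# [0.8, 0.6] and [0.5, 0.7], recorded scaled by 1/2:
partial = [0.4, 0.3, 0.25, 0.35]
batch_losses = compute_batch_loss(partial, 2)  # ~ [0.7, 0.6]

For equal-sized chunks, averaging the recovered per-batch values gives the same
epoch-level training loss an unchunked run would log.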