diff --git a/src/ptbench/data/base_datamodule.py b/src/ptbench/data/base_datamodule.py
index 5cfa1620d14a186ec034821b95c3d1a2a0fdbb8b..5dfbbd7e6e0b00df0d1a6208a7bb70351b1cb417 100644
--- a/src/ptbench/data/base_datamodule.py
+++ b/src/ptbench/data/base_datamodule.py
@@ -60,7 +60,9 @@ class BaseDataModule(pl.LightningDataModule):
 
         return DataLoader(
             self.train_dataset,
-            batch_size=self._compute_chunk_size(self.batch_size),
+            batch_size=self._compute_chunk_size(
+                self.batch_size, self.batch_chunk_count
+            ),
             drop_last=self.drop_incomplete_batch,
             pin_memory=self.pin_memory,
             sampler=train_sampler,
@@ -74,7 +76,9 @@ class BaseDataModule(pl.LightningDataModule):
 
         val_loader = DataLoader(
             dataset=self.validation_dataset,
-            batch_size=self._compute_chunk_size(self.batch_size),
+            batch_size=self._compute_chunk_size(
+                self.batch_size, self.batch_chunk_count
+            ),
             shuffle=False,
             drop_last=False,
             pin_memory=self.pin_memory,
@@ -87,7 +91,9 @@ class BaseDataModule(pl.LightningDataModule):
             for set_idx, extra_set in enumerate(self.extra_validation_datasets):
                 extra_val_loader = DataLoader(
                     dataset=extra_set,
-                    batch_size=self._compute_chunk_size(self.batch_size),
+                    batch_size=self._compute_chunk_size(
+                        self.batch_size, self.batch_chunk_count
+                    ),
                     shuffle=False,
                     drop_last=False,
                     pin_memory=self.pin_memory,
@@ -109,16 +115,16 @@ class BaseDataModule(pl.LightningDataModule):
 
         return loaders_dict
 
-    def _compute_chunk_size(self, batch_size):
+    def _compute_chunk_size(self, batch_size, chunk_count):
         batch_chunk_size = batch_size
-        if batch_size % self.batch_chunk_count != 0:
+        if batch_size % chunk_count != 0:
             # batch_size must be divisible by batch_chunk_count.
             raise RuntimeError(
                 f"--batch-size ({batch_size}) must be divisible by "
-                f"--batch-chunk-size ({self.batch_chunk_count})."
+                f"--batch-chunk-size ({chunk_count})."
             )
         else:
-            batch_chunk_size = batch_size // self.batch_chunk_count
+            batch_chunk_size = batch_size // chunk_count
 
         return batch_chunk_size
 
diff --git a/src/ptbench/data/shenzhen/default.py b/src/ptbench/data/shenzhen/default.py
index f471d426e20e25c3517050b68f54650ce3074314..bf75eacb36b83f121e53ecb19ab97edfdd77d9ab 100644
--- a/src/ptbench/data/shenzhen/default.py
+++ b/src/ptbench/data/shenzhen/default.py
@@ -68,6 +68,7 @@ class DefaultModule(BaseDataModule):
                 self._build_transforms(is_train=True),
                 cache_samples=self._cache_samples,
             )
+
             self.validation_dataset = TBDataset(
                 json_protocol,
                 self._protocol,
diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py
index 962a761cce0b2a82bd6f41a60c92c49333f59da8..c78d52ea806310c087ecda943b269ec4129a4d11 100644
--- a/src/ptbench/engine/callbacks.py
+++ b/src/ptbench/engine/callbacks.py
@@ -56,20 +56,56 @@ class LoggingCallback(Callback):
         )
         current_time = time.time() - self.start_training_time
 
-        self.log("total_time", current_time)
-        self.log("eta", eta_seconds)
-        self.log("loss", numpy.average(self.training_loss))
-        self.log("learning_rate", pl_module.hparams["optimizer_configs"]["lr"])
-        self.log("validation_loss", numpy.average(self.validation_loss))
-
-        if len(self.extra_validation_loss) > 0:
-            for (
-                extra_valid_loss_key,
-                extra_valid_loss_values,
-            ) in self.extra_validation_loss.items:
-                self.log(
-                    extra_valid_loss_key, numpy.average(extra_valid_loss_values)
-                )
+        def _compute_batch_loss(losses, num_chunks):
+            # When accumulating gradients, partial losses need to be summed per batch before averaging
+            if num_chunks != 1:
+                # The loss we get is scaled by the number of accumulation steps
+                losses = numpy.multiply(losses, num_chunks)
+
+                if len(losses) % num_chunks > 0:
+                    num_splits = (len(losses) // num_chunks) + 1
+                else:
+                    num_splits = len(losses) // num_chunks
+
+                batched_losses = numpy.array_split(losses, num_splits)
+
+                summed_batch_losses = []
+
+                for b in batched_losses:
+                    summed_batch_losses.append(numpy.average(b))
+
+                return summed_batch_losses
+
+            # No gradient accumulation, we already have the batch losses
+            else:
+                return losses
+
+        # Do not log during sanity check as results are not relevant
+        if not trainer.sanity_checking:
+        # We get partial losses when using gradient accumulation
+            self.training_loss = _compute_batch_loss(
+                self.training_loss, trainer.accumulate_grad_batches
+            )
+            self.validation_loss = _compute_batch_loss(
+                self.validation_loss, trainer.accumulate_grad_batches
+            )
+
+            self.log("total_time", current_time)
+            self.log("eta", eta_seconds)
+            self.log("loss", numpy.average(self.training_loss))
+            self.log(
+                "learning_rate", pl_module.hparams["optimizer_configs"]["lr"]
+            )
+            self.log("validation_loss", numpy.sum(self.validation_loss))
+
+            if len(self.extra_validation_loss) > 0:
+                for (
+                    extra_valid_loss_key,
+                    extra_valid_loss_values,
+                ) in self.extra_validation_loss.items():
+                    self.log(
+                        extra_valid_loss_key, numpy.sum(extra_valid_loss_values)
+                    )
 
         queue_retries = 0
         # In case the resource monitor takes longer to fetch data from the queue, we wait
@@ -91,8 +127,10 @@ class LoggingCallback(Callback):
 
         assert self.resource_monitor.q.empty()
 
-        for metric_name, metric_value in self.resource_monitor.data:
-            self.log(metric_name, float(metric_value))
+        # Do not log during sanity check as results are not relevant
+        if not trainer.sanity_checking:
+            for metric_name, metric_value in self.resource_monitor.data:
+                self.log(metric_name, float(metric_value))
 
         self.resource_monitor.data = None