Commit 27363e33 authored by Daniel CARRON

Fixed batch loss logging when using gradient accumulation

parent e4ab0fca
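In context: with gradient accumulation enabled, each step records a partial, scaled loss per chunk rather than one loss per batch (PyTorch Lightning's automatic optimization divides the returned loss by accumulate_grad_batches). A minimal sketch, with invented numbers, of the recovery this commit implements: rescale the chunk losses, then regroup and average them per batch.

import numpy

chunk_losses = [0.25, 0.30, 0.35, 0.40]  # two batches, two chunks per batch
num_chunks = 2  # hypothetical accumulate_grad_batches value

# Undo the 1/num_chunks scaling applied to each partial loss...
rescaled = numpy.multiply(chunk_losses, num_chunks)  # [0.5, 0.6, 0.7, 0.8]
# ...then regroup the chunks into batches and average within each batch.
batches = numpy.array_split(rescaled, len(rescaled) // num_chunks)
per_batch = [numpy.average(b) for b in batches]
print(per_batch)  # ~[0.55, 0.75] -- the recovered per-batch losses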
@@ -60,7 +60,9 @@ class BaseDataModule(pl.LightningDataModule):
         return DataLoader(
             self.train_dataset,
-            batch_size=self._compute_chunk_size(self.batch_size),
+            batch_size=self._compute_chunk_size(
+                self.batch_size, self.batch_chunk_count
+            ),
             drop_last=self.drop_incomplete_batch,
             pin_memory=self.pin_memory,
             sampler=train_sampler,
@@ -74,7 +76,9 @@ class BaseDataModule(pl.LightningDataModule):
         val_loader = DataLoader(
             dataset=self.validation_dataset,
-            batch_size=self._compute_chunk_size(self.batch_size),
+            batch_size=self._compute_chunk_size(
+                self.batch_size, self.batch_chunk_count
+            ),
             shuffle=False,
             drop_last=False,
             pin_memory=self.pin_memory,
@@ -87,7 +91,9 @@ class BaseDataModule(pl.LightningDataModule):
         for set_idx, extra_set in enumerate(self.extra_validation_datasets):
             extra_val_loader = DataLoader(
                 dataset=extra_set,
-                batch_size=self._compute_chunk_size(self.batch_size),
+                batch_size=self._compute_chunk_size(
+                    self.batch_size, self.batch_chunk_count
+                ),
                 shuffle=False,
                 drop_last=False,
                 pin_memory=self.pin_memory,
@@ -109,16 +115,16 @@ class BaseDataModule(pl.LightningDataModule):

         return loaders_dict

-    def _compute_chunk_size(self, batch_size):
+    def _compute_chunk_size(self, batch_size, chunk_count):
         batch_chunk_size = batch_size
-        if batch_size % self.batch_chunk_count != 0:
+        if batch_size % chunk_count != 0:
             # batch_size must be divisible by batch_chunk_count.
             raise RuntimeError(
                 f"--batch-size ({batch_size}) must be divisible by "
-                f"--batch-chunk-size ({self.batch_chunk_count})."
+                f"--batch-chunk-count ({chunk_count})."
             )
         else:
-            batch_chunk_size = batch_size // self.batch_chunk_count
+            batch_chunk_size = batch_size // chunk_count

         return batch_chunk_size
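As a usage note, a minimal standalone sketch of the revised helper; the pairing with Trainer(accumulate_grad_batches=...) is an assumption about how the chunk count is wired up elsewhere:

def compute_chunk_size(batch_size: int, chunk_count: int) -> int:
    # Mirrors the revised _compute_chunk_size above: every chunk must be
    # the same size, so the batch size must divide evenly.
    if batch_size % chunk_count != 0:
        raise RuntimeError(
            f"--batch-size ({batch_size}) must be divisible by "
            f"--batch-chunk-count ({chunk_count})."
        )
    return batch_size // chunk_count

# A DataLoader yielding chunks of compute_chunk_size(8, 4) == 2 samples,
# combined with Trainer(accumulate_grad_batches=4), steps the optimizer
# once per effective batch of 8 samples.
assert compute_chunk_size(8, 4) == 2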
@@ -68,6 +68,7 @@ class DefaultModule(BaseDataModule):
             self._build_transforms(is_train=True),
             cache_samples=self._cache_samples,
         )
+
         self.validation_dataset = TBDataset(
             json_protocol,
             self._protocol,
@@ -56,20 +56,56 @@ class LoggingCallback(Callback):
         )

         current_time = time.time() - self.start_training_time

-        self.log("total_time", current_time)
-        self.log("eta", eta_seconds)
-        self.log("loss", numpy.average(self.training_loss))
-        self.log("learning_rate", pl_module.hparams["optimizer_configs"]["lr"])
-        self.log("validation_loss", numpy.average(self.validation_loss))
-        if len(self.extra_validation_loss) > 0:
-            for (
-                extra_valid_loss_key,
-                extra_valid_loss_values,
-            ) in self.extra_validation_loss.items:
-                self.log(
-                    extra_valid_loss_key, numpy.average(extra_valid_loss_values)
-                )
+        def _compute_batch_loss(losses, num_chunks):
+            # When accumulating gradients, partial chunk losses need to be
+            # re-aggregated into one loss per batch
+            if num_chunks != 1:
+                # The loss we get is scaled by the number of accumulation steps
+                losses = numpy.multiply(losses, num_chunks)
+
+                if len(losses) % num_chunks > 0:
+                    num_splits = (len(losses) // num_chunks) + 1
+                else:
+                    num_splits = len(losses) // num_chunks
+
+                batched_losses = numpy.array_split(losses, num_splits)
+
+                summed_batch_losses = []
+                for b in batched_losses:
+                    summed_batch_losses.append(numpy.average(b))
+
+                return summed_batch_losses
+            # No gradient accumulation, we already have the batch losses
+            else:
+                return losses
+
+        # Do not log during sanity check as results are not relevant
+        if not trainer.sanity_checking:
+            # We get partial losses when using gradient accumulation
+            self.training_loss = _compute_batch_loss(
+                self.training_loss, trainer.accumulate_grad_batches
+            )
+            self.validation_loss = _compute_batch_loss(
+                self.validation_loss, trainer.accumulate_grad_batches
+            )
+
+            self.log("total_time", current_time)
+            self.log("eta", eta_seconds)
+            self.log("loss", numpy.average(self.training_loss))
+            self.log(
+                "learning_rate", pl_module.hparams["optimizer_configs"]["lr"]
+            )
+            self.log("validation_loss", numpy.sum(self.validation_loss))
+            if len(self.extra_validation_loss) > 0:
+                for (
+                    extra_valid_loss_key,
+                    extra_valid_loss_values,
+                ) in self.extra_validation_loss.items():
+                    self.log(
+                        extra_valid_loss_key, numpy.sum(extra_valid_loss_values)
+                    )

         queue_retries = 0
         # In case the resource monitor takes longer to fetch data from the queue, we wait
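A small worked run of the same regrouping logic as the _compute_batch_loss helper added above (values invented; the odd-length input shows why the split count rounds up when the last batch is incomplete):

import numpy

losses = [0.20, 0.25, 0.30, 0.35, 0.40]  # five chunk losses
num_chunks = 2                            # accumulation steps per batch

losses = numpy.multiply(losses, num_chunks)  # undo the 1/num_chunks scaling
num_splits = (len(losses) + num_chunks - 1) // num_chunks  # == 3, rounds up
batched = numpy.array_split(losses, num_splits)
print([numpy.average(b) for b in batched])   # ~[0.45, 0.65, 0.8]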
@@ -91,8 +127,10 @@ class LoggingCallback(Callback):
         assert self.resource_monitor.q.empty()

-        for metric_name, metric_value in self.resource_monitor.data:
-            self.log(metric_name, float(metric_value))
+        # Do not log during sanity check as results are not relevant
+        if not trainer.sanity_checking:
+            for metric_name, metric_value in self.resource_monitor.data:
+                self.log(metric_name, float(metric_value))

         self.resource_monitor.data = None