From f45e6ea5e2efd7c8d344ad4ea145c8f6dc7c7814 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.anjos@idiap.ch>
Date: Sun, 5 Apr 2020 10:56:19 +0200
Subject: [PATCH] [engine.trainer] Always remove tqdm bars; Only use tqdm if
 tty is connected

---
 bob/ip/binseg/engine/trainer.py | 55 +++++++++++++++++++++------------
 1 file changed, 35 insertions(+), 20 deletions(-)

diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index 3201df17..b225f69c 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -80,8 +80,15 @@ def do_train(
         logger.info(f"Truncating {logfile_name} - training is restarting...")
         os.unlink(logfile_name)
 
-    logfile_fields = ("epoch", "total-time", "eta", "average-loss",
-            "median-loss", "learning-rate", "gpu-memory-megabytes")
+    logfile_fields = (
+        "epoch",
+        "total-time",
+        "eta",
+        "average-loss",
+        "median-loss",
+        "learning-rate",
+        "gpu-memory-megabytes",
+    )
     with open(logfile_name, "a+", newline="") as logfile:
         logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
 
@@ -105,7 +112,10 @@ def do_train(
             # Epoch time
             start_epoch_time = time.time()
 
-            for samples in tqdm(data_loader):
+            # progress bar only on interactive jobs
+            for samples in tqdm(
+                data_loader, desc="batches", leave=False, disable=None
+            ):
 
                 images = samples[1].to(device)
                 ground_truths = samples[2].to(device)
@@ -126,7 +136,7 @@ def do_train(
             if checkpoint_period and (epoch % checkpoint_period == 0):
                 checkpointer.save("model_{:03d}".format(epoch), **arguments)
 
-            if epoch == max_epoch:
+            if epoch >= max_epoch:
                 checkpointer.save("model_final", **arguments)
 
             # computes ETA (estimated time-of-arrival; end of training) taking
@@ -136,26 +146,31 @@ def do_train(
             current_time = time.time() - start_training_time
 
             logdata = (
-                    ("epoch", f"{epoch}"),
-                    ("total-time",
-                        f"{datetime.timedelta(seconds=int(current_time))}"),
-                    ("eta",
-                        f"{datetime.timedelta(seconds=int(eta_seconds))}"),
-                    ("average-loss", f"{losses.avg:.6f}"),
-                    ("median-loss", f"{losses.median:.6f}"),
-                    ("learning-rate",
-                        f"{optimizer.param_groups[0]['lr']:.6f}"),
-                    ("gpu-memory-megabytes",
-                        f"{torch.cuda.max_memory_allocated()/(1024.0*1024.0)}" \
-                            if torch.cuda.is_available() else "0.0"),
-                    )
+                ("epoch", f"{epoch}"),
+                (
+                    "total-time",
+                    f"{datetime.timedelta(seconds=int(current_time))}",
+                ),
+                ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
+                ("average-loss", f"{losses.avg:.6f}"),
+                ("median-loss", f"{losses.median:.6f}"),
+                ("learning-rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
+                (
+                    "gpu-memory-megabytes",
+                    f"{torch.cuda.max_memory_allocated()/(1024.0*1024.0)}"
+                    if torch.cuda.is_available()
+                    else "0.0",
+                ),
+            )
             logwriter.writerow(dict(k for k in logdata))
-            logger.info("|".join([f"{k}: {v}" for (k,v) in logdata]))
+            logger.info("|".join([f"{k}: {v}" for (k, v) in logdata]))
 
-    logger.info("End of training.")
+    logger.info("End of training")
     total_training_time = time.time() - start_training_time
-    logger.info(f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)")
+    logger.info(
+        f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)"
+    )
 
     # plots a version of the CSV trainlog into a PDF
     logdf = pd.read_csv(logfile_name, header=0, names=logfile_fields)
-- 
GitLab
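
Note on the tqdm flags this patch introduces: disable=None is documented tqdm
behaviour meaning "disable the bar unless the output stream (stderr by
default) is attached to a tty", which keeps non-interactive batch-job logs
free of progress-bar noise, and leave=False erases each per-epoch bar when
its inner loop finishes, so bars do not pile up across epochs. The sketch
below is not part of the patch; it is a minimal standalone reproduction of
that behaviour, with a made-up work() helper and batch count standing in for
the trainer's data_loader iteration:

    import sys
    import time

    from tqdm import tqdm

    def work(n_batches=20):
        # Same flags as the patch: the bar is drawn only when stderr is a
        # tty (disable=None) and is erased once the loop ends (leave=False).
        for _ in tqdm(range(n_batches), desc="batches", leave=False,
                      disable=None):
            time.sleep(0.05)  # stand-in for one training step

    if __name__ == "__main__":
        work()
        # Interactive run: a bar appears, then vanishes when the loop ends.
        # Redirected run (e.g. python sketch.py 2>log.txt): no bar at all.
        print("stderr is a tty:", sys.stderr.isatty())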