Skip to content
Snippets Groups Projects
Commit f45e6ea5 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[engine.trainer] Always remove tqdm bars; Only use tqdm if tty is connected

parent 4aa25bef
No related branches found
No related tags found
1 merge request: !12 "Streamlining"
...@@ -80,8 +80,15 @@ def do_train( ...@@ -80,8 +80,15 @@ def do_train(
logger.info(f"Truncating {logfile_name} - training is restarting...") logger.info(f"Truncating {logfile_name} - training is restarting...")
os.unlink(logfile_name) os.unlink(logfile_name)
logfile_fields = ("epoch", "total-time", "eta", "average-loss", logfile_fields = (
"median-loss", "learning-rate", "gpu-memory-megabytes") "epoch",
"total-time",
"eta",
"average-loss",
"median-loss",
"learning-rate",
"gpu-memory-megabytes",
)
with open(logfile_name, "a+", newline="") as logfile: with open(logfile_name, "a+", newline="") as logfile:
logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields) logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
...@@ -105,7 +112,10 @@ def do_train( ...@@ -105,7 +112,10 @@ def do_train(
# Epoch time # Epoch time
start_epoch_time = time.time() start_epoch_time = time.time()
for samples in tqdm(data_loader): # progress bar only on interactive jobs
for samples in tqdm(
data_loader, desc="batches", leave=False, disable=None
):
images = samples[1].to(device) images = samples[1].to(device)
ground_truths = samples[2].to(device) ground_truths = samples[2].to(device)
...@@ -126,7 +136,7 @@ def do_train( ...@@ -126,7 +136,7 @@ def do_train(
if checkpoint_period and (epoch % checkpoint_period == 0): if checkpoint_period and (epoch % checkpoint_period == 0):
checkpointer.save("model_{:03d}".format(epoch), **arguments) checkpointer.save("model_{:03d}".format(epoch), **arguments)
if epoch == max_epoch: if epoch >= max_epoch:
checkpointer.save("model_final", **arguments) checkpointer.save("model_final", **arguments)
# computes ETA (estimated time-of-arrival; end of training) taking # computes ETA (estimated time-of-arrival; end of training) taking
...@@ -136,26 +146,31 @@ def do_train( ...@@ -136,26 +146,31 @@ def do_train(
current_time = time.time() - start_training_time current_time = time.time() - start_training_time
logdata = ( logdata = (
("epoch", f"{epoch}"), ("epoch", f"{epoch}"),
("total-time", (
f"{datetime.timedelta(seconds=int(current_time))}"), "total-time",
("eta", f"{datetime.timedelta(seconds=int(current_time))}",
f"{datetime.timedelta(seconds=int(eta_seconds))}"), ),
("average-loss", f"{losses.avg:.6f}"), ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
("median-loss", f"{losses.median:.6f}"), ("average-loss", f"{losses.avg:.6f}"),
("learning-rate", ("median-loss", f"{losses.median:.6f}"),
f"{optimizer.param_groups[0]['lr']:.6f}"), ("learning-rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
("gpu-memory-megabytes", (
f"{torch.cuda.max_memory_allocated()/(1024.0*1024.0)}" \ "gpu-memory-megabytes",
if torch.cuda.is_available() else "0.0"), f"{torch.cuda.max_memory_allocated()/(1024.0*1024.0)}"
) if torch.cuda.is_available()
else "0.0",
),
)
logwriter.writerow(dict(k for k in logdata)) logwriter.writerow(dict(k for k in logdata))
logger.info("|".join([f"{k}: {v}" for (k,v) in logdata])) logger.info("|".join([f"{k}: {v}" for (k, v) in logdata]))
logger.info("End of training.") logger.info("End of training")
total_training_time = time.time() - start_training_time total_training_time = time.time() - start_training_time
logger.info(f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)") logger.info(
f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s in average per epoch)"
)
# plots a version of the CSV trainlog into a PDF # plots a version of the CSV trainlog into a PDF
logdf = pd.read_csv(logfile_name, header=0, names=logfile_fields) logdf = pd.read_csv(logfile_name, header=0, names=logfile_fields)
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.