Commit 31ffe6f1 authored by André Anjos

[engine.ssltrainer] Harmonize with engine.trainer

parent f34cb146
@@ -4,21 +4,23 @@
 import os
 import csv
 import time
+import shutil
 import datetime
 import distutils.version

 import numpy
+import pandas
 import torch
 from tqdm import tqdm

 from ..utils.measure import SmoothedValue
 from ..utils.plot import loss_curve
 from ..utils.summary import summary
 from ..utils.resources import cpu_constants, gpu_constants, cpu_log, gpu_log

 import logging
 logger = logging.getLogger(__name__)

-PYTORCH_GE_110 = (distutils.version.StrictVersion(torch.__version__) >= "1.1.0")
+PYTORCH_GE_110 = distutils.version.StrictVersion(torch.__version__) >= "1.1.0"
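The version gate above exists because PyTorch 1.1.0 reversed the recommended ordering of scheduler and optimizer updates: before 1.1.0, `scheduler.step()` was called at the top of each epoch, while from 1.1.0 onwards it must run after `optimizer.step()`. A minimal sketch of the gating pattern used in `run()` below; `run_one_epoch` is a hypothetical placeholder for the batch loop, not part of this module:

    # Sketch: version-gated scheduler stepping, as used in run() below.
    # run_one_epoch() is a hypothetical placeholder, not part of this module.
    for epoch in range(start_epoch, max_epoch):
        if not PYTORCH_GE_110:
            scheduler.step()  # pre-1.1.0: step at the start of the epoch
        run_one_epoch(model, optimizer, data_loader)
        if PYTORCH_GE_110:
            scheduler.step()  # 1.1.0+: step after the optimizer has run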
def sharpen(x, T):
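The body of `sharpen()` is elided by this diff view. In MixMatch (Berthelot et al., 2019), temperature sharpening raises probabilities to the power 1/T and renormalizes, pushing guessed labels towards hard decisions as T decreases. A hedged sketch of the binary form, not necessarily this module's exact implementation:

    # Hedged sketch of MixMatch temperature sharpening for binary
    # probabilities; the real body of sharpen() is not shown in this diff.
    import torch

    def sharpen_sketch(p: torch.Tensor, T: float) -> torch.Tensor:
        num = p ** (1 / T)
        return num / (num + (1 - p) ** (1 / T))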
@@ -48,7 +50,7 @@ def mix_up(alpha, input, target, unlabelled_input, unlabled_target):
     list

     """
     # TODO:

     with torch.no_grad():
         l = numpy.random.beta(alpha, alpha)  # Eq (8)
         l = max(l, 1 - l)  # Eq (9)
@@ -63,10 +65,12 @@ def mix_up(alpha, input, target, unlabelled_input, unlabled_target):
         # Apply MixUp to unlabelled data and entries from W. Alg. 1 Line: 14
         unlabelled_input_mixedup = (
-            l * unlabelled_input + (1 - l) * w_inputs[idx[: len(unlabelled_input)]]
+            l * unlabelled_input
+            + (1 - l) * w_inputs[idx[: len(unlabelled_input)]]
         )
         unlabled_target_mixedup = (
-            l * unlabled_target + (1 - l) * w_targets[idx[: len(unlabled_target)]]
+            l * unlabled_target
+            + (1 - l) * w_targets[idx[: len(unlabled_target)]]
         )

         return (
             input_mixedup,
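The `Eq (8)`/`Eq (9)` comments refer to MixMatch (Berthelot et al., 2019): the mixing coefficient is drawn from Beta(alpha, alpha) and then biased into [0.5, 1.0], so each mixed sample stays closer to its own batch than to the shuffled pool W. A stand-alone sketch with illustrative tensors:

    # Stand-alone sketch of the MixUp step above; x1 and x2 are
    # illustrative stand-ins for a batch and a shuffled batch from W.
    import numpy
    import torch

    alpha = 0.75
    x1 = torch.rand(4, 1, 32, 32)
    x2 = torch.rand(4, 1, 32, 32)

    l = numpy.random.beta(alpha, alpha)  # Eq (8)
    l = max(l, 1 - l)                    # Eq (9): l >= 0.5, mix favours x1
    mixed = l * x1 + (1 - l) * x2        # convex combination of the pair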
@@ -176,6 +180,11 @@ def run(
     """
     Fits an FCN model using semi-supervised learning and saves it to disk.

+    This method supports periodic checkpointing and the output of a
+    CSV-formatted log with the evolution of some figures during training.

     Parameters
     ----------
@@ -193,7 +202,7 @@ def run(
         learning rate scheduler

     checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.DetectronCheckpointer`
-        checkpointer
+        checkpointer implementation

     checkpoint_period : int
         save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
@@ -216,45 +225,85 @@ def run(
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]

-    if not os.path.exists(output_folder):
-        logger.debug(f"Creating output directory '{output_folder}'...")
-        os.makedirs(output_folder)
+    if device != "cpu":
+        # asserts we do have a GPU
+        assert bool(gpu_constants()), (
+            f"Device set to '{device}', but cannot "
+            f"find a GPU (maybe nvidia-smi is not installed?)"
+        )

-    # Log to file
+    os.makedirs(output_folder, exist_ok=True)

+    # Save model summary
+    summary_path = os.path.join(output_folder, "model_summary.txt")
+    logger.info(f"Saving model summary at {summary_path}...")
+    with open(summary_path, "wt") as f:
+        r, n = summary(model)
+        logger.info(f"Model has {n} parameters...")
+        f.write(r)

+    # write static information to a CSV file
+    static_logfile_name = os.path.join(output_folder, "constants.csv")
+    if os.path.exists(static_logfile_name):
+        backup = static_logfile_name + "~"
+        if os.path.exists(backup):
+            os.unlink(backup)
+        shutil.move(static_logfile_name, backup)
+    with open(static_logfile_name, "w", newline="") as f:
+        logdata = cpu_constants()
+        if device != "cpu":
+            logdata += gpu_constants()
+        logdata += (("model_size", n),)
+        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
+        logwriter.writeheader()
+        logwriter.writerow(dict(k for k in logdata))
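`cpu_constants()` and `gpu_constants()` return tuples of `(name, value)` pairs, which is why the writer's field names and the row itself can be assembled straight from `logdata`. A minimal, self-contained illustration of the same pattern, with made-up values:

    # Minimal illustration of the (name, value)-tuple logging pattern used
    # above; the values are made up, not what cpu_constants() reports.
    import csv
    import io

    logdata = (("cpu_count", 8), ("memory_gb", 15.5), ("model_size", 123456))
    buf = io.StringIO()
    writer = csv.DictWriter(buf, fieldnames=[k[0] for k in logdata])
    writer.writeheader()
    writer.writerow(dict(k for k in logdata))
    print(buf.getvalue())
    # cpu_count,memory_gb,model_size
    # 8,15.5,123456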
+    # Log continuous information to (another) file
     logfile_name = os.path.join(output_folder, "trainlog.csv")

     if arguments["epoch"] == 0 and os.path.exists(logfile_name):
-        logger.info(f"Truncating {logfile_name} - training is restarting...")
-        os.unlink(logfile_name)
+        backup = logfile_name + "~"
+        if os.path.exists(backup):
+            os.unlink(backup)
+        shutil.move(logfile_name, backup)

     logfile_fields = (
         "epoch",
-        "total-time",
+        "total_time",
         "eta",
-        "average-loss",
-        "median-loss",
-        "median-labelled-loss",
-        "median-unlabelled-loss",
-        "learning-rate",
-        "gpu-memory-megabytes",
+        "average_loss",
+        "median_loss",
+        "median_labelled_loss",
+        "median_unlabelled_loss",
+        "learning_rate",
     )
+    logfile_fields += tuple([k[0] for k in cpu_log()])
+    if device != "cpu":
+        logfile_fields += tuple([k[0] for k in gpu_log()])

     with open(logfile_name, "a+", newline="") as logfile:
         logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)

         if arguments["epoch"] == 0:
             logwriter.writeheader()
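Opening the log with mode `"a+"` lets a resumed training append to the CSV it started earlier: the header is only written when `arguments["epoch"]` is 0, i.e. on a fresh start, in which case any pre-existing log has already been moved to a `~` backup as shown above.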
-        model.train().to(device)
+        for state in optimizer.state.values():
+            for k, v in state.items():
+                if isinstance(v, torch.Tensor):
+                    state[k] = v.to(device)
+        model.train().to(device)
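Moving the optimizer state explicitly matters when training resumes from a checkpoint: the saved SGD/Adam moment tensors may be restored onto a different device than the one requested for this run, and `model.train().to(device)` alone does not touch them.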
         # Total training timer
         start_training_time = time.time()

-        for epoch in range(start_epoch, max_epoch):
-            if not PYTORCH_GE_110: scheduler.step()
+        for epoch in tqdm(
+            range(start_epoch, max_epoch),
+            desc="epoch",
+            leave=False,
+            disable=None,
+        ):
+            if not PYTORCH_GE_110:
+                scheduler.step()
             losses = SmoothedValue(len(data_loader))
             labelled_loss = SmoothedValue(len(data_loader))
             unlabelled_loss = SmoothedValue(len(data_loader))
@@ -264,8 +313,9 @@ def run(
             # Epoch time
             start_epoch_time = time.time()

-            for samples in tqdm(data_loader, desc="batches", leave=False,
-                disable=None,):
+            for samples in tqdm(
+                data_loader, desc="batch", leave=False, disable=None
+            ):

                 # data forwarding on the existing network
@@ -277,12 +327,16 @@ def run(
                 outputs = model(images)
                 unlabelled_outputs = model(unlabelled_images)

                 # guessed unlabelled outputs
-                unlabelled_ground_truths = guess_labels(unlabelled_images, model)
+                unlabelled_ground_truths = guess_labels(
+                    unlabelled_images, model
+                )
                 # unlabelled_ground_truths = sharpen(unlabelled_ground_truths,0.5)
                 # images, ground_truths, unlabelled_images, unlabelled_ground_truths = mix_up(0.75, images, ground_truths, unlabelled_images, unlabelled_ground_truths)

                 # loss evaluation and learning (backward step)
-                ramp_up_factor = square_rampup(epoch, rampup_length=rampup_length)
+                ramp_up_factor = square_rampup(
+                    epoch, rampup_length=rampup_length
+                )

                 loss, ll, ul = criterion(
                     outputs,
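Neither `guess_labels()` nor `square_rampup()` is shown in this diff. In MixMatch-style semi-supervised training, label guessing derives pseudo ground truth from the model's own predictions on the unlabelled batch, and the weight of the unlabelled loss is ramped up over the first epochs so that early, noisy guesses do not dominate. Hedged sketches of plausible implementations, not necessarily what this module actually does:

    # Hedged sketches only: the real helpers live elsewhere in this module
    # and may differ in detail.
    import numpy
    import torch

    def guess_labels_sketch(unlabelled_images, model):
        # pseudo-labels from the model's own predictions, no gradients
        with torch.no_grad():
            return torch.sigmoid(model(unlabelled_images))

    def square_rampup_sketch(epoch, rampup_length=16):
        # weight in [0, 1] growing quadratically until rampup_length
        if rampup_length == 0:
            return 1.0
        phase = numpy.clip(epoch, 0.0, rampup_length) / rampup_length
        return float(phase ** 2)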
@@ -299,7 +353,8 @@ def run(
                 unlabelled_loss.update(ul)
                 logger.debug(f"batch loss: {loss.item()}")

-            if PYTORCH_GE_110: scheduler.step()
+            if PYTORCH_GE_110:
+                scheduler.step()

             if checkpoint_period and (epoch % checkpoint_period == 0):
                 checkpointer.save(f"model_{epoch:03d}", **arguments)
@@ -316,33 +371,24 @@ def run(
             logdata = (
                 ("epoch", f"{epoch}"),
                 (
-                    "total-time",
+                    "total_time",
                     f"{datetime.timedelta(seconds=int(current_time))}",
                 ),
                 ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
-                ("average-loss", f"{losses.avg:.6f}"),
-                ("median-loss", f"{losses.median:.6f}"),
-                ("median-labelled-loss", f"{labelled_loss.median:.6f}"),
-                ("median-unlabelled-loss", f"{unlabelled_loss.median:.6f}"),
-                ("learning-rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
-                (
-                    "gpu-memory-megabytes",
-                    f"{torch.cuda.max_memory_allocated()/(1024.0*1024.0)}"
-                    if torch.cuda.is_available()
-                    else "0.0",
-                ),
-            )
+                ("average_loss", f"{losses.avg:.6f}"),
+                ("median_loss", f"{losses.median:.6f}"),
+                ("median_labelled_loss", f"{labelled_loss.median:.6f}"),
+                ("median_unlabelled_loss", f"{unlabelled_loss.median:.6f}"),
+                ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
+            ) + cpu_log()
+            if device != "cpu":
+                logdata += gpu_log()

             logwriter.writerow(dict(k for k in logdata))
-            logger.info("|".join([f"{k}: {v}" for (k, v) in logdata]))
             logfile.flush()
+            tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))

         total_training_time = time.time() - start_training_time
         logger.info(
             f"Total training time: {datetime.timedelta(seconds=total_training_time)} ({(total_training_time/max_epoch):.4f}s on average per epoch)"
         )

     # plots a version of the CSV trainlog into a PDF
     logdf = pandas.read_csv(logfile_name, header=0, names=logfile_fields)
     fig = loss_curve(logdf)
     figurefile_name = os.path.join(output_folder, "trainlog.pdf")
     logger.info(f"Saving {figurefile_name}")
     fig.savefig(figurefile_name)
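Because the training log is plain CSV with the `logfile_fields` header defined above, it can also be inspected without the plotting helper. A brief sketch; the path is illustrative and depends on the `output_folder` passed to `run()`:

    # Quick look at the CSV produced during training; "results/" is an
    # illustrative output_folder, not a fixed location.
    import pandas

    df = pandas.read_csv("results/trainlog.csv")
    print(df[["epoch", "median_loss", "learning_rate"]].tail())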