Commit 7eb7def3 authored by André Anjos
[engine.ssltrainer] Re-sync with engine.trainer

parent 5899f1b5
1 merge request: !12 Streamlining
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 import os
+import csv
 import time
 import datetime
 import torch
@@ -21,7 +22,7 @@ def sharpen(x, T):
     return temp / temp.sum(dim=1, keepdim=True)
 
 
-def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
+def mix_up(alpha, input, target, unlabelled_input, unlabled_target):
     """Applies mix up as described in [MIXMATCH_19].
 
     Parameters
@@ -32,7 +33,7 @@ def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
     target : :py:class:`torch.Tensor`
 
-    unlabeled_input : :py:class:`torch.Tensor`
+    unlabelled_input : :py:class:`torch.Tensor`
 
     unlabled_target : :py:class:`torch.Tensor`
@@ -48,17 +49,17 @@ def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
     l = np.random.beta(alpha, alpha)  # Eq (8)
     l = max(l, 1 - l)  # Eq (9)
 
     # Shuffle and concat. Alg. 1 Line: 12
-    w_inputs = torch.cat([input, unlabeled_input], 0)
+    w_inputs = torch.cat([input, unlabelled_input], 0)
     w_targets = torch.cat([target, unlabled_target], 0)
     idx = torch.randperm(w_inputs.size(0))  # get random index
 
-    # Apply MixUp to labeled data and entries from W. Alg. 1 Line: 13
+    # Apply MixUp to labelled data and entries from W. Alg. 1 Line: 13
     input_mixedup = l * input + (1 - l) * w_inputs[idx[len(input) :]]
     target_mixedup = l * target + (1 - l) * w_targets[idx[len(target) :]]
 
-    # Apply MixUp to unlabeled data and entries from W. Alg. 1 Line: 14
-    unlabeled_input_mixedup = (
-        l * unlabeled_input + (1 - l) * w_inputs[idx[: len(unlabeled_input)]]
-    )
+    # Apply MixUp to unlabelled data and entries from W. Alg. 1 Line: 14
+    unlabelled_input_mixedup = (
+        l * unlabelled_input + (1 - l) * w_inputs[idx[: len(unlabelled_input)]]
+    )
     unlabled_target_mixedup = (
         l * unlabled_target + (1 - l) * w_targets[idx[: len(unlabled_target)]]
@@ -66,7 +67,7 @@ def mix_up(alpha, input, target, unlabeled_input, unlabled_target):
     return (
         input_mixedup,
         target_mixedup,
-        unlabeled_input_mixedup,
+        unlabelled_input_mixedup,
         unlabled_target_mixedup,
     )
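
Aside (not part of this commit): because ``l`` is clamped with ``max(l, 1 - l)``, every mixed sample remains dominated by its own batch. A minimal, self-contained sketch of the same recipe on dummy tensors, for illustration only:

    import numpy as np
    import torch

    alpha = 0.75                      # Beta concentration, as in the commented-out call later in this diff
    l = np.random.beta(alpha, alpha)  # Eq (8): l ~ Beta(alpha, alpha)
    l = max(l, 1 - l)                 # Eq (9): forces l >= 0.5

    x = torch.rand(4, 1, 8, 8)        # dummy labelled batch
    u = torch.rand(4, 1, 8, 8)        # dummy unlabelled batch
    w = torch.cat([x, u], 0)          # W: both batches concatenated
    idx = torch.randperm(w.size(0))   # random pairing, as in Alg. 1 Line 12

    # convex combination of each labelled sample with a random entry of W
    x_mixed = l * x + (1 - l) * w[idx[len(x):]]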
@@ -122,14 +123,14 @@ def linear_rampup(current, rampup_length=16):
     return float(current)
 
 
-def guess_labels(unlabeled_images, model):
+def guess_labels(unlabelled_images, model):
     """
     Calculate the average predictions by 2 augmentations: horizontal and vertical flips
 
     Parameters
     ----------
 
-    unlabeled_images : :py:class:`torch.Tensor`
+    unlabelled_images : :py:class:`torch.Tensor`
         ``[n,c,h,w]``
 
     target : :py:class:`torch.Tensor`
@@ -142,12 +143,12 @@ def guess_labels(unlabeled_images, model):
     """
     with torch.no_grad():
-        guess1 = torch.sigmoid(model(unlabeled_images)).unsqueeze(0)
+        guess1 = torch.sigmoid(model(unlabelled_images)).unsqueeze(0)
         # Horizontal flip and unsqueeze to work with batches (increase flip dimension by 1)
-        hflip = torch.sigmoid(model(unlabeled_images.flip(2))).unsqueeze(0)
+        hflip = torch.sigmoid(model(unlabelled_images.flip(2))).unsqueeze(0)
         guess2 = hflip.flip(3)
         # Vertical flip and unsqueeze to work with batches (increase flip dimension by 1)
-        vflip = torch.sigmoid(model(unlabeled_images.flip(3))).unsqueeze(0)
+        vflip = torch.sigmoid(model(unlabelled_images.flip(3))).unsqueeze(0)
         guess3 = vflip.flip(4)
         # Concat
         concat = torch.cat([guess1, guess2, guess3], 0)
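
Aside (not part of this commit): the three stacked guesses are presumably averaged over dimension 0 right after this hunk to produce the pseudo-labels. A minimal sketch of the same flip-based test-time augmentation, assuming only a generic ``model`` callable:

    import torch

    def average_flip_predictions(images, model):
        # direct prediction; unsqueeze(0) adds a stacking dimension
        g1 = torch.sigmoid(model(images)).unsqueeze(0)
        # horizontal flip in, flip the prediction back (dim 3 is height after unsqueeze)
        g2 = torch.sigmoid(model(images.flip(2))).unsqueeze(0).flip(3)
        # vertical flip in, flip the prediction back (dim 4 is width after unsqueeze)
        g3 = torch.sigmoid(model(images.flip(3))).unsqueeze(0).flip(4)
        # average the three guesses over the stacking dimension
        return torch.cat([g1, g2, g3], 0).mean(0)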
@@ -169,13 +170,13 @@ def do_ssltrain(
     rampup_length,
 ):
     """
-    Train model and save to disk.
+    Trains model using semi-supervised learning and saves it to disk.
 
     Parameters
     ----------
 
     model : :py:class:`torch.nn.Module`
-        Network (e.g. DRIU, HED, UNet)
+        Network (e.g. driu, hed, unet)
 
     data_loader : :py:class:`torch.utils.data.DataLoader`
@@ -191,13 +192,14 @@ def do_ssltrain(
     checkpointer
 
     checkpoint_period : int
-        save a checkpoint every n epochs
+        save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
+        not save intermediary checkpoints
 
     device : str
-        device to use ``'cpu'`` or ``'cuda'``
+        device to use ``'cpu'`` or ``'cuda:0'``
 
     arguments : dict
-        start end end epochs
+        start and end epochs
 
     output_folder : str
         output path
@@ -206,15 +208,35 @@ def do_ssltrain(
         rampup epochs
     """
     logger.info("Start SSL training")
 
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
 
-    # Logg to file
-    with open(
-        os.path.join(output_folder, "{}_trainlog.csv".format(model.name)), "a+", 1
-    ) as outfile:
+    # Log to file
+    logfile_name = os.path.join(output_folder, "trainlog.csv")
+
+    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
+        logger.info(f"Truncating {logfile_name} - training is restarting...")
+        os.unlink(logfile_name)
+
+    logfile_fields = (
+        "epoch",
+        "total-time",
+        "eta",
+        "average-loss",
+        "median-loss",
+        "median-labelled-loss",
+        "median-unlabelled-loss",
+        "learning-rate",
+        "gpu-memory-megabytes",
+    )
+
+    with open(logfile_name, "a+", newline="") as logfile:
+        logwriter = csv.DictWriter(logfile, fieldnames=logfile_fields)
+        if arguments["epoch"] == 0:
+            logwriter.writeheader()
 
         for state in optimizer.state.values():
             for k, v in state.items():
                 if isinstance(v, torch.Tensor):
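
Aside (not part of this commit): the rewrite replaces a hand-formatted log string with ``csv.DictWriter``, so each epoch row is keyed by the names in ``logfile_fields`` and the header is only written once, on a fresh run. A standalone illustration of the same stdlib pattern:

    import csv

    fields = ("epoch", "average-loss")
    with open("trainlog.csv", "a+", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fields)
        writer.writeheader()  # emit once, on a brand-new file only
        writer.writerow({"epoch": "1", "average-loss": "0.482193"})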
@@ -223,125 +245,96 @@
         model.train().to(device)
 
         # Total training timer
         start_training_time = time.time()
 
         for epoch in range(start_epoch, max_epoch):
             scheduler.step()
 
             losses = SmoothedValue(len(data_loader))
-            labeled_loss = SmoothedValue(len(data_loader))
-            unlabeled_loss = SmoothedValue(len(data_loader))
+            labelled_loss = SmoothedValue(len(data_loader))
+            unlabelled_loss = SmoothedValue(len(data_loader))
             epoch = epoch + 1
             arguments["epoch"] = epoch
 
             # Epoch time
             start_epoch_time = time.time()
 
-            for samples in tqdm(data_loader):
-                # labeled
+            for samples in tqdm(
+                data_loader, desc="batches", leave=False, disable=None,
+            ):
+
+                # data forwarding on the existing network
+
+                # labelled
                 images = samples[1].to(device)
                 ground_truths = samples[2].to(device)
-                unlabeled_images = samples[4].to(device)
-                # labeled outputs
+                unlabelled_images = samples[4].to(device)
+                # labelled outputs
                 outputs = model(images)
-                unlabeled_outputs = model(unlabeled_images)
-                # guessed unlabeled outputs
-                unlabeled_ground_truths = guess_labels(unlabeled_images, model)
-                # unlabeled_ground_truths = sharpen(unlabeled_ground_truths,0.5)
-                # images, ground_truths, unlabeled_images, unlabeled_ground_truths = mix_up(0.75, images, ground_truths, unlabeled_images, unlabeled_ground_truths)
+                unlabelled_outputs = model(unlabelled_images)
+                # guessed unlabelled outputs
+                unlabelled_ground_truths = guess_labels(unlabelled_images, model)
+                # unlabelled_ground_truths = sharpen(unlabelled_ground_truths,0.5)
+                # images, ground_truths, unlabelled_images, unlabelled_ground_truths = mix_up(0.75, images, ground_truths, unlabelled_images, unlabelled_ground_truths)
+
+                # loss evaluation and learning (backward step)
                 ramp_up_factor = square_rampup(epoch, rampup_length=rampup_length)
 
                 loss, ll, ul = criterion(
                     outputs,
                     ground_truths,
-                    unlabeled_outputs,
-                    unlabeled_ground_truths,
+                    unlabelled_outputs,
+                    unlabelled_ground_truths,
                     ramp_up_factor,
                 )
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
 
                 losses.update(loss)
-                labeled_loss.update(ll)
-                unlabeled_loss.update(ul)
-                logger.debug("batch loss: {}".format(loss.item()))
+                labelled_loss.update(ll)
+                unlabelled_loss.update(ul)
+                logger.debug(f"batch loss: {loss.item()}")
 
-            if epoch % checkpoint_period == 0:
-                checkpointer.save("model_{:03d}".format(epoch), **arguments)
+            if checkpoint_period and (epoch % checkpoint_period == 0):
+                checkpointer.save(f"model_{epoch:03d}", **arguments)
 
-            if epoch == max_epoch:
+            if epoch >= max_epoch:
                 checkpointer.save("model_final", **arguments)
+            # computes ETA (estimated time-of-arrival; end of training) taking
+            # into consideration previous epoch performance
             epoch_time = time.time() - start_epoch_time
             eta_seconds = epoch_time * (max_epoch - epoch)
-            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+            current_time = time.time() - start_training_time
 
-            outfile.write(
-                (
-                    "{epoch}, "
-                    "{avg_loss:.6f}, "
-                    "{median_loss:.6f}, "
-                    "{median_labeled_loss},"
-                    "{median_unlabeled_loss},"
-                    "{lr:.6f}, "
-                    "{memory:.0f}"
-                    "\n"
-                ).format(
-                    eta=eta_string,
-                    epoch=epoch,
-                    avg_loss=losses.avg,
-                    median_loss=losses.median,
-                    median_labeled_loss=labeled_loss.median,
-                    median_unlabeled_loss=unlabeled_loss.median,
-                    lr=optimizer.param_groups[0]["lr"],
-                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
-                    if torch.cuda.is_available()
-                    else 0.0,
-                )
-            )
-            logger.info(
-                (
-                    "eta: {eta}, "
-                    "epoch: {epoch}, "
-                    "avg. loss: {avg_loss:.6f}, "
-                    "median loss: {median_loss:.6f}, "
-                    "labeled loss: {median_labeled_loss}, "
-                    "unlabeled loss: {median_unlabeled_loss}, "
-                    "lr: {lr:.6f}, "
-                    "max mem: {memory:.0f}"
-                ).format(
-                    eta=eta_string,
-                    epoch=epoch,
-                    avg_loss=losses.avg,
-                    median_loss=losses.median,
-                    median_labeled_loss=labeled_loss.median,
-                    median_unlabeled_loss=unlabeled_loss.median,
-                    lr=optimizer.param_groups[0]["lr"],
-                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
-                    if torch.cuda.is_available()
-                    else 0.0,
-                )
-            )
+            logdata = (
+                ("epoch", f"{epoch}"),
+                (
+                    "total-time",
+                    f"{datetime.timedelta(seconds=int(current_time))}",
+                ),
+                ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
+                ("average-loss", f"{losses.avg:.6f}"),
+                ("median-loss", f"{losses.median:.6f}"),
+                ("median-labelled-loss", f"{labelled_loss.median:.6f}"),
+                ("median-unlabelled-loss", f"{unlabelled_loss.median:.6f}"),
+                ("learning-rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
+                (
+                    "gpu-memory-megabytes",
+                    f"{torch.cuda.max_memory_allocated()/(1024.0*1024.0)}"
+                    if torch.cuda.is_available()
+                    else "0.0",
+                ),
+            )
+            logwriter.writerow(dict(k for k in logdata))
+            logger.info("|".join([f"{k}: {v}" for (k, v) in logdata]))
+
+        logger.info("End of training")
 
         total_training_time = time.time() - start_training_time
-        total_time_str = str(datetime.timedelta(seconds=total_training_time))
         logger.info(
-            "Total training time: {} ({:.4f} s / epoch)".format(
-                total_time_str, total_training_time / (max_epoch)
-            )
+            f"Total training time: {datetime.timedelta(seconds=total_training_time)} "
+            f"({(total_training_time/max_epoch):.4f}s in average per epoch)"
         )
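
The ETA is a straight extrapolation from the last epoch: with an ``epoch_time`` of 120 s and 10 epochs remaining, ``eta_seconds = 120 * 10 = 1200``, which ``datetime.timedelta`` renders as ``0:20:00`` in the log.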
-    log_plot_file = os.path.join(output_folder, "{}_trainlog.pdf".format(model.name))
-    logdf = pd.read_csv(
-        os.path.join(output_folder, "{}_trainlog.csv".format(model.name)),
-        header=None,
-        names=[
-            "avg. loss",
-            "median loss",
-            "labeled loss",
-            "unlabeled loss",
-            "lr",
-            "max memory",
-        ],
-    )
-    fig = loss_curve(logdf, output_folder)
-    logger.info("saving {}".format(log_plot_file))
-    fig.savefig(log_plot_file)
+    # plots a version of the CSV trainlog into a PDF
+    logdf = pd.read_csv(logfile_name, header=0, names=logfile_fields)
+    fig = loss_curve(logdf, title="Loss Evolution")
+    figurefile_name = os.path.join(output_folder, "trainlog.pdf")
+    logger.info(f"Saving {figurefile_name}")
+    fig.savefig(figurefile_name)