
Improve trainer function

Merged Driss KHALIL requested to merge ImproveTrainerFunction into master
1 file  +388  −121
@@ -52,6 +52,365 @@ def torch_evaluation(model):
     model.train()
 
 
+def check_gpu(device):
+    """
+    Check the device type and the availability of a GPU.
+
+    Parameters
+    ----------
+
+    device : :py:class:`torch.device`
+        device to use
+    """
+    if device.type == "cuda":
+        # asserts we do have a GPU
+        assert bool(
+            gpu_constants()
+        ), f"Device set to '{device}', but nvidia-smi is not installed"
+
+
+def save_model_summary(output_folder, model):
+    """
+    Save a summary of the model in a text file.
+
+    Parameters
+    ----------
+
+    output_folder : str
+        output path
+
+    model : :py:class:`torch.nn.Module`
+        Network (e.g. driu, hed, unet)
+
+    Returns
+    -------
+
+    r : str
+        The model summary in text format.
+
+    n : int
+        The number of parameters of the model.
+    """
+    summary_path = os.path.join(output_folder, "model_summary.txt")
+    logger.info(f"Saving model summary at {summary_path}...")
+    with open(summary_path, "wt") as f:
+        r, n = summary(model)
+        logger.info(f"Model has {n} parameters...")
+        f.write(r)
+    return r, n
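
A minimal usage sketch (illustration only, not part of this diff), assuming ``model`` is an instantiated network and the output directory already exists:

    # writes results/model_summary.txt and returns the summary text and size
    r, n = save_model_summary("results", model)
    print(f"model has {n} parameters")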
+
+
+def static_information_to_csv(static_logfile_name, device, n):
+    """
+    Save the static information in a CSV file.
+
+    Parameters
+    ----------
+
+    static_logfile_name : str
+        The static file name, which is a join between the output folder and
+        "constants.csv"
+
+    device : :py:class:`torch.device`
+        device to use
+
+    n : int
+        The number of parameters of the model
+    """
+    if os.path.exists(static_logfile_name):
+        backup = static_logfile_name + "~"
+        if os.path.exists(backup):
+            os.unlink(backup)
+        shutil.move(static_logfile_name, backup)
+    with open(static_logfile_name, "w", newline="") as f:
+        logdata = cpu_constants()
+        if device.type == "cuda":
+            logdata += gpu_constants()
+        logdata += (("model_size", n),)
+        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
+        logwriter.writeheader()
+        logwriter.writerow(dict(k for k in logdata))
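
Illustration only: assuming ``n`` holds the parameter count returned by ``save_model_summary()`` above, a CPU-only call produces a one-row ``constants.csv`` whose columns are the CPU constants plus ``model_size``:

    static_information_to_csv(
        os.path.join("results", "constants.csv"), torch.device("cpu"), n
    )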
+
+
+def check_exist_logfile(logfile_name, arguments):
+    """
+    Check the existence of the logfile (trainlog.csv).  If the logfile exists
+    and training is starting from epoch 0, the old file is backed up (with a
+    ``~`` suffix) and replaced.
+
+    Parameters
+    ----------
+
+    logfile_name : str
+        The logfile name, which is a join between the output folder and
+        "trainlog.csv"
+
+    arguments : dict
+        start and end epochs
+    """
+    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
+        backup = logfile_name + "~"
+        if os.path.exists(backup):
+            os.unlink(backup)
+        shutil.move(logfile_name, backup)
+
+
+def create_logfile_fields(valid_loader, device):
+    """
+    Create the fields that will appear in the logfile (trainlog.csv).
+
+    Parameters
+    ----------
+
+    valid_loader : :py:class:`torch.utils.data.DataLoader`
+        To be used to validate the model and enable automatic checkpointing.
+        If set to ``None``, then do not validate it.
+
+    device : :py:class:`torch.device`
+        device to use
+
+    Returns
+    -------
+
+    logfile_fields : tuple
+        The fields that will appear in trainlog.csv
+    """
+    logfile_fields = (
+        "epoch",
+        "total_time",
+        "eta",
+        "average_loss",
+        "median_loss",
+        "learning_rate",
+    )
+    if valid_loader is not None:
+        logfile_fields += ("validation_average_loss", "validation_median_loss")
+    logfile_fields += tuple([k[0] for k in cpu_log()])
+    if device.type == "cuda":
+        logfile_fields += tuple([k[0] for k in gpu_log()])
+    return logfile_fields
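
Illustration only: on a CPU-only device with a validation loader present, the returned tuple starts with the fixed fields, then the validation fields, then whatever keys ``cpu_log()`` reports:

    fields = create_logfile_fields(valid_loader, torch.device("cpu"))
    # ("epoch", "total_time", "eta", "average_loss", "median_loss",
    #  "learning_rate", "validation_average_loss", "validation_median_loss",
    #  ...cpu_log() field names...)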
+
+
+def train_sample_process(samples, model, optimizer, losses, device, criterion):
+    """
+    Process one batch of training inputs (images, ground truths, masks) and
+    apply backpropagation to update the model and the training losses.
+
+    Parameters
+    ----------
+
+    samples : list
+        the current batch: images are expected at index 1, ground truths at
+        index 2 and (optional) masks at index 3
+
+    model : :py:class:`torch.nn.Module`
+        Network (e.g. driu, hed, unet)
+
+    optimizer : :py:mod:`torch.optim`
+
+    losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`
+
+    device : :py:class:`torch.device`
+        device to use
+
+    criterion : :py:class:`torch.nn.modules.loss._Loss`
+        loss function
+
+    Returns
+    -------
+
+    losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`
+
+    optimizer : :py:mod:`torch.optim`
+    """
+    images = samples[1].to(
+        device=device, non_blocking=torch.cuda.is_available()
+    )
+    ground_truths = samples[2].to(
+        device=device, non_blocking=torch.cuda.is_available()
+    )
+    masks = (
+        torch.ones_like(ground_truths)
+        if len(samples) < 4
+        else samples[3].to(
+            device=device, non_blocking=torch.cuda.is_available()
+        )
+    )
+    outputs = model(images)
+    loss = criterion(outputs, ground_truths, masks)
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+    losses.update(loss)
+    logger.debug(f"batch loss: {loss.item()}")
+    return losses, optimizer
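
A minimal usage sketch (illustration only; names are taken from the ``run()`` call site further below):

    for samples in tqdm(data_loader, desc="batch", leave=False, disable=None):
        losses, optimizer = train_sample_process(
            samples, model, optimizer, losses, device, criterion
        )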
+
+
+def valid_sample_process(samples, model, valid_losses, device, criterion):
+    """
+    Process one batch of validation inputs (images, ground truths, masks) and
+    update the validation losses.
+
+    Parameters
+    ----------
+
+    samples : list
+        the current batch: images are expected at index 1, ground truths at
+        index 2 and (optional) masks at index 3
+
+    model : :py:class:`torch.nn.Module`
+        Network (e.g. driu, hed, unet)
+
+    valid_losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`
+
+    device : :py:class:`torch.device`
+        device to use
+
+    criterion : :py:class:`torch.nn.modules.loss._Loss`
+        loss function
+
+    Returns
+    -------
+
+    valid_losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`
+    """
+    images = samples[1].to(
+        device=device,
+        non_blocking=torch.cuda.is_available(),
+    )
+    ground_truths = samples[2].to(
+        device=device,
+        non_blocking=torch.cuda.is_available(),
+    )
+    masks = (
+        torch.ones_like(ground_truths)
+        if len(samples) < 4
+        else samples[3].to(
+            device=device,
+            non_blocking=torch.cuda.is_available(),
+        )
+    )
+    outputs = model(images)
+    loss = criterion(outputs, ground_truths, masks)
+    valid_losses.update(loss)
+    return valid_losses
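
A usage sketch (illustration only; it assumes, as in the ``run()`` loop below, that the iteration is wrapped in ``torch.no_grad()`` and the ``torch_evaluation`` context so no gradients are tracked):

    with torch.no_grad(), torch_evaluation(model):
        for samples in tqdm(valid_loader, desc="valid", leave=False, disable=None):
            valid_losses = valid_sample_process(
                samples, model, valid_losses, device, criterion
            )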
+
+
+def checkpointer_process(
+    checkpointer,
+    checkpoint_period,
+    valid_losses,
+    lowest_validation_loss,
+    arguments,
+    epoch,
+    max_epoch,
+):
+    """
+    Process the checkpointer: save intermediary and final models and keep
+    track of the best (lowest validation loss) model.
+
+    Parameters
+    ----------
+
+    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.Checkpointer`
+        checkpointer implementation
+
+    checkpoint_period : int
+        save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
+        not save intermediary checkpoints
+
+    valid_losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`
+
+    lowest_validation_loss : float
+        keeps track of the best (lowest) validation loss
+
+    arguments : dict
+        start and end epochs
+
+    epoch : int
+        current epoch
+
+    max_epoch : int
+        end epoch
+
+    Returns
+    -------
+
+    lowest_validation_loss : float
+        the updated lowest validation loss
+    """
+    if checkpoint_period and (epoch % checkpoint_period == 0):
+        checkpointer.save(f"model_{epoch:03d}", **arguments)
+    if valid_losses is not None and valid_losses.avg < lowest_validation_loss:
+        lowest_validation_loss = valid_losses.avg
+        logger.info(
+            f"Found new low on validation set:"
+            f" {lowest_validation_loss:.6f}"
+        )
+        checkpointer.save("model_lowest_valid_loss", **arguments)
+    if epoch >= max_epoch:
+        checkpointer.save("model_final", **arguments)
+    return lowest_validation_loss
+
+
+def write_log_info(
+    epoch,
+    current_time,
+    eta_seconds,
+    losses,
+    valid_losses,
+    optimizer,
+    logwriter,
+    logfile,
+    device,
+):
+    """
+    Write the log info in trainlog.csv.
+
+    Parameters
+    ----------
+
+    epoch : int
+        current epoch
+
+    current_time : float
+        current training time
+
+    eta_seconds : float
+        estimated time-of-arrival taking into consideration previous epoch
+        performance
+
+    losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`
+
+    valid_losses : :py:class:`bob.ip.binseg.utils.measure.SmoothedValue`
+
+    optimizer : :py:mod:`torch.optim`
+
+    logwriter : csv.DictWriter
+        dictionary writer used to write to trainlog.csv
+
+    logfile : io.TextIOWrapper
+
+    device : :py:class:`torch.device`
+        device to use
+    """
+    logdata = (
+        ("epoch", f"{epoch}"),
+        (
+            "total_time",
+            f"{datetime.timedelta(seconds=int(current_time))}",
+        ),
+        ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
+        ("average_loss", f"{losses.avg:.6f}"),
+        ("median_loss", f"{losses.median:.6f}"),
+        ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
+    )
+    if valid_losses is not None:
+        logdata += (
+            ("validation_average_loss", f"{valid_losses.avg:.6f}"),
+            ("validation_median_loss", f"{valid_losses.median:.6f}"),
+        )
+    logdata += cpu_log()
+    if device.type == "cuda":
+        logdata += gpu_log()
+    logwriter.writerow(dict(k for k in logdata))
+    logfile.flush()
+    tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))
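 
Illustration only: one call appends a single row to trainlog.csv and echoes the first four fields through ``tqdm.write`` (the sample values below are made up):

    write_log_info(epoch, current_time, eta_seconds, losses, valid_losses,
                   optimizer, logwriter, logfile, device)
    # epoch: 7|total_time: 0:12:34|eta: 1:23:45|average_loss: 0.123456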
+
+
 def run(
     model,
     data_loader,
@@ -113,60 +472,24 @@ def run(
start_epoch = arguments["epoch"]
max_epoch = arguments["max_epoch"]
if device.type == "cuda":
# asserts we do have a GPU
assert bool(
gpu_constants()
), f"Device set to '{device}', but nvidia-smi is not installed"
check_gpu(device)
os.makedirs(output_folder, exist_ok=True)
 
     # Save model summary
-    summary_path = os.path.join(output_folder, "model_summary.txt")
-    logger.info(f"Saving model summary at {summary_path}...")
-    with open(summary_path, "wt") as f:
-        r, n = summary(model)
-        logger.info(f"Model has {n} parameters...")
-        f.write(r)
+    r, n = save_model_summary(output_folder, model)
 
     # write static information to a CSV file
     static_logfile_name = os.path.join(output_folder, "constants.csv")
-    if os.path.exists(static_logfile_name):
-        backup = static_logfile_name + "~"
-        if os.path.exists(backup):
-            os.unlink(backup)
-        shutil.move(static_logfile_name, backup)
-    with open(static_logfile_name, "w", newline="") as f:
-        logdata = cpu_constants()
-        if device.type == "cuda":
-            logdata += gpu_constants()
-        logdata += (("model_size", n),)
-        logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
-        logwriter.writeheader()
-        logwriter.writerow(dict(k for k in logdata))
+    static_information_to_csv(static_logfile_name, device, n)
 
     # Log continuous information to (another) file
     logfile_name = os.path.join(output_folder, "trainlog.csv")
-    if arguments["epoch"] == 0 and os.path.exists(logfile_name):
-        backup = logfile_name + "~"
-        if os.path.exists(backup):
-            os.unlink(backup)
-        shutil.move(logfile_name, backup)
+    check_exist_logfile(logfile_name, arguments)
 
-    logfile_fields = (
-        "epoch",
-        "total_time",
-        "eta",
-        "average_loss",
-        "median_loss",
-        "learning_rate",
-    )
-    if valid_loader is not None:
-        logfile_fields += ("validation_average_loss", "validation_median_loss")
-    logfile_fields += tuple([k[0] for k in cpu_log()])
-    if device.type == "cuda":
-        logfile_fields += tuple([k[0] for k in gpu_log()])
+    logfile_fields = create_logfile_fields(valid_loader, device)
 
     # the lowest validation loss obtained so far - this value is updated only
     # if a validation set is available
@@ -208,31 +531,11 @@ def run(
             for samples in tqdm(
                 data_loader, desc="batch", leave=False, disable=None
             ):
                 # data forwarding on the existing network
-                images = samples[1].to(
-                    device=device, non_blocking=torch.cuda.is_available()
-                )
-                ground_truths = samples[2].to(
-                    device=device, non_blocking=torch.cuda.is_available()
-                )
-                masks = (
-                    torch.ones_like(ground_truths)
-                    if len(samples) < 4
-                    else samples[3].to(
-                        device=device, non_blocking=torch.cuda.is_available()
-                    )
-                )
-
-                outputs = model(images)
-                loss = criterion(outputs, ground_truths, masks)
-                optimizer.zero_grad()
-                loss.backward()
-                optimizer.step()
-                losses.update(loss)
-                logger.debug(f"batch loss: {loss.item()}")
+                losses, optimizer = train_sample_process(
+                    samples, model, optimizer, losses, device, criterion
+                )
 
             if PYTORCH_GE_110:
                 scheduler.step()
@@ -247,43 +550,19 @@ def run(
valid_loader, desc="valid", leave=False, disable=None
):
# data forwarding on the existing network
-                        images = samples[1].to(
-                            device=device,
-                            non_blocking=torch.cuda.is_available(),
-                        )
-                        ground_truths = samples[2].to(
-                            device=device,
-                            non_blocking=torch.cuda.is_available(),
-                        )
-                        masks = (
-                            torch.ones_like(ground_truths)
-                            if len(samples) < 4
-                            else samples[3].to(
-                                device=device,
-                                non_blocking=torch.cuda.is_available(),
-                            )
-                        )
-
-                        outputs = model(images)
-                        loss = criterion(outputs, ground_truths, masks)
-                        valid_losses.update(loss)
+                        valid_losses = valid_sample_process(
+                            samples, model, valid_losses, device, criterion
+                        )
 
-            if checkpoint_period and (epoch % checkpoint_period == 0):
-                checkpointer.save(f"model_{epoch:03d}", **arguments)
-
-            if (
-                valid_losses is not None
-                and valid_losses.avg < lowest_validation_loss
-            ):
-                lowest_validation_loss = valid_losses.avg
-                logger.info(
-                    f"Found new low on validation set:"
-                    f" {lowest_validation_loss:.6f}"
-                )
-                checkpointer.save("model_lowest_valid_loss", **arguments)
-
-            if epoch >= max_epoch:
-                checkpointer.save("model_final", **arguments)
+            lowest_validation_loss = checkpointer_process(
+                checkpointer,
+                checkpoint_period,
+                valid_losses,
+                lowest_validation_loss,
+                arguments,
+                epoch,
+                max_epoch,
+            )
 
             # computes ETA (estimated time-of-arrival; end of training) taking
             # into consideration previous epoch performance
@@ -291,29 +570,17 @@ def run(
             eta_seconds = epoch_time * (max_epoch - epoch)
             current_time = time.time() - start_training_time
 
-            logdata = (
-                ("epoch", f"{epoch}"),
-                (
-                    "total_time",
-                    f"{datetime.timedelta(seconds=int(current_time))}",
-                ),
-                ("eta", f"{datetime.timedelta(seconds=int(eta_seconds))}"),
-                ("average_loss", f"{losses.avg:.6f}"),
-                ("median_loss", f"{losses.median:.6f}"),
-                ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"),
-            )
-
-            if valid_losses is not None:
-                logdata += (
-                    ("validation_average_loss", f"{valid_losses.avg:.6f}"),
-                    ("validation_median_loss", f"{valid_losses.median:.6f}"),
-                )
-            logdata += cpu_log()
-            if device.type == "cuda":
-                logdata += gpu_log()
-
-            logwriter.writerow(dict(k for k in logdata))
-            logfile.flush()
-            tqdm.write("|".join([f"{k}: {v}" for (k, v) in logdata[:4]]))
+            write_log_info(
+                epoch,
+                current_time,
+                eta_seconds,
+                losses,
+                valid_losses,
+                optimizer,
+                logwriter,
+                logfile,
+                device,
+            )
 
     total_training_time = time.time() - start_training_time
     logger.info(