diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py index 00f9318212e999c742f0c87380f6b45cf1c61a5a..81b191dd8f7dde26984623f306669c1ffa1ed957 100644 --- a/bob/ip/binseg/engine/trainer.py +++ b/bob/ip/binseg/engine/trainer.py @@ -45,7 +45,7 @@ def run( ---------- model : :py:class:`torch.nn.Module` - Network (e.g. DRIU, HED, UNet) + Network (e.g. driu, hed, unet) data_loader : :py:class:`torch.utils.data.DataLoader` @@ -142,6 +142,7 @@ def run( for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.to(device) + # Total training timer start_training_time = time.time() @@ -209,7 +210,7 @@ def run( ("median_loss", f"{losses.median:.6f}"), ("learning_rate", f"{optimizer.param_groups[0]['lr']:.6f}"), ) + cpu_log() - if device != 'cpu': + if device != "cpu": logdata += gpu_log() logwriter.writerow(dict(k for k in logdata))