From d70f90aec0c9303c5660ff1e66757bd09f8ee142 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Wed, 13 May 2020 11:37:06 +0200 Subject: [PATCH] [engine] More comments on model.train() and model.to() usage --- bob/ip/binseg/engine/predictor.py | 7 ++++--- bob/ip/binseg/engine/ssltrainer.py | 4 +++- bob/ip/binseg/engine/trainer.py | 4 +++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/bob/ip/binseg/engine/predictor.py b/bob/ip/binseg/engine/predictor.py index 4e4640f0..de78e69f 100644 --- a/bob/ip/binseg/engine/predictor.py +++ b/bob/ip/binseg/engine/predictor.py @@ -125,9 +125,10 @@ def run(model, data_loader, device, output_folder, overlayed_folder): os.makedirs(output_folder, exist_ok=True) logger.info(f"Device: {device}") - model.eval().to(device) - # Sigmoid for predictions - sigmoid = torch.nn.Sigmoid() + + model.eval() # set evaluation mode + model.to(device) # set/cast parameters to device + sigmoid = torch.nn.Sigmoid() # use sigmoid for predictions # Setup timers start_total_time = time.time() diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py index 40566f21..98c17e1b 100644 --- a/bob/ip/binseg/engine/ssltrainer.py +++ b/bob/ip/binseg/engine/ssltrainer.py @@ -300,7 +300,9 @@ def run( if arguments["epoch"] == 0: logwriter.writeheader() - model.train().to(device) + model.train() #set training mode + + model.to(device) # set/cast parameters to device for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py index ed34fbe0..4c4e558e 100644 --- a/bob/ip/binseg/engine/trainer.py +++ b/bob/ip/binseg/engine/trainer.py @@ -179,7 +179,9 @@ def run( if arguments["epoch"] == 0: logwriter.writeheader() - model.train().to(device) + model.train() #set training mode + + model.to(device) # set/cast parameters to device for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): -- GitLab