From d70f90aec0c9303c5660ff1e66757bd09f8ee142 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Wed, 13 May 2020 11:37:06 +0200
Subject: [PATCH] [engine] More comments on model.train() and model.to() usage

---
 bob/ip/binseg/engine/predictor.py  | 7 ++++---
 bob/ip/binseg/engine/ssltrainer.py | 4 +++-
 bob/ip/binseg/engine/trainer.py    | 4 +++-
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/bob/ip/binseg/engine/predictor.py b/bob/ip/binseg/engine/predictor.py
index 4e4640f0..de78e69f 100644
--- a/bob/ip/binseg/engine/predictor.py
+++ b/bob/ip/binseg/engine/predictor.py
@@ -125,9 +125,10 @@ def run(model, data_loader, device, output_folder, overlayed_folder):
     os.makedirs(output_folder, exist_ok=True)
 
     logger.info(f"Device: {device}")
-    model.eval().to(device)
-    # Sigmoid for predictions
-    sigmoid = torch.nn.Sigmoid()
+
+    model.eval()  # set evaluation mode
+    model.to(device)  # set/cast parameters to device
+    sigmoid = torch.nn.Sigmoid()  # use sigmoid for predictions
 
     # Setup timers
     start_total_time = time.time()
diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index 40566f21..98c17e1b 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -300,7 +300,9 @@ def run(
         if arguments["epoch"] == 0:
             logwriter.writeheader()
 
-        model.train().to(device)
+        model.train()  #set training mode
+
+        model.to(device)  # set/cast parameters to device
         for state in optimizer.state.values():
             for k, v in state.items():
                 if isinstance(v, torch.Tensor):
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index ed34fbe0..4c4e558e 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -179,7 +179,9 @@ def run(
         if arguments["epoch"] == 0:
             logwriter.writeheader()
 
-        model.train().to(device)
+        model.train()  #set training mode
+
+        model.to(device)  # set/cast parameters to device
         for state in optimizer.state.values():
             for k, v in state.items():
                 if isinstance(v, torch.Tensor):
-- 
GitLab