From 20364461659ff8c8a5f76f12ca0bdd93df71872f Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.anjos@idiap.ch>
Date: Tue, 7 Apr 2020 12:18:57 +0200
Subject: [PATCH] [engine.trainer] Move creation of output_folder into engine;
 Call train argument output-path -> output-folder to make it more explicit

---
 bob/ip/binseg/engine/ssltrainer.py |  4 ++++
 bob/ip/binseg/engine/trainer.py    |  4 ++++
 bob/ip/binseg/script/train.py      | 16 ++++++----------
 doc/training.rst                   |  1 +
 4 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index 3ba5e46a..4d1d1c2c 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -213,6 +213,10 @@ def run(
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
 
+    if not os.path.exists(output_folder):
+        logger.debug(f"Creating output directory '{output_folder}'...")
+        os.makedirs(output_folder)
+
     # Log to file
     logfile_name = os.path.join(output_folder, "trainlog.csv")
 
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index 501105a9..0d575bc9 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -72,6 +72,10 @@ def run(
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
 
+    if not os.path.exists(output_folder):
+        logger.debug(f"Creating output directory '{output_folder}'...")
+        os.makedirs(output_folder)
+
     # Log to file
     logfile_name = os.path.join(output_folder, "trainlog.csv")
 
diff --git a/bob/ip/binseg/script/train.py b/bob/ip/binseg/script/train.py
index ac1ba769..3302e9ea 100644
--- a/bob/ip/binseg/script/train.py
+++ b/bob/ip/binseg/script/train.py
@@ -1,8 +1,7 @@
 #!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
+# coding=utf-8
 
 import os
-import pkg_resources
 
 import click
 from click_plugins import with_plugins
@@ -20,7 +19,6 @@ from bob.extension.scripts.click_helper import (
 from ..utils.checkpointer import DetectronCheckpointer
 
 import logging
-
 logger = logging.getLogger(__name__)
 
 
@@ -50,7 +48,7 @@ logger = logging.getLogger(__name__)
 """,
 )
 @click.option(
-    "--output-path",
+    "--output-folder",
     "-o",
     help="Path where to store the generated model (created if does not exist)",
     required=True,
@@ -193,7 +191,7 @@ def train(
     model,
     optimizer,
     scheduler,
-    output_path,
+    output_folder,
     epochs,
     pretrained_backbone,
     batch_size,
@@ -217,8 +215,6 @@ def train(
     abruptly.
     """
 
-    if not os.path.exists(output_path):
-        os.makedirs(output_path)
     torch.manual_seed(seed)
 
     # PyTorch dataloader
@@ -232,7 +228,7 @@ def train(
 
     # Checkpointer
     checkpointer = DetectronCheckpointer(
-        model, optimizer, scheduler, save_dir=output_path, save_to_disk=True
+        model, optimizer, scheduler, save_dir=output_folder, save_to_disk=True
     )
 
     arguments = {}
@@ -256,7 +252,7 @@ def train(
             checkpoint_period,
             device,
             arguments,
-            output_path,
+            output_folder,
         )
 
     else:
@@ -271,6 +267,6 @@ def train(
             checkpoint_period,
             device,
             arguments,
-            output_path,
+            output_folder,
             rampup,
         )
diff --git a/doc/training.rst b/doc/training.rst
index 7096bb8a..65cf6669 100644
--- a/doc/training.rst
+++ b/doc/training.rst
@@ -33,6 +33,7 @@ card, for supervised training of baselines.  Use it like this:
 
    # change <model> and <dataset> by one of items bellow
    $ bob binseg train -vv <model> <dataset> --batch-size=<see-table> --device="cuda:0"
+   # check results in the "results" folder
 
 .. list-table::
 
-- 
GitLab