From 20364461659ff8c8a5f76f12ca0bdd93df71872f Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.anjos@idiap.ch>
Date: Tue, 7 Apr 2020 12:18:57 +0200
Subject: [PATCH] [engine.trainer] Move creation of output_folder into engine;
 Call train argument output-path -> output-folder to make it more explicit

---
 bob/ip/binseg/engine/ssltrainer.py |  4 ++++
 bob/ip/binseg/engine/trainer.py    |  4 ++++
 bob/ip/binseg/script/train.py      | 16 ++++++----------
 doc/training.rst                   |  1 +
 4 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index 3ba5e46a..4d1d1c2c 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -213,6 +213,10 @@ def run(
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
 
+    if not os.path.exists(output_folder):
+        logger.debug(f"Creating output directory '{output_folder}'...")
+        os.makedirs(output_folder)
+
     # Log to file
     logfile_name = os.path.join(output_folder, "trainlog.csv")
 
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index 501105a9..0d575bc9 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -72,6 +72,10 @@ def run(
     start_epoch = arguments["epoch"]
     max_epoch = arguments["max_epoch"]
 
+    if not os.path.exists(output_folder):
+        logger.debug(f"Creating output directory '{output_folder}'...")
+        os.makedirs(output_folder)
+
     # Log to file
     logfile_name = os.path.join(output_folder, "trainlog.csv")
 
diff --git a/bob/ip/binseg/script/train.py b/bob/ip/binseg/script/train.py
index ac1ba769..3302e9ea 100644
--- a/bob/ip/binseg/script/train.py
+++ b/bob/ip/binseg/script/train.py
@@ -1,8 +1,7 @@
 #!/usr/bin/env python
-# vim: set fileencoding=utf-8 :
+# coding=utf-8
 
 import os
-import pkg_resources
 
 import click
 from click_plugins import with_plugins
@@ -20,7 +19,6 @@ from bob.extension.scripts.click_helper import (
 from ..utils.checkpointer import DetectronCheckpointer
 
 import logging
-
 logger = logging.getLogger(__name__)
 
 
@@ -50,7 +48,7 @@ logger = logging.getLogger(__name__)
     """,
 )
 @click.option(
-    "--output-path",
+    "--output-folder",
     "-o",
     help="Path where to store the generated model (created if does not exist)",
     required=True,
@@ -193,7 +191,7 @@ def train(
     model,
     optimizer,
     scheduler,
-    output_path,
+    output_folder,
     epochs,
     pretrained_backbone,
     batch_size,
@@ -217,8 +215,6 @@ def train(
     abruptly.
     """
 
-    if not os.path.exists(output_path):
-        os.makedirs(output_path)
     torch.manual_seed(seed)
 
     # PyTorch dataloader
@@ -232,7 +228,7 @@ def train(
 
     # Checkpointer
     checkpointer = DetectronCheckpointer(
-        model, optimizer, scheduler, save_dir=output_path, save_to_disk=True
+        model, optimizer, scheduler, save_dir=output_folder, save_to_disk=True
     )
 
     arguments = {}
@@ -256,7 +252,7 @@ def train(
             checkpoint_period,
             device,
             arguments,
-            output_path,
+            output_folder,
         )
 
     else:
@@ -271,6 +267,6 @@ def train(
             checkpoint_period,
             device,
             arguments,
-            output_path,
+            output_folder,
             rampup,
         )
diff --git a/doc/training.rst b/doc/training.rst
index 7096bb8a..65cf6669 100644
--- a/doc/training.rst
+++ b/doc/training.rst
@@ -33,6 +33,7 @@ card, for supervised training of baselines. Use it like this:
 
    # change <model> and <dataset> by one of items bellow
    $ bob binseg train -vv <model> <dataset> --batch-size=<see-table> --device="cuda:0"
+   # check results in the "results" folder
 
 
 .. list-table::
--
GitLab
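
Note on the change above: with this patch both engines create the output directory lazily at the start of run(), so the CLI script no longer prepares it. Below is a minimal, self-contained sketch of that guard, assuming only the names that appear in the hunks (output_folder, trainlog.csv); the ensure_output_folder() wrapper is a hypothetical helper added here for illustration only, whereas the patch inlines the same three lines directly inside run():

    # Illustrative sketch only -- the patch inlines this guard in
    # bob/ip/binseg/engine/trainer.py and bob/ip/binseg/engine/ssltrainer.py;
    # the ensure_output_folder() wrapper is hypothetical.
    import os
    import logging

    logger = logging.getLogger(__name__)


    def ensure_output_folder(output_folder):
        # Create the output directory if it does not exist yet, then return
        # the path of the training log file stored inside it.
        if not os.path.exists(output_folder):
            logger.debug(f"Creating output directory '{output_folder}'...")
            os.makedirs(output_folder)
        return os.path.join(output_folder, "trainlog.csv")

With the renamed option, a training call would look roughly like the example in doc/training.rst plus an explicit output location; <model>, <dataset> and the output path are placeholders:

    $ bob binseg train -vv <model> <dataset> --batch-size=<see-table> \
          --device="cuda:0" --output-folder=results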