From 507513c14a129e2bca1cf41e383d8b2503ec3f92 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.anjos@idiap.ch> Date: Sun, 5 Apr 2020 15:51:47 +0200 Subject: [PATCH] [engine.train] Streamline method names --- bob/ip/binseg/engine/ssltrainer.py | 4 ++-- bob/ip/binseg/engine/trainer.py | 4 ++-- bob/ip/binseg/script/train.py | 9 ++++----- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py index 542d3162..3ba5e46a 100644 --- a/bob/ip/binseg/engine/ssltrainer.py +++ b/bob/ip/binseg/engine/ssltrainer.py @@ -156,7 +156,7 @@ def guess_labels(unlabelled_images, model): return avg_guess -def do_ssltrain( +def run( model, data_loader, optimizer, @@ -170,7 +170,7 @@ def do_ssltrain( rampup_length, ): """ - Trains model using semi-supervised learning and saves it to disk. + Fits an FCN model using semi-supervised learning and saves it to disk. Parameters ---------- diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py index 28ffeaa2..501105a9 100644 --- a/bob/ip/binseg/engine/trainer.py +++ b/bob/ip/binseg/engine/trainer.py @@ -16,7 +16,7 @@ import logging logger = logging.getLogger(__name__) -def do_train( +def run( model, data_loader, optimizer, @@ -29,7 +29,7 @@ def do_train( output_folder, ): """ - Train models and save it to disk. + Fits an FCN model using supervised learning and save it to disk. This method supports periodic checkpointing and the output of a CSV-formatted log with the evolution of some figures during training. diff --git a/bob/ip/binseg/script/train.py b/bob/ip/binseg/script/train.py index 3156b2e6..ac1ba769 100644 --- a/bob/ip/binseg/script/train.py +++ b/bob/ip/binseg/script/train.py @@ -18,8 +18,6 @@ from bob.extension.scripts.click_helper import ( ) from ..utils.checkpointer import DetectronCheckpointer -from ..engine.trainer import do_train -from ..engine.ssltrainer import do_ssltrain import logging @@ -247,7 +245,8 @@ def train( logger.info("Continuing from epoch {}".format(arguments["epoch"])) if not ssl: - do_train( + from ..engine.trainer import run + run( model, data_loader, optimizer, @@ -261,8 +260,8 @@ def train( ) else: - - do_ssltrain( + from ..engine.ssltrainer import run + run( model, data_loader, optimizer, -- GitLab