diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py index 542d316209b64cfbc6f367a7e0ebf0f26ad2d6be..3ba5e46acd2fb8028e2e6f3aa1e9619428a1d23e 100644 --- a/bob/ip/binseg/engine/ssltrainer.py +++ b/bob/ip/binseg/engine/ssltrainer.py @@ -156,7 +156,7 @@ def guess_labels(unlabelled_images, model): return avg_guess -def do_ssltrain( +def run( model, data_loader, optimizer, @@ -170,7 +170,7 @@ def do_ssltrain( rampup_length, ): """ - Trains model using semi-supervised learning and saves it to disk. + Fits an FCN model using semi-supervised learning and saves it to disk. Parameters ---------- diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py index 28ffeaa2dce1ebf721da384053c6b2ec23a15a8b..501105a9036f35f522bf52bf3dc3b1650813e557 100644 --- a/bob/ip/binseg/engine/trainer.py +++ b/bob/ip/binseg/engine/trainer.py @@ -16,7 +16,7 @@ import logging logger = logging.getLogger(__name__) -def do_train( +def run( model, data_loader, optimizer, @@ -29,7 +29,7 @@ def do_train( output_folder, ): """ - Train models and save it to disk. + Fits an FCN model using supervised learning and save it to disk. This method supports periodic checkpointing and the output of a CSV-formatted log with the evolution of some figures during training. diff --git a/bob/ip/binseg/script/train.py b/bob/ip/binseg/script/train.py index 3156b2e6d484bf94ca846056953ffcd4d9e9751b..ac1ba76987d0e6a2cf6d85161ea91d18d734ad81 100644 --- a/bob/ip/binseg/script/train.py +++ b/bob/ip/binseg/script/train.py @@ -18,8 +18,6 @@ from bob.extension.scripts.click_helper import ( ) from ..utils.checkpointer import DetectronCheckpointer -from ..engine.trainer import do_train -from ..engine.ssltrainer import do_ssltrain import logging @@ -247,7 +245,8 @@ def train( logger.info("Continuing from epoch {}".format(arguments["epoch"])) if not ssl: - do_train( + from ..engine.trainer import run + run( model, data_loader, optimizer, @@ -261,8 +260,8 @@ def train( ) else: - - do_ssltrain( + from ..engine.ssltrainer import run + run( model, data_loader, optimizer,