From 507513c14a129e2bca1cf41e383d8b2503ec3f92 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.anjos@idiap.ch>
Date: Sun, 5 Apr 2020 15:51:47 +0200
Subject: [PATCH] [engine.train] Streamline method names

---
 bob/ip/binseg/engine/ssltrainer.py | 4 ++--
 bob/ip/binseg/engine/trainer.py    | 4 ++--
 bob/ip/binseg/script/train.py      | 9 ++++-----
 3 files changed, 8 insertions(+), 9 deletions(-)

diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index 542d3162..3ba5e46a 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -156,7 +156,7 @@ def guess_labels(unlabelled_images, model):
         return avg_guess
 
 
-def do_ssltrain(
+def run(
     model,
     data_loader,
     optimizer,
@@ -170,7 +170,7 @@ def do_ssltrain(
     rampup_length,
 ):
     """
-    Trains model using semi-supervised learning and saves it to disk.
+    Fits an FCN model using semi-supervised learning and saves it to disk.
 
     Parameters
     ----------
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index 28ffeaa2..501105a9 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -16,7 +16,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-def do_train(
+def run(
     model,
     data_loader,
     optimizer,
@@ -29,7 +29,7 @@ def do_train(
     output_folder,
 ):
     """
-    Train models and save it to disk.
+    Fits an FCN model using supervised learning and save it to disk.
 
     This method supports periodic checkpointing and the output of a
     CSV-formatted log with the evolution of some figures during training.
diff --git a/bob/ip/binseg/script/train.py b/bob/ip/binseg/script/train.py
index 3156b2e6..ac1ba769 100644
--- a/bob/ip/binseg/script/train.py
+++ b/bob/ip/binseg/script/train.py
@@ -18,8 +18,6 @@ from bob.extension.scripts.click_helper import (
 )
 
 from ..utils.checkpointer import DetectronCheckpointer
-from ..engine.trainer import do_train
-from ..engine.ssltrainer import do_ssltrain
 
 import logging
 
@@ -247,7 +245,8 @@ def train(
     logger.info("Continuing from epoch {}".format(arguments["epoch"]))
 
     if not ssl:
-        do_train(
+        from ..engine.trainer import run
+        run(
             model,
             data_loader,
             optimizer,
@@ -261,8 +260,8 @@ def train(
         )
 
     else:
-
-        do_ssltrain(
+        from ..engine.ssltrainer import run
+        run(
             model,
             data_loader,
             optimizer,
-- 
GitLab