Skip to content
Snippets Groups Projects
Commit 507513c1 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[engine.train] Streamline method names

parent dd60fe53
No related branches found
No related tags found
1 merge request!12Streamlining
Pipeline #38730 passed
...@@ -156,7 +156,7 @@ def guess_labels(unlabelled_images, model): ...@@ -156,7 +156,7 @@ def guess_labels(unlabelled_images, model):
return avg_guess return avg_guess
def do_ssltrain( def run(
model, model,
data_loader, data_loader,
optimizer, optimizer,
...@@ -170,7 +170,7 @@ def do_ssltrain( ...@@ -170,7 +170,7 @@ def do_ssltrain(
rampup_length, rampup_length,
): ):
""" """
Trains model using semi-supervised learning and saves it to disk. Fits an FCN model using semi-supervised learning and saves it to disk.
Parameters Parameters
---------- ----------
......
...@@ -16,7 +16,7 @@ import logging ...@@ -16,7 +16,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def do_train( def run(
model, model,
data_loader, data_loader,
optimizer, optimizer,
...@@ -29,7 +29,7 @@ def do_train( ...@@ -29,7 +29,7 @@ def do_train(
output_folder, output_folder,
): ):
""" """
Train models and save it to disk. Fits an FCN model using supervised learning and save it to disk.
This method supports periodic checkpointing and the output of a This method supports periodic checkpointing and the output of a
CSV-formatted log with the evolution of some figures during training. CSV-formatted log with the evolution of some figures during training.
......
...@@ -18,8 +18,6 @@ from bob.extension.scripts.click_helper import ( ...@@ -18,8 +18,6 @@ from bob.extension.scripts.click_helper import (
) )
from ..utils.checkpointer import DetectronCheckpointer from ..utils.checkpointer import DetectronCheckpointer
from ..engine.trainer import do_train
from ..engine.ssltrainer import do_ssltrain
import logging import logging
...@@ -247,7 +245,8 @@ def train( ...@@ -247,7 +245,8 @@ def train(
logger.info("Continuing from epoch {}".format(arguments["epoch"])) logger.info("Continuing from epoch {}".format(arguments["epoch"]))
if not ssl: if not ssl:
do_train( from ..engine.trainer import run
run(
model, model,
data_loader, data_loader,
optimizer, optimizer,
...@@ -261,8 +260,8 @@ def train( ...@@ -261,8 +260,8 @@ def train(
) )
else: else:
from ..engine.ssltrainer import run
do_ssltrain( run(
model, model,
data_loader, data_loader,
optimizer, optimizer,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment