Skip to content
Snippets Groups Projects
Commit 507513c1 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[engine.train] Streamline method names

parent dd60fe53
No related branches found
No related tags found
1 merge request!12Streamlining
Pipeline #38730 passed
......@@ -156,7 +156,7 @@ def guess_labels(unlabelled_images, model):
return avg_guess
def do_ssltrain(
def run(
model,
data_loader,
optimizer,
......@@ -170,7 +170,7 @@ def do_ssltrain(
rampup_length,
):
"""
Trains model using semi-supervised learning and saves it to disk.
Fits an FCN model using semi-supervised learning and saves it to disk.
Parameters
----------
......
......@@ -16,7 +16,7 @@ import logging
logger = logging.getLogger(__name__)
def do_train(
def run(
model,
data_loader,
optimizer,
......@@ -29,7 +29,7 @@ def do_train(
output_folder,
):
"""
Train models and save it to disk.
Fits an FCN model using supervised learning and save it to disk.
This method supports periodic checkpointing and the output of a
CSV-formatted log with the evolution of some figures during training.
......
......@@ -18,8 +18,6 @@ from bob.extension.scripts.click_helper import (
)
from ..utils.checkpointer import DetectronCheckpointer
from ..engine.trainer import do_train
from ..engine.ssltrainer import do_ssltrain
import logging
......@@ -247,7 +245,8 @@ def train(
logger.info("Continuing from epoch {}".format(arguments["epoch"]))
if not ssl:
do_train(
from ..engine.trainer import run
run(
model,
data_loader,
optimizer,
......@@ -261,8 +260,8 @@ def train(
)
else:
do_ssltrain(
from ..engine.ssltrainer import run
run(
model,
data_loader,
optimizer,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment