Skip to content
Snippets Groups Projects
Commit 5899f1b5 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[engine.trainer] Improve docs; Use f-strings where possible

parent 39ed1c24
No related branches found
No related tags found
1 merge request!12Streamlining
...@@ -13,7 +13,6 @@ from bob.ip.binseg.utils.metric import SmoothedValue ...@@ -13,7 +13,6 @@ from bob.ip.binseg.utils.metric import SmoothedValue
from bob.ip.binseg.utils.plot import loss_curve from bob.ip.binseg.utils.plot import loss_curve
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -63,7 +62,7 @@ def do_train( ...@@ -63,7 +62,7 @@ def do_train(
device to use ``'cpu'`` or ``cuda:0`` device to use ``'cpu'`` or ``cuda:0``
arguments : dict arguments : dict
start end end epochs start and end epochs
output_folder : str output_folder : str
output path output path
...@@ -133,10 +132,10 @@ def do_train( ...@@ -133,10 +132,10 @@ def do_train(
optimizer.step() optimizer.step()
losses.update(loss) losses.update(loss)
logger.debug("batch loss: {}".format(loss.item())) logger.debug(f"batch loss: {loss.item()}")
if checkpoint_period and (epoch % checkpoint_period == 0): if checkpoint_period and (epoch % checkpoint_period == 0):
checkpointer.save("model_{:03d}".format(epoch), **arguments) checkpointer.save(f"model_{epoch:03d}", **arguments)
if epoch >= max_epoch: if epoch >= max_epoch:
checkpointer.save("model_final", **arguments) checkpointer.save("model_final", **arguments)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment