Skip to content
Snippets Groups Projects
Commit e4695064 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[engine] Set logger on module level

parent e6a4630b
No related branches found
No related tags found
1 merge request!12Streamlining
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os import os
import logging
import time import time
import datetime import datetime
import numpy as np import numpy as np
...@@ -17,6 +16,9 @@ from bob.ip.binseg.utils.metric import base_metrics ...@@ -17,6 +16,9 @@ from bob.ip.binseg.utils.metric import base_metrics
from bob.ip.binseg.utils.plot import precision_recall_f1iso_confintval from bob.ip.binseg.utils.plot import precision_recall_f1iso_confintval
from bob.ip.binseg.utils.summary import summary from bob.ip.binseg.utils.summary import summary
import logging
logger = logging.getLogger(__name__)
def batch_metrics(predictions, ground_truths, names, output_folder, logger): def batch_metrics(predictions, ground_truths, names, output_folder, logger):
""" """
...@@ -165,7 +167,7 @@ def do_inference(model, data_loader, device, output_folder=None): ...@@ -165,7 +167,7 @@ def do_inference(model, data_loader, device, output_folder=None):
device to use ``'cpu'`` or ``'cuda'`` device to use ``'cpu'`` or ``'cuda'``
output_folder : str output_folder : str
""" """
logger = logging.getLogger("bob.ip.binseg.engine.inference")
logger.info("Start evaluation") logger.info("Start evaluation")
logger.info("Output folder: {}, Device: {}".format(output_folder, device)) logger.info("Output folder: {}, Device: {}".format(output_folder, device))
results_subfolder = os.path.join(output_folder, "results") results_subfolder = os.path.join(output_folder, "results")
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os import os
import logging
import time import time
import datetime import datetime
import numpy as np import numpy as np
...@@ -12,6 +11,9 @@ from tqdm import tqdm ...@@ -12,6 +11,9 @@ from tqdm import tqdm
from bob.ip.binseg.engine.inferencer import save_probability_images from bob.ip.binseg.engine.inferencer import save_probability_images
from bob.ip.binseg.engine.inferencer import save_hdf from bob.ip.binseg.engine.inferencer import save_hdf
import logging
logger = logging.getLogger(__name__)
def do_predict(model, data_loader, device, output_folder=None): def do_predict(model, data_loader, device, output_folder=None):
...@@ -27,7 +29,6 @@ def do_predict(model, data_loader, device, output_folder=None): ...@@ -27,7 +29,6 @@ def do_predict(model, data_loader, device, output_folder=None):
device to use ``'cpu'`` or ``'cuda'`` device to use ``'cpu'`` or ``'cuda'``
output_folder : str output_folder : str
""" """
logger = logging.getLogger("bob.ip.binseg.engine.inference")
logger.info("Start evaluation") logger.info("Start evaluation")
logger.info("Output folder: {}, Device: {}".format(output_folder, device)) logger.info("Output folder: {}, Device: {}".format(output_folder, device))
results_subfolder = os.path.join(output_folder, "results") results_subfolder = os.path.join(output_folder, "results")
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os import os
import logging
import time import time
import datetime import datetime
import torch import torch
...@@ -13,6 +12,9 @@ import numpy as np ...@@ -13,6 +12,9 @@ import numpy as np
from bob.ip.binseg.utils.metric import SmoothedValue from bob.ip.binseg.utils.metric import SmoothedValue
from bob.ip.binseg.utils.plot import loss_curve from bob.ip.binseg.utils.plot import loss_curve
import logging
logger = logging.getLogger(__name__)
def sharpen(x, T): def sharpen(x, T):
temp = x ** (1 / T) temp = x ** (1 / T)
...@@ -204,7 +206,6 @@ def do_ssltrain( ...@@ -204,7 +206,6 @@ def do_ssltrain(
rampup epochs rampup epochs
""" """
logger = logging.getLogger("bob.ip.binseg.engine.trainer")
logger.info("Start training") logger.info("Start training")
start_epoch = arguments["epoch"] start_epoch = arguments["epoch"]
max_epoch = arguments["max_epoch"] max_epoch = arguments["max_epoch"]
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os import os
import logging
import time import time
import datetime import datetime
import torch import torch
...@@ -12,6 +11,9 @@ from tqdm import tqdm ...@@ -12,6 +11,9 @@ from tqdm import tqdm
from bob.ip.binseg.utils.metric import SmoothedValue from bob.ip.binseg.utils.metric import SmoothedValue
from bob.ip.binseg.utils.plot import loss_curve from bob.ip.binseg.utils.plot import loss_curve
import logging
logger = logging.getLogger(__name__)
def do_train( def do_train(
model, model,
...@@ -25,12 +27,12 @@ def do_train( ...@@ -25,12 +27,12 @@ def do_train(
arguments, arguments,
output_folder, output_folder,
): ):
""" """
Train model and save to disk. Train model and save to disk.
Parameters Parameters
---------- ----------
model : :py:class:`torch.nn.Module` model : :py:class:`torch.nn.Module`
Network (e.g. DRIU, HED, UNet) Network (e.g. DRIU, HED, UNet)
data_loader : :py:class:`torch.utils.data.DataLoader` data_loader : :py:class:`torch.utils.data.DataLoader`
optimizer : :py:mod:`torch.optim` optimizer : :py:mod:`torch.optim`
...@@ -42,14 +44,13 @@ def do_train( ...@@ -42,14 +44,13 @@ def do_train(
checkpointer checkpointer
checkpoint_period : int checkpoint_period : int
save a checkpoint every n epochs save a checkpoint every n epochs
device : str device : str
device to use ``'cpu'`` or ``'cuda'`` device to use ``'cpu'`` or ``'cuda'``
arguments : dict arguments : dict
start end end epochs start end end epochs
output_folder : str output_folder : str
output path output path
""" """
logger = logging.getLogger("bob.ip.binseg.engine.trainer")
logger.info("Start training") logger.info("Start training")
start_epoch = arguments["epoch"] start_epoch = arguments["epoch"]
max_epoch = arguments["max_epoch"] max_epoch = arguments["max_epoch"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment