diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py index 27c0089f4658171f5f0b87c5c8be6bba604f5521..d7310ed057de296eadaffd6833891dd296290257 100644 --- a/bob/ip/binseg/engine/ssltrainer.py +++ b/bob/ip/binseg/engine/ssltrainer.py @@ -5,10 +5,12 @@ import os import csv import time import datetime +import distutils.version + +import numpy +import pandas import torch -import pandas as pd from tqdm import tqdm -import numpy as np from bob.ip.binseg.utils.metric import SmoothedValue from bob.ip.binseg.utils.plot import loss_curve @@ -48,7 +50,7 @@ def mix_up(alpha, input, target, unlabelled_input, unlabled_target): """ # TODO: with torch.no_grad(): - l = np.random.beta(alpha, alpha) # Eq (8) + l = numpy.random.beta(alpha, alpha) # Eq (8) l = max(l, 1 - l) # Eq (9) # Shuffle and concat. Alg. 1 Line: 12 w_inputs = torch.cat([input, unlabelled_input], 0) @@ -96,7 +98,7 @@ def square_rampup(current, rampup_length=16): if rampup_length == 0: return 1.0 else: - current = np.clip((current / float(rampup_length)) ** 2, 0.0, 1.0) + current = numpy.clip((current / float(rampup_length)) ** 2, 0.0, 1.0) return float(current) @@ -121,7 +123,7 @@ def linear_rampup(current, rampup_length=16): if rampup_length == 0: return 1.0 else: - current = np.clip(current / rampup_length, 0.0, 1.0) + current = numpy.clip(current / rampup_length, 0.0, 1.0) return float(current) @@ -340,7 +342,7 @@ def run( ) # plots a version of the CSV trainlog into a PDF - logdf = pd.read_csv(logfile_name, header=0, names=logfile_fields) + logdf = pandas.read_csv(logfile_name, header=0, names=logfile_fields) fig = loss_curve(logdf) figurefile_name = os.path.join(output_folder, "trainlog.pdf") logger.info(f"Saving {figurefile_name}")