From de12a4aa6b0f7959e1a551ff928a6631f87a35b9 Mon Sep 17 00:00:00 2001 From: Andre Anjos <andre.dos.anjos@gmail.com> Date: Mon, 20 Apr 2020 19:48:40 +0200 Subject: [PATCH] [engine.ssltrainer] Fix doc generation --- bob/ip/binseg/engine/ssltrainer.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py index 27c0089f..d7310ed0 100644 --- a/bob/ip/binseg/engine/ssltrainer.py +++ b/bob/ip/binseg/engine/ssltrainer.py @@ -5,10 +5,12 @@ import os import csv import time import datetime +import distutils.version + +import numpy +import pandas import torch -import pandas as pd from tqdm import tqdm -import numpy as np from bob.ip.binseg.utils.metric import SmoothedValue from bob.ip.binseg.utils.plot import loss_curve @@ -48,7 +50,7 @@ def mix_up(alpha, input, target, unlabelled_input, unlabled_target): """ # TODO: with torch.no_grad(): - l = np.random.beta(alpha, alpha) # Eq (8) + l = numpy.random.beta(alpha, alpha) # Eq (8) l = max(l, 1 - l) # Eq (9) # Shuffle and concat. Alg. 1 Line: 12 w_inputs = torch.cat([input, unlabelled_input], 0) @@ -96,7 +98,7 @@ def square_rampup(current, rampup_length=16): if rampup_length == 0: return 1.0 else: - current = np.clip((current / float(rampup_length)) ** 2, 0.0, 1.0) + current = numpy.clip((current / float(rampup_length)) ** 2, 0.0, 1.0) return float(current) @@ -121,7 +123,7 @@ def linear_rampup(current, rampup_length=16): if rampup_length == 0: return 1.0 else: - current = np.clip(current / rampup_length, 0.0, 1.0) + current = numpy.clip(current / rampup_length, 0.0, 1.0) return float(current) @@ -340,7 +342,7 @@ def run( ) # plots a version of the CSV trainlog into a PDF - logdf = pd.read_csv(logfile_name, header=0, names=logfile_fields) + logdf = pandas.read_csv(logfile_name, header=0, names=logfile_fields) fig = loss_curve(logdf) figurefile_name = os.path.join(output_folder, "trainlog.pdf") logger.info(f"Saving {figurefile_name}") -- GitLab