From de12a4aa6b0f7959e1a551ff928a6631f87a35b9 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Mon, 20 Apr 2020 19:48:40 +0200
Subject: [PATCH] [engine.ssltrainer] Fix doc generation

---
 bob/ip/binseg/engine/ssltrainer.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index 27c0089f..d7310ed0 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -5,10 +5,12 @@ import os
 import csv
 import time
 import datetime
+import distutils.version
+
+import numpy
+import pandas
 import torch
-import pandas as pd
 from tqdm import tqdm
-import numpy as np
 
 from bob.ip.binseg.utils.metric import SmoothedValue
 from bob.ip.binseg.utils.plot import loss_curve
@@ -48,7 +50,7 @@ def mix_up(alpha, input, target, unlabelled_input, unlabled_target):
     """
     # TODO:
     with torch.no_grad():
-        l = np.random.beta(alpha, alpha)  # Eq (8)
+        l = numpy.random.beta(alpha, alpha)  # Eq (8)
         l = max(l, 1 - l)  # Eq (9)
         # Shuffle and concat. Alg. 1 Line: 12
         w_inputs = torch.cat([input, unlabelled_input], 0)
@@ -96,7 +98,7 @@ def square_rampup(current, rampup_length=16):
     if rampup_length == 0:
         return 1.0
     else:
-        current = np.clip((current / float(rampup_length)) ** 2, 0.0, 1.0)
+        current = numpy.clip((current / float(rampup_length)) ** 2, 0.0, 1.0)
     return float(current)
 
 
@@ -121,7 +123,7 @@ def linear_rampup(current, rampup_length=16):
     if rampup_length == 0:
         return 1.0
     else:
-        current = np.clip(current / rampup_length, 0.0, 1.0)
+        current = numpy.clip(current / rampup_length, 0.0, 1.0)
     return float(current)
 
 
@@ -340,7 +342,7 @@ def run(
         )
 
     # plots a version of the CSV trainlog into a PDF
-    logdf = pd.read_csv(logfile_name, header=0, names=logfile_fields)
+    logdf = pandas.read_csv(logfile_name, header=0, names=logfile_fields)
     fig = loss_curve(logdf)
     figurefile_name = os.path.join(output_folder, "trainlog.pdf")
     logger.info(f"Saving {figurefile_name}")
-- 
GitLab