From c9311f9c7dc6d4289683e766ee542df022929c22 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Thu, 6 Jun 2019 17:30:45 +0200
Subject: [PATCH] Move utils to the utils folder

---
 bob/learn/tensorflow/__init__.py            |  4 --
 bob/learn/tensorflow/estimators/__init__.py | 47 +------------------
 bob/learn/tensorflow/loss/__init__.py       | 16 +++++--
 bob/learn/tensorflow/network/utils.py       | 52 +--------------------
 4 files changed, 15 insertions(+), 104 deletions(-)

diff --git a/bob/learn/tensorflow/__init__.py b/bob/learn/tensorflow/__init__.py
index 48cde48a..b1734864 100644
--- a/bob/learn/tensorflow/__init__.py
+++ b/bob/learn/tensorflow/__init__.py
@@ -1,7 +1,3 @@
-import logging
-logging.getLogger("tensorflow").setLevel(logging.WARNING)
-
-
 def get_config():
     """
     Returns a string containing the configuration information.
diff --git a/bob/learn/tensorflow/estimators/__init__.py b/bob/learn/tensorflow/estimators/__init__.py
index 7f144bf2..10f500c5 100644
--- a/bob/learn/tensorflow/estimators/__init__.py
+++ b/bob/learn/tensorflow/estimators/__init__.py
@@ -4,52 +4,7 @@
 
 import tensorflow as tf
 
-
-def check_features(features):
-    if "data" not in features or "key" not in features:
-        raise ValueError(
-            "The input function needs to contain a dictionary with the keys `data` and `key` "
-        )
-    return True
-
-
-def get_trainable_variables(extra_checkpoint, mode=tf.estimator.ModeKeys.TRAIN):
-    """
-    Given the extra_checkpoint dictionary provided to the estimator,
-    extract the content of "trainable_variables".
-
-    If trainable_variables is not provided, all end points are trainable by
-    default.
-    If trainable_variables==[], all end points are NOT trainable.
-    If trainable_variables contains some end_points, ONLY these endpoints will
-    be trainable.
-
-    Attributes
-    ----------
-
-    extra_checkpoint: dict
-      The extra_checkpoint dictionary provided to the estimator
-
-    mode:
-        The estimator mode. TRAIN, EVAL, and PREDICT. If not TRAIN, None is
-        returned.
-
-    Returns
-    -------
-
-    Returns `None` if **trainable_variables** is not in extra_checkpoint;
-    otherwise returns the content of extra_checkpoint .
-    """
-    if mode != tf.estimator.ModeKeys.TRAIN:
-        return None
-
-    # If you don't set anything, everything is trainable
-    if extra_checkpoint is None or "trainable_variables" not in extra_checkpoint:
-        return None
-
-    return extra_checkpoint["trainable_variables"]
-
-
+from ..utils import get_trainable_variables, check_features
 from .utils import MovingAverageOptimizer, learning_rate_decay_fn
 from .Logits import Logits, LogitsCenterLoss
 from .Siamese import Siamese
diff --git a/bob/learn/tensorflow/loss/__init__.py b/bob/learn/tensorflow/loss/__init__.py
index f35a0ceb..eab22bb4 100644
--- a/bob/learn/tensorflow/loss/__init__.py
+++ b/bob/learn/tensorflow/loss/__init__.py
@@ -3,6 +3,7 @@ from .ContrastiveLoss import contrastive_loss
 from .TripletLoss import triplet_loss, triplet_average_loss, triplet_fisher_loss
 from .StyleLoss import linear_gram_style_loss, content_loss, denoising_loss
 from .vat import VATLoss
+from .pixel_wise import PixelWise
 from .utils import *
 
 
@@ -22,7 +23,14 @@ def __appropriate__(*args):
         obj.__module__ = __name__
 
 
-__appropriate__(mean_cross_entropy_loss, mean_cross_entropy_center_loss,
-                contrastive_loss, triplet_loss, triplet_average_loss,
-                triplet_fisher_loss)
-__all__ = [_ for _ in dir() if not _.startswith('_')]
+__appropriate__(
+    mean_cross_entropy_loss,
+    mean_cross_entropy_center_loss,
+    contrastive_loss,
+    triplet_loss,
+    triplet_average_loss,
+    triplet_fisher_loss,
+    VATLoss,
+    PixelWise,
+)
+__all__ = [_ for _ in dir() if not _.startswith("_")]
diff --git a/bob/learn/tensorflow/network/utils.py b/bob/learn/tensorflow/network/utils.py
index 780a1682..0c5c855f 100644
--- a/bob/learn/tensorflow/network/utils.py
+++ b/bob/learn/tensorflow/network/utils.py
@@ -2,53 +2,5 @@
 # vim: set fileencoding=utf-8 :
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
 
-import tensorflow as tf
-import tensorflow.contrib.slim as slim
-
-
-def append_logits(graph,
-                  n_classes,
-                  reuse=False,
-                  l2_regularizer=5e-05,
-                  weights_std=0.1, trainable_variables=None,
-                  name='Logits'):
-    trainable = is_trainable(name, trainable_variables)
-    return slim.fully_connected(
-        graph,
-        n_classes,
-        activation_fn=None,
-        weights_initializer=tf.truncated_normal_initializer(
-            stddev=weights_std),
-        weights_regularizer=slim.l2_regularizer(l2_regularizer),
-        scope=name,
-        reuse=reuse,
-        trainable=trainable,
-    )
-
-
-def is_trainable(name, trainable_variables, mode=tf.estimator.ModeKeys.TRAIN):
-    """
-    Check if a variable is trainable or not
-
-    Parameters
-    ----------
-
-    name: str
-       Layer name
-
-    trainable_variables: list
-       List containing the variables or scopes to be trained.
-       If None, the variable/scope is trained
-    """
-
-    # if mode is not training, so we shutdown
-    if mode != tf.estimator.ModeKeys.TRAIN:
-        return False
-
-    # If None, we train by default
-    if trainable_variables is None:
-        return True
-
-    # Here is my choice to shutdown the whole scope
-    return name in trainable_variables
-
+# functions were moved
+from ..utils.network import append_logits, is_trainable
-- 
GitLab