diff --git a/bob/learn/tensorflow/__init__.py b/bob/learn/tensorflow/__init__.py
index 48cde48af85134ab3034ee1a75833c0647eec32c..b17348644c602f4d16633e7c392b19a37c59086a 100644
--- a/bob/learn/tensorflow/__init__.py
+++ b/bob/learn/tensorflow/__init__.py
@@ -1,7 +1,3 @@
-import logging
-logging.getLogger("tensorflow").setLevel(logging.WARNING)
-
-
def get_config():
"""
Returns a string containing the configuration information.
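
The removed lines above mean the package no longer silences TensorFlow's Python logger at import time. Downstream code that relied on the old behaviour can restore it itself; a minimal sketch, using exactly the logger name the removed code targeted:

```python
import logging

# Re-apply the filtering the package used to do on import:
# drop TensorFlow log records below WARNING (e.g. INFO and DEBUG).
logging.getLogger("tensorflow").setLevel(logging.WARNING)
```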
diff --git a/bob/learn/tensorflow/estimators/__init__.py b/bob/learn/tensorflow/estimators/__init__.py
index 7f144bf271cb20c8198ab1b67e5a25e9157c52a0..10f500c531d9c10e19204272d5dd1a9fb27355a8 100644
--- a/bob/learn/tensorflow/estimators/__init__.py
+++ b/bob/learn/tensorflow/estimators/__init__.py
@@ -4,52 +4,7 @@
import tensorflow as tf
-
-def check_features(features):
-    if "data" not in features or "key" not in features:
-        raise ValueError(
-            "The input function needs to contain a dictionary with the keys `data` and `key`"
-        )
-    return True
-
-
-def get_trainable_variables(extra_checkpoint, mode=tf.estimator.ModeKeys.TRAIN):
- """
- Given the extra_checkpoint dictionary provided to the estimator,
- extract the content of "trainable_variables".
-
- If trainable_variables is not provided, all end points are trainable by
- default.
- If trainable_variables==[], all end points are NOT trainable.
- If trainable_variables contains some end_points, ONLY these endpoints will
- be trainable.
-
- Attributes
- ----------
-
- extra_checkpoint: dict
- The extra_checkpoint dictionary provided to the estimator
-
- mode:
- The estimator mode. TRAIN, EVAL, and PREDICT. If not TRAIN, None is
- returned.
-
- Returns
- -------
-
- Returns `None` if **trainable_variables** is not in extra_checkpoint;
- otherwise returns the content of extra_checkpoint .
- """
- if mode != tf.estimator.ModeKeys.TRAIN:
- return None
-
- # If you don't set anything, everything is trainable
- if extra_checkpoint is None or "trainable_variables" not in extra_checkpoint:
- return None
-
- return extra_checkpoint["trainable_variables"]
-
-
+from ..utils import get_trainable_variables, check_features
from .utils import MovingAverageOptimizer, learning_rate_decay_fn
from .Logits import Logits, LogitsCenterLoss
from .Siamese import Siamese
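
For reference, a minimal sketch of the three cases described in the docstring of the removed `get_trainable_variables`, assuming the helper keeps the same behaviour after its move to `bob.learn.tensorflow.utils` (the import target shown in the hunk above):

```python
import tensorflow as tf
from bob.learn.tensorflow.utils import get_trainable_variables

# No dict, or no "trainable_variables" key: None, i.e. train everything.
assert get_trainable_variables(None) is None

# Empty list: returned as-is, i.e. nothing is trainable.
assert get_trainable_variables({"trainable_variables": []}) == []

# Explicit list: only these endpoints/scopes are trained.
extra = {"trainable_variables": ["Logits"]}
assert get_trainable_variables(extra) == ["Logits"]

# Outside of TRAIN mode, the helper always returns None.
assert get_trainable_variables(extra, mode=tf.estimator.ModeKeys.EVAL) is None
```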
diff --git a/bob/learn/tensorflow/loss/__init__.py b/bob/learn/tensorflow/loss/__init__.py
index f35a0cebf7afb894c8ef5671b4517827f06ab9df..eab22bb48fbeaccf3b9fa3f734fd0839f76281f6 100644
--- a/bob/learn/tensorflow/loss/__init__.py
+++ b/bob/learn/tensorflow/loss/__init__.py
@@ -3,6 +3,7 @@ from .ContrastiveLoss import contrastive_loss
from .TripletLoss import triplet_loss, triplet_average_loss, triplet_fisher_loss
from .StyleLoss import linear_gram_style_loss, content_loss, denoising_loss
from .vat import VATLoss
+from .pixel_wise import PixelWise
from .utils import *
@@ -22,7 +23,14 @@ def __appropriate__(*args):
        obj.__module__ = __name__
-__appropriate__(mean_cross_entropy_loss, mean_cross_entropy_center_loss,
-                contrastive_loss, triplet_loss, triplet_average_loss,
-                triplet_fisher_loss)
-__all__ = [_ for _ in dir() if not _.startswith('_')]
+__appropriate__(
+    mean_cross_entropy_loss,
+    mean_cross_entropy_center_loss,
+    contrastive_loss,
+    triplet_loss,
+    triplet_average_loss,
+    triplet_fisher_loss,
+    VATLoss,
+    PixelWise,
+)
+__all__ = [_ for _ in dir() if not _.startswith("_")]
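
Only the tail of `__appropriate__` is visible in the hunk above; the pattern, common across bob packages, re-points `__module__` on re-exported objects so that Sphinx documents them as members of this package rather than of their defining submodules. A condensed sketch of the idea:

```python
def __appropriate__(*args):
    # Mark each object as if it had been declared in this module, so that
    # documentation tools resolve it under the package's short import path.
    for obj in args:
        obj.__module__ = __name__
```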
diff --git a/bob/learn/tensorflow/network/utils.py b/bob/learn/tensorflow/network/utils.py
index 780a16820170af128a0ca45b5832db79b28c8a6d..0c5c855f023be38dcc3b4d130f1117cb0ccfa307 100644
--- a/bob/learn/tensorflow/network/utils.py
+++ b/bob/learn/tensorflow/network/utils.py
@@ -2,53 +2,5 @@
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-import tensorflow as tf
-import tensorflow.contrib.slim as slim
-
-
-def append_logits(graph,
-                  n_classes,
-                  reuse=False,
-                  l2_regularizer=5e-05,
-                  weights_std=0.1, trainable_variables=None,
-                  name='Logits'):
-    trainable = is_trainable(name, trainable_variables)
-    return slim.fully_connected(
-        graph,
-        n_classes,
-        activation_fn=None,
-        weights_initializer=tf.truncated_normal_initializer(
-            stddev=weights_std),
-        weights_regularizer=slim.l2_regularizer(l2_regularizer),
-        scope=name,
-        reuse=reuse,
-        trainable=trainable,
-    )
-
-
-def is_trainable(name, trainable_variables, mode=tf.estimator.ModeKeys.TRAIN):
- """
- Check if a variable is trainable or not
-
- Parameters
- ----------
-
- name: str
- Layer name
-
- trainable_variables: list
- List containing the variables or scopes to be trained.
- If None, the variable/scope is trained
- """
-
- # if mode is not training, so we shutdown
- if mode != tf.estimator.ModeKeys.TRAIN:
- return False
-
- # If None, we train by default
- if trainable_variables is None:
- return True
-
- # Here is my choice to shutdown the whole scope
- return name in trainable_variables
-
+# These functions were moved to ..utils.network; re-exported here for backward compatibility
+from ..utils.network import append_logits, is_trainable
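
A minimal sketch of how `is_trainable` resolves its three cases from its new home, assuming the removed body shown above is preserved in `bob.learn.tensorflow.utils.network` (the import target on the last line):

```python
import tensorflow as tf
from bob.learn.tensorflow.utils.network import is_trainable

# Default mode is TRAIN; with no list given, every scope is trainable.
assert is_trainable("Logits", None)

# An explicit list restricts training to the named scopes only.
assert is_trainable("Logits", ["Logits"])
assert not is_trainable("conv1", ["Logits"])

# Outside of training, nothing is trainable regardless of the list.
assert not is_trainable("Logits", None, mode=tf.estimator.ModeKeys.EVAL)
```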