From c9311f9c7dc6d4289683e766ee542df022929c22 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Thu, 6 Jun 2019 17:30:45 +0200
Subject: [PATCH] Move utils to the utils folder

---
 bob/learn/tensorflow/__init__.py            |  4 --
 bob/learn/tensorflow/estimators/__init__.py | 47 +------------------
 bob/learn/tensorflow/loss/__init__.py       | 16 +++++--
 bob/learn/tensorflow/network/utils.py       | 52 +--------------------
 4 files changed, 15 insertions(+), 104 deletions(-)

diff --git a/bob/learn/tensorflow/__init__.py b/bob/learn/tensorflow/__init__.py
index 48cde48a..b1734864 100644
--- a/bob/learn/tensorflow/__init__.py
+++ b/bob/learn/tensorflow/__init__.py
@@ -1,7 +1,3 @@
-import logging
-logging.getLogger("tensorflow").setLevel(logging.WARNING)
-
-
 def get_config():
     """
     Returns a string containing the configuration information.
diff --git a/bob/learn/tensorflow/estimators/__init__.py b/bob/learn/tensorflow/estimators/__init__.py
index 7f144bf2..10f500c5 100644
--- a/bob/learn/tensorflow/estimators/__init__.py
+++ b/bob/learn/tensorflow/estimators/__init__.py
@@ -4,52 +4,7 @@
 import tensorflow as tf
 
 
-
-def check_features(features):
-    if "data" not in features or "key" not in features:
-        raise ValueError(
-            "The input function needs to contain a dictionary with the keys `data` and `key` "
-        )
-    return True
-
-
-def get_trainable_variables(extra_checkpoint, mode=tf.estimator.ModeKeys.TRAIN):
-    """
-    Given the extra_checkpoint dictionary provided to the estimator,
-    extract the content of "trainable_variables".
-
-    If trainable_variables is not provided, all end points are trainable by
-    default.
-    If trainable_variables==[], all end points are NOT trainable.
-    If trainable_variables contains some end_points, ONLY these endpoints will
-    be trainable.
-
-    Attributes
-    ----------
-
-    extra_checkpoint: dict
-        The extra_checkpoint dictionary provided to the estimator
-
-    mode:
-        The estimator mode. TRAIN, EVAL, and PREDICT. If not TRAIN, None is
-        returned.
-
-    Returns
-    -------
-
-    Returns `None` if **trainable_variables** is not in extra_checkpoint;
-    otherwise returns the content of extra_checkpoint .
-    """
-    if mode != tf.estimator.ModeKeys.TRAIN:
-        return None
-
-    # If you don't set anything, everything is trainable
-    if extra_checkpoint is None or "trainable_variables" not in extra_checkpoint:
-        return None
-
-    return extra_checkpoint["trainable_variables"]
-
-
+from ..utils import get_trainable_variables, check_features
 from .utils import MovingAverageOptimizer, learning_rate_decay_fn
 from .Logits import Logits, LogitsCenterLoss
 from .Siamese import Siamese
diff --git a/bob/learn/tensorflow/loss/__init__.py b/bob/learn/tensorflow/loss/__init__.py
index f35a0ceb..eab22bb4 100644
--- a/bob/learn/tensorflow/loss/__init__.py
+++ b/bob/learn/tensorflow/loss/__init__.py
@@ -3,6 +3,7 @@
 from .ContrastiveLoss import contrastive_loss
 from .TripletLoss import triplet_loss, triplet_average_loss, triplet_fisher_loss
 from .StyleLoss import linear_gram_style_loss, content_loss, denoising_loss
 from .vat import VATLoss
+from .pixel_wise import PixelWise
 from .utils import *
 
@@ -22,7 +23,14 @@ def __appropriate__(*args):
         obj.__module__ = __name__
 
 
-__appropriate__(mean_cross_entropy_loss, mean_cross_entropy_center_loss,
-                contrastive_loss, triplet_loss, triplet_average_loss,
-                triplet_fisher_loss)
-__all__ = [_ for _ in dir() if not _.startswith('_')]
+__appropriate__(
+    mean_cross_entropy_loss,
+    mean_cross_entropy_center_loss,
+    contrastive_loss,
+    triplet_loss,
+    triplet_average_loss,
+    triplet_fisher_loss,
+    VATLoss,
+    PixelWise,
+)
+__all__ = [_ for _ in dir() if not _.startswith("_")]
diff --git a/bob/learn/tensorflow/network/utils.py b/bob/learn/tensorflow/network/utils.py
index 780a1682..0c5c855f 100644
--- a/bob/learn/tensorflow/network/utils.py
+++ b/bob/learn/tensorflow/network/utils.py
@@ -2,53 +2,5 @@
 # vim: set fileencoding=utf-8 :
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
 
-import tensorflow as tf
-import tensorflow.contrib.slim as slim
-
-
-def append_logits(graph,
-                  n_classes,
-                  reuse=False,
-                  l2_regularizer=5e-05,
-                  weights_std=0.1, trainable_variables=None,
-                  name='Logits'):
-    trainable = is_trainable(name, trainable_variables)
-    return slim.fully_connected(
-        graph,
-        n_classes,
-        activation_fn=None,
-        weights_initializer=tf.truncated_normal_initializer(
-            stddev=weights_std),
-        weights_regularizer=slim.l2_regularizer(l2_regularizer),
-        scope=name,
-        reuse=reuse,
-        trainable=trainable,
-    )
-
-
-def is_trainable(name, trainable_variables, mode=tf.estimator.ModeKeys.TRAIN):
-    """
-    Check if a variable is trainable or not
-
-    Parameters
-    ----------
-
-    name: str
-        Layer name
-
-    trainable_variables: list
-        List containing the variables or scopes to be trained.
-        If None, the variable/scope is trained
-    """
-
-    # if mode is not training, so we shutdown
-    if mode != tf.estimator.ModeKeys.TRAIN:
-        return False
-
-    # If None, we train by default
-    if trainable_variables is None:
-        return True
-
-    # Here is my choice to shutdown the whole scope
-    return name in trainable_variables
-
+# functions were moved
+from ..utils.network import append_logits, is_trainable
--
GitLab