Commit 6aabd230 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Move utils to the utils folder

parent 22eabcd1
import logging
logging.getLogger("tensorflow").setLevel(logging.WARNING)
def get_config():
"""
Returns a string containing the configuration information.
......
......@@ -4,52 +4,7 @@
import tensorflow as tf
def check_features(features):
if "data" not in features or "key" not in features:
raise ValueError(
"The input function needs to contain a dictionary with the keys `data` and `key` "
)
return True
def get_trainable_variables(extra_checkpoint, mode=tf.estimator.ModeKeys.TRAIN):
"""
Given the extra_checkpoint dictionary provided to the estimator,
extract the content of "trainable_variables".
If trainable_variables is not provided, all end points are trainable by
default.
If trainable_variables==[], all end points are NOT trainable.
If trainable_variables contains some end_points, ONLY these endpoints will
be trainable.
Attributes
----------
extra_checkpoint: dict
The extra_checkpoint dictionary provided to the estimator
mode:
The estimator mode. TRAIN, EVAL, and PREDICT. If not TRAIN, None is
returned.
Returns
-------
Returns `None` if **trainable_variables** is not in extra_checkpoint;
otherwise returns the content of extra_checkpoint .
"""
if mode != tf.estimator.ModeKeys.TRAIN:
return None
# If you don't set anything, everything is trainable
if extra_checkpoint is None or "trainable_variables" not in extra_checkpoint:
return None
return extra_checkpoint["trainable_variables"]
from ..utils import get_trainable_variables, check_features
from .utils import MovingAverageOptimizer, learning_rate_decay_fn
from .Logits import Logits, LogitsCenterLoss
from .Siamese import Siamese
......
......@@ -3,6 +3,7 @@ from .ContrastiveLoss import contrastive_loss
from .TripletLoss import triplet_loss, triplet_average_loss, triplet_fisher_loss
from .StyleLoss import linear_gram_style_loss, content_loss, denoising_loss
from .vat import VATLoss
from .pixel_wise import PixelWise
from .utils import *
......@@ -22,7 +23,14 @@ def __appropriate__(*args):
obj.__module__ = __name__
__appropriate__(mean_cross_entropy_loss, mean_cross_entropy_center_loss,
contrastive_loss, triplet_loss, triplet_average_loss,
triplet_fisher_loss)
__all__ = [_ for _ in dir() if not _.startswith('_')]
__appropriate__(
mean_cross_entropy_loss,
mean_cross_entropy_center_loss,
contrastive_loss,
triplet_loss,
triplet_average_loss,
triplet_fisher_loss,
VATLoss,
PixelWise,
)
__all__ = [_ for _ in dir() if not _.startswith("_")]
......@@ -2,53 +2,5 @@
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
import tensorflow as tf
import tensorflow.contrib.slim as slim
def append_logits(graph,
n_classes,
reuse=False,
l2_regularizer=5e-05,
weights_std=0.1, trainable_variables=None,
name='Logits'):
trainable = is_trainable(name, trainable_variables)
return slim.fully_connected(
graph,
n_classes,
activation_fn=None,
weights_initializer=tf.truncated_normal_initializer(
stddev=weights_std),
weights_regularizer=slim.l2_regularizer(l2_regularizer),
scope=name,
reuse=reuse,
trainable=trainable,
)
def is_trainable(name, trainable_variables, mode=tf.estimator.ModeKeys.TRAIN):
"""
Check if a variable is trainable or not
Parameters
----------
name: str
Layer name
trainable_variables: list
List containing the variables or scopes to be trained.
If None, the variable/scope is trained
"""
# if mode is not training, so we shutdown
if mode != tf.estimator.ModeKeys.TRAIN:
return False
# If None, we train by default
if trainable_variables is None:
return True
# Here is my choice to shutdown the whole scope
return name in trainable_variables
# functions were moved
from ..utils.network import append_logits, is_trainable
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment