Commit ed0ecb85 authored by Amir MOHAMMADI

Move utils to the utils folder

parent 0e2d6351
import logging

logging.getLogger("tensorflow").setLevel(logging.WARNING)


def get_config():
    """
    Returns a string containing the configuration information.
    ...
@@ -4,52 +4,7 @@
 import tensorflow as tf
 
-def check_features(features):
-    if "data" not in features or "key" not in features:
-        raise ValueError(
-            "The input function needs to contain a dictionary with the keys `data` and `key` "
-        )
-    return True
-
-
-def get_trainable_variables(extra_checkpoint, mode=tf.estimator.ModeKeys.TRAIN):
-    """
-    Given the extra_checkpoint dictionary provided to the estimator,
-    extract the content of "trainable_variables".
-
-    If trainable_variables is not provided, all end points are trainable by
-    default.
-    If trainable_variables==[], all end points are NOT trainable.
-    If trainable_variables contains some end points, ONLY those end points
-    will be trainable.
-
-    Attributes
-    ----------
-    extra_checkpoint: dict
-        The extra_checkpoint dictionary provided to the estimator.
-    mode:
-        The estimator mode: TRAIN, EVAL, or PREDICT. If not TRAIN, None is
-        returned.
-
-    Returns
-    -------
-    Returns `None` if **trainable_variables** is not in extra_checkpoint;
-    otherwise returns the content of extra_checkpoint["trainable_variables"].
-    """
-    if mode != tf.estimator.ModeKeys.TRAIN:
-        return None
-
-    # If you don't set anything, everything is trainable
-    if extra_checkpoint is None or "trainable_variables" not in extra_checkpoint:
-        return None
-
-    return extra_checkpoint["trainable_variables"]
+from ..utils import get_trainable_variables, check_features
 
 from .utils import MovingAverageOptimizer, learning_rate_decay_fn
 from .Logits import Logits, LogitsCenterLoss
 from .Siamese import Siamese
...
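For context, the two relocated helpers are the gatekeepers for fine-tuning in the estimators. A minimal sketch of their contract, assuming the package is bob.learn.tensorflow (suggested by the module layout; the tensors and checkpoint path below are illustrative, not from this commit):

import tensorflow as tf
from bob.learn.tensorflow.utils import check_features, get_trainable_variables

# Input functions must yield a dict with the keys `data` and `key`.
features = {"data": tf.zeros([32, 28, 28, 1]), "key": tf.constant(["sample-0"])}
assert check_features(features)

# trainable_variables semantics, mirroring the docstring above:
assert get_trainable_variables(None) is None                   # train everything
assert get_trainable_variables({"checkpoint_path": "/tmp/ckpt"}) is None  # key absent
assert get_trainable_variables({"trainable_variables": []}) == []         # train nothing
assert get_trainable_variables({"trainable_variables": ["Logits"]}) == ["Logits"]

# Outside TRAIN mode the answer is always None.
assert get_trainable_variables(
    {"trainable_variables": ["Logits"]}, mode=tf.estimator.ModeKeys.EVAL
) is None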
@@ -3,6 +3,7 @@
 from .ContrastiveLoss import contrastive_loss
 from .TripletLoss import triplet_loss, triplet_average_loss, triplet_fisher_loss
 from .StyleLoss import linear_gram_style_loss, content_loss, denoising_loss
 from .vat import VATLoss
+from .pixel_wise import PixelWise
 from .utils import *
@@ -22,7 +23,14 @@ def __appropriate__(*args):
         obj.__module__ = __name__
 
-__appropriate__(mean_cross_entropy_loss, mean_cross_entropy_center_loss,
-                contrastive_loss, triplet_loss, triplet_average_loss,
-                triplet_fisher_loss)
-__all__ = [_ for _ in dir() if not _.startswith('_')]
+__appropriate__(
+    mean_cross_entropy_loss,
+    mean_cross_entropy_center_loss,
+    contrastive_loss,
+    triplet_loss,
+    triplet_average_loss,
+    triplet_fisher_loss,
+    VATLoss,
+    PixelWise,
+)
+__all__ = [_ for _ in dir() if not _.startswith("_")]
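A note on the __appropriate__ idiom in this hunk: it re-stamps the __module__ of every re-exported symbol so Sphinx documents it under the short import path. A minimal sketch of the helper, reconstructed from the context lines shown above (the docstring wording is an assumption, not part of this diff):

def __appropriate__(*args):
    """Declares the given objects as belonging to this module rather than to
    the submodules they were imported from, so Sphinx can resolve the
    shortened paths."""
    for obj in args:
        obj.__module__ = __name__  # e.g. the loss package, not .vat or .pixel_wise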
@@ -2,53 +2,5 @@
 # vim: set fileencoding=utf-8 :
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
 
-import tensorflow as tf
-import tensorflow.contrib.slim as slim
-
-
-def append_logits(graph,
-                  n_classes,
-                  reuse=False,
-                  l2_regularizer=5e-05,
-                  weights_std=0.1,
-                  trainable_variables=None,
-                  name='Logits'):
-    trainable = is_trainable(name, trainable_variables)
-    return slim.fully_connected(
-        graph,
-        n_classes,
-        activation_fn=None,
-        weights_initializer=tf.truncated_normal_initializer(
-            stddev=weights_std),
-        weights_regularizer=slim.l2_regularizer(l2_regularizer),
-        scope=name,
-        reuse=reuse,
-        trainable=trainable,
-    )
-
-
-def is_trainable(name, trainable_variables, mode=tf.estimator.ModeKeys.TRAIN):
-    """
-    Check if a variable is trainable or not.
-
-    Parameters
-    ----------
-    name: str
-        Layer name
-    trainable_variables: list
-        List containing the variables or scopes to be trained.
-        If None, the variable/scope is trained.
-    """
-    # if the mode is not TRAIN, everything is frozen
-    if mode != tf.estimator.ModeKeys.TRAIN:
-        return False
-
-    # if None, we train by default
-    if trainable_variables is None:
-        return True
-
-    # by design, an entire scope is enabled or disabled at once
-    return name in trainable_variables
+# functions were moved
+from ..utils.network import append_logits, is_trainable
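To illustrate how the two relocated functions cooperate, here is a minimal sketch against the TF1/tf.contrib.slim API this code targets; the absolute import path assumes the bob.learn.tensorflow package, and the shapes and scope names are made up:

import tensorflow as tf
from bob.learn.tensorflow.utils.network import append_logits, is_trainable

# is_trainable implements scope-level freezing:
assert is_trainable("Logits", trainable_variables=None)        # default: train everything
assert not is_trainable("Logits", trainable_variables=[])      # freeze everything
assert is_trainable("Logits", trainable_variables=["Logits"])  # train only this scope
assert not is_trainable(                                       # never train outside TRAIN
    "Logits", ["Logits"], mode=tf.estimator.ModeKeys.PREDICT
)

# append_logits attaches the final fully connected layer on top of an
# embedding and resolves its own trainability through is_trainable:
embeddings = tf.placeholder(tf.float32, shape=(None, 128))
logits = append_logits(embeddings, n_classes=10, trainable_variables=["Logits"])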