Commit c9311f9c authored by Amir MOHAMMADI

Move utils to the utils folder

parent 2f7fec82
import logging

logging.getLogger("tensorflow").setLevel(logging.WARNING)


def get_config():
    """
    Returns a string containing the configuration information.
@@ -4,52 +4,7 @@
import tensorflow as tf


def check_features(features):
    if "data" not in features or "key" not in features:
        raise ValueError(
            "The input function must return a dictionary containing the keys `data` and `key`."
        )
    return True
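
A minimal usage sketch (illustrative, not part of this commit); the `features` values below are hypothetical stand-ins for what an input function would produce:

# Hypothetical features dictionary, as an input function would yield it
features = {
    "data": tf.ones([32, 112, 112, 3]),      # a batch of images
    "key": tf.constant(["sample-1"] * 32),   # per-sample identifiers
}
check_features(features)  # returns True; raises ValueError if a key is missing
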
def get_trainable_variables(extra_checkpoint, mode=tf.estimator.ModeKeys.TRAIN):
    """
    Given the `extra_checkpoint` dictionary provided to the estimator,
    extract the content of "trainable_variables".

    If trainable_variables is not provided, all end points are trainable by
    default.
    If trainable_variables==[], all end points are NOT trainable.
    If trainable_variables contains some end points, ONLY these end points
    will be trainable.

    Parameters
    ----------
    extra_checkpoint: dict
        The extra_checkpoint dictionary provided to the estimator

    mode:
        The estimator mode. TRAIN, EVAL, and PREDICT. If not TRAIN, None is
        returned.

    Returns
    -------
    Returns `None` if **trainable_variables** is not in extra_checkpoint;
    otherwise returns the content of extra_checkpoint["trainable_variables"].
    """
    if mode != tf.estimator.ModeKeys.TRAIN:
        return None

    # If nothing is set, everything is trainable
    if extra_checkpoint is None or "trainable_variables" not in extra_checkpoint:
        return None

    return extra_checkpoint["trainable_variables"]
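
A short sketch of the three cases the docstring describes, plus the non-TRAIN behavior; the dictionaries are illustrative:

# No "trainable_variables" entry: returns None, so everything stays trainable
assert get_trainable_variables({"checkpoint_path": "/path/to/ckpt"}) is None

# Empty list: returns [], so no end point is trainable
assert get_trainable_variables({"trainable_variables": []}) == []

# Explicit list: only these scopes will be trainable
assert get_trainable_variables({"trainable_variables": ["Logits"]}) == ["Logits"]

# Outside of training, the result is always None
assert get_trainable_variables(
    {"trainable_variables": ["Logits"]}, mode=tf.estimator.ModeKeys.PREDICT
) is None
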
from ..utils import get_trainable_variables, check_features
from .utils import MovingAverageOptimizer, learning_rate_decay_fn
from .Logits import Logits, LogitsCenterLoss
from .Siamese import Siamese
@@ -3,6 +3,7 @@ from .ContrastiveLoss import contrastive_loss
from .TripletLoss import triplet_loss, triplet_average_loss, triplet_fisher_loss
from .StyleLoss import linear_gram_style_loss, content_loss, denoising_loss
from .vat import VATLoss
from .pixel_wise import PixelWise
from .utils import *
@@ -22,7 +23,14 @@ def __appropriate__(*args):
        obj.__module__ = __name__


__appropriate__(mean_cross_entropy_loss, mean_cross_entropy_center_loss,
                contrastive_loss, triplet_loss, triplet_average_loss,
                triplet_fisher_loss)
__all__ = [_ for _ in dir() if not _.startswith('_')]
__appropriate__(
    mean_cross_entropy_loss,
    mean_cross_entropy_center_loss,
    contrastive_loss,
    triplet_loss,
    triplet_average_loss,
    triplet_fisher_loss,
    VATLoss,
    PixelWise,
)
__all__ = [_ for _ in dir() if not _.startswith("_")]
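
For context, `__appropriate__` is the package's re-export idiom: it stamps each object's `__module__` with this package's name so that documentation tools attribute the symbol here rather than to its defining file. A minimal reconstruction of the collapsed definition, based on the hunk header and the loop body above:

def __appropriate__(*args):
    """Says an object was actually declared here, not in the imported module."""
    for obj in args:
        obj.__module__ = __name__
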
@@ -2,53 +2,5 @@
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
import tensorflow as tf
import tensorflow.contrib.slim as slim


def append_logits(graph,
                  n_classes,
                  reuse=False,
                  l2_regularizer=5e-05,
                  weights_std=0.1,
                  trainable_variables=None,
                  name='Logits'):
    """Appends a fully connected logits layer to `graph`."""
    trainable = is_trainable(name, trainable_variables)
    return slim.fully_connected(
        graph,
        n_classes,
        activation_fn=None,
        weights_initializer=tf.truncated_normal_initializer(stddev=weights_std),
        weights_regularizer=slim.l2_regularizer(l2_regularizer),
        scope=name,
        reuse=reuse,
        trainable=trainable,
    )
def is_trainable(name, trainable_variables, mode=tf.estimator.ModeKeys.TRAIN):
    """
    Check whether a variable or scope is trainable.

    Parameters
    ----------
    name: str
        Layer name

    trainable_variables: list
        List containing the variables or scopes to be trained.
        If None, the variable/scope is trained by default.

    mode:
        The estimator mode. If not TRAIN, nothing is trainable.
    """
    # If we are not training, nothing is trainable
    if mode != tf.estimator.ModeKeys.TRAIN:
        return False

    # If None, everything is trainable by default
    if trainable_variables is None:
        return True

    # Otherwise, the whole scope is trainable only if it is listed
    return name in trainable_variables
# These functions were moved to ..utils.network:
from ..utils.network import append_logits, is_trainable
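
A brief sketch of how the two moved functions interact, assuming a hypothetical `prelogits` embedding tensor produced by some upstream network:

# Illustrative embeddings; any [batch, dim] float tensor would do
prelogits = tf.placeholder(tf.float32, shape=[None, 128])

# Only the new "Logits" scope is trainable; upstream scopes stay frozen
logits = append_logits(prelogits, n_classes=10, trainable_variables=["Logits"])

# is_trainable answers the same question per scope
assert is_trainable("Logits", ["Logits"])
assert not is_trainable("Conv1", ["Logits"])
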