From fcc104240637ecdcdcedf4652833985cda2b93f3 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Wed, 17 Apr 2019 16:22:50 +0200
Subject: [PATCH] Fix the euclidean distance function so that its gradients
 don't become NaN. Also move the bytes2human function

---
 bob/learn/tensorflow/utils/util.py | 64 ++++++++++++++++++++++++++----
 1 file changed, 56 insertions(+), 8 deletions(-)

diff --git a/bob/learn/tensorflow/utils/util.py b/bob/learn/tensorflow/utils/util.py
index a2fd7ebe..a9890cdc 100644
--- a/bob/learn/tensorflow/utils/util.py
+++ b/bob/learn/tensorflow/utils/util.py
@@ -6,6 +6,22 @@
 import numpy
 import tensorflow as tf
 from tensorflow.python.client import device_lib
+from tensorflow.python.framework import function
+import tensorflow.keras.backend as K
+
+
+def keras_channels_index():
+    return -3 if K.image_data_format() == 'channels_first' else -1
+
+
+@function.Defun(tf.float32, tf.float32)
+def norm_grad(x, dy):
+    return tf.expand_dims(dy, -1) * (x / (tf.expand_dims(tf.norm(x, ord=2, axis=-1), -1) + 1.0e-19))
+
+
+@function.Defun(tf.float32, grad_func=norm_grad)
+def norm(x):
+    return tf.norm(x, ord=2, axis=-1)
 
 
 def compute_euclidean_distance(x, y):
@@ -14,7 +30,8 @@ def compute_euclidean_distance(x, y):
     """
 
     with tf.name_scope('euclidean_distance'):
-        d = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(x, y)), 1))
+        # d = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(x, y)), 1))
+        d = norm(tf.subtract(x, y))
         return d
 
 
@@ -167,10 +184,10 @@ def debug_embbeding(image, architecture, embbeding_dim=2, feature_layer="fc3"):
     return embeddings
 
 
-def cdist(A):
+def pdist(A):
     """
     Compute a pairwise euclidean distance in the same fashion
-    as in scipy.spation.distance.cdist
+    as in scipy.spatial.distance.pdist
     """
     with tf.variable_scope('Pairwisedistance'):
         ones_1 = tf.reshape(
@@ -199,7 +216,7 @@ def predict_using_tensors(embedding, labels, num=None):
     # sample)
     inf = tf.cast(tf.ones_like(labels), tf.float32) * numpy.inf
 
-    distances = cdist(embedding)
+    distances = pdist(embedding)
     distances = tf.matrix_set_diag(distances, inf)
     indexes = tf.argmin(distances, axis=1)
     return [labels[i] for i in tf.unstack(indexes, num=num)]
@@ -208,7 +225,7 @@ def predict_using_tensors(embedding, labels, num=None):
 def compute_embedding_accuracy_tensors(embedding, labels, num=None):
     """
     Compute the accuracy in a closed-set
-    
+
     **Parameters**
 
     embeddings: `tf.Tensor`
@@ -232,7 +249,7 @@ def compute_embedding_accuracy_tensors(embedding, labels, num=None):
 def compute_embedding_accuracy(embedding, labels):
     """
     Compute the accuracy in a closed-set
-    
+
     **Parameters**
 
     embeddings: :any:`numpy.array`
@@ -242,9 +259,9 @@ def compute_embedding_accuracy(embedding, labels):
       Correspondent labels
 
     """
-    from scipy.spatial.distance import cdist
+    from scipy.spatial.distance import pdist, squareform
 
-    distances = cdist(embedding, embedding)
+    distances = squareform(pdist(embedding))
 
     n_samples = embedding.shape[0]
 
@@ -344,3 +361,34 @@ def to_channels_first(image):
 to_skimage = to_matplotlib = to_channels_last
 
 to_bob = to_channels_first
+
+
+def bytes2human(n, format='%(value).1f %(symbol)s', symbols='customary'):
+    """Convert n bytes into a human readable string based on format.
+    From: https://code.activestate.com/recipes/578019-bytes-to-human-human-to-
+    bytes-converter/
+    Author: Giampaolo Rodola' <g.rodola [AT] gmail [DOT] com>
+    License: MIT
+    symbols can be either "customary", "customary_ext", "iec" or "iec_ext",
+    see: http://goo.gl/kTQMs
+    """
+    SYMBOLS = {
+        'customary': ('B', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y'),
+        'customary_ext': ('byte', 'kilo', 'mega', 'giga', 'tera', 'peta',
+                          'exa', 'zetta', 'iotta'),
+        'iec': ('Bi', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi', 'Yi'),
+        'iec_ext': ('byte', 'kibi', 'mebi', 'gibi', 'tebi', 'pebi', 'exbi',
+                    'zebi', 'yobi'),
+    }
+    n = int(n)
+    if n < 0:
+        raise ValueError("n < 0")
+    symbols = SYMBOLS[symbols]
+    prefix = {}
+    for i, s in enumerate(symbols[1:]):
+        prefix[s] = 1 << (i + 1) * 10
+    for symbol in reversed(symbols[1:]):
+        if n >= prefix[symbol]:
+            value = float(n) / prefix[symbol]
+            return format % locals()
+    return format % dict(symbol=symbols[0], value=n)
--
GitLab
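
Note: below is a minimal sketch (not part of the patch) of the problem the Defun-wrapped norm above works around: the gradient of tf.norm at a zero vector evaluates to 0/0 and propagates NaN, while a hand-written grad_func with the 1.0e-19 term stays finite. It assumes TensorFlow 1.x graph mode (tf.Session, tf.gradients), matching the function.Defun API used in the patch; safe_norm and safe_norm_grad are hypothetical names that mirror the patch's norm and norm_grad so the snippet is self-contained.

# Minimal sketch, assuming TensorFlow 1.x graph mode.
import tensorflow as tf
from tensorflow.python.framework import function


@function.Defun(tf.float32, tf.float32)
def safe_norm_grad(x, dy):
    # Hand-written gradient of the L2 norm along the last axis: dy * x / ||x||.
    # The small epsilon keeps the division finite when x is the zero vector.
    return tf.expand_dims(dy, -1) * (
        x / (tf.expand_dims(tf.norm(x, ord=2, axis=-1), -1) + 1.0e-19))


@function.Defun(tf.float32, grad_func=safe_norm_grad)
def safe_norm(x):
    return tf.norm(x, ord=2, axis=-1)


# Identical embeddings give a zero difference vector, the case that produces NaNs.
diff = tf.zeros([1, 3])

grad_plain = tf.gradients(tf.norm(diff, ord=2, axis=-1), diff)[0]
grad_safe = tf.gradients(safe_norm(diff), diff)[0]

with tf.Session() as sess:
    print(sess.run(grad_plain))  # nan values: the sqrt gradient is inf * 0
    print(sess.run(grad_safe))   # zeros: the stabilized gradient stays finite

With the patch, compute_euclidean_distance routes through norm, so losses built on it no longer produce NaN gradients when two embeddings coincide.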