From fcc104240637ecdcdcedf4652833985cda2b93f3 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Wed, 17 Apr 2019 16:22:50 +0200
Subject: [PATCH] Fix the euclidean function so that its gradients don't become
 nan. Also moves the bytes2human function

---
 bob/learn/tensorflow/utils/util.py | 64 ++++++++++++++++++++++++++----
 1 file changed, 56 insertions(+), 8 deletions(-)

diff --git a/bob/learn/tensorflow/utils/util.py b/bob/learn/tensorflow/utils/util.py
index a2fd7ebe..a9890cdc 100644
--- a/bob/learn/tensorflow/utils/util.py
+++ b/bob/learn/tensorflow/utils/util.py
@@ -6,6 +6,22 @@
 import numpy
 import tensorflow as tf
 from tensorflow.python.client import device_lib
+from tensorflow.python.framework import function
+import tensorflow.keras.backend as K
+
+
+def keras_channels_index():  # negative axis index of the channels dim for Keras' image_data_format
+    return -3 if K.image_data_format() == 'channels_first' else -1
+
+
+@function.Defun(tf.float32, tf.float32)
+def norm_grad(x, dy):  # gradient registered for norm(): dy * x / (||x|| + eps); eps avoids 0/0 at x == 0
+    return tf.expand_dims(dy, -1) * (x / (tf.expand_dims(tf.norm(x, ord=2, axis=-1), -1) + 1.0e-19))
+
+
+@function.Defun(tf.float32, grad_func=norm_grad)
+def norm(x):  # L2 norm along the last axis; backprop uses norm_grad rather than tf.norm's own gradient
+    return tf.norm(x, ord=2, axis=-1)
 
 
 def compute_euclidean_distance(x, y):
@@ -14,7 +30,8 @@ def compute_euclidean_distance(x, y):
     """
 
     with tf.name_scope('euclidean_distance'):
-        d = tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(x, y)), 1))
+        # tf.sqrt(tf.reduce_sum(...)) gave nan gradients; norm() has a safe one
+        d = norm(tf.subtract(x, y))
         return d
 
 
@@ -167,10 +184,10 @@ def debug_embbeding(image, architecture, embbeding_dim=2, feature_layer="fc3"):
     return embeddings
 
 
-def cdist(A):
+def pdist(A):
     """
     Compute a pairwise euclidean distance in the same fashion
-    as in scipy.spation.distance.cdist
+    as in scipy.spatial.distance.pdist
     """
     with tf.variable_scope('Pairwisedistance'):
         ones_1 = tf.reshape(
@@ -199,7 +216,7 @@ def predict_using_tensors(embedding, labels, num=None):
     # sample)
     inf = tf.cast(tf.ones_like(labels), tf.float32) * numpy.inf
 
-    distances = cdist(embedding)
+    distances = pdist(embedding)
     distances = tf.matrix_set_diag(distances, inf)
     indexes = tf.argmin(distances, axis=1)
     return [labels[i] for i in tf.unstack(indexes, num=num)]
@@ -208,7 +225,7 @@ def predict_using_tensors(embedding, labels, num=None):
 def compute_embedding_accuracy_tensors(embedding, labels, num=None):
     """
     Compute the accuracy in a closed-set
-    
+
     **Parameters**
 
     embeddings: `tf.Tensor`
@@ -232,7 +249,7 @@ def compute_embedding_accuracy_tensors(embedding, labels, num=None):
 def compute_embedding_accuracy(embedding, labels):
     """
     Compute the accuracy in a closed-set
-    
+
     **Parameters**
 
     embeddings: :any:`numpy.array`
@@ -242,9 +259,9 @@ def compute_embedding_accuracy(embedding, labels):
       Correspondent labels
     """
 
-    from scipy.spatial.distance import cdist
+    from scipy.spatial.distance import pdist, squareform
 
-    distances = cdist(embedding, embedding)
+    distances = squareform(pdist(embedding))
 
     n_samples = embedding.shape[0]
 
@@ -344,3 +361,34 @@ def to_channels_first(image):
 
 to_skimage = to_matplotlib = to_channels_last
 to_bob = to_channels_first
+
+
+def bytes2human(n, format='%(value).1f %(symbol)s', symbols='customary'):
+    """Convert n bytes into a human readable string based on format.
+    From: https://code.activestate.com/recipes/578019-bytes-to-human-human-to-
+    bytes-converter/
+    Author: Giampaolo Rodola' <g.rodola [AT] gmail [DOT] com>
+    License: MIT
+    symbols can be either "customary", "customary_ext", "iec" or "iec_ext",
+    see: http://goo.gl/kTQMs
+    format is a %-style template; it is given the keys "value" and "symbol".
+    Raises ValueError if n < 0 and KeyError if symbols is not a known name.
+    """
+    SYMBOLS = {
+        'customary': ('B', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y'),
+        'customary_ext': ('byte', 'kilo', 'mega', 'giga', 'tera', 'peta',
+                          'exa', 'zetta', 'yotta'),
+        'iec': ('Bi', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi', 'Yi'),
+        'iec_ext': ('byte', 'kibi', 'mebi', 'gibi', 'tebi', 'pebi', 'exbi',
+                    'zebi', 'yobi'),
+    }
+    n = int(n)
+    if n < 0:
+        raise ValueError("n < 0")
+    symbols = SYMBOLS[symbols]
+    prefix = {s: 1 << (i + 1) * 10 for i, s in enumerate(symbols[1:])}
+    for symbol in reversed(symbols[1:]):
+        if n >= prefix[symbol]:
+            value = float(n) / prefix[symbol]
+            return format % locals()
+    return format % dict(symbol=symbols[0], value=n)
-- 
GitLab