From d67d354fea727ac2782337691a654375d4fa094c Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Thu, 5 Apr 2018 16:50:48 +0200
Subject: [PATCH] Renamed the bobrc variables and update the docs

Renamed the bobrc variables and update the docs
---
 bob/ip/tensorflow_extractor/DrGanMSU.py | 26 +++++++++++++++++++++----
 bob/ip/tensorflow_extractor/FaceNet.py  | 25 ++++++++++++++++++++----
 doc/guide.rst                           | 15 ++++++++++++++
 3 files changed, 58 insertions(+), 8 deletions(-)

diff --git a/bob/ip/tensorflow_extractor/DrGanMSU.py b/bob/ip/tensorflow_extractor/DrGanMSU.py
index b856ca2..298de65 100644
--- a/bob/ip/tensorflow_extractor/DrGanMSU.py
+++ b/bob/ip/tensorflow_extractor/DrGanMSU.py
@@ -6,6 +6,7 @@ import numpy
 import tensorflow as tf
 import os
 from bob.extension import rc
+from bob.extension.rc_config import _saverc
 from . import download_file
 import logging
 logger = logging.getLogger(__name__)
@@ -331,7 +332,7 @@ class DrGanMSUExtractor(object):
       
     """
 
-    def __init__(self, model_path=rc["drgan_modelpath"], image_size=[96, 96, 3]):
+    def __init__(self, model_path=rc["bob.ip.tensorflow_extractor.drgan_modelpath"], image_size=[96, 96, 3]):
 
         self.image_size = image_size
         self.session = tf.Session()
@@ -363,9 +364,22 @@ class DrGanMSUExtractor(object):
 
     @staticmethod
     def get_modelpath():
-        import pkg_resources
-        return pkg_resources.resource_filename(__name__,
-                                               'data/DR_GAN_model')
+        
+        # Priority to the RC path
+        model_path = rc[DrGanMSUExtractor.get_rcvariable()]
+
+        if model_path is None:
+            import pkg_resources
+            model_path = pkg_resources.resource_filename(__name__,
+                                                 'data/DR_GAN_model')
+
+        return model_path
+
+
+    @staticmethod
+    def get_rcvariable():
+        return "bob.ip.tensorflow_extractor.drgan_modelpath"
+
 
     @staticmethod
     def download_model():
@@ -400,6 +414,10 @@ class DrGanMSUExtractor(object):
         with zipfile.ZipFile(zip_file) as myzip:
             myzip.extractall(os.path.dirname(DrGanMSUExtractor.get_modelpath()))
 
+        logger.info("Saving the path `{0}` in the ~.bobrc file".format(DrGanMSUExtractor.get_modelpath()))
+        rc[DrGanMSUExtractor.get_rcvariable()] = DrGanMSUExtractor.get_modelpath()
+        _saverc(rc)
+
         # delete extra files
         os.unlink(zip_file)
 
diff --git a/bob/ip/tensorflow_extractor/FaceNet.py b/bob/ip/tensorflow_extractor/FaceNet.py
index 6021136..4b2a8e1 100644
--- a/bob/ip/tensorflow_extractor/FaceNet.py
+++ b/bob/ip/tensorflow_extractor/FaceNet.py
@@ -8,6 +8,7 @@ from bob.ip.color import gray_to_rgb
 from bob.io.image import to_matplotlib
 from . import download_file
 from bob.extension import rc
+from bob.extension.rc_config import _saverc
 
 
 logger = logging.getLogger(__name__)
@@ -74,7 +75,7 @@ class FaceNet(object):
     """
 
     def __init__(self,
-                 model_path=rc["facenet_modelpath"],
+                 model_path=rc["bob.ip.tensorflow_extractor.facenet_modelpath"],
                  image_size=160,
                  **kwargs):
         super(FaceNet, self).__init__()
@@ -140,11 +141,23 @@ class FaceNet(object):
     def __del__(self):
         tf.reset_default_graph()
 
+    @staticmethod
+    def get_rcvariable():
+        return "bob.ip.tensorflow_extractor.facenet_modelpath"
+
     @staticmethod
     def get_modelpath():
-        import pkg_resources
-        return pkg_resources.resource_filename(__name__,
-                                               'data/FaceNet/20170512-110547')
+        
+        # Priority to the RC path
+        model_path = rc[FaceNet.get_rcvariable()]
+
+        if model_path is None:
+            import pkg_resources
+            model_path = pkg_resources.resource_filename(__name__,
+                                                         'data/FaceNet/20170512-110547')
+
+        return model_path
+
 
     @staticmethod
     def download_model():
@@ -182,5 +195,9 @@ class FaceNet(object):
         with zipfile.ZipFile(zip_file) as myzip:
             myzip.extractall(os.path.dirname(FaceNet.get_modelpath()))
 
+        logger.info("Saving the path `{0}` in the ~.bobrc file".format(FaceNet.get_modelpath()))
+        rc[FaceNet.get_rcvariable()] = FaceNet.get_modelpath()
+        _saverc(rc)
+
         # delete extra files
         os.unlink(zip_file)
diff --git a/doc/guide.rst b/doc/guide.rst
index aaadfac..529a2d8 100644
--- a/doc/guide.rst
+++ b/doc/guide.rst
@@ -55,6 +55,14 @@ Facenet Model
 :ref:`bob.bio.base <bob.bio.base>` wrapper Facenet model.
 Check `here for more info <py_api.html#bob.ip.tensorflow_extractor.FaceNet>`_
 
+.. note::
+
+   The models will automatically download to the data folder of this package and save it in 
+   ``[env-path]./bob/ip/tensorflow_extractor/data/FaceNet``.
+   If you want want set another path for this model do::
+   
+   $ bob config set bob.ip.tensorflow_extractor.facenet_modelpath /path/to/mydatabase
+
 
 
 DRGan from L.Tran @ MSU:
@@ -63,6 +71,13 @@ DRGan from L.Tran @ MSU:
 :ref:`bob.bio.base <bob.bio.base>` wrapper to the DRGan model trained by L.Tran @ MSU.
 Check `here <py_api.html#bob.ip.tensorflow_extractor.DrGanMSUExtractor>`_ for more info
 
+.. note::
+
+   The models will automatically download to the data folder of this package and save it in 
+   ``[env-path]./bob/ip/tensorflow_extractor/data/DR_GAN_model``.
+   If you want want set another path for this model do::
+   
+   $ bob config set bob.ip.tensorflow_extractor.drgan_modelpath /path/to/mydatabase
 
 
 
-- 
GitLab