diff --git a/bob/ip/tensorflow_extractor/DrGanMSU.py b/bob/ip/tensorflow_extractor/DrGanMSU.py index b856ca2598a5f3ccf35c092bf3eb1ee3fed540db..298de65d9e9c110337157361f0fba588df72253d 100644 --- a/bob/ip/tensorflow_extractor/DrGanMSU.py +++ b/bob/ip/tensorflow_extractor/DrGanMSU.py @@ -6,6 +6,7 @@ import numpy import tensorflow as tf import os from bob.extension import rc +from bob.extension.rc_config import _saverc from . import download_file import logging logger = logging.getLogger(__name__) @@ -331,7 +332,7 @@ class DrGanMSUExtractor(object): """ - def __init__(self, model_path=rc["drgan_modelpath"], image_size=[96, 96, 3]): + def __init__(self, model_path=rc["bob.ip.tensorflow_extractor.drgan_modelpath"], image_size=[96, 96, 3]): self.image_size = image_size self.session = tf.Session() @@ -363,9 +364,22 @@ class DrGanMSUExtractor(object): @staticmethod def get_modelpath(): - import pkg_resources - return pkg_resources.resource_filename(__name__, - 'data/DR_GAN_model') + + # Priority to the RC path + model_path = rc[DrGanMSUExtractor.get_rcvariable()] + + if model_path is None: + import pkg_resources + model_path = pkg_resources.resource_filename(__name__, + 'data/DR_GAN_model') + + return model_path + + + @staticmethod + def get_rcvariable(): + return "bob.ip.tensorflow_extractor.drgan_modelpath" + @staticmethod def download_model(): @@ -400,6 +414,10 @@ class DrGanMSUExtractor(object): with zipfile.ZipFile(zip_file) as myzip: myzip.extractall(os.path.dirname(DrGanMSUExtractor.get_modelpath())) + logger.info("Saving the path `{0}` in the ~.bobrc file".format(DrGanMSUExtractor.get_modelpath())) + rc[DrGanMSUExtractor.get_rcvariable()] = DrGanMSUExtractor.get_modelpath() + _saverc(rc) + # delete extra files os.unlink(zip_file) diff --git a/bob/ip/tensorflow_extractor/FaceNet.py b/bob/ip/tensorflow_extractor/FaceNet.py index 60211362e67a801cea8da04881988e201d0e7960..4b2a8e16402ec644c07596051a4150d337f8b32d 100644 --- a/bob/ip/tensorflow_extractor/FaceNet.py +++ b/bob/ip/tensorflow_extractor/FaceNet.py @@ -8,6 +8,7 @@ from bob.ip.color import gray_to_rgb from bob.io.image import to_matplotlib from . import download_file from bob.extension import rc +from bob.extension.rc_config import _saverc logger = logging.getLogger(__name__) @@ -74,7 +75,7 @@ class FaceNet(object): """ def __init__(self, - model_path=rc["facenet_modelpath"], + model_path=rc["bob.ip.tensorflow_extractor.facenet_modelpath"], image_size=160, **kwargs): super(FaceNet, self).__init__() @@ -140,11 +141,23 @@ class FaceNet(object): def __del__(self): tf.reset_default_graph() + @staticmethod + def get_rcvariable(): + return "bob.ip.tensorflow_extractor.facenet_modelpath" + @staticmethod def get_modelpath(): - import pkg_resources - return pkg_resources.resource_filename(__name__, - 'data/FaceNet/20170512-110547') + + # Priority to the RC path + model_path = rc[FaceNet.get_rcvariable()] + + if model_path is None: + import pkg_resources + model_path = pkg_resources.resource_filename(__name__, + 'data/FaceNet/20170512-110547') + + return model_path + @staticmethod def download_model(): @@ -182,5 +195,9 @@ class FaceNet(object): with zipfile.ZipFile(zip_file) as myzip: myzip.extractall(os.path.dirname(FaceNet.get_modelpath())) + logger.info("Saving the path `{0}` in the ~.bobrc file".format(FaceNet.get_modelpath())) + rc[FaceNet.get_rcvariable()] = FaceNet.get_modelpath() + _saverc(rc) + # delete extra files os.unlink(zip_file) diff --git a/doc/guide.rst b/doc/guide.rst index aaadfac8fefeec856a20ae011f57e7a113e66f56..529a2d8dc4e9059c2fc2e2927dfd7402dbb19192 100644 --- a/doc/guide.rst +++ b/doc/guide.rst @@ -55,6 +55,14 @@ Facenet Model :ref:`bob.bio.base <bob.bio.base>` wrapper Facenet model. Check `here for more info <py_api.html#bob.ip.tensorflow_extractor.FaceNet>`_ +.. note:: + + The models will automatically download to the data folder of this package and save it in + ``[env-path]./bob/ip/tensorflow_extractor/data/FaceNet``. + If you want want set another path for this model do:: + + $ bob config set bob.ip.tensorflow_extractor.facenet_modelpath /path/to/mydatabase + DRGan from L.Tran @ MSU: @@ -63,6 +71,13 @@ DRGan from L.Tran @ MSU: :ref:`bob.bio.base <bob.bio.base>` wrapper to the DRGan model trained by L.Tran @ MSU. Check `here <py_api.html#bob.ip.tensorflow_extractor.DrGanMSUExtractor>`_ for more info +.. note:: + + The models will automatically download to the data folder of this package and save it in + ``[env-path]./bob/ip/tensorflow_extractor/data/DR_GAN_model``. + If you want want set another path for this model do:: + + $ bob config set bob.ip.tensorflow_extractor.drgan_modelpath /path/to/mydatabase