Commit ff0a629f authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

Merge branch 'rcvariables' into 'master'

Renamed the bobrc variables and update the docs

Closes #2

See merge request !6
parents 81aae6a2 d67d354f
Pipeline #18978 passed with stages
in 38 minutes and 3 seconds
...@@ -6,6 +6,7 @@ import numpy ...@@ -6,6 +6,7 @@ import numpy
import tensorflow as tf import tensorflow as tf
import os import os
from bob.extension import rc from bob.extension import rc
from bob.extension.rc_config import _saverc
from . import download_file from . import download_file
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -331,7 +332,7 @@ class DrGanMSUExtractor(object): ...@@ -331,7 +332,7 @@ class DrGanMSUExtractor(object):
""" """
def __init__(self, model_path=rc["drgan_modelpath"], image_size=[96, 96, 3]): def __init__(self, model_path=rc["bob.ip.tensorflow_extractor.drgan_modelpath"], image_size=[96, 96, 3]):
self.image_size = image_size self.image_size = image_size
self.session = tf.Session() self.session = tf.Session()
...@@ -363,10 +364,23 @@ class DrGanMSUExtractor(object): ...@@ -363,10 +364,23 @@ class DrGanMSUExtractor(object):
@staticmethod @staticmethod
def get_modelpath(): def get_modelpath():
# Priority to the RC path
model_path = rc[DrGanMSUExtractor.get_rcvariable()]
if model_path is None:
import pkg_resources import pkg_resources
return pkg_resources.resource_filename(__name__, model_path = pkg_resources.resource_filename(__name__,
'data/DR_GAN_model') 'data/DR_GAN_model')
return model_path
@staticmethod
def get_rcvariable():
return "bob.ip.tensorflow_extractor.drgan_modelpath"
@staticmethod @staticmethod
def download_model(): def download_model():
""" """
...@@ -400,6 +414,10 @@ class DrGanMSUExtractor(object): ...@@ -400,6 +414,10 @@ class DrGanMSUExtractor(object):
with zipfile.ZipFile(zip_file) as myzip: with zipfile.ZipFile(zip_file) as myzip:
myzip.extractall(os.path.dirname(DrGanMSUExtractor.get_modelpath())) myzip.extractall(os.path.dirname(DrGanMSUExtractor.get_modelpath()))
logger.info("Saving the path `{0}` in the ~.bobrc file".format(DrGanMSUExtractor.get_modelpath()))
rc[DrGanMSUExtractor.get_rcvariable()] = DrGanMSUExtractor.get_modelpath()
_saverc(rc)
# delete extra files # delete extra files
os.unlink(zip_file) os.unlink(zip_file)
......
...@@ -8,6 +8,7 @@ from bob.ip.color import gray_to_rgb ...@@ -8,6 +8,7 @@ from bob.ip.color import gray_to_rgb
from bob.io.image import to_matplotlib from bob.io.image import to_matplotlib
from . import download_file from . import download_file
from bob.extension import rc from bob.extension import rc
from bob.extension.rc_config import _saverc
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -74,7 +75,7 @@ class FaceNet(object): ...@@ -74,7 +75,7 @@ class FaceNet(object):
""" """
def __init__(self, def __init__(self,
model_path=rc["facenet_modelpath"], model_path=rc["bob.ip.tensorflow_extractor.facenet_modelpath"],
image_size=160, image_size=160,
**kwargs): **kwargs):
super(FaceNet, self).__init__() super(FaceNet, self).__init__()
...@@ -140,12 +141,24 @@ class FaceNet(object): ...@@ -140,12 +141,24 @@ class FaceNet(object):
def __del__(self): def __del__(self):
tf.reset_default_graph() tf.reset_default_graph()
@staticmethod
def get_rcvariable():
return "bob.ip.tensorflow_extractor.facenet_modelpath"
@staticmethod @staticmethod
def get_modelpath(): def get_modelpath():
# Priority to the RC path
model_path = rc[FaceNet.get_rcvariable()]
if model_path is None:
import pkg_resources import pkg_resources
return pkg_resources.resource_filename(__name__, model_path = pkg_resources.resource_filename(__name__,
'data/FaceNet/20170512-110547') 'data/FaceNet/20170512-110547')
return model_path
@staticmethod @staticmethod
def download_model(): def download_model():
""" """
...@@ -182,5 +195,9 @@ class FaceNet(object): ...@@ -182,5 +195,9 @@ class FaceNet(object):
with zipfile.ZipFile(zip_file) as myzip: with zipfile.ZipFile(zip_file) as myzip:
myzip.extractall(os.path.dirname(FaceNet.get_modelpath())) myzip.extractall(os.path.dirname(FaceNet.get_modelpath()))
logger.info("Saving the path `{0}` in the ~.bobrc file".format(FaceNet.get_modelpath()))
rc[FaceNet.get_rcvariable()] = FaceNet.get_modelpath()
_saverc(rc)
# delete extra files # delete extra files
os.unlink(zip_file) os.unlink(zip_file)
...@@ -55,6 +55,14 @@ Facenet Model ...@@ -55,6 +55,14 @@ Facenet Model
:ref:`bob.bio.base <bob.bio.base>` wrapper Facenet model. :ref:`bob.bio.base <bob.bio.base>` wrapper Facenet model.
Check `here for more info <py_api.html#bob.ip.tensorflow_extractor.FaceNet>`_ Check `here for more info <py_api.html#bob.ip.tensorflow_extractor.FaceNet>`_
.. note::
The models will automatically download to the data folder of this package and save it in
``[env-path]./bob/ip/tensorflow_extractor/data/FaceNet``.
If you want want set another path for this model do::
$ bob config set bob.ip.tensorflow_extractor.facenet_modelpath /path/to/mydatabase
DRGan from L.Tran @ MSU: DRGan from L.Tran @ MSU:
...@@ -63,6 +71,13 @@ DRGan from L.Tran @ MSU: ...@@ -63,6 +71,13 @@ DRGan from L.Tran @ MSU:
:ref:`bob.bio.base <bob.bio.base>` wrapper to the DRGan model trained by L.Tran @ MSU. :ref:`bob.bio.base <bob.bio.base>` wrapper to the DRGan model trained by L.Tran @ MSU.
Check `here <py_api.html#bob.ip.tensorflow_extractor.DrGanMSUExtractor>`_ for more info Check `here <py_api.html#bob.ip.tensorflow_extractor.DrGanMSUExtractor>`_ for more info
.. note::
The models will automatically download to the data folder of this package and save it in
``[env-path]./bob/ip/tensorflow_extractor/data/DR_GAN_model``.
If you want want set another path for this model do::
$ bob config set bob.ip.tensorflow_extractor.drgan_modelpath /path/to/mydatabase
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment