Skip to content
Snippets Groups Projects
Commit 373dd226 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Replaced the download_model method to the new one implemented in bob.extension

Updated URLs

Updated URLs

Fixed download method

Fixing import

Created the directory automatically
parent 0ddf5ed3
No related branches found
No related tags found
1 merge request!7Replaced the download_model method to the new one implemented in bob.extension
Pipeline #
...@@ -6,9 +6,10 @@ import numpy ...@@ -6,9 +6,10 @@ import numpy
import tensorflow as tf import tensorflow as tf
import os import os
from bob.extension import rc from bob.extension import rc
from bob.extension.rc_config import _saverc
from . import download_file from . import download_file
import logging import logging
import bob.extension.download
import bob.io.base
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -348,12 +349,21 @@ class DrGanMSUExtractor(object): ...@@ -348,12 +349,21 @@ class DrGanMSUExtractor(object):
# If the path is not, set the default path # If the path is not, set the default path
if model_path is None: if model_path is None:
model_path = self.get_modelpath() model_path = self.get_modelpath()
# If does not exist, download # If does not exist, download
if not os.path.exists(model_path): if not os.path.exists(model_path):
self.download_model() bob.io.base.create_directories_safe(DrGanMSUExtractor.get_modelpath())
zip_file = os.path.join(DrGanMSUExtractor.get_modelpath(),
"DR_GAN_model.zip")
urls = [
# This is a private link at Idiap to save bandwidth.
"http://www.idiap.ch/private/wheels/gitlab/"
"DR_GAN_model.zip",
]
bob.extension.download.download_and_unzip(urls, zip_file)
self.saver = tf.train.Saver() self.saver = tf.train.Saver()
# Reestore either from the last checkpoint or from a particular checkpoint # Reestore either from the last checkpoint or from a particular checkpoint
if os.path.isdir(model_path): if os.path.isdir(model_path):
...@@ -381,46 +391,6 @@ class DrGanMSUExtractor(object): ...@@ -381,46 +391,6 @@ class DrGanMSUExtractor(object):
return "bob.ip.tensorflow_extractor.drgan_modelpath" return "bob.ip.tensorflow_extractor.drgan_modelpath"
@staticmethod
def download_model():
"""
Download and extract the DrGanMSU files in bob/ip/tensorflow_extractor
"""
import zipfile
zip_file = os.path.join(DrGanMSUExtractor.get_modelpath(),
"DR_GAN_model.zip")
urls = [
# This is a private link at Idiap to save bandwidth.
"http://beatubulatest.lab.idiap.ch/private/wheels/gitlab/"
"DR_GAN_model.zip",
]
for url in urls:
try:
logger.info(
"Downloading the DrGanMSU model from "
"{} ...".format(url))
download_file(url, zip_file)
break
except Exception:
logger.warning(
"Could not download from the %s url", url, exc_info=True)
else: # else is for the for loop
if not os.path.isfile(zip_file):
raise RuntimeError("Could not download the zip file.")
# Unzip
logger.info("Unziping in {0}".format(DrGanMSUExtractor.get_modelpath()))
with zipfile.ZipFile(zip_file) as myzip:
myzip.extractall(os.path.dirname(DrGanMSUExtractor.get_modelpath()))
logger.info("Saving the path `{0}` in the ~.bobrc file".format(DrGanMSUExtractor.get_modelpath()))
rc[DrGanMSUExtractor.get_rcvariable()] = DrGanMSUExtractor.get_modelpath()
_saverc(rc)
# delete extra files
os.unlink(zip_file)
def __call__(self, image): def __call__(self, image):
"""__call__(image) -> feature """__call__(image) -> feature
......
...@@ -8,8 +8,8 @@ from bob.ip.color import gray_to_rgb ...@@ -8,8 +8,8 @@ from bob.ip.color import gray_to_rgb
from bob.io.image import to_matplotlib from bob.io.image import to_matplotlib
from . import download_file from . import download_file
from bob.extension import rc from bob.extension import rc
from bob.extension.rc_config import _saverc import bob.extension.download
import bob.io.base
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -98,7 +98,21 @@ class FaceNet(object): ...@@ -98,7 +98,21 @@ class FaceNet(object):
if self.model_path is None: if self.model_path is None:
self.model_path = self.get_modelpath() self.model_path = self.get_modelpath()
if not os.path.exists(self.model_path): if not os.path.exists(self.model_path):
self.download_model() bob.io.base.create_directories_safe(FaceNet.get_modelpath())
zip_file = os.path.join(FaceNet.get_modelpath(),
"20170512-110547.zip")
urls = [
# This is a private link at Idiap to save bandwidth.
"http://www.idiap.ch/private/wheels/gitlab/"
"facenet_model2_20170512-110547.zip",
# this works for everybody
"https://drive.google.com/uc?export=download&id="
"0B5MzpY9kBtDVZ2RpVDYwWmxoSUk",
]
bob.extension.download.download_and_unzip(urls, zip_file)
# code from https://github.com/davidsandberg/facenet # code from https://github.com/davidsandberg/facenet
model_exp = os.path.expanduser(self.model_path) model_exp = os.path.expanduser(self.model_path)
if (os.path.isfile(model_exp)): if (os.path.isfile(model_exp)):
...@@ -158,46 +172,3 @@ class FaceNet(object): ...@@ -158,46 +172,3 @@ class FaceNet(object):
return model_path return model_path
@staticmethod
def download_model():
"""
Download and extract the FaceNet files in bob/ip/tensorflow_extractor
"""
import zipfile
zip_file = os.path.join(FaceNet.get_modelpath(),
"20170512-110547.zip")
urls = [
# This is a private link at Idiap to save bandwidth.
"http://beatubulatest.lab.idiap.ch/private/wheels/gitlab/"
"facenet_model2_20170512-110547.zip",
# this works for everybody
"https://drive.google.com/uc?export=download&id="
"0B5MzpY9kBtDVZ2RpVDYwWmxoSUk",
]
for url in urls:
try:
logger.info(
"Downloading the FaceNet model from "
"{} ...".format(url))
download_file(url, zip_file)
break
except Exception:
logger.warning(
"Could not download from the %s url", url, exc_info=True)
else: # else is for the for loop
if not os.path.isfile(zip_file):
raise RuntimeError("Could not download the zip file.")
# Unzip
logger.info("Unziping in {0}".format(FaceNet.get_modelpath()))
with zipfile.ZipFile(zip_file) as myzip:
myzip.extractall(os.path.dirname(FaceNet.get_modelpath()))
logger.info("Saving the path `{0}` in the ~.bobrc file".format(FaceNet.get_modelpath()))
rc[FaceNet.get_rcvariable()] = FaceNet.get_modelpath()
_saverc(rc)
# delete extra files
os.unlink(zip_file)
...@@ -47,41 +47,9 @@ def test_facenet(): ...@@ -47,41 +47,9 @@ def test_facenet():
assert output.size == 128, output.shape assert output.size == 128, output.shape
def test_drgan(): def test_drgan():
"""
'/remote/idiap.svm/user.active/heusch/work/dev/DR-GAN_code_wmodel/DR_GAN_model/DCGAN.model-590000'
"""
from bob.ip.tensorflow_extractor import DrGanMSUExtractor from bob.ip.tensorflow_extractor import DrGanMSUExtractor
#extractor = DrGanMSUExtractor("/idiap/project/hface/models/cnn/DR_GAN_model/", image_size=[96, 96, 3])
extractor = DrGanMSUExtractor() extractor = DrGanMSUExtractor()
data = numpy.random.rand(3, 96, 96).astype("uint8") data = numpy.random.rand(3, 96, 96).astype("uint8")
output = extractor(data) output = extractor(data)
assert output.size == 320, output.shape assert output.size == 320, output.shape
"""
def test_output_from_meta():
# Loading MNIST model
filename = os.path.join( pkg_resources.resource_filename(__name__, 'data'), "model.ckp.meta")
inputs = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
# Testing the last output
graph = scratch_network(inputs)
extractor = bob.ip.tensorflow_extractor.Extractor(filename, inputs, graph)
data = numpy.random.rand(2, 28, 28, 1).astype("float32")
output = extractor(data)
assert extractor(data).shape == (2, 10)
del extractor
# Testing flatten
inputs = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
graph = scratch_network(inputs, end_point="flatten1")
extractor = bob.ip.tensorflow_extractor.Extractor(filename, inputs, graph)
data = numpy.random.rand(2, 28, 28, 1).astype("float32")
output = extractor(data)
assert extractor(data).shape == (2, 1690)
del extractor
"""
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment