Commit 7d39c378 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira

Merge branch 'issue-50' into 'master'

Replaced the download_model method to the new one implemented in bob.extension

Closes #5

See merge request !7
parents 0ddf5ed3 73c665b3
Pipeline #19942 passed with stages
in 26 minutes and 21 seconds
......@@ -6,9 +6,9 @@ import numpy
import tensorflow as tf
import os
from bob.extension import rc
from bob.extension.rc_config import _saverc
from . import download_file
import logging
import bob.extension.download
import bob.io.base
logger = logging.getLogger(__name__)
......@@ -348,12 +348,21 @@ class DrGanMSUExtractor(object):
# If the path is not, set the default path
if model_path is None:
model_path = self.get_modelpath()
model_path = self.get_modelpath()
# If does not exist, download
if not os.path.exists(model_path):
self.download_model()
bob.io.base.create_directories_safe(DrGanMSUExtractor.get_modelpath())
zip_file = os.path.join(DrGanMSUExtractor.get_modelpath(),
"DR_GAN_model.zip")
urls = [
# This is a private link at Idiap to save bandwidth.
"http://beatubulatest.lab.idiap.ch/private/wheels/gitlab/"
"DR_GAN_model.zip",
]
bob.extension.download.download_and_unzip(urls, zip_file)
self.saver = tf.train.Saver()
# Reestore either from the last checkpoint or from a particular checkpoint
if os.path.isdir(model_path):
......@@ -381,46 +390,6 @@ class DrGanMSUExtractor(object):
return "bob.ip.tensorflow_extractor.drgan_modelpath"
@staticmethod
def download_model():
"""
Download and extract the DrGanMSU files in bob/ip/tensorflow_extractor
"""
import zipfile
zip_file = os.path.join(DrGanMSUExtractor.get_modelpath(),
"DR_GAN_model.zip")
urls = [
# This is a private link at Idiap to save bandwidth.
"http://beatubulatest.lab.idiap.ch/private/wheels/gitlab/"
"DR_GAN_model.zip",
]
for url in urls:
try:
logger.info(
"Downloading the DrGanMSU model from "
"{} ...".format(url))
download_file(url, zip_file)
break
except Exception:
logger.warning(
"Could not download from the %s url", url, exc_info=True)
else: # else is for the for loop
if not os.path.isfile(zip_file):
raise RuntimeError("Could not download the zip file.")
# Unzip
logger.info("Unziping in {0}".format(DrGanMSUExtractor.get_modelpath()))
with zipfile.ZipFile(zip_file) as myzip:
myzip.extractall(os.path.dirname(DrGanMSUExtractor.get_modelpath()))
logger.info("Saving the path `{0}` in the ~.bobrc file".format(DrGanMSUExtractor.get_modelpath()))
rc[DrGanMSUExtractor.get_rcvariable()] = DrGanMSUExtractor.get_modelpath()
_saverc(rc)
# delete extra files
os.unlink(zip_file)
def __call__(self, image):
"""__call__(image) -> feature
......
......@@ -6,10 +6,9 @@ import numpy
import tensorflow as tf
from bob.ip.color import gray_to_rgb
from bob.io.image import to_matplotlib
from . import download_file
from bob.extension import rc
from bob.extension.rc_config import _saverc
import bob.extension.download
import bob.io.base
logger = logging.getLogger(__name__)
......@@ -77,7 +76,7 @@ class FaceNet(object):
def __init__(self,
model_path=rc["bob.ip.tensorflow_extractor.facenet_modelpath"],
image_size=160,
**kwargs):
**kwargs):
super(FaceNet, self).__init__()
self.model_path = model_path
self.image_size = image_size
......@@ -98,7 +97,19 @@ class FaceNet(object):
if self.model_path is None:
self.model_path = self.get_modelpath()
if not os.path.exists(self.model_path):
self.download_model()
bob.io.base.create_directories_safe(FaceNet.get_modelpath())
zip_file = os.path.join(FaceNet.get_modelpath(),
"20170512-110547.zip")
urls = [
# This is a private link at Idiap to save bandwidth.
"http://beatubulatest.lab.idiap.ch/private/wheels/gitlab/"
"facenet_model2_20170512-110547.zip",
# this works for everybody
"https://drive.google.com/uc?export=download&id="
"0B5MzpY9kBtDVZ2RpVDYwWmxoSUk",
]
bob.extension.download.download_and_unzip(urls, zip_file)
# code from https://github.com/davidsandberg/facenet
model_exp = os.path.expanduser(self.model_path)
if (os.path.isfile(model_exp)):
......@@ -143,10 +154,20 @@ class FaceNet(object):
@staticmethod
def get_rcvariable():
"""
Variable name used in the Bob Global Configuration System
https://www.idiap.ch/software/bob/docs/bob/bob.extension/stable/rc.html#global-configuration-system
"""
return "bob.ip.tensorflow_extractor.facenet_modelpath"
@staticmethod
def get_modelpath():
"""
Get default model path.
First we try the to search this path via Global Configuration System.
If we can not find it, we set the path in the directory `<project>/data`
"""
# Priority to the RC path
model_path = rc[FaceNet.get_rcvariable()]
......@@ -158,46 +179,3 @@ class FaceNet(object):
return model_path
@staticmethod
def download_model():
"""
Download and extract the FaceNet files in bob/ip/tensorflow_extractor
"""
import zipfile
zip_file = os.path.join(FaceNet.get_modelpath(),
"20170512-110547.zip")
urls = [
# This is a private link at Idiap to save bandwidth.
"http://beatubulatest.lab.idiap.ch/private/wheels/gitlab/"
"facenet_model2_20170512-110547.zip",
# this works for everybody
"https://drive.google.com/uc?export=download&id="
"0B5MzpY9kBtDVZ2RpVDYwWmxoSUk",
]
for url in urls:
try:
logger.info(
"Downloading the FaceNet model from "
"{} ...".format(url))
download_file(url, zip_file)
break
except Exception:
logger.warning(
"Could not download from the %s url", url, exc_info=True)
else: # else is for the for loop
if not os.path.isfile(zip_file):
raise RuntimeError("Could not download the zip file.")
# Unzip
logger.info("Unziping in {0}".format(FaceNet.get_modelpath()))
with zipfile.ZipFile(zip_file) as myzip:
myzip.extractall(os.path.dirname(FaceNet.get_modelpath()))
logger.info("Saving the path `{0}` in the ~.bobrc file".format(FaceNet.get_modelpath()))
rc[FaceNet.get_rcvariable()] = FaceNet.get_modelpath()
_saverc(rc)
# delete extra files
os.unlink(zip_file)
......@@ -26,37 +26,6 @@ def scratch_network(inputs, end_point="fc1", reuse=False):
return end_points[end_point]
def download_file(url, out_file):
"""Downloads a file from a given url
Parameters
----------
url : str
The url to download form.
out_file : str
Where to save the file.
"""
from bob.io.base import create_directories_safe
import os
create_directories_safe(os.path.dirname(out_file))
import sys
if sys.version_info[0] < 3:
# python2 technique for downloading a file
from urllib2 import urlopen
with open(out_file, 'wb') as f:
response = urlopen(url)
f.write(response.read())
else:
# python3 technique for downloading a file
from urllib.request import urlopen
from shutil import copyfileobj
with urlopen(url) as response:
with open(out_file, 'wb') as f:
copyfileobj(response, f)
def get_config():
"""Returns a string containing the configuration information.
"""
......
......@@ -47,41 +47,9 @@ def test_facenet():
assert output.size == 128, output.shape
def test_drgan():
"""
'/remote/idiap.svm/user.active/heusch/work/dev/DR-GAN_code_wmodel/DR_GAN_model/DCGAN.model-590000'
"""
from bob.ip.tensorflow_extractor import DrGanMSUExtractor
#extractor = DrGanMSUExtractor("/idiap/project/hface/models/cnn/DR_GAN_model/", image_size=[96, 96, 3])
extractor = DrGanMSUExtractor()
data = numpy.random.rand(3, 96, 96).astype("uint8")
output = extractor(data)
assert output.size == 320, output.shape
"""
def test_output_from_meta():
# Loading MNIST model
filename = os.path.join( pkg_resources.resource_filename(__name__, 'data'), "model.ckp.meta")
inputs = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
# Testing the last output
graph = scratch_network(inputs)
extractor = bob.ip.tensorflow_extractor.Extractor(filename, inputs, graph)
data = numpy.random.rand(2, 28, 28, 1).astype("float32")
output = extractor(data)
assert extractor(data).shape == (2, 10)
del extractor
# Testing flatten
inputs = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
graph = scratch_network(inputs, end_point="flatten1")
extractor = bob.ip.tensorflow_extractor.Extractor(filename, inputs, graph)
data = numpy.random.rand(2, 28, 28, 1).astype("float32")
output = extractor(data)
assert extractor(data).shape == (2, 1690)
del extractor
"""
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment