Skip to content
Snippets Groups Projects
Commit 373dd226 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Replaced the download_model method to the new one implemented in bob.extension

Updated URLs

Updated URLs

Fixed download method

Fixing import

Created the directory automatically
parent 0ddf5ed3
Branches
Tags
1 merge request!7Replaced the download_model method to the new one implemented in bob.extension
Pipeline #
......@@ -6,9 +6,10 @@ import numpy
import tensorflow as tf
import os
from bob.extension import rc
from bob.extension.rc_config import _saverc
from . import download_file
import logging
import bob.extension.download
import bob.io.base
logger = logging.getLogger(__name__)
......@@ -348,12 +349,21 @@ class DrGanMSUExtractor(object):
# If the path is not, set the default path
if model_path is None:
model_path = self.get_modelpath()
model_path = self.get_modelpath()
# If does not exist, download
if not os.path.exists(model_path):
self.download_model()
bob.io.base.create_directories_safe(DrGanMSUExtractor.get_modelpath())
zip_file = os.path.join(DrGanMSUExtractor.get_modelpath(),
"DR_GAN_model.zip")
urls = [
# This is a private link at Idiap to save bandwidth.
"http://www.idiap.ch/private/wheels/gitlab/"
"DR_GAN_model.zip",
]
bob.extension.download.download_and_unzip(urls, zip_file)
self.saver = tf.train.Saver()
# Reestore either from the last checkpoint or from a particular checkpoint
if os.path.isdir(model_path):
......@@ -381,46 +391,6 @@ class DrGanMSUExtractor(object):
return "bob.ip.tensorflow_extractor.drgan_modelpath"
@staticmethod
def download_model():
"""
Download and extract the DrGanMSU files in bob/ip/tensorflow_extractor
"""
import zipfile
zip_file = os.path.join(DrGanMSUExtractor.get_modelpath(),
"DR_GAN_model.zip")
urls = [
# This is a private link at Idiap to save bandwidth.
"http://beatubulatest.lab.idiap.ch/private/wheels/gitlab/"
"DR_GAN_model.zip",
]
for url in urls:
try:
logger.info(
"Downloading the DrGanMSU model from "
"{} ...".format(url))
download_file(url, zip_file)
break
except Exception:
logger.warning(
"Could not download from the %s url", url, exc_info=True)
else: # else is for the for loop
if not os.path.isfile(zip_file):
raise RuntimeError("Could not download the zip file.")
# Unzip
logger.info("Unziping in {0}".format(DrGanMSUExtractor.get_modelpath()))
with zipfile.ZipFile(zip_file) as myzip:
myzip.extractall(os.path.dirname(DrGanMSUExtractor.get_modelpath()))
logger.info("Saving the path `{0}` in the ~.bobrc file".format(DrGanMSUExtractor.get_modelpath()))
rc[DrGanMSUExtractor.get_rcvariable()] = DrGanMSUExtractor.get_modelpath()
_saverc(rc)
# delete extra files
os.unlink(zip_file)
def __call__(self, image):
"""__call__(image) -> feature
......
......@@ -8,8 +8,8 @@ from bob.ip.color import gray_to_rgb
from bob.io.image import to_matplotlib
from . import download_file
from bob.extension import rc
from bob.extension.rc_config import _saverc
import bob.extension.download
import bob.io.base
logger = logging.getLogger(__name__)
......@@ -98,7 +98,21 @@ class FaceNet(object):
if self.model_path is None:
self.model_path = self.get_modelpath()
if not os.path.exists(self.model_path):
self.download_model()
bob.io.base.create_directories_safe(FaceNet.get_modelpath())
zip_file = os.path.join(FaceNet.get_modelpath(),
"20170512-110547.zip")
urls = [
# This is a private link at Idiap to save bandwidth.
"http://www.idiap.ch/private/wheels/gitlab/"
"facenet_model2_20170512-110547.zip",
# this works for everybody
"https://drive.google.com/uc?export=download&id="
"0B5MzpY9kBtDVZ2RpVDYwWmxoSUk",
]
bob.extension.download.download_and_unzip(urls, zip_file)
# code from https://github.com/davidsandberg/facenet
model_exp = os.path.expanduser(self.model_path)
if (os.path.isfile(model_exp)):
......@@ -158,46 +172,3 @@ class FaceNet(object):
return model_path
@staticmethod
def download_model():
"""
Download and extract the FaceNet files in bob/ip/tensorflow_extractor
"""
import zipfile
zip_file = os.path.join(FaceNet.get_modelpath(),
"20170512-110547.zip")
urls = [
# This is a private link at Idiap to save bandwidth.
"http://beatubulatest.lab.idiap.ch/private/wheels/gitlab/"
"facenet_model2_20170512-110547.zip",
# this works for everybody
"https://drive.google.com/uc?export=download&id="
"0B5MzpY9kBtDVZ2RpVDYwWmxoSUk",
]
for url in urls:
try:
logger.info(
"Downloading the FaceNet model from "
"{} ...".format(url))
download_file(url, zip_file)
break
except Exception:
logger.warning(
"Could not download from the %s url", url, exc_info=True)
else: # else is for the for loop
if not os.path.isfile(zip_file):
raise RuntimeError("Could not download the zip file.")
# Unzip
logger.info("Unziping in {0}".format(FaceNet.get_modelpath()))
with zipfile.ZipFile(zip_file) as myzip:
myzip.extractall(os.path.dirname(FaceNet.get_modelpath()))
logger.info("Saving the path `{0}` in the ~.bobrc file".format(FaceNet.get_modelpath()))
rc[FaceNet.get_rcvariable()] = FaceNet.get_modelpath()
_saverc(rc)
# delete extra files
os.unlink(zip_file)
......@@ -47,41 +47,9 @@ def test_facenet():
assert output.size == 128, output.shape
def test_drgan():
"""
'/remote/idiap.svm/user.active/heusch/work/dev/DR-GAN_code_wmodel/DR_GAN_model/DCGAN.model-590000'
"""
from bob.ip.tensorflow_extractor import DrGanMSUExtractor
#extractor = DrGanMSUExtractor("/idiap/project/hface/models/cnn/DR_GAN_model/", image_size=[96, 96, 3])
extractor = DrGanMSUExtractor()
data = numpy.random.rand(3, 96, 96).astype("uint8")
output = extractor(data)
assert output.size == 320, output.shape
"""
def test_output_from_meta():
# Loading MNIST model
filename = os.path.join( pkg_resources.resource_filename(__name__, 'data'), "model.ckp.meta")
inputs = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
# Testing the last output
graph = scratch_network(inputs)
extractor = bob.ip.tensorflow_extractor.Extractor(filename, inputs, graph)
data = numpy.random.rand(2, 28, 28, 1).astype("float32")
output = extractor(data)
assert extractor(data).shape == (2, 10)
del extractor
# Testing flatten
inputs = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
graph = scratch_network(inputs, end_point="flatten1")
extractor = bob.ip.tensorflow_extractor.Extractor(filename, inputs, graph)
data = numpy.random.rand(2, 28, 28, 1).astype("float32")
output = extractor(data)
assert extractor(data).shape == (2, 1690)
del extractor
"""
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment