Skip to content
Snippets Groups Projects

Replaced the download_model method to the new one implemented in bob.extension

Merged Tiago de Freitas Pereira requested to merge issue-50 into master
Files
4
@@ -6,9 +6,9 @@ import numpy
import tensorflow as tf
import os
from bob.extension import rc
from bob.extension.rc_config import _saverc
from . import download_file
import logging
import bob.extension.download
import bob.io.base
logger = logging.getLogger(__name__)
@@ -348,12 +348,21 @@ class DrGanMSUExtractor(object):
# If the path is not, set the default path
if model_path is None:
model_path = self.get_modelpath()
model_path = self.get_modelpath()
# If does not exist, download
if not os.path.exists(model_path):
self.download_model()
bob.io.base.create_directories_safe(DrGanMSUExtractor.get_modelpath())
zip_file = os.path.join(DrGanMSUExtractor.get_modelpath(),
"DR_GAN_model.zip")
urls = [
# This is a private link at Idiap to save bandwidth.
"http://beatubulatest.lab.idiap.ch/private/wheels/gitlab/"
"DR_GAN_model.zip",
]
bob.extension.download.download_and_unzip(urls, zip_file)
self.saver = tf.train.Saver()
# Reestore either from the last checkpoint or from a particular checkpoint
if os.path.isdir(model_path):
@@ -381,46 +390,6 @@ class DrGanMSUExtractor(object):
return "bob.ip.tensorflow_extractor.drgan_modelpath"
@staticmethod
def download_model():
"""
Download and extract the DrGanMSU files in bob/ip/tensorflow_extractor
"""
import zipfile
zip_file = os.path.join(DrGanMSUExtractor.get_modelpath(),
"DR_GAN_model.zip")
urls = [
# This is a private link at Idiap to save bandwidth.
"http://beatubulatest.lab.idiap.ch/private/wheels/gitlab/"
"DR_GAN_model.zip",
]
for url in urls:
try:
logger.info(
"Downloading the DrGanMSU model from "
"{} ...".format(url))
download_file(url, zip_file)
break
except Exception:
logger.warning(
"Could not download from the %s url", url, exc_info=True)
else: # else is for the for loop
if not os.path.isfile(zip_file):
raise RuntimeError("Could not download the zip file.")
# Unzip
logger.info("Unziping in {0}".format(DrGanMSUExtractor.get_modelpath()))
with zipfile.ZipFile(zip_file) as myzip:
myzip.extractall(os.path.dirname(DrGanMSUExtractor.get_modelpath()))
logger.info("Saving the path `{0}` in the ~.bobrc file".format(DrGanMSUExtractor.get_modelpath()))
rc[DrGanMSUExtractor.get_rcvariable()] = DrGanMSUExtractor.get_modelpath()
_saverc(rc)
# delete extra files
os.unlink(zip_file)
def __call__(self, image):
"""__call__(image) -> feature
Loading