diff --git a/bob/bio/base/test/utils.py b/bob/bio/base/test/utils.py index 7d757b036e28af9bf66ecab2d09f38e2ddc0c29d..8a87ed1e2eb3beda858d669aa131d18242921d81 100644 --- a/bob/bio/base/test/utils.py +++ b/bob/bio/base/test/utils.py @@ -8,7 +8,7 @@ import os import sys import functools from nose.plugins.skip import SkipTest - +from bob.extension.download import download_and_unzip # based on: http://stackoverflow.com/questions/6796492/temporarily-redirect-stdout-stderr class Quiet(object): @@ -91,31 +91,20 @@ def atnt_database_directory(): if os.path.exists(atnt_default_directory): return atnt_default_directory - import sys, tempfile - if sys.version_info[0] <= 2: - import urllib2 as urllib - else: - import urllib.request as urllib +# TODO: THIS SHOULD BE A CLASS METHOD OF bob.db.atnt database + source_url = ['http://bobconda.lab.idiap.ch/public/data/bob/att_faces.zip', + 'http://www.idiap.ch/software/bob/data/bob/att_faces.zip'] + import tempfile atnt_downloaded_directory = tempfile.mkdtemp(prefix='atnt_db_') - db_url = "http://www.cl.cam.ac.uk/Research/DTG/attarchive/pub/data/att_faces.zip" - logger.warn("Downloading the AT&T database from '%s' to '%s' ...", db_url, atnt_downloaded_directory) + logger.warn("Downloading the AT&T database from '%s' to '%s' ...", source_url, atnt_downloaded_directory) logger.warn("To avoid this, please download the database manually, extract the data and set the ATNT_DATABASE_DIRECTORY environment variable to this directory.") # to avoid re-downloading in parallel test execution os.environ['ATNT_DATABASE_DIRECTORY'] = atnt_downloaded_directory - # download - url = urllib.urlopen(db_url) - local_zip_file = os.path.join(atnt_downloaded_directory, 'att_faces.zip') - dfile = open(local_zip_file, 'wb') - dfile.write(url.read()) - dfile.close() - - # unzip - import zipfile - zip = zipfile.ZipFile(local_zip_file) - zip.extractall(atnt_downloaded_directory) - os.remove(local_zip_file) + if not os.path.exists(atnt_downloaded_directory): + os.mkdir(atnt_downloaded_directory) + download_and_unzip(source_url, os.path.join(atnt_downloaded_directory, "att_faces.zip")) return atnt_downloaded_directory