Commit 1d621fbb authored by André Anjos's avatar André Anjos 💬
Browse files

[python3] Minor adjustments for py2/3 compatibility

parent e57d041e
......@@ -50,7 +50,7 @@ class Database():
def __del__(self):
try:
if self.m_tmp_dir:
if self.m_tmp_dir:
shutil.rmtree(self.m_tmp_dir) # delete directory
except OSError as e:
if e.errno != 2: # code 2 - no such file or directory
......@@ -66,34 +66,48 @@ class Database():
return True
def __create_tmp_dir_and_download__(self):
import urllib2, tempfile, os
import tempfile, os, sys
tmp_directory = tempfile.mkdtemp(prefix='mnist_db')
print("Downloading the mnist database from http://yann.lecun.com/exdb/mnist/ ...")
for f in self.m_mnist_filenames:
url = urllib2.urlopen('http://yann.lecun.com/exdb/mnist/'+f)
dfile = open(os.path.join(tmp_directory, f), 'w')
dfile.write(url.read())
dfile.close()
tmp_file = os.path.join(tmp_directory, f)
url = 'http://yann.lecun.com/exdb/mnist/'+f
if sys.version_info[0] < 3:
# python2 technique for downloading a file
from urllib2 import urlopen
with urlopen(url) as response, open(tmp_file, 'wb') as out_file:
dfile.write(response.read())
else:
# python3 technique for downloading a file
from urllib.request import urlopen
from shutil import copyfileobj
with urlopen(url) as response, open(tmp_file, 'wb') as out_file:
copyfileobj(response, out_file)
return tmp_directory
def __read_labels__(self, fname):
"""Reads the labels from the original MNIST label binary file"""
import struct, gzip, numpy
f = gzip.GzipFile(fname, 'rb')
# reads 2 big-ending integers
magic_nr, n_examples = struct.unpack(">II", f.read(8))
# reads the rest, using an uint8 dataformat (endian-less)
# reads 2 big-ending integers
magic_nr, n_examples = struct.unpack(">II", f.read(8))
# reads the rest, using an uint8 dataformat (endian-less)
labels = numpy.fromstring(f.read(), dtype='uint8')
return labels
def __read_images__(self, fname):
"""Reads the images from the original MNIST label binary file"""
import struct, gzip, numpy
f = gzip.GzipFile(fname, 'rb')
# reads 4 big-ending integers
magic_nr, n_examples, rows, cols = struct.unpack(">IIII", f.read(16))
shape = (n_examples, rows*cols)
# reads the rest, using an uint8 dataformat (endian-less)
# reads 4 big-ending integers
magic_nr, n_examples, rows, cols = struct.unpack(">IIII", f.read(16))
shape = (n_examples, rows*cols)
# reads the rest, using an uint8 dataformat (endian-less)
images = numpy.fromstring(f.read(), dtype='uint8').reshape(shape)
return images
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment