Disk.py 2.64 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST 

import numpy
import bob.io.base
import bob.io.image
import bob.ip.base
import bob.core
from .Base import Base

logger = bob.core.log.setup("bob.learn.tensorflow")


class Disk(Base):
    """
    Data shuffler for databases whose samples are stored as files on disk.

    Samples are not kept in memory; every batch is assembled by loading the
    selected files on the fly with ``bob.io.base.load``.
    """

    def __init__(self, data, labels,
                 input_shape,
                 input_dtype="float64",
                 scale=True,
                 batch_size=1,
                 seed=10):
        """
         This datashuffler deals with databases that are stored on disk.
         The data is loaded on the fly.

         **Parameters**
         data: File names of the samples (list or numpy array). A list is
           converted to a numpy array. Each entry is later handed to
           `bob.io.base.load`.
         labels: Label of each sample (list or numpy array). A list is
           converted to a numpy array. Labels are returned as int64 by
           `get_batch`.
         input_shape: Shape of the input. If `input_shape != data.shape`, the
           data will be reshaped. It is indexed as `[0:2]` and `[2]` here, the
           latter being compared against 3 to detect colour input
           (presumably height, width, channels -- confirm against `Base`).
         input_dtype="float64": Forwarded unchanged to `Base`.
         scale=True: Forwarded to `Base`. When `self.scale` is true,
           `get_batch` multiplies each loaded sample by `self.scale_value`.
         batch_size=1: Number of samples loaded per call to `get_batch`.
         seed=10: Forwarded to `Base` and also used to seed numpy's global
           random generator.
        """

        if isinstance(data, list):
            data = numpy.array(data)

        if isinstance(labels, list):
            labels = numpy.array(labels)

        super(Disk, self).__init__(
            data=data,
            labels=labels,
            input_shape=input_shape,
            input_dtype=input_dtype,
            scale=scale,
            batch_size=batch_size,
            seed=seed
        )

        # Setting the seed. NOTE: this seeds numpy's *global* generator, so it
        # affects every other user of numpy.random in the process.
        numpy.random.seed(seed)

        # TODO: very bad solution to deal with bob.shape images and tf shape images
        # The last axis of `input_shape` is moved to the front.
        self.bob_shape = (input_shape[2],) + tuple(input_shape[0:2])

    def load_from_file(self, file_name):
        """
        Load one sample from disk and return it ready to be put in a batch.

        **Parameters**
        file_name: Path handed to `bob.io.base.load`.

        **Returns**
        The loaded array after `self.rescale`. Gray-scale inputs get a
        trailing axis of size 1; other inputs additionally go through
        `self.bob2skimage`. A warning is logged when the result contains NaN.
        """
        d = bob.io.base.load(file_name)

        # NOTE(review): gray scale is detected from the size of the first axis
        # rather than from the number of dimensions; a 2D image whose first
        # axis has size 3 would take the colour branch. Kept as-is because the
        # contracts of `rescale`/`bob2skimage` live outside this file.
        is_gray_scale = d.shape[0] != 3 and self.input_shape[2] != 3
        if is_gray_scale:
            # Append an explicit single-channel axis.
            data = numpy.zeros(shape=(d.shape[0], d.shape[1], 1))
            data[:, :, 0] = d
            data = self.rescale(data)
        else:
            d = self.rescale(d)
            data = self.bob2skimage(d)

        # Checking NaN
        if numpy.isnan(data).any():
            logger.warning("######### Sample {0} has noise #########".format(file_name))

        return data

    def get_batch(self):
        """
        Draw a random batch of `self.batch_size` samples.

        All sample indexes are reshuffled on every call and the first
        `self.batch_size` of them are used, so successive batches are drawn
        independently of each other.

        **Returns**
        A two-element list `[data, labels]`: `data` is a float32 array of
        shape `self.shape`, `labels` is the matching int64 array.
        """

        # Shuffling samples
        indexes = numpy.arange(self.data.shape[0])
        numpy.random.shuffle(indexes)

        selected_data = numpy.zeros(shape=self.shape)
        for i in range(self.batch_size):
            file_name = self.data[indexes[i]]
            selected_data[i, ...] = self.load_from_file(file_name)
            if self.scale:
                selected_data[i, ...] *= self.scale_value

        selected_labels = self.labels[indexes[0:self.batch_size]]

        return [selected_data.astype("float32"), selected_labels.astype("int64")]