sample.py 9.52 KB
Newer Older
Amir MOHAMMADI's avatar
Amir MOHAMMADI committed
1
"""Base definition of sample."""
2

3 4
from collections.abc import MutableSequence
from collections.abc import Sequence
5
from typing import Any
6

7
import h5py
8 9
import numpy as np

10
from bob.io.base import vstack_features
11

12
# Names of the attributes that carry a sample's payload. ``_copy_attributes``
# never copies them from a parent object or from keyword arguments.
SAMPLE_DATA_ATTRS = ("data", "samples")
Amir MOHAMMADI's avatar
Amir MOHAMMADI committed
13

14

15
def _copy_attributes(sample, parent, kwargs, exclude_list=None):
    """Copy public attributes from ``parent`` and ``kwargs`` onto ``sample``.

    A name is skipped when it starts with an underscore, when it is listed in
    ``SAMPLE_DATA_ATTRS``, or when it is found in ``exclude_list``.

    Parameters
    ----------

        sample : object
            Object receiving the attributes (through ``setattr``).

        parent : object or None
            If not None, every key of ``parent.__dict__`` that passes the
            filter is read with ``getattr`` and set on ``sample``.

        kwargs : dict
            Extra ``name: value`` pairs. They are applied after the parent's
            attributes, so they win on a name clash.

        exclude_list : container or None
            Additional names to skip. Any falsy value means "nothing extra".
    """
    skipped = exclude_list if exclude_list else []

    def _is_copyable(name):
        # Private names, payload attributes and excluded names are left alone.
        return not (
            name.startswith("_") or name in SAMPLE_DATA_ATTRS or name in skipped
        )

    if parent is not None:
        for name in parent.__dict__:
            if _is_copyable(name):
                setattr(sample, name, getattr(parent, name))

    for name, value in kwargs.items():
        if _is_copyable(name):
            setattr(sample, name, value)
30 31


32 33 34 35
class _ReprMixin:
    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
36
            + ", ".join(
37
                f"{k}={v!r}" for k, v in self.__dict__.items() if not k.startswith("_")
38
            )
39 40 41
            + ")"
        )

42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
    def __eq__(self, other):
        sorted_self = {
            k: v for k, v in sorted(self.__dict__.items(), key=lambda item: item[0])
        }
        sorted_other = {
            k: v for k, v in sorted(other.__dict__.items(), key=lambda item: item[0])
        }

        for s, o in zip(sorted_self, sorted_other):
            # Checking keys
            if s != o:
                return False

            # Checking values
            if isinstance(sorted_self[s], np.ndarray) and isinstance(
                sorted_self[o], np.ndarray
            ):
                if not np.allclose(sorted_self[s], sorted_other[o]):
                    return False
            else:
                if sorted_self[s] != sorted_other[o]:
                    return False

        return True

67 68

class Sample(_ReprMixin):
    """A simple container wrapping one data-point together with its metadata
    (see :ref:`bob.pipelines.sample`).

    Every sample exposes at least:

        * attribute ``data``: the data held by this sample


    Parameters
    ----------

        data : object
            Object representing the data to initialize this sample with.

        parent : object
            A parent object from which all other attributes are inherited
            (``data`` itself is never taken from the parent)

        kwargs : dict
            Further attributes to attach to this sample. They are applied
            after the parent's attributes.
    """

    def __init__(self, data, parent=None, **kwargs):
        # ``data`` is assigned first so that it stays the first entry of
        # ``__dict__`` (and therefore of the repr).
        self.data = data
        _copy_attributes(sample=self, parent=parent, kwargs=kwargs)
91 92 93


class DelayedSample(_ReprMixin):
    """Representation of sample that can be loaded via a callable.

    The optional ``**kwargs`` argument allows you to attach more attributes to
    this sample instance.


    Parameters
    ----------

        load
            A python function that can be called parameterlessly, to load the
            sample in question from whatever medium

        parent : :any:`DelayedSample`, :any:`Sample`, None
            If passed, consider this as a parent of this sample, to copy
            information

        delayed_attributes : dict or None
            A dictionary of name : load_fn pairs that will be used to create
            attributes of name : load_fn() in this class. Use this option to
            create more delayed attributes than just ``sample.data``. The
            dictionary is copied; it is never modified by this sample.

        kwargs : dict
            Further attributes of this sample, to be stored and eventually
            transmitted to transformed versions of the sample
    """

    def __init__(self, load, parent=None, delayed_attributes=None, **kwargs):
        # While this flag is present in ``__dict__``, ``__setattr__`` stores
        # values without touching the delayed-attribute table.
        self.__running_init__ = True

        # Merge parent's and param's delayed_attributes. Both sources are
        # copied: ``__setattr__`` deletes entries from the table later on, and
        # that must not leak into the parent or into the caller's dictionary
        # (which may be shared by several samples).
        parent_attr = getattr(parent, "_delayed_attributes", None)
        self._delayed_attributes = None if parent_attr is None else parent_attr.copy()
        if delayed_attributes is not None:
            if self._delayed_attributes is None:
                self._delayed_attributes = dict(delayed_attributes)
            else:
                self._delayed_attributes.update(delayed_attributes)

        # Inherit attributes from parent, without calling delayed_attributes
        for key in getattr(parent, "__dict__", []):
            if (
                not key.startswith("_")
                and key not in SAMPLE_DATA_ATTRS
                and (
                    self._delayed_attributes is None
                    or key not in self._delayed_attributes
                )
            ):
                setattr(self, key, getattr(parent, key))

        # Create the delayed attributes, but leave their values as None for now.
        # The placeholder makes the name show up in ``__dict__``; reads are
        # intercepted by ``__getattribute__`` as long as the name is delayed.
        if self._delayed_attributes is not None:
            kwargs.update({k: None for k in self._delayed_attributes})

        # Set attribute from kwargs
        _copy_attributes(self, None, kwargs)
        self._load = load
        del self.__running_init__

    def __getattribute__(self, name: str) -> Any:
        """Return ``name``, calling its loader if it is a delayed attribute."""
        try:
            delayed_attributes = super().__getattribute__("_delayed_attributes")
        except AttributeError:
            # ``_delayed_attributes`` is not set yet (early in ``__init__``).
            delayed_attributes = None
        if delayed_attributes is None or name not in delayed_attributes:
            return super().__getattribute__(name)
        # The loader is called on every access; its result is not stored.
        return delayed_attributes[name]()

    def __setattr__(self, name: str, value: Any) -> None:
        """Store ``value``; outside ``__init__`` this un-delays ``name``."""
        # NOTE(review): the literal below has no leading underscore while the
        # table is stored as ``_delayed_attributes`` -- confirm which name was
        # intended. As written the comparison is true for the stored name too.
        if name != "delayed_attributes" and "__running_init__" not in self.__dict__:
            delayed_attributes = getattr(self, "_delayed_attributes", None)
            # if setting an attribute which was delayed, remove it from delayed_attributes
            if delayed_attributes is not None and name in delayed_attributes:
                del delayed_attributes[name]

        super().__setattr__(name, value)

    @property
    def data(self):
        """Return the result of calling ``load`` (called on every access)."""
        return self._load()
171 172


173
class SampleSet(MutableSequence, _ReprMixin):
    """A mutable sequence of samples that also carries extra attributes.

    All sequence operations are forwarded to ``self.samples``.

    Parameters
    ----------

        samples
            The underlying container of samples.

        parent : object or None
            Object whose public attributes are copied onto this set. Names
            listed in the parent's ``_delayed_attributes`` (if it has any)
            are not copied.

        kwargs : dict
            Further attributes to attach to this set.
    """

    def __init__(self, samples, parent=None, **kwargs):
        self.samples = samples
        delayed_names = getattr(parent, "_delayed_attributes", None)
        _copy_attributes(self, parent, kwargs, exclude_list=delayed_names)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, item):
        return self.samples[item]

    def __setitem__(self, key, item):
        self.samples[key] = item

    def __delitem__(self, item):
        del self.samples[item]

    def insert(self, index, item):
        """Insert ``item`` before position ``index`` of the underlying container."""
        self.samples.insert(index, item)
Amir MOHAMMADI's avatar
Amir MOHAMMADI committed
200 201


202 203 204 205
class DelayedSampleSet(SampleSet):
    """A set of samples, with extra attributes, obtained from a callable.

    Parameters
    ----------

        load
            Callable taking no argument and returning the samples. It is
            called every time ``samples`` is accessed.

        parent : object or None
            Object whose public attributes are copied onto this set. Names
            listed in the parent's ``_delayed_attributes`` (if it has any)
            are not copied.

        kwargs : dict
            Further attributes to attach to this set.
    """

    def __init__(self, load, parent=None, **kwargs):
        self._load = load
        delayed_names = getattr(parent, "_delayed_attributes", None)
        _copy_attributes(self, parent, kwargs, exclude_list=delayed_names)

    @property
    def samples(self):
        """The samples, as returned by a fresh call to ``load``."""
        return self._load()
217 218


219 220 221 222 223 224
class DelayedSampleSetCached(DelayedSampleSet):
    """A cached version of DelayedSampleSet.

    ``load`` is called on the first access to ``samples``; its result is kept
    and returned by later accesses. A result of ``None`` is not considered
    cached, so ``load`` is called again in that case.

    Parameters
    ----------

        load
            Callable taking no argument and returning the samples.

        parent : object or None
            Object from which attributes are inherited.

        kwargs : dict
            Further attributes to attach to this set.
    """

    def __init__(self, load, parent=None, **kwargs):
        # The keyword arguments are unpacked: passing ``kwargs=kwargs`` would
        # store the whole dictionary as one attribute named ``kwargs``. The
        # base class already copies the parent's and the keyword attributes.
        super().__init__(load, parent=parent, **kwargs)
        self._data = None

    @property
    def samples(self):
        """The samples, loaded on first access and reused afterwards."""
        if self._data is None:
            self._data = self._load()
        return self._data


Amir MOHAMMADI's avatar
Amir MOHAMMADI committed
239 240 241 242 243 244 245
class SampleBatch(Sequence, _ReprMixin):
    """A read-only view that behaves like ``[s.data for s in samples]``.

    Calling ``np.array(SampleBatch)`` builds a numpy array from the chosen
    attribute of each sample in a memory efficient way.

    Parameters
    ----------

        samples
            Sequence of samples to expose.

        sample_attribute : str
            Name of the attribute read from each sample (``"data"`` by
            default).
    """

    def __init__(self, samples, sample_attribute="data"):
        self.samples = samples
        self.sample_attribute = sample_attribute

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, item):
        selected = self.samples[item]
        return getattr(selected, self.sample_attribute)

    def __array__(self, dtype=None, *args, **kwargs):
        """Stack the chosen attribute of all samples into one array."""

        def _with_leading_axis(sample):
            # A new first axis is prepended so that values are stacked
            # sample-wise.
            value = getattr(sample, self.sample_attribute)
            return value[np.newaxis, ...]

        stacked = vstack_features(_with_leading_axis, self.samples, dtype=dtype)
        return np.asarray(stacked, dtype, *args, **kwargs)
263 264 265 266 267 268


def sample_to_hdf5(sample, hdf5):
    """
    Writes the attributes of a sample (or of a list of samples) to HDF5

    A single sample has every entry of its ``__dict__`` stored under the
    attribute's name. For a list, one group named after the position of each
    element (``"0"``, ``"1"``, ...) is created and the element is written into
    it recursively.

    Parameters
    ----------

        sample: :any:`Sample` or :any:`DelayedSample` or :any:`list`
            Sample to be saved

        hdf5: `h5py.File`
            Pointer to a HDF5 file for writing
    """
    if not isinstance(sample, list):
        for name in sample.__dict__:
            hdf5[name] = getattr(sample, name)
        return

    for index, item in enumerate(sample):
        sample_to_hdf5(item, hdf5.create_group(str(index)))
285 286 287 288 289 290


def hdf5_to_sample(hdf5):
    """
    Reads the content of a HDF5File and returns a :any:`Sample`

    Parameters
    ----------

        hdf5: `h5py.File`
            Pointer to a HDF5 file for reading

    Returns
    -------

        :any:`Sample` or :any:`list`
            If at least one member of ``hdf5`` is a group, a list with one
            entry per member (each read recursively). Otherwise a single
            :any:`Sample` with one attribute per dataset.
    """
    keys = list(hdf5.keys())

    # Checking if it has groups
    has_groups = any(isinstance(hdf5[k], h5py.Group) for k in keys)

    if not has_groups:
        # If hasn't groups, returns a sample
        sample = Sample(None)
        for k in keys:
            # ``Dataset.value`` was removed in h5py 3; indexing with an empty
            # tuple reads the whole dataset.
            setattr(sample, k, hdf5[k][()])
        return sample

    # If has groups, returns a list of Samples.
    # h5py iterates members in name order by default ("0", "1", "10", "2",
    # ...). Groups named after list positions are therefore sorted
    # numerically to restore the order in which the list was written.
    if all(k.isdecimal() for k in keys):
        keys.sort(key=int)
    return [hdf5_to_sample(hdf5[k]) for k in keys]