diff --git a/api_test.py b/api_test.py index d59f571d515ff73ade4b442922e2291b46485ae3..940fff96b543bc9520a5f9e4dc400352a64f346e 100644 --- a/api_test.py +++ b/api_test.py @@ -2,6 +2,7 @@ from bob.io.stream import Stream, StreamFile, StreamFilter, stream_filter #, Str import numpy as np import bob.io.base + # create data sets num_frame = 10 @@ -40,7 +41,31 @@ assert(stream_b.timestamps == None) assert(stream_a.camera == None) assert(stream_b.camera == None) -print(stream_a.image_points) +# load full datasets +ld_a = stream_a.load() +ld_b = stream_b.load() + +assert(ld_a.shape == data_a.shape) +assert(ld_b.shape == data_b.shape) +assert(np.array_equal(ld_a, data_a)) +assert(np.array_equal(ld_b, data_b)) + +# try some indices slices cases and compare data + +tests = [ ( None, None, None), + ( None, 3, None), + ( 5, None, None), + ( None, None, 3), + ( 1, 10, 3), + ( 9, 0, -3), + ( -5, -1, None), + ( 10, 0, -3)] + +for t in tests: + s = slice(t[0], t[1], t[2]) + dd = data_a[s] + sd = stream_a.load(s) + assert(np.array_equal(sd, dd)) ########### diff --git a/bob/io/stream/stream.py b/bob/io/stream/stream.py index 14fc163d56501d758751c37adedea3e232b9452f..12f23e2020bccbb09b0dafaa6716b8f2f0b89dbb 100644 --- a/bob/io/stream/stream.py +++ b/bob/io/stream/stream.py @@ -128,7 +128,7 @@ class Stream: raise Exception("not yet implemented") # load one or several frames - def load(self, index): + def load(self, index=None): indices = self.get_indices(index) # return buffered data OR load from file OR process data if self.__loaded == indices and self.__data is not None: @@ -141,20 +141,50 @@ class Stream: self.__loaded = indices return self.__data - # get list of indices - # TODO check all cases + # get list of frame indices def get_indices(self, index): + # None index is equivalent to [:] i.e. slice(None, None, None) + if index is None: + index = slice(None, None, None) + # frame index transform to list if isinstance(index, int): indices = [index] + # slice transform to list elif isinstance(index, slice): - if index.step == None: - indices = list(range(index.start, index.stop)) + # start value: handle None and negative + if index.start is not None: + if index.start < 0: + start = self.shape[0] + index.start + else: + start = index.start + # boundary case + if start >= self.shape[0]: + start = self.shape[0] - 1 + else: + start = 0 + # stop value: handle None and negative + if index.stop is not None: + if index.stop < 0: + stop = self.shape[0] + index.stop + else: + stop = index.stop + # boundary case + if stop >= self.shape[0]: + stop = self.shape[0] - 1 + else: + stop = self.shape[0] + # step value: handle None + if index.step is not None: + step = index.step else: - indices = list(range(index.start, index.stop, index.step)) + step = 1 + # generate list + indices = list(range(start, stop, step)) + # pass lists thru elif isinstance(index, list): indices = index else: - raise Exception("index can only be int, slice, tuple or list") + raise Exception("index can only be None, int, slice or list") return indices # filters diff --git a/bob/io/stream/stream_file.py b/bob/io/stream/stream_file.py index 6e4d2e5dc0707b697a8a53595402caf6d86e6846..4b9c1d41a89f33d1af0b4e8ea162a055540409f6 100644 --- a/bob/io/stream/stream_file.py +++ b/bob/io/stream/stream_file.py @@ -51,7 +51,7 @@ class StreamFile: else: # return a generic config if no config is present # TODO: make formal - data_config = { 'array_format' : None, + data_config = { 'array_format' : {}, 'rotation' : None, 'camera' : None, 'path' : stream_name}