Commit 55b417d6 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[SampleBatch] Allow other attributes than data

parent aa994f82
Pipeline #45901 failed with stage
in 7 minutes and 33 seconds
......@@ -179,19 +179,20 @@ class SampleBatch(Sequence, _ReprMixin):
sample.data attributes in a memory efficient way.
"""
def __init__(self, samples):
def __init__(self, samples, sample_attribute="data"):
self.samples = samples
self.sample_attribute = sample_attribute
def __len__(self):
return len(self.samples)
def __getitem__(self, item):
return self.samples[item].data
return getattr(self.samples[item], self.sample_attribute)
def __array__(self, dtype=None, *args, **kwargs):
def _reader(s):
# adding one more dimension to data so they get stacked sample-wise
return s.data[None, ...]
return getattr(s, self.sample_attribute)[None, ...]
arr = vstack_features(_reader, self.samples, dtype=dtype)
return np.asarray(arr, dtype, *args, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment