Commit 55b417d6 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[SampleBatch] Allow other attributes than data

parent aa994f82
Pipeline #45901 failed with stage
in 7 minutes and 33 seconds
...@@ -179,19 +179,20 @@ class SampleBatch(Sequence, _ReprMixin): ...@@ -179,19 +179,20 @@ class SampleBatch(Sequence, _ReprMixin):
sample.data attributes in a memory efficient way. sample.data attributes in a memory efficient way.
""" """
def __init__(self, samples): def __init__(self, samples, sample_attribute="data"):
self.samples = samples self.samples = samples
self.sample_attribute = sample_attribute
def __len__(self): def __len__(self):
return len(self.samples) return len(self.samples)
def __getitem__(self, item): def __getitem__(self, item):
return self.samples[item].data return getattr(self.samples[item], self.sample_attribute)
def __array__(self, dtype=None, *args, **kwargs): def __array__(self, dtype=None, *args, **kwargs):
def _reader(s): def _reader(s):
# adding one more dimension to data so they get stacked sample-wise # adding one more dimension to data so they get stacked sample-wise
return s.data[None, ...] return getattr(s, self.sample_attribute)[None, ...]
arr = vstack_features(_reader, self.samples, dtype=dtype) arr = vstack_features(_reader, self.samples, dtype=dtype)
return np.asarray(arr, dtype, *args, **kwargs) return np.asarray(arr, dtype, *args, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment