Removed aggrgation from sample loaders

parent 98849410
Pipeline #46543 passed with stage
in 8 minutes and 26 seconds
......@@ -127,9 +127,6 @@ class AnnotationsLoader(TransformerMixin, BaseEstimator):
Parameters
----------
csv_to_sample_loader: :any:`CSVToSampleLoader`
Mechanisms that knows how to convert one line to one sample
annotation_directory: str
Path where the annotations are store
......@@ -143,7 +140,6 @@ class AnnotationsLoader(TransformerMixin, BaseEstimator):
def __init__(
self,
csv_to_sample_loader,
annotation_directory=None,
annotation_extension=".json",
annotation_type="json",
......@@ -151,15 +147,13 @@ class AnnotationsLoader(TransformerMixin, BaseEstimator):
self.annotation_directory = annotation_directory
self.annotation_extension = annotation_extension
self.annotation_type = annotation_type
self.csv_to_sample_loader = csv_to_sample_loader
def transform(self, X):
if self.annotation_directory is None:
return None
samples = self.csv_to_sample_loader.transform(X)
annotated_samples = []
for x in samples:
for x in X:
# since the file id is equal to the file name, we can simply use it
annotation_file = os.path.join(
......@@ -179,3 +173,12 @@ class AnnotationsLoader(TransformerMixin, BaseEstimator):
)
return annotated_samples
def fit(self, X, y=None):
return self
def _more_tags(self):
return {
"stateless": True,
"requires_fit": False,
}
......@@ -4,6 +4,7 @@ import bob.io.base
import bob.io.image
import os
import numpy as np
from sklearn.pipeline import make_pipeline
def test_sample_loader():
......@@ -23,19 +24,20 @@ def test_sample_loader():
def test_annotations_loader():
path = pkg_resources.resource_filename(__name__, os.path.join("data", "samples"))
sample_loader = CSVToSampleLoader(
csv_sample_loader = CSVToSampleLoader(
data_loader=bob.io.base.load, dataset_original_directory=path, extension=".pgm"
)
annotation_loader = AnnotationsLoader(
sample_loader,
annotation_directory=path,
annotation_extension=".pos",
annotation_type="eyecenter",
)
sample_loader = make_pipeline(csv_sample_loader, annotation_loader)
f = open(os.path.join(path, "samples.csv"))
samples = annotation_loader.transform(f)
samples = sample_loader.transform(f)
assert len(samples) == 2
assert np.alltrue([s.data.shape == (112, 92) for s in samples])
assert np.alltrue([isinstance(s.annotations, dict) for s in samples])
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment