Commit d10d3d24 authored by Tiago de Freitas Pereira

Batching demographic datasets

parent d02339d3
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
from torch.utils.data import Dataset

from bob.bio.face.database import MEDSDatabase, MorphDatabase


class DemoraphicTorchDataset(Dataset):
    """Base torch ``Dataset`` that serves each sample together with its
    identity label and a demographic label. Subclasses must implement
    ``load_demographics`` and ``get_demographics``."""

    def __init__(self, bob_dataset):
        self.bob_dataset = bob_dataset

        # Flatten the z-probe and t-reference sample sets into one bucket
        self.bucket = [s for sset in self.bob_dataset.zprobes() for s in sset]
        self.bucket += [s for sset in self.bob_dataset.treferences() for s in sset]

        # Defining keys and labels: one integer identity label per subject
        keys = [sset.subject_id for sset in self.bob_dataset.zprobes()] + [
            sset.subject_id for sset in self.bob_dataset.treferences()
        ]
        self.labels = dict(zip(keys, range(len(keys))))

        self.metadata_keys = self.load_demographics()

    def __len__(self):
        return len(self.bucket)

    def __getitem__(self, idx):
        sample = self.bucket[idx]
        image = sample.data
        label = self.labels[sample.subject_id]
        demography = self.get_demographics(sample)
        return {"data": image, "label": label, "demography": demography}


class MedsTorchDataset(DemoraphicTorchDataset):
    def __init__(self, protocol, database_path, database_extension=".h5"):
        bob_dataset = MEDSDatabase(
            protocol=protocol,
            dataset_original_directory=database_path,
            dataset_original_extension=database_extension,
        )
        super().__init__(bob_dataset)

    def load_demographics(self):
        # MEDS is grouped by race only (the `rac` attribute); sorting makes
        # the demographic-to-integer mapping deterministic across runs
        target_metadata = "rac"
        metadata_keys = sorted(
            set(
                [getattr(sset, target_metadata) for sset in self.bob_dataset.zprobes()]
                + [
                    getattr(sset, target_metadata)
                    for sset in self.bob_dataset.treferences()
                ]
            )
        )
        return dict(zip(metadata_keys, range(len(metadata_keys))))

    def get_demographics(self, sample):
        demographic_key = getattr(sample, "rac")
        return self.metadata_keys[demographic_key]


class MorphTorchDataset(DemoraphicTorchDataset):
    def __init__(self, protocol, database_path, database_extension=".h5"):
        bob_dataset = MorphDatabase(
            protocol=protocol,
            dataset_original_directory=database_path,
            dataset_original_extension=database_extension,
        )
        super().__init__(bob_dataset)

    def load_demographics(self):
        # MORPH is grouped by the race-sex combination; sorting makes the
        # demographic-to-integer mapping deterministic across runs
        metadata_keys = sorted(
            set(
                [f"{sset.rac}-{sset.sex}" for sset in self.bob_dataset.zprobes()]
                + [f"{sset.rac}-{sset.sex}" for sset in self.bob_dataset.treferences()]
            )
        )
        return dict(zip(metadata_keys, range(len(metadata_keys))))

    def get_demographics(self, sample):
        demographic_key = f"{sample.rac}-{sample.sex}"
        return self.metadata_keys[demographic_key]
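A minimal usage sketch for these dataset classes; the database path below is a placeholder for a directory of checkpointed face crops like the one the tests further down point at:

from torch.utils.data import DataLoader

from bob.bio.demographics.datasets import MedsTorchDataset

# "/path/to/meds/samplewrapper" is hypothetical; point it at your own crops
dataset = MedsTorchDataset(
    protocol="verification_fold1", database_path="/path/to/meds/samplewrapper"
)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

# Each batch carries the image tensor plus identity and demographic labels
batch = next(iter(loader))
print(batch["data"].shape, batch["label"], batch["demography"])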
from .facecrop import facecrop_pipeline
from bob.pipelines import wrap
from sklearn.pipeline import make_pipeline


def facecrop_pipeline(database, preprocessor, output_dir, dask_client):
    """Runs ``preprocessor`` over every sample of ``database``, checkpointing
    the cropped faces under ``output_dir`` and distributing the work via dask."""

    # Forward each sample's `annotations` attribute to the preprocessor
    transform_extra_arguments = (("annotations", "annotations"),)
    pipeline = make_pipeline(
        wrap(
            ["sample"],
            preprocessor,
            transform_extra_arguments=transform_extra_arguments,
        )
    )
    # Checkpoint transformed samples to disk and make the pipeline dask-aware
    pipeline = wrap(["checkpoint", "dask"], pipeline, features_dir=output_dir)

    background_model_samples = database.background_model_samples()
    pipeline.transform(background_model_samples).compute(scheduler=dask_client)

    if hasattr(database, "zprobes"):
        pipeline.transform(database.zprobes()).compute(scheduler=dask_client)

    if hasattr(database, "treferences"):
        pipeline.transform(database.treferences()).compute(scheduler=dask_client)
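For a quick serial run without a cluster, the dask scheduler argument can also be a plain string, as hinted by the commented-out `dask_client = "single-threaded"` line in the script below. A minimal sketch under that assumption (the output directory is a placeholder):

from bob.bio.demographics.preprocessor import facecrop_pipeline
from bob.bio.face.database import MEDSDatabase
from bob.bio.face.preprocessor import FaceCrop

database = MEDSDatabase(protocol="verification_fold1")
preprocessor = FaceCrop(
    cropped_image_size=(112, 112),
    cropped_positions={"leye": (55, 72), "reye": (55, 40)},
    color_channel="rgb",
)
# "single-threaded" makes every .compute() call run in the current process,
# which is convenient for debugging; "/tmp/facecrops" is a placeholder path
facecrop_pipeline(database, preprocessor, "/tmp/facecrops", "single-threaded")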
from bob.bio.demographics.preprocessor import facecrop_pipeline
from bob.bio.face.preprocessor import FaceCrop


def face_crop():
    # MEDS variant, kept for reference:
    # from bob.bio.face.database import MEDSDatabase
    # output_dir = "/idiap/temp/tpereira/3.FaceCrops/meds/"
    # protocol = "verification_fold1"
    # database = MEDSDatabase(protocol=protocol)

    from bob.bio.face.database import MorphDatabase

    output_dir = "/idiap/temp/tpereira/3.FaceCrops/morph/"
    protocol = "verification_fold1"
    database = MorphDatabase(protocol=protocol)

    # Crop faces to 112x112 RGB using fixed eye positions
    cropped_image_size = (112, 112)
    cropped_positions = {
        "leye": (55, 72),
        "reye": (55, 40),
    }
    color_channel = "rgb"
    preprocessor = FaceCrop(
        cropped_image_size=cropped_image_size,
        cropped_positions=cropped_positions,
        color_channel=color_channel,
    )

    # For local debugging: dask_client = "single-threaded"
    from dask.distributed import Client
    from bob.pipelines.distributed.sge import SGEMultipleQueuesCluster

    cluster = SGEMultipleQueuesCluster(min_jobs=1)
    dask_client = Client(cluster)

    facecrop_pipeline(database, preprocessor, output_dir, dask_client)


if __name__ == "__main__":
    face_crop()
import os

import pytest
from torch.utils.data import DataLoader  # https://pytorch.org/docs/stable/data.html

from bob.bio.demographics.datasets import MedsTorchDataset, MorphTorchDataset
from bob.extension import rc


@pytest.mark.skipif(
    rc.get("bob.bio.demographics.directory") is None,
    reason="Demographics features directory not available. Please run `bob config set bob.bio.demographics.directory [PATH]` to set the base features path.",
)
def test_meds():
    database_path = os.path.join(
        rc.get("bob.bio.demographics.directory"), "meds", "samplewrapper"
    )
    dataset = MedsTorchDataset(
        protocol="verification_fold1", database_path=database_path
    )
    dataloader = DataLoader(
        dataset, batch_size=64, shuffle=True, pin_memory=True, num_workers=2
    )
    batch = next(iter(dataloader))
    assert batch["data"].shape == (64, 3, 112, 112)


@pytest.mark.skipif(
    rc.get("bob.bio.demographics.directory") is None,
    reason="Demographics features directory not available. Please run `bob config set bob.bio.demographics.directory [PATH]` to set the base features path.",
)
def test_morph():
    database_path = os.path.join(
        rc.get("bob.bio.demographics.directory"), "morph", "samplewrapper"
    )
    dataset = MorphTorchDataset(
        protocol="verification_fold1", database_path=database_path
    )
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
    batch = next(iter(dataloader))
    assert batch["data"].shape == (64, 3, 112, 112)