# vgg2-full.py
# Last change: "Update", committed by Tiago de Freitas Pereira.
from bob.bio.face.pytorch.datasets import VGG2TorchDataset

# https://pytorch.org/docs/stable/data.html
from torch.utils.data import DataLoader
from bob.extension import rc

import torch
from functools import partial

import torchvision.transforms as transforms

from bob.bio.face.pytorch.preprocessing import get_standard_data_augmentation


import logging

# Training settings for the "vgg2-full" protocol of the VGG2 face crops.
BATCH_SIZE = 128
PROTOCOL = "vgg2-full"
# Root directory of the VGG2 crops, looked up in the bob.extension global
# configuration (``rc``). NOTE(review): a dict-style ``get`` presumably yields
# None when the key is not configured — handled explicitly below.
DATABASE_PATH = rc.get("bob.bio.face.vgg2-crops.directory")
DATABASE_EXTENSION = ".jpg"

logger = logging.getLogger(__name__)

if DATABASE_PATH is None:
    # Make a missing configuration visible right away instead of letting a
    # None path travel silently into the dataset constructor.
    logger.warning(
        "The bob configuration key %r is not set; DATABASE_PATH is None",
        "bob.bio.face.vgg2-crops.directory",
    )

# Lazy %-style arguments: the message is only formatted if INFO is enabled,
# and the rendered text is the same as the previous f-string.
logger.info("Loading protocol %s from %s", PROTOCOL, DATABASE_PATH)


# Augmentation pipeline from bob.bio.face, handed to the dataset as its
# ``transform``.
transform = get_standard_data_augmentation()

# Training dataset for the selected protocol, reading files with
# DATABASE_EXTENSION from DATABASE_PATH.
train_dataset = VGG2TorchDataset(
    protocol=PROTOCOL,
    database_path=DATABASE_PATH,
    database_extension=DATABASE_EXTENSION,
    transform=transform,
)

# Disabled: a validation split of the same protocol (``train=False``).
# validation_dataset = VGG2TorchDataset(
#     protocol=PROTOCOL,
#     database_path=DATABASE_PATH,
#     database_extension=DATABASE_EXTENSION,
#     transform=transform,
#     train=False,
# )

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # Page-locked host memory only speeds up host-to-CUDA transfers; without
    # a CUDA device it brings nothing and PyTorch complains about it, so only
    # request it when CUDA is actually available.
    pin_memory=torch.cuda.is_available(),
    num_workers=4,
)

# Validation is disabled for this configuration: no validation dataset is built.
validation_dataset = None

# NOTE(review): the author's original note here is truncated ("For some reason
# we have an issue with"). It presumably refers to the validation dataloader
# below, which is commented out, so ``validation_dataloader`` is not defined in
# this file (not even as None). Confirm the underlying issue with the author
# before re-enabling it.
#validation_dataloader = None
# validation_dataloader = DataLoader(
#     validation_dataset,
#     batch_size=BATCH_SIZE,
#     shuffle=False,
#     pin_memory=False,
#     num_workers=1,
# )