vgg2-short.py 1.33 KB
Newer Older
Tiago de Freitas Pereira's avatar
Update  
Tiago de Freitas Pereira committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from bob.bio.face.pytorch.datasets import VGG2TorchDataset

# https://pytorch.org/docs/stable/data.html
from torch.utils.data import DataLoader
from bob.extension import rc

import torch
from functools import partial

import torchvision.transforms as transforms

from bob.bio.face.pytorch.preprocessing import get_standard_data_augmentation


# Mini-batch size used by the training dataloader. Falls back to 128 unless a
# non-None BATCH_SIZE already exists in this namespace (presumably set by a
# configuration loaded earlier into the same namespace).
BATCH_SIZE = 128 if locals().get("BATCH_SIZE") is None else BATCH_SIZE


import logging

# Dataset protocol, location and file extension of the VGG2 crops.
PROTOCOL = "vgg2-short"
# May be None when the rc key is not configured; checked below.
DATABASE_PATH = rc.get("bob.bio.face.vgg2-crops.directory")
DATABASE_EXTENSION = ".jpg"

logger = logging.getLogger(__name__)

if DATABASE_PATH is None:
    # Without this, the only hint of a missing rc entry would be the
    # "from None" in the info message below.
    logger.warning(
        "The rc key 'bob.bio.face.vgg2-crops.directory' is not set; "
        "the VGG2 crops directory cannot be located."
    )

# Lazy %-style arguments: the message is only formatted if INFO is enabled.
logger.info("Loading protocol %s from %s", PROTOCOL, DATABASE_PATH)


# Standard augmentation pipeline, handed to the dataset as its ``transform``.
transform = get_standard_data_augmentation()

# Training samples of the PROTOCOL protocol, located via DATABASE_PATH and
# DATABASE_EXTENSION.
train_dataset = VGG2TorchDataset(
    database_path=DATABASE_PATH,
    database_extension=DATABASE_EXTENSION,
    protocol=PROTOCOL,
    transform=transform,
)

# Validation dataset, currently disabled; the file later sets
# ``validation_dataset = None``, which would overwrite this if uncommented.
# NOTE(review): ``train=False`` presumably selects the non-training split of
# VGG2TorchDataset - confirm. This also reuses the (augmenting) training
# ``transform``; validation usually wants a deterministic transform.
# validation_dataset = VGG2TorchDataset(
#    protocol=PROTOCOL,
#    database_path=DATABASE_PATH,
#    database_extension=DATABASE_EXTENSION,
#    transform=transform,
#    train=False,
# )

# Shuffled mini-batches over the training dataset.
# Pinned (page-locked) host memory only speeds up host-to-CUDA transfers, so
# it is requested only when CUDA is available; newer PyTorch releases warn
# when pin_memory=True is used without an accelerator.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),
    num_workers=2,
)

# Validation loader, currently disabled. Re-enabling it requires the
# commented-out ``validation_dataset`` definition and removing the
# ``validation_dataset = None`` assignment below.
# validation_dataloader = DataLoader(
#    validation_dataset,
#    batch_size=BATCH_SIZE,
#    shuffle=False,
#    pin_memory=False,
#    num_workers=1,
# )

# This configuration exposes no validation dataset.
validation_dataset = None