Skip to content
Snippets Groups Projects
Commit a181ea80 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Added regular head

parent cc839af2
Branches
Tags
1 merge request!50WIP: Lightning
Pipeline #53117 failed
......@@ -5,7 +5,7 @@ from torch import nn
class Lenet5(Module):
def __init__(self, num_features=30):
def __init__(self, num_features=84):
super(Lenet5, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.relu1 = nn.ReLU()
......@@ -15,12 +15,12 @@ class Lenet5(Module):
self.pool2 = nn.MaxPool2d(2)
self.fc1 = nn.Linear(256, 120)
self.relu3 = nn.ReLU()
self.fc2 = nn.Linear(120, 84)
self.fc2 = nn.Linear(120, num_features)
self.relu4 = nn.ReLU()
# Tiago: this was a slight adaptation I made
# self.fc3 = nn.Linear(84, 10)
self.fc3 = nn.Linear(84, num_features)
self.relu5 = nn.ReLU()
# self.fc3 = nn.Linear(84, num_features)
# self.relu5 = nn.ReLU()
def forward(self, x):
......@@ -35,6 +35,6 @@ class Lenet5(Module):
y = self.relu3(y)
y = self.fc2(y)
y = self.relu4(y)
y = self.fc3(y)
y = self.relu5(y)
# y = self.fc3(y)
# y = self.relu5(y)
return y
from .arcface import ArcFace
from .regular import Regular
from torch.nn import Module, Linear
class Regular(Module):
"""
Implement a regular head used for softmax layers
"""
def __init__(self, feat_dim, num_class):
super(Regular, self).__init__()
self.fc = Linear(feat_dim, num_class, bias=False)
def forward(self, feats, labels):
return self.fc(feats)
......@@ -9,30 +9,31 @@ import os
from torch.utils.data import DataLoader
from bob.learn.pytorch.trainers import BackboneHeadModel
from bob.learn.pytorch.architectures.lenet import Lenet5
from bob.learn.pytorch.head import ArcFace
from bob.learn.pytorch.head import ArcFace, Regular
from functools import partial
import pytorch_lightning as pl
import torch
from torch.utils.data import IterableDataset
from torch.utils.data import Dataset
# import torchvision.transforms as transforms
class MnistDictionaryDataset(IterableDataset):
class MnistDictionaryDataset(Dataset):
def __init__(self, fashion_mnist_dataset):
super(MnistDictionaryDataset, self).__init__()
self.fashion_mnist_dataset = fashion_mnist_dataset
def __iter__(self):
def __len__(self):
return self.fashion_mnist_dataset.data.shape[0]
for d, l in zip(
self.fashion_mnist_dataset.data, self.fashion_mnist_dataset.targets
):
yield {
"data": torch.unsqueeze(d / 255.0, axis=0),
"label": l,
}
def __getitem__(self, idx):
return {
"data": torch.unsqueeze(
self.fashion_mnist_dataset.data[idx] / 255.0, axis=0
),
"label": self.fashion_mnist_dataset.targets[idx],
}
def convert_dataset(dataset):
......@@ -50,6 +51,9 @@ def test_boring_model():
torchvision.datasets.FashionMNIST(root_path, download=True, train=True)
),
batch_size=128,
shuffle=True,
persistent_workers=True,
num_workers=2,
)
validation_dataloader = DataLoader(
convert_dataset(
......@@ -59,7 +63,8 @@ def test_boring_model():
)
backbone = Lenet5()
head = ArcFace(feat_dim=30, num_class=10)
# head = ArcFace(feat_dim=30, num_class=10)
head = Regular(feat_dim=84, num_class=10)
optimizer = partial(torch.optim.SGD, lr=0.1, momentum=0.9)
# Preparing lightining model
......@@ -70,16 +75,17 @@ def test_boring_model():
optimizer_fn=optimizer,
)
# using this code to learn too
# TODO: using this code to learn too
# so, be nice with my comments
trainer = pl.Trainer(
# callbacks=..... # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#callbacks
# logger=logger,...
max_epochs=10,
max_epochs=4,
gpus=-1 if torch.cuda.is_available() else None,
# resume_from_checkpoint=resume_from_checkpoint, #https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#resume-from-checkpoint
# debug flags
limit_train_batches=10, # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#limit-train-batches
limit_val_batches=1,
# limit_train_batches=10, # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#limit-train-batches
# limit_val_batches=1,
amp_level="00", # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#amp-level
)
......@@ -89,4 +95,6 @@ def test_boring_model():
val_dataloaders=validation_dataloader,
)
pass
## Assert the accuracy
assert trainer.validate()[0]["validation/accuracy"] > 0.5
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment