balance.py 3.16 KB
Newer Older
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
1
from bob.bio.face.pytorch.lightning import BackboneHeadModel
Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
2
3
4
5
6
7
8
9
10
11
from torch.nn import Module, Linear
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import numpy as np
import torch.nn as nn
import copy
import math


Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
12
13
14
15
16
17
18
19
20
21
22
def switch(model, flag):
    """
    Enable or disable training for every parameter of ``model``.

    With ``flag=True`` the module is put in training mode and all of its
    parameters track gradients; with ``flag=False`` the module is put in
    evaluation mode and all of its parameters are frozen.

    Parameters
    ----------
    model:
        Module to (un)freeze; it is modified in place.
    flag:
        ``True`` to train the module, ``False`` to freeze it.

    Returns
    -------
    The same ``model`` object that was passed in.
    """
    # Training/eval mode (affects layers such as dropout and batch-norm)
    model.train(flag)
    # Gradient tracking for every parameter of the module, in place
    model.requires_grad_(flag)
    return model



Tiago de Freitas Pereira's avatar
Tiago de Freitas Pereira committed
23
24
25
26
27
28
29
class SimpleBalanceModel(BackboneHeadModel):
    """
    Trainer that trains using a balanced dataset

    Training follows a two-phase schedule selected from the current epoch
    (see :meth:`define_step`):

    * Step 0 (``current_epoch < head_epochs``): the backbone is frozen with
      ``switch(backbone, False)`` so only the identity head is updated.
    * Step 1 (afterwards): the backbone is unfrozen with
      ``switch(backbone, True)`` and both networks are updated.

    Optimization is manual (``automatic_optimization = False``) with one
    optimizer for the head and one for the backbone.

    Parameters
    ----------
    backbone:
        Network mapping the input data to embeddings.
    identity_head:
        Module called as ``identity_head(embedding, label)`` returning logits.
    loss_fn:
        Callable called as ``loss_fn(logits, label)`` returning the loss.
    backbone_checkpoint_file:
        Stored on the instance; it is not read anywhere in this class.
    head_epochs:
        Number of initial epochs during which the backbone stays frozen.
    head_lr:
        Learning rate of the Adam optimizer of the identity head.
    backbone_lr:
        Learning rate of the Adam optimizer of the backbone.
    **kwargs:
        Forwarded to ``pl.LightningModule.__init__``.
    """

    def __init__(
        self,
        backbone=None,
        identity_head=None,
        loss_fn=None,
        backbone_checkpoint_file=None,
        head_epochs=2,
        head_lr=0.1,
        backbone_lr=0.001,
        **kwargs,
    ):
        # LightningModule is initialised explicitly instead of through
        # super().__init__().
        # NOTE(review): presumably to skip BackboneHeadModel's own
        # constructor -- confirm against the parent class.
        pl.LightningModule.__init__(self, **kwargs)
        self.backbone = backbone
        self.identity_head = identity_head
        self.loss_fn = loss_fn
        self.head_epochs = head_epochs
        self.head_lr = head_lr
        self.backbone_lr = backbone_lr

        self.backbone_checkpoint_file = backbone_checkpoint_file

        # Control the networks that will be updated. Each flag records that
        # the corresponding phase transition has already been applied, so
        # `switch` runs only once per phase.
        self.head_switch = False
        self.backbone_switch = False

        # Important: This property activates manual optimization.
        self.automatic_optimization = False

    def define_step(self):
        """
        Select the training phase from the current epoch.

        Step 0: Trains the head only

        Step 1: Trains the backbone only

        Returns
        -------
        ``0`` while ``current_epoch < head_epochs``, ``1`` afterwards.
        """
        return 0 if self.current_epoch < self.head_epochs else 1

    def training_step(self, batch, batch_idx):
        """
        Run one manually-optimized training step.

        Parameters
        ----------
        batch:
            Mapping with the keys ``"data"`` (backbone input) and
            ``"label"`` (identity labels).
        batch_idx:
            Index of the batch; unused.

        Returns
        -------
        The identity loss of the batch.
        """
        head_opt, backbone_opt = self.optimizers()

        data = batch["data"]
        label = batch["label"]

        step = self.define_step()
        self.log("train/multitask_step", step)

        # NOTE(review): the (un)freezing below is applied once per phase. If
        # the trainer re-enables train mode on the whole module between
        # epochs, the backbone would leave eval mode while still frozen --
        # confirm against the Lightning version in use.
        if step == 0:
            # First we learn the head: freeze the backbone once
            if not self.head_switch:
                self.head_switch = True
                self.backbone = switch(self.backbone, False)
        elif not self.backbone_switch:
            # Head phase is over: release the backbone once
            self.backbone_switch = True
            self.backbone = switch(self.backbone, True)

        # Embedding
        embedding = self.backbone(data)

        # Identity loss
        logits_identity = self.identity_head(embedding, label)
        loss_identity = self.loss_fn(logits_identity, label)
        self.log("train/loss_identity", loss_identity)

        # Updating the head and backbone. Both optimizers are always
        # stepped; parameters without a gradient (e.g. a frozen backbone)
        # are left untouched by the optimizer.
        head_opt.zero_grad()
        backbone_opt.zero_grad()
        self.manual_backward(loss_identity)
        head_opt.step()
        backbone_opt.step()

        # Accuracy computed on the tensors directly, avoiding the round-trip
        # through CPU/NumPy. Assumes logits are (batch, classes) and label is
        # (batch,), as the original argmax over axis 1 did.
        predictions = logits_identity.detach().argmax(dim=1)
        acc = (predictions == label).sum().float() / label.shape[0]
        self.log("train/acc_identity", acc)

        self.log("train/total_loss", loss_identity)

        return loss_identity

    def configure_optimizers(self):
        """
        Build the two Adam optimizers used by :meth:`training_step`.

        Returns
        -------
        A ``(head_opt, backbone_opt)`` tuple, in the order in which
        ``training_step`` unpacks ``self.optimizers()``.
        """
        head_opt = torch.optim.Adam(
            self.identity_head.parameters(), lr=self.head_lr
        )
        backbone_opt = torch.optim.Adam(
            self.backbone.parameters(), lr=self.backbone_lr
        )
        return head_opt, backbone_opt