transformers.py 853 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from bob.bio.face.embeddings.pytorch import PyTorchModel, iresnet_template


class RunnableTransformer(PyTorchModel):
    """
    ``PyTorchModel`` whose network is built by a user-supplied callable.

    Instead of reading a checkpoint/config pair, the model is obtained by
    calling ``runnable_pytorch_model()`` (no arguments) when ``_load_model``
    runs.

    Parameters
    ----------
    runnable_pytorch_model : callable
        Zero-argument callable returning the model object. The returned
        object must provide an ``eval()`` method (presumably a
        ``torch.nn.Module`` -- confirm against callers).
    preprocessor : callable, optional
        Forwarded unchanged to ``PyTorchModel``. The default computes
        ``(x - 127.5) / 128.0``.
    memory_demanding : bool, optional
        Forwarded unchanged to ``PyTorchModel``.
    device : optional
        Forwarded unchanged to ``PyTorchModel``.
    **kwargs
        Any extra keyword arguments, forwarded to ``PyTorchModel``.
    """

    def __init__(
        self,
        runnable_pytorch_model,
        preprocessor=lambda x: (x - 127.5) / 128.0,
        memory_demanding=False,
        device=None,
        **kwargs,
    ):
        # The model comes from the callable below, so the parent's
        # checkpoint/config arguments are deliberately passed as empty strings.
        super().__init__(
            checkpoint_path="",
            config="",
            memory_demanding=memory_demanding,
            preprocessor=preprocessor,
            device=device,
            **kwargs,
        )
        # Stored under the same name as the constructor argument.
        self.runnable_pytorch_model = runnable_pytorch_model

    def _load_model(self):
        """Build the model from the stored callable and prepare it for use.

        Calls ``self.runnable_pytorch_model()``, stores the result in
        ``self.model``, calls ``eval()`` on it, and then calls
        ``self.place_model_on_device()``.
        """
        self.model = self.runnable_pytorch_model()

        self.model.eval()
        self.place_model_on_device()