Skip to content
Snippets Groups Projects
Commit bc98b557 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixed pytorch normalization bit

parent bcf473ed
No related branches found
No related tags found
No related merge requests found
Pipeline #52288 failed
......@@ -160,7 +160,9 @@ class IResnet34(PyTorchModel):
ArcFace model (RESNET 34) from Insightface ported to pytorch
"""
def __init__(self, preprocessor=lambda x: (x - 127.5) / 128.0, memory_demanding=False):
def __init__(
self, preprocessor=lambda x: (x - 127.5) / 128.0, memory_demanding=False
):
urls = [
"https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet-91a5de61.tar.gz",
......@@ -174,7 +176,10 @@ class IResnet34(PyTorchModel):
checkpoint_path = os.path.join(path, "iresnet34-5b0d0e90.pth")
super(IResnet34, self).__init__(
checkpoint_path, config, memory_demanding=memory_demanding
checkpoint_path,
config,
memory_demanding=memory_demanding,
preprocessor=preprocessor,
)
def _load_model(self):
......@@ -188,7 +193,9 @@ class IResnet50(PyTorchModel):
ArcFace model (RESNET 50) from Insightface ported to pytorch
"""
def __init__(self, preprocessor=lambda x: (x - 127.5) / 128.0, memory_demanding=False):
def __init__(
self, preprocessor=lambda x: (x - 127.5) / 128.0, memory_demanding=False
):
filename = _get_iresnet_file()
......@@ -197,7 +204,10 @@ class IResnet50(PyTorchModel):
checkpoint_path = os.path.join(path, "iresnet50-7f187506.pth")
super(IResnet50, self).__init__(
checkpoint_path, config, memory_demanding=memory_demanding
checkpoint_path,
config,
memory_demanding=memory_demanding,
preprocessor=preprocessor,
)
def _load_model(self):
......@@ -206,12 +216,14 @@ class IResnet50(PyTorchModel):
self.model = model
class IResnet100(PyTorchModel):
class IResnet100(PyTorchModel):
"""
ArcFace model (RESNET 100) from Insightface ported to pytorch
"""
def __init__(self, preprocessor=lambda x: (x - 127.5) / 128.0, memory_demanding=False):
def __init__(
self, preprocessor=lambda x: (x - 127.5) / 128.0, memory_demanding=False
):
filename = _get_iresnet_file()
......@@ -220,7 +232,10 @@ class IResnet100(PyTorchModel):
checkpoint_path = os.path.join(path, "iresnet100-73e07ba7.pth")
super(IResnet100, self).__init__(
checkpoint_path, config, memory_demanding=memory_demanding
checkpoint_path,
config,
memory_demanding=memory_demanding,
preprocessor=preprocessor,
)
def _load_model(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment