Skip to content
Snippets Groups Projects
Commit a504c37c authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[models] Fix setup of normalizer

parent c648dfd6
No related branches found
No related tags found
1 merge request!64Add object detection
Pipeline #91753 passed
......@@ -89,7 +89,7 @@ class Alexnet(Model):
if self.pretrained:
from ..normalizer import make_imagenet_normalizer
self.normalizer = make_imagenet_normalizer()
Model.normalizer.fset(self, make_imagenet_normalizer()) # type: ignore[attr-defined]
logger.info(f"Loading pretrained `{self.name}` model weights")
weights = models.AlexNet_Weights.DEFAULT
......
......@@ -91,7 +91,7 @@ class Densenet(Model):
if self.pretrained:
from ..normalizer import make_imagenet_normalizer
self.normalizer = make_imagenet_normalizer()
Model.normalizer.fset(self, make_imagenet_normalizer()) # type: ignore[attr-defined]
logger.info(f"Loading pretrained `{self.name}` model weights")
weights = models.DenseNet121_Weights.DEFAULT
......
......@@ -88,7 +88,7 @@ class FasterRCNN(Model):
if pretrained:
from ..normalizer import make_cocov1_normalizer
self.normalizer = make_cocov1_normalizer()
Model.normalizer.fset(self, make_cocov1_normalizer()) # type: ignore[attr-defined]
logger.info(f"Loading pretrained `{self.name}` model weights")
match self.variant:
......
......@@ -131,7 +131,7 @@ class DRIU(Model):
if pretrained:
from ..normalizer import make_imagenet_normalizer
self.normalizer = make_imagenet_normalizer()
Model.normalizer.fset(self, make_imagenet_normalizer()) # type: ignore[attr-defined]
self.backbone = vgg16_for_segmentation(
pretrained=pretrained,
......
......@@ -135,7 +135,8 @@ class DRIUBN(Model):
if pretrained:
from ..normalizer import make_imagenet_normalizer
self.normalizer = make_imagenet_normalizer()
assert Model.normalizer.fset is not None # type: ignore[attr-defined]
Model.normalizer.fset(self, make_imagenet_normalizer()) # type: ignore[attr-defined]
self.backbone = vgg16_for_segmentation(
pretrained=pretrained,
......
......@@ -116,7 +116,7 @@ class DRIUOD(Model):
if pretrained:
from ..normalizer import make_imagenet_normalizer
self.normalizer = make_imagenet_normalizer()
Model.normalizer.fset(self, make_imagenet_normalizer()) # type: ignore[attr-defined]
self.backbone = vgg16_for_segmentation(
pretrained=pretrained,
......
......@@ -120,7 +120,7 @@ class DRIUPix(Model):
if pretrained:
from ..normalizer import make_imagenet_normalizer
self.normalizer = make_imagenet_normalizer()
Model.normalizer.fset(self, make_imagenet_normalizer()) # type: ignore[attr-defined]
self.backbone = vgg16_for_segmentation(
pretrained=pretrained,
......
......@@ -135,7 +135,7 @@ class HED(Model):
if pretrained:
from ..normalizer import make_imagenet_normalizer
self.normalizer = make_imagenet_normalizer()
Model.normalizer.fset(self, make_imagenet_normalizer()) # type: ignore[attr-defined]
self.backbone = vgg16_for_segmentation(
pretrained=pretrained, return_features=[3, 8, 14, 22, 29]
......
......@@ -183,7 +183,7 @@ class M2Unet(Model):
if pretrained:
from ..normalizer import make_imagenet_normalizer
self.normalizer = make_imagenet_normalizer()
Model.normalizer.fset(self, make_imagenet_normalizer()) # type: ignore[attr-defined]
self.backbone = mobilenet_v2_for_segmentation(
pretrained=pretrained,
......@@ -196,31 +196,3 @@ class M2Unet(Model):
x = self.normalizer(x)
x = self.backbone(x)
return self.head(x)
def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
"""Initialize the normalizer for the current model.
If ``pretrained = True`` the normalizer is set to ImageNet weights, else, a new
set of weights is calculated for Z-normalization based on the input dataloader
data.
Parameters
----------
dataloader
A torch Dataloader from which to compute the mean and std.
Will not be used if the model is pretrained.
"""
if self.pretrained:
from ...utils.string import rewrap
from ..normalizer import make_imagenet_normalizer
logger.warning(
rewrap(
f"""ImageNet pre-trained `{self.name}` model - NOT computing z-norm
factors from train dataloader. Using preset factors from
torchvision."""
)
)
self.normalizer = make_imagenet_normalizer()
else:
super().set_normalizer(dataloader)
......@@ -123,7 +123,7 @@ class Unet(Model):
if pretrained:
from ..normalizer import make_imagenet_normalizer
self.normalizer = make_imagenet_normalizer()
Model.normalizer.fset(self, make_imagenet_normalizer()) # type: ignore[attr-defined]
self.backbone = vgg16_for_segmentation(
pretrained=pretrained,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment