Skip to content
Snippets Groups Projects
Commit a2c1b973 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[models.driu_od] Fix copy-n-paste error on backbone layer selection

parent 37add490
No related branches found
No related tags found
No related merge requests found
Pipeline #40177 passed
...@@ -24,7 +24,9 @@ class DRIUOD(torch.nn.Module): ...@@ -24,7 +24,9 @@ class DRIUOD(torch.nn.Module):
def __init__(self, in_channels_list=None): def __init__(self, in_channels_list=None):
super(DRIUOD, self).__init__() super(DRIUOD, self).__init__()
in_upsample2, in_upsample_4, in_upsample_8, in_upsample_16 = in_channels_list in_upsample2, in_upsample_4, in_upsample_8, in_upsample_16 = (
in_channels_list
)
self.upsample2 = UpsampleCropBlock(in_upsample2, 16, 4, 2, 0) self.upsample2 = UpsampleCropBlock(in_upsample2, 16, 4, 2, 0)
# Upsample layers # Upsample layers
...@@ -83,14 +85,16 @@ def driu_od(pretrained_backbone=True, progress=True): ...@@ -83,14 +85,16 @@ def driu_od(pretrained_backbone=True, progress=True):
""" """
backbone = vgg16_for_segmentation( backbone = vgg16_for_segmentation(
pretrained=pretrained_backbone, progress=progress, pretrained=pretrained_backbone,
return_features=[3, 8, 14, 22], progress=progress,
return_features=[8, 14, 22, 29],
) )
head = DRIUOD([128, 256, 512, 512]) head = DRIUOD([128, 256, 512, 512])
order = [("backbone", backbone), ("head", head)] order = [("backbone", backbone), ("head", head)]
if pretrained_backbone: if pretrained_backbone:
from .normalizer import TorchVisionNormalizer from .normalizer import TorchVisionNormalizer
order = [("normalizer", TorchVisionNormalizer())] + order order = [("normalizer", TorchVisionNormalizer())] + order
model = torch.nn.Sequential(OrderedDict(order)) model = torch.nn.Sequential(OrderedDict(order))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment