#!/usr/bin/env python
# coding=utf-8

"""Tests model loading"""

import nose.tools

from ..models.normalizer import TorchVisionNormalizer
from ..models.backbones.vgg import VGG4Segmentation


def _check_model_layout(factory, backbone_type, head_type):
    """Assert the layout of the models built by ``factory``.

    The factory is called twice and the returned object is inspected through
    ``len()`` and integer indexing only:

    * ``factory(pretrained_backbone=True, progress=True)`` must yield three
      elements: a ``TorchVisionNormalizer``, the backbone and the head;
    * ``factory(pretrained_backbone=False)`` must yield two elements: the
      backbone and the head.

    Types are compared exactly (``type(x) == T``), so subclasses of the
    expected types do not satisfy the check.

    The name intentionally does not contain "test" so that test runners do
    not collect this helper as a test case.

    Parameters
    ----------
    factory : callable
        Model factory accepting the ``pretrained_backbone`` and ``progress``
        keyword arguments.
    backbone_type : type
        Exact type expected for the backbone element.
    head_type : type
        Exact type expected for the head (last) element.

    Raises
    ------
    AssertionError
        If the length or any element type differs from the expected layout.
    """
    # With a pretrained backbone: normalizer + backbone + head
    model = factory(pretrained_backbone=True, progress=True)
    nose.tools.eq_(len(model), 3)
    nose.tools.eq_(type(model[0]), TorchVisionNormalizer)
    nose.tools.eq_(type(model[1]), backbone_type)  # backbone
    nose.tools.eq_(type(model[2]), head_type)  # head

    # Without a pretrained backbone: backbone + head only
    model = factory(pretrained_backbone=False)
    nose.tools.eq_(len(model), 2)
    nose.tools.eq_(type(model[0]), backbone_type)  # backbone
    nose.tools.eq_(type(model[1]), head_type)  # head


def test_driu():
    """Check the layout of ``driu()`` models (VGG backbone, DRIU head)."""
    # Imported lazily so an import failure only affects this test
    from ..models.driu import driu, DRIU

    _check_model_layout(driu, VGG4Segmentation, DRIU)


def test_driu_bn():
    """Check the layout of ``driu_bn()`` models (VGG backbone, DRIUBN head)."""
    from ..models.driu_bn import driu_bn, DRIUBN

    _check_model_layout(driu_bn, VGG4Segmentation, DRIUBN)


def test_driu_od():
    """Check the layout of ``driu_od()`` models (VGG backbone, DRIUOD head)."""
    from ..models.driu_od import driu_od, DRIUOD

    _check_model_layout(driu_od, VGG4Segmentation, DRIUOD)


def test_driu_pix():
    """Check the layout of ``driu_pix()`` models (VGG backbone, DRIUPIX head)."""
    from ..models.driu_pix import driu_pix, DRIUPIX

    _check_model_layout(driu_pix, VGG4Segmentation, DRIUPIX)


def test_unet():
    """Check the layout of ``unet()`` models (VGG backbone, UNet head)."""
    from ..models.unet import unet, UNet

    _check_model_layout(unet, VGG4Segmentation, UNet)


def test_hed():
    """Check the layout of ``hed()`` models (VGG backbone, HED head)."""
    from ..models.hed import hed, HED

    _check_model_layout(hed, VGG4Segmentation, HED)


def test_m2unet():
    """Check the layout of ``m2unet()`` models (MobileNetV2 backbone, M2UNet head)."""
    from ..models.m2unet import m2unet, M2UNet
    from ..models.backbones.mobilenetv2 import MobileNetV24Segmentation

    _check_model_layout(m2unet, MobileNetV24Segmentation, M2UNet)


def test_resunet50():
    """Check the layout of ``resunet50()`` models (ResNet backbone, ResUNet head)."""
    from ..models.resunet import resunet50, ResUNet
    from ..models.backbones.resnet import ResNet4Segmentation

    _check_model_layout(resunet50, ResNet4Segmentation, ResUNet)