Skip to content
Snippets Groups Projects
Commit acbd24f8 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[*] Use package import when possible

parent fd907257
No related branches found
No related tags found
No related merge requests found
......@@ -12,8 +12,8 @@ import pandas
import torch
from tqdm import tqdm
from bob.ip.binseg.utils.metric import SmoothedValue
from bob.ip.binseg.utils.plot import loss_curve
from ..utils.metric import SmoothedValue
from ..utils.plot import loss_curve
import logging
logger = logging.getLogger(__name__)
......
......@@ -4,8 +4,8 @@
import torch
import torch.nn
from collections import OrderedDict
from bob.ip.binseg.modeling.backbones.vgg import vgg16
from bob.ip.binseg.modeling.make_layers import (
from .backbones.vgg import vgg16
from .make_layers import (
conv_with_kaiming_uniform,
convtrans_with_kaiming_uniform,
UpsampleCropBlock,
......
......@@ -4,8 +4,8 @@
import torch
import torch.nn
from collections import OrderedDict
from bob.ip.binseg.modeling.backbones.vgg import vgg16_bn
from bob.ip.binseg.modeling.make_layers import (
from .backbones.vgg import vgg16_bn
from .make_layers import (
conv_with_kaiming_uniform,
convtrans_with_kaiming_uniform,
UpsampleCropBlock,
......
......@@ -4,8 +4,8 @@
import torch
import torch.nn
from collections import OrderedDict
from bob.ip.binseg.modeling.backbones.vgg import vgg16
from bob.ip.binseg.modeling.make_layers import (
from .backbones.vgg import vgg16
from .make_layers import (
conv_with_kaiming_uniform,
convtrans_with_kaiming_uniform,
UpsampleCropBlock,
......
......@@ -4,8 +4,8 @@
import torch
import torch.nn
from collections import OrderedDict
from bob.ip.binseg.modeling.backbones.vgg import vgg16
from bob.ip.binseg.modeling.make_layers import (
from .backbones.vgg import vgg16
from .make_layers import (
conv_with_kaiming_uniform,
convtrans_with_kaiming_uniform,
UpsampleCropBlock,
......
......@@ -4,8 +4,8 @@
import torch
import torch.nn
from collections import OrderedDict
from bob.ip.binseg.modeling.backbones.vgg import vgg16
from bob.ip.binseg.modeling.make_layers import (
from .backbones.vgg import vgg16
from .make_layers import (
conv_with_kaiming_uniform,
convtrans_with_kaiming_uniform,
UpsampleCropBlock,
......
......@@ -6,7 +6,7 @@
from collections import OrderedDict
import torch
import torch.nn
from bob.ip.binseg.modeling.backbones.mobilenetv2 import MobileNetV2, InvertedResidual
from .backbones.mobilenetv2 import MobileNetV2, InvertedResidual
class DecoderBlock(torch.nn.Module):
......
......@@ -3,13 +3,13 @@
import torch.nn as nn
from collections import OrderedDict
from bob.ip.binseg.modeling.make_layers import (
from .make_layers import (
conv_with_kaiming_uniform,
convtrans_with_kaiming_uniform,
PixelShuffle_ICNR,
UnetBlock,
)
from bob.ip.binseg.modeling.backbones.resnet import resnet50
from .backbones.resnet import resnet50
class ResUNet(nn.Module):
......
......@@ -3,13 +3,13 @@
import torch.nn as nn
from collections import OrderedDict
from bob.ip.binseg.modeling.make_layers import (
from .make_layers import (
conv_with_kaiming_uniform,
convtrans_with_kaiming_uniform,
PixelShuffle_ICNR,
UnetBlock,
)
from bob.ip.binseg.modeling.backbones.vgg import vgg16
from .backbones.vgg import vgg16
class UNet(nn.Module):
......
......@@ -9,7 +9,7 @@ import torch
from torch import nn
import os
from bob.ip.binseg.utils.checkpointer import Checkpointer
from ..utils.checkpointer import Checkpointer
class TestCheckpointer(unittest.TestCase):
......@@ -81,7 +81,3 @@ class TestCheckpointer(unittest.TestCase):
self.assertFalse(id(trained_p) == id(loaded_p))
# same content
self.assertTrue(trained_p.equal(loaded_p))
if __name__ == "__main__":
unittest.main()
......@@ -3,12 +3,13 @@
import os
import unittest
from bob.ip.binseg.modeling.driu import build_driu
from bob.ip.binseg.modeling.driuod import build_driuod
from bob.ip.binseg.modeling.hed import build_hed
from bob.ip.binseg.modeling.unet import build_unet
from bob.ip.binseg.modeling.resunet import build_res50unet
from bob.ip.binseg.utils.summary import summary
from ..modeling.driu import build_driu
from ..modeling.driuod import build_driuod
from ..modeling.hed import build_hed
from ..modeling.unet import build_unet
from ..modeling.resunet import build_res50unet
from ..utils.summary import summary
class Tester(unittest.TestCase):
......@@ -45,7 +46,3 @@ class Tester(unittest.TestCase):
s, param = summary(model)
self.assertIsInstance(s, str)
self.assertIsInstance(param, int)
if __name__ == "__main__":
unittest.main()
......@@ -5,8 +5,9 @@
import torch
import os
from bob.ip.binseg.utils.model_serialization import load_state_dict
from bob.ip.binseg.utils.model_zoo import cache_url
from .model_serialization import load_state_dict
from .model_zoo import cache_url
import logging
logger = logging.getLogger(__name__)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment