diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index c3aadc6d0466777fb0adcb62eb7d20887cf36b20..c567b9b3e8bac6b3fa52e0a95e62d55f9627e748 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -13,7 +13,8 @@ import torch.utils.data import torchvision.models as models import torchvision.transforms -from ..data.typing import Checkpoint, DataLoader, TransformSequence +from ..data.typing import DataLoader, TransformSequence +from .typing import Checkpoint logger = logging.getLogger(__name__) diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index 021f6ce2c6f5cfb3ad3819144a442744577d5eaa..ea2cab0047736d4bd08fdc75974b1e212325abbd 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -13,7 +13,8 @@ import torch.utils.data import torchvision.models as models import torchvision.transforms -from ..data.typing import Checkpoint, DataLoader, TransformSequence +from ..data.typing import DataLoader, TransformSequence +from .typing import Checkpoint logger = logging.getLogger(__name__) diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index 7f0cf57dde2baf79cef18021543d6257a98e35a3..aaa5a2b0dabc72d09b20ace356a32ce211619f8c 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -13,7 +13,8 @@ import torch.optim.optimizer import torch.utils.data import torchvision.transforms -from ..data.typing import Checkpoint, DataLoader, TransformSequence +from ..data.typing import DataLoader, TransformSequence +from .typing import Checkpoint logger = logging.getLogger(__name__)