From 7e55b4be45777847d94650007c601d0ff08ce6cb Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Tue, 9 May 2023 13:23:15 +0200 Subject: [PATCH] Replaced imports of pytorch_lightning py lightning.pytorch --- doc/conf.py | 2 +- src/ptbench/engine/callbacks.py | 4 ++-- src/ptbench/engine/predictor.py | 2 +- src/ptbench/engine/trainer.py | 8 ++++---- src/ptbench/models/alexnet.py | 4 ++-- src/ptbench/models/densenet.py | 2 +- src/ptbench/models/densenet_rs.py | 2 +- src/ptbench/models/logistic_regression.py | 2 +- src/ptbench/models/pasa.py | 2 +- src/ptbench/models/signs_to_tb.py | 2 +- src/ptbench/scripts/train.py | 2 +- 11 files changed, 16 insertions(+), 16 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 38c97b27..49b7ceac 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -122,7 +122,7 @@ auto_intersphinx_packages = [ "psutil", "torch", "torchvision", - "pytorch-lightning", + "lightning", ("clapper", "latest"), ("python", "3"), ] diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py index b1d86c8f..d6e35e84 100644 --- a/src/ptbench/engine/callbacks.py +++ b/src/ptbench/engine/callbacks.py @@ -3,8 +3,8 @@ import time import numpy -from pytorch_lightning import Callback -from pytorch_lightning.callbacks import BasePredictionWriter +from lightning.pytorch import Callback +from lightning.pytorch.callbacks import BasePredictionWriter # This ensures CSVLogger logs training and evaluation metrics on the same line diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py index 8afc8b85..33dc18be 100644 --- a/src/ptbench/engine/predictor.py +++ b/src/ptbench/engine/predictor.py @@ -5,7 +5,7 @@ import logging import os -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from ..utils.accelerator import AcceleratorProcessor from .callbacks import PredictionsWriter diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py index b8532e85..76aaaf76 100644 --- a/src/ptbench/engine/trainer.py +++ b/src/ptbench/engine/trainer.py @@ -7,10 +7,10 @@ import logging import os import shutil -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger -from pytorch_lightning.utilities.model_summary import ModelSummary +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger +from lightning.pytorch.utilities.model_summary import ModelSummary from ..utils.accelerator import AcceleratorProcessor from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py index 74c07b71..073013cd 100644 --- a/src/ptbench/models/alexnet.py +++ b/src/ptbench/models/alexnet.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.nn as nn import torchvision.models as models @@ -10,7 +10,7 @@ import torchvision.models as models from .normalizer import TorchVisionNormalizer -class Alexnet(pl.LightningModule): +class Alexnet(pl.core.LightningModule): """Alexnet module. Note: only usable with a normalized dataset diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py index 31abf44d..27c3393d 100644 --- a/src/ptbench/models/densenet.py +++ b/src/ptbench/models/densenet.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.nn as nn import torchvision.models as models diff --git a/src/ptbench/models/densenet_rs.py b/src/ptbench/models/densenet_rs.py index 557d34a9..16f4eefb 100644 --- a/src/ptbench/models/densenet_rs.py +++ b/src/ptbench/models/densenet_rs.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.nn as nn import torchvision.models as models diff --git a/src/ptbench/models/logistic_regression.py b/src/ptbench/models/logistic_regression.py index d53f8df0..c6df54bc 100644 --- a/src/ptbench/models/logistic_regression.py +++ b/src/ptbench/models/logistic_regression.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.nn as nn diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py index 76aab5a4..3d4a7641 100644 --- a/src/ptbench/models/pasa.py +++ b/src/ptbench/models/pasa.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.nn as nn import torch.nn.functional as F diff --git a/src/ptbench/models/signs_to_tb.py b/src/ptbench/models/signs_to_tb.py index f88707e9..47337727 100644 --- a/src/ptbench/models/signs_to_tb.py +++ b/src/ptbench/models/signs_to_tb.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import pytorch_lightning as pl +import lightning.pytorch as pl import torch diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index 6d117c5f..50705299 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -6,7 +6,7 @@ import click from clapper.click import ConfigCommand, ResourceOption, verbosity_option from clapper.logging import setup -from pytorch_lightning import seed_everything +from lightning.pytorch import seed_everything from ..utils.checkpointer import get_checkpoint -- GitLab