diff --git a/src/ptbench/configs/datasets/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py index b2d3ab1d48a9bd0d3a66e4a66c50fcc2300be2c9..635c9ed9b7ed23a255ba9f5e5f79828de17b7f17 100644 --- a/src/ptbench/configs/datasets/shenzhen/default.py +++ b/src/ptbench/configs/datasets/shenzhen/default.py @@ -10,8 +10,6 @@ * See :py:mod:`ptbench.data.shenzhen` for dataset details """ -from ....data.shenzhen.datamodule import ShenzhenDataModule from . import _maker dataset = _maker("default") -datamodule = ShenzhenDataModule diff --git a/src/ptbench/data/shenzhen/datamodule.py b/src/ptbench/data/datamodule.py similarity index 98% rename from src/ptbench/data/shenzhen/datamodule.py rename to src/ptbench/data/datamodule.py index 60de2efb1fa1d79be5a01054578c3718904f1d2b..efbcfaf9e930025a5c0583221d2b8626eb033e3e 100644 --- a/src/ptbench/data/shenzhen/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -13,7 +13,7 @@ from ptbench.configs.datasets import get_samples_weights logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") -class ShenzhenDataModule(pl.LightningDataModule): +class DataModule(pl.LightningDataModule): def __init__( self, dataset, diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index eb6910cdce997df94155e602591508b99ec512dd..fe5a2b851806b108ad63e48edd9f4f0809536ebc 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -8,6 +8,7 @@ from clapper.click import ConfigCommand, ResourceOption, verbosity_option from clapper.logging import setup from lightning.pytorch import seed_everything +from ..data.datamodule import DataModule from ..utils.checkpointer import get_checkpoint logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -43,12 +44,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") required=True, cls=ResourceOption, ) -@click.option( - "--datamodule", - help="A torch.nn.Module instance implementing the network to be trained", - required=True, - cls=ResourceOption, -) @click.option( "--dataset", "-d", @@ -238,7 +233,6 @@ def train( drop_incomplete_batch, criterion, criterion_valid, - datamodule, dataset, checkpoint_period, accelerator, @@ -296,7 +290,7 @@ def train( else: batch_chunk_size = batch_size // batch_chunk_count - datamodule = datamodule( + datamodule = DataModule( dataset, train_batch_size=batch_chunk_size, multiproc_kwargs=multiproc_kwargs,