diff --git a/src/ptbench/configs/datasets/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py
index b2d3ab1d48a9bd0d3a66e4a66c50fcc2300be2c9..635c9ed9b7ed23a255ba9f5e5f79828de17b7f17 100644
--- a/src/ptbench/configs/datasets/shenzhen/default.py
+++ b/src/ptbench/configs/datasets/shenzhen/default.py
@@ -10,8 +10,6 @@
 * See :py:mod:`ptbench.data.shenzhen` for dataset details
 """
 
-from ....data.shenzhen.datamodule import ShenzhenDataModule
 from . import _maker
 
 dataset = _maker("default")
-datamodule = ShenzhenDataModule
diff --git a/src/ptbench/data/shenzhen/datamodule.py b/src/ptbench/data/datamodule.py
similarity index 98%
rename from src/ptbench/data/shenzhen/datamodule.py
rename to src/ptbench/data/datamodule.py
index 60de2efb1fa1d79be5a01054578c3718904f1d2b..efbcfaf9e930025a5c0583221d2b8626eb033e3e 100644
--- a/src/ptbench/data/shenzhen/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -13,7 +13,7 @@ from ptbench.configs.datasets import get_samples_weights
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
-class ShenzhenDataModule(pl.LightningDataModule):
+class DataModule(pl.LightningDataModule):
     def __init__(
         self,
         dataset,
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index eb6910cdce997df94155e602591508b99ec512dd..fe5a2b851806b108ad63e48edd9f4f0809536ebc 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -8,6 +8,7 @@ from clapper.click import ConfigCommand, ResourceOption, verbosity_option
 from clapper.logging import setup
 from lightning.pytorch import seed_everything
 
+from ..data.datamodule import DataModule
 from ..utils.checkpointer import get_checkpoint
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@@ -43,12 +44,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     required=True,
     cls=ResourceOption,
 )
-@click.option(
-    "--datamodule",
-    help="A torch.nn.Module instance implementing the network to be trained",
-    required=True,
-    cls=ResourceOption,
-)
 @click.option(
     "--dataset",
     "-d",
@@ -238,7 +233,6 @@ def train(
     drop_incomplete_batch,
     criterion,
     criterion_valid,
-    datamodule,
     dataset,
     checkpoint_period,
     accelerator,
@@ -296,7 +290,7 @@ def train(
     else:
         batch_chunk_size = batch_size // batch_chunk_count
 
-    datamodule = datamodule(
+    datamodule = DataModule(
         dataset,
         train_batch_size=batch_chunk_size,
         multiproc_kwargs=multiproc_kwargs,