From 8a7292a5554d1d136ef3167d6fa681b046ae8cac Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 24 May 2023 13:41:13 +0200
Subject: [PATCH] Moved DataModule to its own module

DataModule is generic and not tied to a specific dataset.
---
 src/ptbench/configs/datasets/shenzhen/default.py |  2 --
 src/ptbench/data/{shenzhen => }/datamodule.py    |  2 +-
 src/ptbench/scripts/train.py                     | 10 ++--------
 3 files changed, 3 insertions(+), 11 deletions(-)
 rename src/ptbench/data/{shenzhen => }/datamodule.py (98%)

diff --git a/src/ptbench/configs/datasets/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py
index b2d3ab1d..635c9ed9 100644
--- a/src/ptbench/configs/datasets/shenzhen/default.py
+++ b/src/ptbench/configs/datasets/shenzhen/default.py
@@ -10,8 +10,6 @@
 * See :py:mod:`ptbench.data.shenzhen` for dataset details
 """
 
-from ....data.shenzhen.datamodule import ShenzhenDataModule
 from . import _maker
 
 dataset = _maker("default")
-datamodule = ShenzhenDataModule
diff --git a/src/ptbench/data/shenzhen/datamodule.py b/src/ptbench/data/datamodule.py
similarity index 98%
rename from src/ptbench/data/shenzhen/datamodule.py
rename to src/ptbench/data/datamodule.py
index 60de2efb..efbcfaf9 100644
--- a/src/ptbench/data/shenzhen/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -13,7 +13,7 @@ from ptbench.configs.datasets import get_samples_weights
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
-class ShenzhenDataModule(pl.LightningDataModule):
+class DataModule(pl.LightningDataModule):
     def __init__(
         self,
         dataset,
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index eb6910cd..fe5a2b85 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -8,6 +8,7 @@ from clapper.click import ConfigCommand, ResourceOption, verbosity_option
 from clapper.logging import setup
 from lightning.pytorch import seed_everything
 
+from ..data.datamodule import DataModule
 from ..utils.checkpointer import get_checkpoint
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@@ -43,12 +44,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     required=True,
     cls=ResourceOption,
 )
-@click.option(
-    "--datamodule",
-    help="A torch.nn.Module instance implementing the network to be trained",
-    required=True,
-    cls=ResourceOption,
-)
 @click.option(
     "--dataset",
     "-d",
@@ -238,7 +233,6 @@ def train(
     drop_incomplete_batch,
     criterion,
     criterion_valid,
-    datamodule,
     dataset,
     checkpoint_period,
     accelerator,
@@ -296,7 +290,7 @@ def train(
     else:
         batch_chunk_size = batch_size // batch_chunk_count
 
-    datamodule = datamodule(
+    datamodule = DataModule(
         dataset,
         train_batch_size=batch_chunk_size,
         multiproc_kwargs=multiproc_kwargs,
-- 
GitLab