diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index 8b158c284da769f5f3a65994dc2bfe4fba71d78c..e1ba75a3da01179e48b274d4f6d488c67a4f6bf7 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -815,7 +815,7 @@ class ConcatDataModule(lightning.LightningDataModule): * ``test``: uses only the test dataset * ``predict``: uses only the test dataset """ - pass + super().teardown(stage) def train_dataloader(self) -> DataLoader: """Returns the train data loader."""