diff --git a/tests/test_mc.py b/tests/test_mc.py index 87de46eaee3637da52394b567cfabf92fd8b4c8b..ec4d85cd01a42d018104e053fb3a20d7c7c62e4f 100644 --- a/tests/test_mc.py +++ b/tests/test_mc.py @@ -15,7 +15,7 @@ def test_protocol_consistency(): datamodule = importlib.import_module( "ptbench.data.montgomery.default" ).datamodule - subset = datamodule.dataset_split.subsets + subset = datamodule.database_split.subsets assert len(subset) == 3 @@ -126,7 +126,7 @@ def test_loading(): assert isinstance(data, torch.Tensor) - assert data.size in ( + assert data.size() in ( (1, 4020, 4892), # portrait (1, 4892, 4020), # landscape (1, 512, 512), # test database @ CI @@ -143,7 +143,7 @@ def test_loading(): datamodule = importlib.import_module( "ptbench.data.montgomery.default" ).datamodule - subset = datamodule.database_split.subsetss + subset = datamodule.database_split.subsets raw_data_loader = datamodule.raw_data_loader # Need to use private function so we can limit the number of samples to use