From ab0f2f297cbcbc982db9c2f6be728916f89aa25d Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Mon, 5 Feb 2024 11:28:13 +0100 Subject: [PATCH] [doc] Add missing "Returns" sections in docstrings --- src/mednet/config/data/tbx11k/datamodule.py | 7 +- .../data/tbx11k/make_splits_from_database.py | 28 +++++- src/mednet/data/augmentations.py | 7 +- src/mednet/data/datamodule.py | 96 ++++++++++++++++--- src/mednet/data/split.py | 1 - src/mednet/data/typing.py | 10 ++ src/mednet/engine/device.py | 13 ++- src/mednet/engine/saliency/generator.py | 4 + src/mednet/utils/rc.py | 7 +- src/mednet/utils/resources.py | 8 +- tests/conftest.py | 8 +- 11 files changed, 170 insertions(+), 19 deletions(-) diff --git a/src/mednet/config/data/tbx11k/datamodule.py b/src/mednet/config/data/tbx11k/datamodule.py index dfb3038b..9593d23d 100644 --- a/src/mednet/config/data/tbx11k/datamodule.py +++ b/src/mednet/config/data/tbx11k/datamodule.py @@ -119,7 +119,12 @@ class BoundingBoxes(collections.abc.Sequence[BoundingBox]): # explained at: # https://pytorch.org/docs/stable/data.html#torch.utils.data.default_collate def _collate_boundingboxes_fn(batch, *, collate_fn_map=None): - """Custom collate_fn() for pytorch dataloaders that ignores BoundingBoxes objects.""" + """Custom collate_fn() for pytorch dataloaders that ignores BoundingBoxes objects. + + Returns + ------- + The given batch. + """ return batch diff --git a/src/mednet/config/data/tbx11k/make_splits_from_database.py b/src/mednet/config/data/tbx11k/make_splits_from_database.py index 4c09c375..48b0b90b 100644 --- a/src/mednet/config/data/tbx11k/make_splits_from_database.py +++ b/src/mednet/config/data/tbx11k/make_splits_from_database.py @@ -59,7 +59,18 @@ from sklearn.model_selection import StratifiedKFold, train_test_split def reorder(data: dict) -> list: - """Reorder data from TBX11K into a sample-based organisation.""" + """Reorder data from TBX11K into a sample-based organisation. + + Parameters + ---------- + data + A dictionary containing the loaded data. + + Returns + ------- + list + The reordered data. + """ categories = {k["id"]: k["name"] for k in data["categories"]} assert len(set(categories.values())) == len( @@ -112,6 +123,11 @@ def normalize_labels(data: list) -> list: bounding boxes with label 0 and no bounding box with label 1 4: sick (but no tb), comes from the imgs/sick subdir, does not have any annotated bounding box. + + Returns + ------- + list + A list of labels per sample. """ def _set_label(s: list) -> int: @@ -213,6 +229,11 @@ def create_v1_default_split(d: dict, seed: int, validation_size: float) -> dict: validation_size The proportion of data when we split the training set to make a train and validation sets. + + Returns + ------- + dict + A dict containing the various v1 splits. """ # filter cases (only interested in labels 0:healthy or 1:active-tb) @@ -251,6 +272,11 @@ def create_v2_default_split(d: dict, seed: int, validation_size) -> dict: 2. The original training set is split into new training and validation sets. The selection of samples is stratified (respects class proportions in Özgür's way - see comments) + + Returns + ------- + dict + A dict containing the various v2 splits. """ # filter cases (only interested in labels 0:healthy or 1:active-tb) diff --git a/src/mednet/data/augmentations.py b/src/mednet/data/augmentations.py index c985b01e..10c4ed38 100644 --- a/src/mednet/data/augmentations.py +++ b/src/mednet/data/augmentations.py @@ -239,7 +239,7 @@ class ElasticDeformation: self.parallel = parallel @property - def parallel(self): + def parallel(self) -> int: """Use multiprocessing for data augmentation. If set to -1 (default), disables multiprocessing. If set to -2, @@ -247,6 +247,11 @@ class ElasticDeformation: batch size and total number of processing cores). Set to 0 to enable as many processes as processing cores available in the system. Set to >= 1 to enable that many processes. + + Returns + ------- + int + The multiprocessing type. """ return self._parallel diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py index 62e07b98..9dd094a2 100644 --- a/src/mednet/data/datamodule.py +++ b/src/mednet/data/datamodule.py @@ -46,7 +46,18 @@ def _sample_size_bytes(s: Sample) -> int: """ def _tensor_size_bytes(t: torch.Tensor) -> int: - """Return a tensor size in bytes.""" + """Return a tensor size in bytes. + + Parameters + ---------- + t + A torch Tensor. + + Returns + ------- + int + The size of the Tensor in bytes. + """ return int(t.element_size() * torch.prod(torch.tensor(t.shape))) size = sys.getsizeof(s[0]) # tensor metadata @@ -100,7 +111,13 @@ class _DelayedLoadingDataset(Dataset): logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb") def labels(self) -> list[int | list[int]]: - """Return the integer labels for all samples in the dataset.""" + """Return the integer labels for all samples in the dataset. + + Returns + ------- + list[int | list[int]] + The integer labels for all samples in the dataset. + """ return [self.loader.label(k) for k in self.raw_dataset] def __getitem__(self, key: int) -> Sample: @@ -207,7 +224,13 @@ class _CachedDataset(Dataset): ) def labels(self) -> list[int | list[int]]: - """Return the integer labels for all samples in the dataset.""" + """Return the integer labels for all samples in the dataset. + + Returns + ------- + list[int | list[int]] + The integer labels for all samples in the dataset. + """ return [k[1]["label"] for k in self.data] def __getitem__(self, key: int) -> Sample: @@ -239,7 +262,13 @@ class _ConcatDataset(Dataset): ] def labels(self) -> list[int | list[int]]: - """Return the integer labels for all samples in the dataset.""" + """Return the integer labels for all samples in the dataset. + + Returns + ------- + list[int | list[int]] + The integer labels for all samples in the dataset. + """ return list(itertools.chain(*[k.labels() for k in self._datasets])) def __getitem__(self, key: int) -> Sample: @@ -550,6 +579,11 @@ class ConcatDataModule(lightning.LightningDataModule): - ``parallel`` - Runs mini-batch data loading on as many external processes as set on ``parallel`` + + Returns + ------- + int + The value of self._parallel. """ return self._parallel @@ -588,6 +622,10 @@ class ConcatDataModule(lightning.LightningDataModule): data is cached, it is cached **after** model-transforms are applied, as that is a potential memory saver (e.g., if it contains a resizing operation to smaller images). + + Returns + ------- + A list containing the model tansforms. """ return self._model_transforms @@ -606,7 +644,7 @@ class ConcatDataModule(lightning.LightningDataModule): self._datasets = {} @property - def balance_sampler_by_class(self): + def balance_sampler_by_class(self) -> bool: """Whether to balance samples across labels/datasets. If set, then modifies the random sampler used during training @@ -620,6 +658,11 @@ class ConcatDataModule(lightning.LightningDataModule): samples acording to their ground-truth (labels). If you'd like to have samples balanced per dataset, then implement your own data module inheriting from this one. + + Returns + ------- + bool + True if self._train_sample is set, else False. """ return self._train_sampler is not None @@ -733,7 +776,13 @@ class ConcatDataModule(lightning.LightningDataModule): self._datasets[name] = _ConcatDataset(datasets) def _val_dataset_keys(self) -> list[str]: - """Return list of validation dataset names.""" + """Return list of validation dataset names. + + Returns + ------- + list[str] + The list of validation dataset names. + """ return ["validation"] + [ k for k in self.splits.keys() if k.startswith("monitor-") ] @@ -801,7 +850,12 @@ class ConcatDataModule(lightning.LightningDataModule): super().teardown(stage) def train_dataloader(self) -> DataLoader: - """Return the train data loader.""" + """Return the train data loader. + + Returns + ------- + The train data loader(s). + """ return torch.utils.data.DataLoader( self._datasets["train"], @@ -814,7 +868,12 @@ class ConcatDataModule(lightning.LightningDataModule): ) def unshuffled_train_dataloader(self) -> DataLoader: - """Return the train data loader without shuffling.""" + """Return the train data loader without shuffling. + + Returns + ------- + The train data loader without shuffling. + """ return torch.utils.data.DataLoader( self._datasets["train"], @@ -825,7 +884,12 @@ class ConcatDataModule(lightning.LightningDataModule): ) def val_dataloader(self) -> dict[str, DataLoader]: - """Return the validation data loader(s)""" + """Return the validation data loader(s). + + Returns + ------- + The validation data loader(s). + """ validation_loader_opts = { "batch_size": self._chunk_size, @@ -843,7 +907,12 @@ class ConcatDataModule(lightning.LightningDataModule): } def test_dataloader(self) -> dict[str, DataLoader]: - """Return the test data loader(s)""" + """Return the test data loader(s). + + Returns + ------- + The test data loader(s). + """ return dict( test=torch.utils.data.DataLoader( @@ -857,7 +926,12 @@ class ConcatDataModule(lightning.LightningDataModule): ) def predict_dataloader(self) -> dict[str, DataLoader]: - """Return the prediction data loader(s)""" + """Return the prediction data loader(s). + + Returns + ------- + The prediction data loader(s). + """ return { k: torch.utils.data.DataLoader( diff --git a/src/mednet/data/split.py b/src/mednet/data/split.py index c6abae45..5d70aafd 100644 --- a/src/mednet/data/split.py +++ b/src/mednet/data/split.py @@ -156,7 +156,6 @@ class CSVDatabaseSplit(DatabaseSplit): Returns ------- - datasets : dict A dictionary mapping dataset names to lists of JSON objects. """ diff --git a/src/mednet/data/typing.py b/src/mednet/data/typing.py index 658cad9d..f61b4230 100644 --- a/src/mednet/data/typing.py +++ b/src/mednet/data/typing.py @@ -34,6 +34,16 @@ class RawDataLoader: If you do not override this implementation, then, by default, this method will call :py:meth:`sample` to load the whole sample and extract the label. + + Parameters + ---------- + k + The sample to load. This is implementation-dependent. + + Returns + ------- + int | list[int] + The label corresponding to the specified sample. """ return self.sample(k)[1]["label"] diff --git a/src/mednet/engine/device.py b/src/mednet/engine/device.py index c501f7f2..b008c4cb 100644 --- a/src/mednet/engine/device.py +++ b/src/mednet/engine/device.py @@ -21,7 +21,18 @@ SupportedPytorchDevice: typing.TypeAlias = typing.Literal[ def _split_int_list(s: str) -> list[int]: - """Split a list of integers encoded in a string (e.g. "1,2,3") into a Python list of integers (e.g. ``[1, 2, 3]``).""" + """Split a list of integers encoded in a string (e.g. "1,2,3") into a Python list of integers (e.g. ``[1, 2, 3]``). + + Parameters + ---------- + s + A list of integers encoded in a string. + + Returns + ------- + list[int] + A Python list of integers. + """ return [int(k.strip()) for k in s.split(",")] diff --git a/src/mednet/engine/saliency/generator.py b/src/mednet/engine/saliency/generator.py index 65b4b62d..e99f7887 100644 --- a/src/mednet/engine/saliency/generator.py +++ b/src/mednet/engine/saliency/generator.py @@ -36,6 +36,10 @@ def _create_saliency_map_callable( The target layers to compute CAM for. use_cuda Whether to use cuda or not. + + Returns + ------- + A class activation map (CAM) instance for the given model. """ import pytorch_grad_cam diff --git a/src/mednet/utils/rc.py b/src/mednet/utils/rc.py index b3657e02..fcc1659d 100644 --- a/src/mednet/utils/rc.py +++ b/src/mednet/utils/rc.py @@ -6,5 +6,10 @@ from clapper.rc import UserDefaults def load_rc() -> UserDefaults: - """Return global configuration variables.""" + """Return global configuration variables. + + Returns + ------- + The user defaults read from the user .toml configuration file. + """ return UserDefaults("mednet.toml") diff --git a/src/mednet/utils/resources.py b/src/mednet/utils/resources.py index 184b9d39..a17d3cbd 100644 --- a/src/mednet/utils/resources.py +++ b/src/mednet/utils/resources.py @@ -431,7 +431,13 @@ class _InformationGatherer: self.data[k] = [] def summary(self) -> dict[str, list[int | float]]: - """Return the current data.""" + """Return the current data. + + Returns + ------- + dict[str, list[int | float]] + A dictionary with a list of resources and their corresponding values. + """ if len(next(iter(self.data.values()))) == 0: self.logger.error("CPU/GPU logger was not able to collect any data") return self.data diff --git a/tests/conftest.py b/tests/conftest.py index bacef95e..01223e6b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,13 @@ from mednet.data.typing import DatabaseSplit @pytest.fixture def datadir(request) -> pathlib.Path: - """Return the directory in which the test is sitting.""" + """Return the directory in which the test is sitting. + + Returns + ------- + pathlib.Path + The directory in which the test is sitting. + """ return pathlib.Path(request.module.__file__).parents[0] / "data" -- GitLab