diff --git a/src/mednet/config/data/tbx11k/datamodule.py b/src/mednet/config/data/tbx11k/datamodule.py
index dfb3038b24c6097e37edd230f43de29468415f51..9593d23d6de35d5ab874d8aaaecaf0c1cf61262c 100644
--- a/src/mednet/config/data/tbx11k/datamodule.py
+++ b/src/mednet/config/data/tbx11k/datamodule.py
@@ -119,7 +119,12 @@ class BoundingBoxes(collections.abc.Sequence[BoundingBox]):
 # explained at:
 # https://pytorch.org/docs/stable/data.html#torch.utils.data.default_collate
 def _collate_boundingboxes_fn(batch, *, collate_fn_map=None):
-    """Custom collate_fn() for pytorch dataloaders that ignores BoundingBoxes objects."""
+    """Custom collate_fn() for pytorch dataloaders that ignores BoundingBoxes objects.
+
+    Returns
+    -------
+        The given batch.
+    """
     return batch
diff --git a/src/mednet/config/data/tbx11k/make_splits_from_database.py b/src/mednet/config/data/tbx11k/make_splits_from_database.py
index 4c09c375da3e385cd2845351db64d3d8e24b7975..48b0b90bccb96e8c0c81126f47684197222ecd75 100644
--- a/src/mednet/config/data/tbx11k/make_splits_from_database.py
+++ b/src/mednet/config/data/tbx11k/make_splits_from_database.py
@@ -59,7 +59,18 @@
 from sklearn.model_selection import StratifiedKFold, train_test_split


 def reorder(data: dict) -> list:
-    """Reorder data from TBX11K into a sample-based organisation."""
+    """Reorder data from TBX11K into a sample-based organisation.
+
+    Parameters
+    ----------
+    data
+        A dictionary containing the loaded data.
+
+    Returns
+    -------
+    list
+        The reordered data.
+    """
     categories = {k["id"]: k["name"] for k in data["categories"]}
     assert len(set(categories.values())) == len(
@@ -112,6 +123,11 @@ def normalize_labels(data: list) -> list:
        bounding boxes with label 0 and no bounding box with label 1
     4: sick (but no tb), comes from the imgs/sick subdir, does not have
        any annotated bounding box.
+
+    Returns
+    -------
+    list
+        A list of labels per sample.
     """

     def _set_label(s: list) -> int:
@@ -213,6 +229,11 @@ def create_v1_default_split(d: dict, seed: int, validation_size: float) -> dict:
     validation_size
         The proportion of data when we split the training set to make a
         train and validation sets.
+
+    Returns
+    -------
+    dict
+        A dict containing the various v1 splits.
     """

     # filter cases (only interested in labels 0:healthy or 1:active-tb)
@@ -251,6 +272,11 @@ def create_v2_default_split(d: dict, seed: int, validation_size) -> dict:
     2. The original training set is split into new training and validation
        sets.  The selection of samples is stratified (respects class
        proportions in Özgür's way - see comments)
+
+    Returns
+    -------
+    dict
+        A dict containing the various v2 splits.
     """

     # filter cases (only interested in labels 0:healthy or 1:active-tb)
diff --git a/src/mednet/data/augmentations.py b/src/mednet/data/augmentations.py
index c985b01e39ea5c85946dc3c8a5868e0888862ba8..10c4ed384685ca93bc771f1e880facf940b18e6e 100644
--- a/src/mednet/data/augmentations.py
+++ b/src/mednet/data/augmentations.py
@@ -239,7 +239,7 @@ class ElasticDeformation:
         self.parallel = parallel

     @property
-    def parallel(self):
+    def parallel(self) -> int:
         """Use multiprocessing for data augmentation.

         If set to -1 (default), disables multiprocessing.  If set to -2,
@@ -247,6 +247,11 @@ class ElasticDeformation:
         batch size and total number of processing cores).  Set to 0 to
         enable as many processes as processing cores available in the
         system.  Set to >= 1 to enable that many processes.
+
+        Returns
+        -------
+        int
+            The number of processes to use, following the convention above.
         """
         return self._parallel
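Note on the hunk above: the -1/-2/0/>=1 convention is easy to misread, so here is a minimal sketch of how it could be resolved into a concrete worker count. The helper `resolve_parallel` is hypothetical and for illustration only; it is not part of this patch or of mednet:

import multiprocessing


def resolve_parallel(parallel: int, batch_size: int) -> int:
    # Hypothetical helper mapping the documented ``parallel`` convention to
    # a concrete number of worker processes.
    cores = multiprocessing.cpu_count()
    if parallel == -1:  # default: multiprocessing disabled, work serially
        return 0
    if parallel == -2:  # auto-tune: minimum of first batch size and core count
        return min(batch_size, cores)
    if parallel == 0:  # one worker per available processing core
        return cores
    return parallel  # >= 1: exactly that many workers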
""" return self._parallel diff --git a/src/mednet/data/datamodule.py b/src/mednet/data/datamodule.py index 62e07b9871dbc02e18b2e80c5ec91b281b707473..9dd094a27a38f0e547d5a2e66e5a177514903c06 100644 --- a/src/mednet/data/datamodule.py +++ b/src/mednet/data/datamodule.py @@ -46,7 +46,18 @@ def _sample_size_bytes(s: Sample) -> int: """ def _tensor_size_bytes(t: torch.Tensor) -> int: - """Return a tensor size in bytes.""" + """Return a tensor size in bytes. + + Parameters + ---------- + t + A torch Tensor. + + Returns + ------- + int + The size of the Tensor in bytes. + """ return int(t.element_size() * torch.prod(torch.tensor(t.shape))) size = sys.getsizeof(s[0]) # tensor metadata @@ -100,7 +111,13 @@ class _DelayedLoadingDataset(Dataset): logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb") def labels(self) -> list[int | list[int]]: - """Return the integer labels for all samples in the dataset.""" + """Return the integer labels for all samples in the dataset. + + Returns + ------- + list[int | list[int]] + The integer labels for all samples in the dataset. + """ return [self.loader.label(k) for k in self.raw_dataset] def __getitem__(self, key: int) -> Sample: @@ -207,7 +224,13 @@ class _CachedDataset(Dataset): ) def labels(self) -> list[int | list[int]]: - """Return the integer labels for all samples in the dataset.""" + """Return the integer labels for all samples in the dataset. + + Returns + ------- + list[int | list[int]] + The integer labels for all samples in the dataset. + """ return [k[1]["label"] for k in self.data] def __getitem__(self, key: int) -> Sample: @@ -239,7 +262,13 @@ class _ConcatDataset(Dataset): ] def labels(self) -> list[int | list[int]]: - """Return the integer labels for all samples in the dataset.""" + """Return the integer labels for all samples in the dataset. + + Returns + ------- + list[int | list[int]] + The integer labels for all samples in the dataset. + """ return list(itertools.chain(*[k.labels() for k in self._datasets])) def __getitem__(self, key: int) -> Sample: @@ -550,6 +579,11 @@ class ConcatDataModule(lightning.LightningDataModule): - ``parallel`` - Runs mini-batch data loading on as many external processes as set on ``parallel`` + + Returns + ------- + int + The value of self._parallel. """ return self._parallel @@ -588,6 +622,10 @@ class ConcatDataModule(lightning.LightningDataModule): data is cached, it is cached **after** model-transforms are applied, as that is a potential memory saver (e.g., if it contains a resizing operation to smaller images). + + Returns + ------- + A list containing the model tansforms. """ return self._model_transforms @@ -606,7 +644,7 @@ class ConcatDataModule(lightning.LightningDataModule): self._datasets = {} @property - def balance_sampler_by_class(self): + def balance_sampler_by_class(self) -> bool: """Whether to balance samples across labels/datasets. If set, then modifies the random sampler used during training @@ -620,6 +658,11 @@ class ConcatDataModule(lightning.LightningDataModule): samples acording to their ground-truth (labels). If you'd like to have samples balanced per dataset, then implement your own data module inheriting from this one. + + Returns + ------- + bool + True if self._train_sample is set, else False. 
""" return self._train_sampler is not None @@ -733,7 +776,13 @@ class ConcatDataModule(lightning.LightningDataModule): self._datasets[name] = _ConcatDataset(datasets) def _val_dataset_keys(self) -> list[str]: - """Return list of validation dataset names.""" + """Return list of validation dataset names. + + Returns + ------- + list[str] + The list of validation dataset names. + """ return ["validation"] + [ k for k in self.splits.keys() if k.startswith("monitor-") ] @@ -801,7 +850,12 @@ class ConcatDataModule(lightning.LightningDataModule): super().teardown(stage) def train_dataloader(self) -> DataLoader: - """Return the train data loader.""" + """Return the train data loader. + + Returns + ------- + The train data loader(s). + """ return torch.utils.data.DataLoader( self._datasets["train"], @@ -814,7 +868,12 @@ class ConcatDataModule(lightning.LightningDataModule): ) def unshuffled_train_dataloader(self) -> DataLoader: - """Return the train data loader without shuffling.""" + """Return the train data loader without shuffling. + + Returns + ------- + The train data loader without shuffling. + """ return torch.utils.data.DataLoader( self._datasets["train"], @@ -825,7 +884,12 @@ class ConcatDataModule(lightning.LightningDataModule): ) def val_dataloader(self) -> dict[str, DataLoader]: - """Return the validation data loader(s)""" + """Return the validation data loader(s). + + Returns + ------- + The validation data loader(s). + """ validation_loader_opts = { "batch_size": self._chunk_size, @@ -843,7 +907,12 @@ class ConcatDataModule(lightning.LightningDataModule): } def test_dataloader(self) -> dict[str, DataLoader]: - """Return the test data loader(s)""" + """Return the test data loader(s). + + Returns + ------- + The test data loader(s). + """ return dict( test=torch.utils.data.DataLoader( @@ -857,7 +926,12 @@ class ConcatDataModule(lightning.LightningDataModule): ) def predict_dataloader(self) -> dict[str, DataLoader]: - """Return the prediction data loader(s)""" + """Return the prediction data loader(s). + + Returns + ------- + The prediction data loader(s). + """ return { k: torch.utils.data.DataLoader( diff --git a/src/mednet/data/split.py b/src/mednet/data/split.py index c6abae451f86be71ee47e54bbdedbe3a382620de..5d70aafd551b18fca856f846d91818442bb63c03 100644 --- a/src/mednet/data/split.py +++ b/src/mednet/data/split.py @@ -156,7 +156,6 @@ class CSVDatabaseSplit(DatabaseSplit): Returns ------- - datasets : dict A dictionary mapping dataset names to lists of JSON objects. """ diff --git a/src/mednet/data/typing.py b/src/mednet/data/typing.py index 658cad9d603f8a9e067c5de0089d811648b5d279..f61b423059f165f22fd0f860843f2573fa6be772 100644 --- a/src/mednet/data/typing.py +++ b/src/mednet/data/typing.py @@ -34,6 +34,16 @@ class RawDataLoader: If you do not override this implementation, then, by default, this method will call :py:meth:`sample` to load the whole sample and extract the label. + + Parameters + ---------- + k + The sample to load. This is implementation-dependent. + + Returns + ------- + int | list[int] + The label corresponding to the specified sample. 
""" return self.sample(k)[1]["label"] diff --git a/src/mednet/engine/device.py b/src/mednet/engine/device.py index c501f7f2029780239f6169e867b547eccefea77e..b008c4cbe952c33bf4ace294426e1362ff15b205 100644 --- a/src/mednet/engine/device.py +++ b/src/mednet/engine/device.py @@ -21,7 +21,18 @@ SupportedPytorchDevice: typing.TypeAlias = typing.Literal[ def _split_int_list(s: str) -> list[int]: - """Split a list of integers encoded in a string (e.g. "1,2,3") into a Python list of integers (e.g. ``[1, 2, 3]``).""" + """Split a list of integers encoded in a string (e.g. "1,2,3") into a Python list of integers (e.g. ``[1, 2, 3]``). + + Parameters + ---------- + s + A list of integers encoded in a string. + + Returns + ------- + list[int] + A Python list of integers. + """ return [int(k.strip()) for k in s.split(",")] diff --git a/src/mednet/engine/saliency/generator.py b/src/mednet/engine/saliency/generator.py index 65b4b62d35c7e815242560b9d5f7040f38d1064b..e99f78879f6049b6520c1525ef69b2db4f89143e 100644 --- a/src/mednet/engine/saliency/generator.py +++ b/src/mednet/engine/saliency/generator.py @@ -36,6 +36,10 @@ def _create_saliency_map_callable( The target layers to compute CAM for. use_cuda Whether to use cuda or not. + + Returns + ------- + A class activation map (CAM) instance for the given model. """ import pytorch_grad_cam diff --git a/src/mednet/utils/rc.py b/src/mednet/utils/rc.py index b3657e0292014135cc9eb15a0b29a07dc3c03ab0..fcc1659d1f5304c44e0561a7e1e70c6bd905db7a 100644 --- a/src/mednet/utils/rc.py +++ b/src/mednet/utils/rc.py @@ -6,5 +6,10 @@ from clapper.rc import UserDefaults def load_rc() -> UserDefaults: - """Return global configuration variables.""" + """Return global configuration variables. + + Returns + ------- + The user defaults read from the user .toml configuration file. + """ return UserDefaults("mednet.toml") diff --git a/src/mednet/utils/resources.py b/src/mednet/utils/resources.py index 184b9d39a81d664e00dc689cf3757d1b2fd02d47..a17d3cbd2750e01650e3d7ddc89bb5f379df0157 100644 --- a/src/mednet/utils/resources.py +++ b/src/mednet/utils/resources.py @@ -431,7 +431,13 @@ class _InformationGatherer: self.data[k] = [] def summary(self) -> dict[str, list[int | float]]: - """Return the current data.""" + """Return the current data. + + Returns + ------- + dict[str, list[int | float]] + A dictionary with a list of resources and their corresponding values. + """ if len(next(iter(self.data.values()))) == 0: self.logger.error("CPU/GPU logger was not able to collect any data") return self.data diff --git a/tests/conftest.py b/tests/conftest.py index bacef95ee8c7193eb684fd71fc2cbe46354cc76f..01223e6b8251f5e9fb9a1f2bb2bb55e8983cf2f9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,13 @@ from mednet.data.typing import DatabaseSplit @pytest.fixture def datadir(request) -> pathlib.Path: - """Return the directory in which the test is sitting.""" + """Return the directory in which the test is sitting. + + Returns + ------- + pathlib.Path + The directory in which the test is sitting. + """ return pathlib.Path(request.module.__file__).parents[0] / "data"