Skip to content
Snippets Groups Projects
Commit ab0f2f29 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

[doc] Add missing "Returns" sections in docstrings

parent a5c670f4
No related branches found
No related tags found
1 merge request: !15 Update documentation
Pipeline #83884 canceled
...@@ -119,7 +119,12 @@ class BoundingBoxes(collections.abc.Sequence[BoundingBox]): ...@@ -119,7 +119,12 @@ class BoundingBoxes(collections.abc.Sequence[BoundingBox]):
# explained at: # explained at:
# https://pytorch.org/docs/stable/data.html#torch.utils.data.default_collate # https://pytorch.org/docs/stable/data.html#torch.utils.data.default_collate
def _collate_boundingboxes_fn(batch, *, collate_fn_map=None): def _collate_boundingboxes_fn(batch, *, collate_fn_map=None):
"""Custom collate_fn() for pytorch dataloaders that ignores BoundingBoxes objects.""" """Custom collate_fn() for pytorch dataloaders that ignores BoundingBoxes objects.
Returns
-------
The given batch.
"""
return batch return batch
......
...@@ -59,7 +59,18 @@ from sklearn.model_selection import StratifiedKFold, train_test_split ...@@ -59,7 +59,18 @@ from sklearn.model_selection import StratifiedKFold, train_test_split
def reorder(data: dict) -> list: def reorder(data: dict) -> list:
"""Reorder data from TBX11K into a sample-based organisation.""" """Reorder data from TBX11K into a sample-based organisation.
Parameters
----------
data
A dictionary containing the loaded data.
Returns
-------
list
The reordered data.
"""
categories = {k["id"]: k["name"] for k in data["categories"]} categories = {k["id"]: k["name"] for k in data["categories"]}
assert len(set(categories.values())) == len( assert len(set(categories.values())) == len(
...@@ -112,6 +123,11 @@ def normalize_labels(data: list) -> list: ...@@ -112,6 +123,11 @@ def normalize_labels(data: list) -> list:
bounding boxes with label 0 and no bounding box with label 1 bounding boxes with label 0 and no bounding box with label 1
4: sick (but no tb), comes from the imgs/sick subdir, does not have any 4: sick (but no tb), comes from the imgs/sick subdir, does not have any
annotated bounding box. annotated bounding box.
Returns
-------
list
A list of labels per sample.
""" """
def _set_label(s: list) -> int: def _set_label(s: list) -> int:
...@@ -213,6 +229,11 @@ def create_v1_default_split(d: dict, seed: int, validation_size: float) -> dict: ...@@ -213,6 +229,11 @@ def create_v1_default_split(d: dict, seed: int, validation_size: float) -> dict:
validation_size validation_size
The proportion of data when we split the training set to make a The proportion of data when we split the training set to make a
train and validation sets. train and validation sets.
Returns
-------
dict
A dict containing the various v1 splits.
""" """
# filter cases (only interested in labels 0:healthy or 1:active-tb) # filter cases (only interested in labels 0:healthy or 1:active-tb)
...@@ -251,6 +272,11 @@ def create_v2_default_split(d: dict, seed: int, validation_size) -> dict: ...@@ -251,6 +272,11 @@ def create_v2_default_split(d: dict, seed: int, validation_size) -> dict:
2. The original training set is split into new training and validation 2. The original training set is split into new training and validation
sets. The selection of samples is stratified (respects class sets. The selection of samples is stratified (respects class
proportions in Özgür's way - see comments) proportions in Özgür's way - see comments)
Returns
-------
dict
A dict containing the various v2 splits.
""" """
# filter cases (only interested in labels 0:healthy or 1:active-tb) # filter cases (only interested in labels 0:healthy or 1:active-tb)
......
...@@ -239,7 +239,7 @@ class ElasticDeformation: ...@@ -239,7 +239,7 @@ class ElasticDeformation:
self.parallel = parallel self.parallel = parallel
@property @property
def parallel(self): def parallel(self) -> int:
"""Use multiprocessing for data augmentation. """Use multiprocessing for data augmentation.
If set to -1 (default), disables multiprocessing. If set to -2, If set to -1 (default), disables multiprocessing. If set to -2,
...@@ -247,6 +247,11 @@ class ElasticDeformation: ...@@ -247,6 +247,11 @@ class ElasticDeformation:
batch size and total number of processing cores). Set to 0 to batch size and total number of processing cores). Set to 0 to
enable as many processes as processing cores available in the enable as many processes as processing cores available in the
system. Set to >= 1 to enable that many processes. system. Set to >= 1 to enable that many processes.
Returns
-------
int
The multiprocessing type.
""" """
return self._parallel return self._parallel
......
...@@ -46,7 +46,18 @@ def _sample_size_bytes(s: Sample) -> int: ...@@ -46,7 +46,18 @@ def _sample_size_bytes(s: Sample) -> int:
""" """
def _tensor_size_bytes(t: torch.Tensor) -> int: def _tensor_size_bytes(t: torch.Tensor) -> int:
"""Return a tensor size in bytes.""" """Return a tensor size in bytes.
Parameters
----------
t
A torch Tensor.
Returns
-------
int
The size of the Tensor in bytes.
"""
return int(t.element_size() * torch.prod(torch.tensor(t.shape))) return int(t.element_size() * torch.prod(torch.tensor(t.shape)))
size = sys.getsizeof(s[0]) # tensor metadata size = sys.getsizeof(s[0]) # tensor metadata
...@@ -100,7 +111,13 @@ class _DelayedLoadingDataset(Dataset): ...@@ -100,7 +111,13 @@ class _DelayedLoadingDataset(Dataset):
logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb") logger.info(f"Estimated sample size: {sample_size_mb:.1f} Mb")
def labels(self) -> list[int | list[int]]: def labels(self) -> list[int | list[int]]:
"""Return the integer labels for all samples in the dataset.""" """Return the integer labels for all samples in the dataset.
Returns
-------
list[int | list[int]]
The integer labels for all samples in the dataset.
"""
return [self.loader.label(k) for k in self.raw_dataset] return [self.loader.label(k) for k in self.raw_dataset]
def __getitem__(self, key: int) -> Sample: def __getitem__(self, key: int) -> Sample:
...@@ -207,7 +224,13 @@ class _CachedDataset(Dataset): ...@@ -207,7 +224,13 @@ class _CachedDataset(Dataset):
) )
def labels(self) -> list[int | list[int]]: def labels(self) -> list[int | list[int]]:
"""Return the integer labels for all samples in the dataset.""" """Return the integer labels for all samples in the dataset.
Returns
-------
list[int | list[int]]
The integer labels for all samples in the dataset.
"""
return [k[1]["label"] for k in self.data] return [k[1]["label"] for k in self.data]
def __getitem__(self, key: int) -> Sample: def __getitem__(self, key: int) -> Sample:
...@@ -239,7 +262,13 @@ class _ConcatDataset(Dataset): ...@@ -239,7 +262,13 @@ class _ConcatDataset(Dataset):
] ]
def labels(self) -> list[int | list[int]]: def labels(self) -> list[int | list[int]]:
"""Return the integer labels for all samples in the dataset.""" """Return the integer labels for all samples in the dataset.
Returns
-------
list[int | list[int]]
The integer labels for all samples in the dataset.
"""
return list(itertools.chain(*[k.labels() for k in self._datasets])) return list(itertools.chain(*[k.labels() for k in self._datasets]))
def __getitem__(self, key: int) -> Sample: def __getitem__(self, key: int) -> Sample:
...@@ -550,6 +579,11 @@ class ConcatDataModule(lightning.LightningDataModule): ...@@ -550,6 +579,11 @@ class ConcatDataModule(lightning.LightningDataModule):
- ``parallel`` - ``parallel``
- Runs mini-batch data loading on as many external processes as set on - Runs mini-batch data loading on as many external processes as set on
``parallel`` ``parallel``
Returns
-------
int
The value of self._parallel.
""" """
return self._parallel return self._parallel
...@@ -588,6 +622,10 @@ class ConcatDataModule(lightning.LightningDataModule): ...@@ -588,6 +622,10 @@ class ConcatDataModule(lightning.LightningDataModule):
data is cached, it is cached **after** model-transforms are data is cached, it is cached **after** model-transforms are
applied, as that is a potential memory saver (e.g., if it applied, as that is a potential memory saver (e.g., if it
contains a resizing operation to smaller images). contains a resizing operation to smaller images).
Returns
-------
A list containing the model transforms.
""" """
return self._model_transforms return self._model_transforms
...@@ -606,7 +644,7 @@ class ConcatDataModule(lightning.LightningDataModule): ...@@ -606,7 +644,7 @@ class ConcatDataModule(lightning.LightningDataModule):
self._datasets = {} self._datasets = {}
@property @property
def balance_sampler_by_class(self): def balance_sampler_by_class(self) -> bool:
"""Whether to balance samples across labels/datasets. """Whether to balance samples across labels/datasets.
If set, then modifies the random sampler used during training If set, then modifies the random sampler used during training
...@@ -620,6 +658,11 @@ class ConcatDataModule(lightning.LightningDataModule): ...@@ -620,6 +658,11 @@ class ConcatDataModule(lightning.LightningDataModule):
samples according to their ground-truth (labels). If you'd like to samples according to their ground-truth (labels). If you'd like to
have samples balanced per dataset, then implement your own data have samples balanced per dataset, then implement your own data
module inheriting from this one. module inheriting from this one.
Returns
-------
bool
True if self._train_sampler is set, else False.
""" """
return self._train_sampler is not None return self._train_sampler is not None
...@@ -733,7 +776,13 @@ class ConcatDataModule(lightning.LightningDataModule): ...@@ -733,7 +776,13 @@ class ConcatDataModule(lightning.LightningDataModule):
self._datasets[name] = _ConcatDataset(datasets) self._datasets[name] = _ConcatDataset(datasets)
def _val_dataset_keys(self) -> list[str]: def _val_dataset_keys(self) -> list[str]:
"""Return list of validation dataset names.""" """Return list of validation dataset names.
Returns
-------
list[str]
The list of validation dataset names.
"""
return ["validation"] + [ return ["validation"] + [
k for k in self.splits.keys() if k.startswith("monitor-") k for k in self.splits.keys() if k.startswith("monitor-")
] ]
...@@ -801,7 +850,12 @@ class ConcatDataModule(lightning.LightningDataModule): ...@@ -801,7 +850,12 @@ class ConcatDataModule(lightning.LightningDataModule):
super().teardown(stage) super().teardown(stage)
def train_dataloader(self) -> DataLoader: def train_dataloader(self) -> DataLoader:
"""Return the train data loader.""" """Return the train data loader.
Returns
-------
The train data loader(s).
"""
return torch.utils.data.DataLoader( return torch.utils.data.DataLoader(
self._datasets["train"], self._datasets["train"],
...@@ -814,7 +868,12 @@ class ConcatDataModule(lightning.LightningDataModule): ...@@ -814,7 +868,12 @@ class ConcatDataModule(lightning.LightningDataModule):
) )
def unshuffled_train_dataloader(self) -> DataLoader: def unshuffled_train_dataloader(self) -> DataLoader:
"""Return the train data loader without shuffling.""" """Return the train data loader without shuffling.
Returns
-------
The train data loader without shuffling.
"""
return torch.utils.data.DataLoader( return torch.utils.data.DataLoader(
self._datasets["train"], self._datasets["train"],
...@@ -825,7 +884,12 @@ class ConcatDataModule(lightning.LightningDataModule): ...@@ -825,7 +884,12 @@ class ConcatDataModule(lightning.LightningDataModule):
) )
def val_dataloader(self) -> dict[str, DataLoader]: def val_dataloader(self) -> dict[str, DataLoader]:
"""Return the validation data loader(s)""" """Return the validation data loader(s).
Returns
-------
The validation data loader(s).
"""
validation_loader_opts = { validation_loader_opts = {
"batch_size": self._chunk_size, "batch_size": self._chunk_size,
...@@ -843,7 +907,12 @@ class ConcatDataModule(lightning.LightningDataModule): ...@@ -843,7 +907,12 @@ class ConcatDataModule(lightning.LightningDataModule):
} }
def test_dataloader(self) -> dict[str, DataLoader]: def test_dataloader(self) -> dict[str, DataLoader]:
"""Return the test data loader(s)""" """Return the test data loader(s).
Returns
-------
The test data loader(s).
"""
return dict( return dict(
test=torch.utils.data.DataLoader( test=torch.utils.data.DataLoader(
...@@ -857,7 +926,12 @@ class ConcatDataModule(lightning.LightningDataModule): ...@@ -857,7 +926,12 @@ class ConcatDataModule(lightning.LightningDataModule):
) )
def predict_dataloader(self) -> dict[str, DataLoader]: def predict_dataloader(self) -> dict[str, DataLoader]:
"""Return the prediction data loader(s)""" """Return the prediction data loader(s).
Returns
-------
The prediction data loader(s).
"""
return { return {
k: torch.utils.data.DataLoader( k: torch.utils.data.DataLoader(
......
...@@ -156,7 +156,6 @@ class CSVDatabaseSplit(DatabaseSplit): ...@@ -156,7 +156,6 @@ class CSVDatabaseSplit(DatabaseSplit):
Returns Returns
------- -------
datasets : dict
A dictionary mapping dataset names to lists of JSON objects. A dictionary mapping dataset names to lists of JSON objects.
""" """
......
...@@ -34,6 +34,16 @@ class RawDataLoader: ...@@ -34,6 +34,16 @@ class RawDataLoader:
If you do not override this implementation, then, by default, If you do not override this implementation, then, by default,
this method will call :py:meth:`sample` to load the whole sample this method will call :py:meth:`sample` to load the whole sample
and extract the label. and extract the label.
Parameters
----------
k
The sample to load. This is implementation-dependent.
Returns
-------
int | list[int]
The label corresponding to the specified sample.
""" """
return self.sample(k)[1]["label"] return self.sample(k)[1]["label"]
......
...@@ -21,7 +21,18 @@ SupportedPytorchDevice: typing.TypeAlias = typing.Literal[ ...@@ -21,7 +21,18 @@ SupportedPytorchDevice: typing.TypeAlias = typing.Literal[
def _split_int_list(s: str) -> list[int]: def _split_int_list(s: str) -> list[int]:
"""Split a list of integers encoded in a string (e.g. "1,2,3") into a Python list of integers (e.g. ``[1, 2, 3]``).""" """Split a list of integers encoded in a string (e.g. "1,2,3") into a Python list of integers (e.g. ``[1, 2, 3]``).
Parameters
----------
s
A list of integers encoded in a string.
Returns
-------
list[int]
A Python list of integers.
"""
return [int(k.strip()) for k in s.split(",")] return [int(k.strip()) for k in s.split(",")]
......
...@@ -36,6 +36,10 @@ def _create_saliency_map_callable( ...@@ -36,6 +36,10 @@ def _create_saliency_map_callable(
The target layers to compute CAM for. The target layers to compute CAM for.
use_cuda use_cuda
Whether to use cuda or not. Whether to use cuda or not.
Returns
-------
A class activation map (CAM) instance for the given model.
""" """
import pytorch_grad_cam import pytorch_grad_cam
......
...@@ -6,5 +6,10 @@ from clapper.rc import UserDefaults ...@@ -6,5 +6,10 @@ from clapper.rc import UserDefaults
def load_rc() -> UserDefaults: def load_rc() -> UserDefaults:
"""Return global configuration variables.""" """Return global configuration variables.
Returns
-------
The user defaults read from the user .toml configuration file.
"""
return UserDefaults("mednet.toml") return UserDefaults("mednet.toml")
...@@ -431,7 +431,13 @@ class _InformationGatherer: ...@@ -431,7 +431,13 @@ class _InformationGatherer:
self.data[k] = [] self.data[k] = []
def summary(self) -> dict[str, list[int | float]]: def summary(self) -> dict[str, list[int | float]]:
"""Return the current data.""" """Return the current data.
Returns
-------
dict[str, list[int | float]]
A dictionary with a list of resources and their corresponding values.
"""
if len(next(iter(self.data.values()))) == 0: if len(next(iter(self.data.values()))) == 0:
self.logger.error("CPU/GPU logger was not able to collect any data") self.logger.error("CPU/GPU logger was not able to collect any data")
return self.data return self.data
......
...@@ -17,7 +17,13 @@ from mednet.data.typing import DatabaseSplit ...@@ -17,7 +17,13 @@ from mednet.data.typing import DatabaseSplit
@pytest.fixture @pytest.fixture
def datadir(request) -> pathlib.Path: def datadir(request) -> pathlib.Path:
"""Return the directory in which the test is sitting.""" """Return the directory in which the test is sitting.
Returns
-------
pathlib.Path
The directory in which the test is sitting.
"""
return pathlib.Path(request.module.__file__).parents[0] / "data" return pathlib.Path(request.module.__file__).parents[0] / "data"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment