Skip to content
Snippets Groups Projects
Commit 620192f5 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[extractor] fixed the docstrings

parent a57c4814
Branches
Tags
No related merge requests found
......@@ -9,8 +9,30 @@ from bob.learn.pytorch.architectures import CNN8
from bob.bio.base.extractor import Extractor
class CNN8Extractor(Extractor):
""" The class implementing the feature extraction of CASIA-Net embeddings.
Attributes
----------
network: :py:class:`torch.nn.Module`
The network architecture
to_tensor: :py:mod:`torchvision.transforms`
The transform from numpy.array to torch.Tensor
norm: :py:mod:`torchvision.transforms`
The transform to normalize the input
"""
def __init__(self, model_file=None, num_classes=10575):
""" Init method
Parameters
----------
model_file: str
The path of the trained network to load
drop_rate: float
The number of classes.
"""
Extractor.__init__(self, skip_extractor_training=True)
......@@ -30,20 +52,20 @@ class CNN8Extractor(Extractor):
self.norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
def __call__(self, image):
"""__call__(image) -> feature
Extract features from an image
**Parameters:**
""" Extract features from an image
Parameters
----------
image : 3D :py:class:`numpy.ndarray` (floats)
The image to extract the features from. Its size must be 3x128x128
**Returns:**
Returns
-------
feature : 2D :py:class:`numpy.ndarray` (floats)
The extracted features as a 1d array of size 512
The extracted features as a 1d array of size 320
"""
input_image = numpy.rollaxis(numpy.rollaxis(image, 2),2)
input_image = self.to_tensor(input_image)
input_image = self.norm(input_image)
......
......@@ -15,9 +15,9 @@ class CasiaNetExtractor(Extractor):
----------
network: :py:class:`torch.nn.Module`
The network architecture
to_tensor: :py:class:`torchvision.transforms`
to_tensor: :py:mod:`torchvision.transforms`
The transform from numpy.array to torch.Tensor
norm: :py:class:`torchvision.transforms`
norm: :py:mod:`torchvision.transforms`
The transform to normalize the input
"""
......@@ -61,10 +61,11 @@ class CasiaNetExtractor(Extractor):
Returns
-------
feature : 2D :py:class:`numpy.ndarray` (floats)
The extracted features as a 1d array of size 320
"""
input_image = numpy.rollaxis(numpy.rollaxis(image, 2),2)
input_image = self.to_tensor(input_image)
input_image = self.norm(input_image)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment