Commit bb1ddfbb authored by Anjith GEORGE

Some mods in docstrings and arguments

parent c4d7b4d1
1 merge request: !18 MCCNN trainer
@@ -58,7 +58,7 @@ bob_hldi_instance_train = BatlPadDatabase(
 #==============================================================================
 # Initialize the torch dataset, subselect channels from the pretrained files if needed.
-SELECTED_CHANNELS = [0,1,2] # selects only color, depth and infrared
+SELECTED_CHANNELS = [0,1,2,3] # selects color, depth, infrared and thermal
 img_transform_train = transforms.Compose([ChannelSelect(selected_channels = SELECTED_CHANNELS),transforms.ToPILImage(),transforms.RandomHorizontalFlip(),transforms.ToTensor()])# Add p=0.5 later
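For context, a minimal sketch (plain numpy, not the library code) of the channel sub-selection that the ChannelSelect transform is expected to perform on a preprocessed sample; the channels-first 4x128x128 layout is taken from the extractor docstring and the test further down, not from this hunk.

import numpy

SELECTED_CHANNELS = [0, 1, 2, 3]  # color, depth, infrared, thermal
sample = numpy.random.rand(4, 128, 128).astype("float32")  # dummy preprocessed sample
subselected = sample[SELECTED_CHANNELS, :, :]  # keep only the listed channels
assert subselected.shape[0] == len(SELECTED_CHANNELS)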
@@ -74,7 +74,9 @@ dataset = DataFolder(data_folder=data_folder_train,
 #==============================================================================
 # Load the architecture
-NUM_CHANNELS = 3
+NUM_CHANNELS = 4
+assert(len(SELECTED_CHANNELS)==NUM_CHANNELS)
 network=MCCNN(num_channels = NUM_CHANNELS)
@@ -87,7 +89,7 @@ learning_rate=0.0001
 seed = 3
 output_dir = 'training_mccn'
 use_gpu = False
-adapted_layers = 'conv1-ffc'
+adapted_layers = 'conv1-group1-block1-ffc'
 adapt_reference_channel = False
 verbose = 2
@@ -23,13 +23,16 @@ class MCCNNExtractor(Extractor):
 """
-def __init__(self, num_channels=4, transforms = transforms.Compose([transforms.ToTensor()]), model_file=None):
+def __init__(self, num_channels_used=4, transforms = transforms.Compose([transforms.ToTensor()]), model_file=None):
 """ Init method

 Parameters
 ----------
-num_channels: int
-The number of channels present in the input
+num_channels_used: int
+The number of channels to be used by the network. This could be different
+from the number of channels present in the input image, for instance, when
+used together with the 'ChannelSelect' transform. The value of `num_channels_used`
+should be the number of channels eventually used by the network (i.e., the
+output of the transform).
 model_file: str
 The path of the trained PAD network to load
 transforms: :py:mod:`torchvision.transforms`
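To make the relationship in the new docstring concrete, a small illustration with hypothetical values: the preprocessed file can carry more channels than the network consumes, and `num_channels_used` should match what is left after the transform.

channels_in_preprocessed_file = 4            # e.g. color, depth, infrared, thermal
selected_channels = [0, 1, 2]                # hypothetically kept by a ChannelSelect transform
num_channels_used = len(selected_channels)   # what MCCNN should be built with -> 3
assert num_channels_used <= channels_in_preprocessed_file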
@@ -41,22 +41,26 @@ class MCCNNExtractor(Extractor):
 # model
 self.transforms = transforms
-self.network = MCCNN(num_channels=num_channels)
+self.network = MCCNN(num_channels=num_channels_used)
 logger.debug('Initialized model with lightCNN weights')
 #self.network=self.network.to(device)
 if model_file is None:
 # do nothing (used mainly for unit testing)
-logger.info("No pretrained file provided")
+logger.debug("No pretrained file provided")
 pass
 else:
 # With the new training
 logger.debug('Starting to load the pretrained PAD model')
 cp = torch.load(model_file)
 if 'state_dict' in cp:
 self.network.load_state_dict(cp['state_dict'])
-logger.info('Loaded the pretrained PAD model')
+logger.debug('Loaded the pretrained PAD model')
 self.network.eval()
@@ -69,6 +76,9 @@ class MCCNNExtractor(Extractor):
 ----------
 image : 3D :py:class:`numpy.ndarray` (floats)
 The multi-channel image to extract the score from. Its size must be num_channelsx128x128;
+Note: here `num_channels` is the number of channels in the preprocessed data. The
+number of channels actually used by the network may differ, for instance when the
+ChannelSelect transform is applied.

 Returns
 -------
@@ -323,7 +323,7 @@ def test_extractors():
 # MCCNN
 from ..extractor.image import MCCNNExtractor
-extractor = MCCNNExtractor(num_channels=4)
+extractor = MCCNNExtractor(num_channels_used=4)
 # this architecture expects num_channelsx128x128 Multi channel images
 data = numpy.random.rand(4, 128, 128).astype("float32")
 output = extractor(data)
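Building on the test above, a hedged usage sketch of the mismatched case the docstrings describe: 4-channel preprocessed data of which only 2 channels reach the network. The ChannelSelect import path and its behaviour on a channels-first array are assumptions, not part of this diff; MCCNNExtractor is imported as in the test.

import numpy
from torchvision import transforms
from bob.learn.pytorch.datasets import ChannelSelect  # assumed import path
from ..extractor.image import MCCNNExtractor  # same relative import as the test above

# network built for 2 channels; the transform is assumed to drop the other two
extractor = MCCNNExtractor(
    num_channels_used=2,
    transforms=transforms.Compose([ChannelSelect(selected_channels=[0, 1]),
                                   transforms.ToTensor()]))
data = numpy.random.rand(4, 128, 128).astype("float32")  # 4-channel preprocessed sample
score = extractor(data)  # only channels 0 and 1 are fed to MCCNN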