Skip to content
Snippets Groups Projects
Commit df63a19e authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[trainer] fixed docstrings

parent bf7663dc
Branches
Tags
No related merge requests found
......@@ -19,26 +19,37 @@ class CNNTrainer(object):
"""
Class to train a CNN
**Parameters**
network: pytorch nn.Module
The network
Attributes
----------
network: :py:class: torch.nn.Module
The network to train
batch_size: int
The size of your minibatch
use_gpu: boolean
If you would like to use the gpu
verbosity_level: int
The level of verbosity output to stdout
"""
def __init__(self, network, batch_size=64, use_gpu=False, verbosity_level=2):
""" Init function
Parameters
----------
network: :py:class: torch.nn.Module
The network to train
batch_size: int
The size of your minibatch
use_gpu: boolean
If you would like to use the gpu
verbosity_level: int
The level of verbosity output to stdout
"""
self.network = network
self.batch_size = batch_size
self.use_gpu = use_gpu
self.criterion = nn.CrossEntropyLoss()
if self.use_gpu:
......@@ -48,56 +59,51 @@ class CNNTrainer(object):
def load_model(self, model_filename):
"""
Loads an existing model
**Parameters**
"""Loads an existing model
Parameters
----------
model_file: str
The filename of the model to load
**Returns**
Returns
-------
start_epoch: int
The epoch to start with
start_iteration: int
The iteration to start with
losses: list
The list of losses wfrom previous training
"""
cp = torch.load(model_filename)
self.network.load_state_dict(cp['state_dict'])
start_epoch = cp['epoch']
start_iter = cp['iteration']
losses = cp['loss']
return start_epoch, start_iter, losses
def save_model(self, output_dir, epoch=0, iteration=0, losses=None):
"""
Save the trained network
**Parameters**
output_dir: str
The directory to write the models to
epoch: int
the current epoch
iteration: int
the current (last) iteration
losses: list
def save_model(self, output_dir, epoch=0, iteration=0, losses=None):
"""Save the trained network
Parameters
----------
output_dir: str
The directory to write the models to
epoch: int
the current epoch
iteration: int
the current (last) iteration
losses: list
The list of losses since the beginning of training
"""
saved_filename = 'model_{}_{}.pth'.format(epoch, iteration)
saved_path = os.path.join(output_dir, saved_filename)
logger.info('Saving model to {}'.format(saved_path))
cp = {'epoch': epoch,
'iteration': iteration,
'loss': losses,
......@@ -111,23 +117,21 @@ class CNNTrainer(object):
def train(self, dataloader, n_epochs=20, learning_rate=0.01, output_dir='out', model=None):
"""
Function that performs the training.
**Parameters**
"""Performs the training.
dataloader: pytorch DataLoader
Parameters
----------
dataloader: :py:class:: torch.utils.data.DataLoader
The dataloader for your data
n_epochs: int
The number of epochs you would like to train for
learning_rate: float
The learning rate for SGD optimizer
output_dir: path
The directory where you would like to save models
"""
# if model exists, load it
if model is not None:
start_epoch, start_iter, losses = self.load_model(model)
......@@ -138,7 +142,6 @@ class CNNTrainer(object):
losses = []
logger.info('Starting training from scratch')
# setup optimizer
optimizer = optim.SGD(self.network.parameters(), learning_rate, momentum = 0.9, weight_decay = 0.0005)
......@@ -151,7 +154,7 @@ class CNNTrainer(object):
start = time.time()
images = data['image']
labels = data['id']
labels = data['label']
batch_size = len(images)
if self.use_gpu:
images = images.cuda()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment