Commit fb856cdf authored by Guillaume HEUSCH

[architectures] added CASIA Net

parent 49195707
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import make_conv_layers

# channel configuration of the convolutional stack ('M' marks a max-pooling layer)
CASIA_CONFIG = [32, 64, 'M', 64, 128, 'M', 96, 192, 'M', 128, 256, 'M', 160, 320]


class CASIANet(nn.Module):
    """CASIA Net: a convolutional feature extractor followed by a linear classifier."""

    def __init__(self, num_cls, drop_rate=0.5):
        super(CASIANet, self).__init__()
        self.num_classes = num_cls
        self.drop_rate = float(drop_rate)
        self.conv = make_conv_layers(CASIA_CONFIG)
        self.avgpool = nn.AvgPool2d(8)
        self.classifier = nn.Linear(320, self.num_classes)

    def forward(self, x):
        x = self.conv(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = F.dropout(x, p=self.drop_rate, training=self.training)
        out = self.classifier(x)
        return out, x  # return both the class scores and the 320-d feature vector
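
For orientation, a minimal usage sketch of the new module, under the following assumptions: bob.learn.pytorch is importable as in the script change below, make_conv_layers follows the usual VGG-style convention where 'M' is a stride-2 max-pooling, and inputs are 128x128 RGB images as in the removed training script further down, so four poolings leave an 8x8 map for the AvgPool2d(8). The class count is illustrative.

# Hedged usage sketch; input size and number of classes are assumptions.
import torch
from bob.learn.pytorch.architectures import CASIANet

net = CASIANet(num_cls=10, drop_rate=0.5)
net.eval()  # disables the dropout applied in forward()

# dummy batch of two 3x128x128 images: 128 / 2**4 = 8, matching AvgPool2d(8)
images = torch.randn(2, 3, 128, 128)
logits, features = net(images)
print(logits.shape)    # torch.Size([2, 10])  -> class scores
print(features.shape)  # torch.Size([2, 320]) -> pooled feature vector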
@@ -18,6 +18,7 @@ from .DRGANOriginal import DRGANOriginal_discriminator
from .DCGAN import weights_init
from .CNN8 import CNN8
from .CASIANet import CASIANet
# gets sphinx autodoc done right - don't remove it
__all__ = [_ for _ in dir() if not _.startswith('_')]
...
@@ -128,180 +128,14 @@ def main(user_input=None):
if network == 'CNN8':
    from bob.learn.pytorch.architectures import CNN8
    net = CNN8(number_of_classes, dropout)
elif network == 'CASIANet':
    from bob.learn.pytorch.architectures import CASIANet
    net = CASIANet(number_of_classes, dropout)
print(net)
#print model
#if args['--model']:
# cp = torch.load(args['--model'])
# net.load_state_dict(cp['state_dict'])
# start_epoch = cp['epoch']
# epoch_ = start_epoch
# start_iter = cp['iteration']
# losses = cp['loss']
# =====
# TRAIN
# =====
trainer = CNNTrainer(net, batch_size=batch_size, use_gpu=use_gpu, verbosity_level=verbosity_level)
trainer.train(dataloader, n_epochs=epochs, learning_rate=learning_rate, output_dir=output_dir, model=model)
#elif modname=='CASIA_NET':
# net = networks.CASIA_NET(int(args['--num_classes']), args['--drop'])
#elif modname=='ResNet26':
# net = networks.ResNet26(int(args['--num_classes']), args['--drop'])
#def exit_handler():
# if 'args' in globals():
# global losses, epoch_, iteration_, net
# print 'Ctrl-C! Safely exit after saving your model...'
# save_network(args['<save_dir>'], net, args['--net'], losses, epoch_, iteration_)
#
#def get_time_str():
# return time.strftime("%Y-%m-%d, %H:%M:%S ", time.localtime((time.time()) ))
#
#def print_info(msg):
# print get_time_str(), msg
# sys.stdout.flush()
#
#def setup(modname):
# global net
# if modname=='CNN8':
# net = networks.CNN8(int(args['--num_classes']), args['--drop'])
# elif modname=='CASIA_NET':
# net = networks.CASIA_NET(int(args['--num_classes']), args['--drop'])
# elif modname=='ResNet26':
# net = networks.ResNet26(int(args['--num_classes']), args['--drop'])
#
#
#def save_network(save_dir, network,network_label, trlosses = None, epoch=0, iteration = 0):
# save_filename = '{}_{}_{}.pth'.format(network_label,epoch, iteration)
# save_path = os.path.join(save_dir, save_filename)
# print_info('saving model to {}'.format(save_path))
# cp = {'epoch': epoch, 'iteration': iteration,
# 'loss': trlosses, 'state_dict': network.cpu().state_dict()}
# torch.save(cp, save_path)
# print_info('Successful!')
# if torch.cuda.is_available(): #back to gpu
# network.cuda()
#
#atexit.register(exit_handler)
#
#args = docopt.docopt(__doc__)
#if args['--visdom']:
# import visdom
# vis = visdom.Visdom()
#
#
#
#global losses, epoch_, iteration_, net
#if __name__ == '__main__':
# print args
# setup(args['--net']) # setup the net to train
# net.train()
#
# save_dir = args['<save_dir>']
# if not os.path.exists(save_dir):
# os.makedirs(save_dir)
#
# transform = transforms.Compose([
# transforms.RandomHorizontalFlip(),
# transforms.Scale((128,128)), # fix scale for all model
# transforms.ToTensor(),
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# dset = folder.ImageFolder(args['<data_dir>'], transform, filelist=args['--trainlist'])
# loader = data.DataLoader(dset,
# batch_size=int(args['--batch_size']), shuffle=True,
# num_workers=int(args['--num_workers']), pin_memory=args['--pm'])
# iters_per_epoch = len(loader)
# print_info('# training images = %d' % len(dset))
# print_info('# classes = %d' % len(dset.classes))
# print_info('iters_per_epoch = %d' % iters_per_epoch)
#
# if args['--gpu']: net.cuda()
# start_epoch = 0
# start_iter = 0
# iteration_ = 0
# epoch_ = 0
# losses = []
# if args['--resume']:
# cp = torch.load(args['--resume'])
# net.load_state_dict(cp['state_dict'])
# start_epoch = cp['epoch']
# epoch_ = start_epoch
# start_iter = cp['iteration']
# losses = cp['loss']
#
# criterion = nn.CrossEntropyLoss()
# optimizer = optim.SGD(net.parameters(), float(args['--lr']) , momentum = 0.9, weight_decay = 0.0005)
# # normally you watch visdom, make early stop, set smaller lr and resume
# # first to finish the last end epoch if necessary
# if start_iter>0:
# if args['--visdom']:
# win = vis.line(X=np.arange(len(losses))*int(args['--plot_freq_iter']),Y=np.array(losses)) # see http://localhost:8097/
# for i, (input, target) in enumerate(loader):
# iteration_ = i
# if i>=start_iter:
# if args['--gpu']:
# input_var = torch.autograd.Variable(input.cuda())
# target_var = torch.autograd.Variable(target.cuda())
# else:
# input_var = torch.autograd.Variable(input)
# target_var = torch.autograd.Variable(target)
# output, _ = net(input_var)
# loss = criterion(output, target_var)
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
#
# if i % int(args['--plot_freq_iter']) == 0:
# print_info('Epoch: {}/{}, iter: {}/{}, loss: {}'.format(start_epoch, args['--epochs'], i, iters_per_epoch, loss.data[0]))
# losses.append(loss.data[0])
# if args['--visdom']:
# vis.line(X=np.arange(len(losses))*int(args['--plot_freq_iter']), Y=np.array(losses),
# opts={
# 'title':' loss over time',
# 'xlabel': 'iteration',
# 'ylabel': 'loss'},
# win=win)
# if i % int(args['--save_freq_iter']) == 0: #i!=0
# save_network(save_dir, net, args['--net'], losses, start_epoch, i)
# epoch_ += 1
# # now normal epochs
# for epoch in range(start_epoch+1, int(args['--epochs'])):
# epoch_ = epoch # for exit writing
# for i, (input, target) in enumerate(loader):
# iteration_ = i
# if args['--gpu']:
# input_var = torch.autograd.Variable(input.cuda())
# target_var = torch.autograd.Variable(target.cuda())
# else:
# input_var = torch.autograd.Variable(input)
# target_var = torch.autograd.Variable(target)
# output, _ = net(input_var)
# loss = criterion(output, target_var)
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
#
# if i % int(args['--plot_freq_iter']) == 0:
# print_info('Epoch: {}/{}, iter: {}/{}, loss: {}'.format(epoch, args['--epochs'], i, iters_per_epoch, loss.data[0]))
# losses.append(loss.data[0])
# if args['--visdom']:
# if i==0 and start_iter==0:
# win = vis.line(X=np.arange(len(losses))*int(args['--plot_freq_iter']),Y=np.array(losses)) # see http://localhost:8097/
# else:
# vis.line(X=np.arange(len(losses))*int(args['--plot_freq_iter']), Y=np.array(losses),
# opts={
# 'title':' loss over time',
# 'xlabel': 'iteration',
# 'ylabel': 'loss'},
# win=win)
# if i>0 and i % int(args['--save_freq_iter']) == 0: #i!=0
# save_network(save_dir, net, args['--net'], losses, epoch, i)
#
# if epoch>0 and epoch % int(args['--save_freq_epoch']) == 0:
# save_network(save_dir, net, args['--net'], losses, epoch, 0)
#
# print 'Done'
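
As a complement, a hypothetical end-to-end sketch of the training path that survives this commit. The FakeData dataset, class count, output path and hyper-parameters are placeholders, the CNNTrainer import location is assumed, and the transform mirrors the 128x128 scaling and 0.5 normalization of the removed script above; only the CASIANet and CNNTrainer calls follow the diff.

# Hypothetical sketch only; dataset, paths and hyper-parameters are placeholders.
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from bob.learn.pytorch.architectures import CASIANet
from bob.learn.pytorch.trainers import CNNTrainer  # assumed import location

# same fixed 128x128 scale and 0.5 normalization as the removed script above
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# FakeData stands in for the real image/file-list dataset used by the script
dataset = datasets.FakeData(size=64, image_size=(3, 128, 128),
                            num_classes=10, transform=transform)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

net = CASIANet(num_cls=10, drop_rate=0.5)
trainer = CNNTrainer(net, batch_size=16, use_gpu=torch.cuda.is_available(),
                     verbosity_level=2)
trainer.train(dataloader, n_epochs=2, learning_rate=0.01,
              output_dir='./trained-casianet', model=None)  # model=None: train from scratch (assumption)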