Skip to content
Snippets Groups Projects
Commit 07aa13df authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[config] add config files for both CNNs and GANs

parent c75f9631
Branches
Tags
No related merge requests found
Pipeline #26287 passed
### DATA ###
from bob.learn.pytorch.datasets import CasiaWebFaceDataset
import torchvision.transforms as transforms
from bob.learn.pytorch.datasets import RollChannels
from bob.learn.pytorch.datasets import ToTensor
from bob.learn.pytorch.datasets import Normalize
dataset = CasiaWebFaceDataset(root_dir='/idiap/project/fargo/xpeng_prepro/CASIA-Webface-crop-128/',
transform=transforms.Compose([
RollChannels(), # bob to skimage:
ToTensor(),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
)
### NETWORK ###
from bob.learn.pytorch.architectures import CASIANet
number_of_classes = 10575
dropout = 0.5
network = CASIANet(number_of_classes, dropout)
### DATA ###
from bob.learn.pytorch.datasets import CasiaWebFaceDataset
import torchvision.transforms as transforms
from bob.learn.pytorch.datasets import RollChannels
from bob.learn.pytorch.datasets import ToTensor
from bob.learn.pytorch.datasets import Normalize
dataset = CasiaWebFaceDataset(root_dir='/idiap/project/fargo/xpeng_prepro/CASIA-Webface-crop-128/',
transform=transforms.Compose([
RollChannels(), # bob to skimage:
ToTensor(),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
)
### NETWORK ###
from bob.learn.pytorch.architectures import CNN8
number_of_classes = 10575
dropout = 0.5
network = CNN8(number_of_classes, dropout)
### NETWORK ###
from bob.learn.pytorch.architectures import ConditionalGAN_generator
from bob.learn.pytorch.architectures import ConditionalGAN_discriminator
from bob.learn.pytorch.architectures import weights_init
noise_dim = 100
conditional_dim = 13
channels = 3
ngpu = 1
generator = ConditionalGAN_generator(noise_dim, conditional_dim)
generator.apply(weights_init)
discriminator = ConditionalGAN_discriminator(conditional_dim)
discriminator.apply(weights_init)
### DATA ###
from bob.learn.pytorch.datasets.multipie import MultiPIEDataset
from bob.learn.pytorch.datasets import RollChannels
from bob.learn.pytorch.datasets import ToTensor
from bob.learn.pytorch.datasets import Normalize
import torchvision.transforms as transforms
print("loading data ....")
dataset = MultiPIEDataset(root_dir='/idiap/temp/heusch/data/multipie-cropped-64x64',
frontal_only=False,
transform=transforms.Compose([
RollChannels(), # bob to skimage:
ToTensor(),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
)
print("done")
### DATA ###
from bob.learn.pytorch.datasets.multipie import MultiPIEDataset
from bob.learn.pytorch.datasets import RollChannels
from bob.learn.pytorch.datasets import ToTensor
from bob.learn.pytorch.datasets import Normalize
import torchvision.transforms as transforms
dataset = MultiPIEDataset(root_dir='/idiap/temp/heusch/data/multipie-cropped-64x64',
frontal_only=True,
transform=transforms.Compose([
RollChannels(), # bob to skimage:
ToTensor(),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
)
### NETWORK ###
from bob.learn.pytorch.architectures import DCGAN_generator
from bob.learn.pytorch.architectures import DCGAN_discriminator
from bob.learn.pytorch.architectures import weights_init
ngpu = 1
generator = DCGAN_generator(ngpu)
generator.apply(weights_init)
discriminator = DCGAN_discriminator(ngpu)
discriminator.apply(weights_init)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment