Commit 7c6e5e9e authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixing test cases

parent 52081a00
Pipeline #59264 failed with stages
in 19 minutes and 23 seconds
......@@ -22,7 +22,7 @@ def run_mine(is_mine_f):
x_dim = X.shape[1]
z_dim = Z.shape[1]
t = T(x_dim, z_dim)
t = T(x_dim, z_dim, device=torch.device("cpu"))
model = Mine(T=t, is_mine_f=is_mine_f)
X = torch.tensor(X.astype("float32"))
......
......@@ -136,23 +136,21 @@ def test_mine():
# )
train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
backbone = iresnet34(
pretrained="/idiap/temp/tpereira/bob/data/pytorch/iresnet-91a5de61/iresnet34-5b0d0e90.pth"
)
identity_backbone = iresnet34()
demographic_backbone = iresnet34()
# list(dataloader.dataset.labels.values())
#####################
## IDENTITY
num_class = len(list(train_dataloader.dataset.labels.values()))
num_class = len(train_dataloader.dataset.labels)
identity_head = ArcFace(
feat_dim=backbone.features.num_features, num_class=num_class
feat_dim=identity_backbone.features.num_features, num_class=num_class
)
######################
## DEMOGRAPHIC
num_class = len(list(train_dataloader.dataset.demographic_keys.values()))
demographic_head = DemographicRegularHead(
feat_dim=backbone.features.num_features, num_class=num_class
feat_dim=demographic_backbone.features.num_features, num_class=num_class
)
################
......@@ -161,7 +159,8 @@ def test_mine():
# Preparing lightining model
model = MINEModel(
backbone=backbone,
identity_backbone=identity_backbone,
demographic_backbone=demographic_backbone,
identity_head=identity_head,
demographic_head=demographic_head,
loss_fn=torch.nn.CrossEntropyLoss(),
......@@ -175,13 +174,13 @@ def test_mine():
trainer = pl.Trainer(
# callbacks=..... # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#callbacks
# logger=logger,...
max_epochs=4,
gpus=-1 if torch.cuda.is_available() else None,
max_epochs=1,
gpus=None,
# resume_from_checkpoint=resume_from_checkpoint, #https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#resume-from-checkpoint
# debug flags
# limit_train_batches=10, # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#limit-train-batches
limit_train_batches=2, # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#limit-train-batches
# limit_val_batches=1,
amp_level="00", # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#amp-level
# amp_level="00", # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#amp-level
)
trainer.fit(
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment