diff --git a/train.py b/train.py index 72f56ef549551f6efae70b902a792481f6c9066a..40bba9246c76200cf684d9043a50901b67ba32d1 100644 --- a/train.py +++ b/train.py @@ -13,9 +13,9 @@ parser = argparse.ArgumentParser(description='Train face reconstruction network parser.add_argument('--path_eg3d_repo', metavar='<path_eg3d_repo>', type= str, default='./eg3d', help='./eg3d') parser.add_argument('--path_eg3d_checkpoint', metavar='<path_eg3d_checkpoint>', type= str, default='./ffhq512-128.pkl', - help='./ffhq512-128.pkl`') + help='./ffhq512-128.pkl') parser.add_argument('--path_ffhq_dataset', metavar='<path_ffhq_dataset>', type= str, default='./Flickr-Faces-HQ/images1024x1024', - help='FFHQ directory`') + help='FFHQ directory') parser.add_argument('--FR_system', metavar='<FR_system>', type= str, default='ArcFace', help='ArcFace/ElasticFace (FR system from whose database the templates are leaked)') parser.add_argument('--FR_loss', metavar='<FR_loss>', type= str, default='ArcFace', @@ -26,7 +26,7 @@ args = parser.parse_args() import os,sys sys.path.append(os.getcwd()) # import src -sys.path.append(args.path_eg3d_repo) # import eg3d files +sys.path.append(f"{args.path_eg3d_repo}/eg3d") # import eg3d files from camera_utils import LookAtPoseSampler, FOV_to_intrinsics import pickle @@ -79,8 +79,8 @@ z_dim_EG3D = 512 from src.Dataset import MyDataset from torch.utils.data import DataLoader -training_dataset = MyDataset(FR_system= args.FR_system, train=True, device=device) -testing_dataset = MyDataset(FR_system= args.FR_system, train=False, device=device) +training_dataset = MyDataset(dataset_dir=args.path_ffhq_dataset, FR_system= args.FR_system, train=True, device=device) +testing_dataset = MyDataset(dataset_dir=args.path_ffhq_dataset, FR_system= args.FR_system, train=False, device=device) train_dataloader = training_dataset test_dataloader = DataLoader(testing_dataset, batch_size=18, shuffle=False)