diff --git a/bob/ip/binseg/engine/predictor.py b/bob/ip/binseg/engine/predictor.py index 4e4640f05b82bef935fd6a70db03c82ba9e1e98f..de78e69f5ba071d017babd5f7b3654244e64bb80 100644 --- a/bob/ip/binseg/engine/predictor.py +++ b/bob/ip/binseg/engine/predictor.py @@ -125,9 +125,10 @@ def run(model, data_loader, device, output_folder, overlayed_folder): os.makedirs(output_folder, exist_ok=True) logger.info(f"Device: {device}") - model.eval().to(device) - # Sigmoid for predictions - sigmoid = torch.nn.Sigmoid() + + model.eval() # set evaluation mode + model.to(device) # set/cast parameters to device + sigmoid = torch.nn.Sigmoid() # use sigmoid for predictions # Setup timers start_total_time = time.time() diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py index 40566f21aad29a1ed588ff67938dd7cd506d5cac..98c17e1b56c7e0732870f36d005dbf395446da74 100644 --- a/bob/ip/binseg/engine/ssltrainer.py +++ b/bob/ip/binseg/engine/ssltrainer.py @@ -300,7 +300,9 @@ def run( if arguments["epoch"] == 0: logwriter.writeheader() - model.train().to(device) + model.train() #set training mode + + model.to(device) # set/cast parameters to device for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py index ed34fbe0226c749c3e94ad59b57e23dfeece8b5c..4c4e558e66bdfdba99d066ca9060d7d4a41678df 100644 --- a/bob/ip/binseg/engine/trainer.py +++ b/bob/ip/binseg/engine/trainer.py @@ -179,7 +179,9 @@ def run( if arguments["epoch"] == 0: logwriter.writeheader() - model.train().to(device) + model.train() #set training mode + + model.to(device) # set/cast parameters to device for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor):