Skip to content
Snippets Groups Projects
Commit d70f90ae authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[engine] More comments on model.train() and model.to() usage

parent aaab33de
No related branches found
No related tags found
No related merge requests found
Pipeline #39842 failed
......@@ -125,9 +125,10 @@ def run(model, data_loader, device, output_folder, overlayed_folder):
os.makedirs(output_folder, exist_ok=True)
logger.info(f"Device: {device}")
model.eval().to(device)
# Sigmoid for predictions
sigmoid = torch.nn.Sigmoid()
model.eval() # set evaluation mode
model.to(device) # set/cast parameters to device
sigmoid = torch.nn.Sigmoid() # use sigmoid for predictions
# Setup timers
start_total_time = time.time()
......
......@@ -300,7 +300,9 @@ def run(
if arguments["epoch"] == 0:
logwriter.writeheader()
model.train().to(device)
model.train() #set training mode
model.to(device) # set/cast parameters to device
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
......
......@@ -179,7 +179,9 @@ def run(
if arguments["epoch"] == 0:
logwriter.writeheader()
model.train().to(device)
model.train() #set training mode
model.to(device) # set/cast parameters to device
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment