Commit 687f8b24 authored by André Anjos


[engine.*trainer] Set non-blocking operation for CPU->GPU data transfers to make communication asynchronous
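
For context, a minimal sketch (not taken from this repository) of the pattern the commit applies: host-to-device copies are requested with non_blocking=True whenever CUDA is available, which lets them overlap with GPU computation provided the source tensor sits in pinned (page-locked) host memory. The tensor shapes below are placeholders.

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # placeholder batch; real batches come from the project's data loaders
    batch = torch.randn(4, 3, 512, 512)
    if torch.cuda.is_available():
        # page-locked memory is what actually allows the copy to be asynchronous
        batch = batch.pin_memory()

    # same call signature as introduced by this commit
    images = batch.to(device=device, non_blocking=torch.cuda.is_available())
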
parent d70f90ae
Pipeline #39843 failed
@@ -23,7 +23,6 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-
 def sharpen(x, T):
     temp = x ** (1 / T)
     return temp / temp.sum(dim=1, keepdim=True)
@@ -300,7 +299,7 @@ def run(
     if arguments["epoch"] == 0:
         logwriter.writeheader()
 
-    model.train() #set training mode
+    model.train()  # set training mode
 
     model.to(device)  # set/cast parameters to device
     for state in optimizer.state.values():
@@ -336,9 +335,15 @@ def run(
             # data forwarding on the existing network
             # labelled
-            images = samples[1].to(device)
-            ground_truths = samples[2].to(device)
-            unlabelled_images = samples[4].to(device)
+            images = samples[1].to(
+                device=device, non_blocking=torch.cuda.is_available()
+            )
+            ground_truths = samples[2].to(
+                device=device, non_blocking=torch.cuda.is_available()
+            )
+            unlabelled_images = samples[4].to(
+                device=device, non_blocking=torch.cuda.is_available()
+            )
 
             # labelled outputs
             outputs = model(images)
             unlabelled_outputs = model(unlabelled_images)
@@ -382,9 +387,18 @@ def run(
             ):
                 # labelled
-                images = samples[1].to(device)
-                ground_truths = samples[2].to(device)
-                unlabelled_images = samples[4].to(device)
+                images = samples[1].to(
+                    device=device,
+                    non_blocking=torch.cuda.is_available(),
+                )
+                ground_truths = samples[2].to(
+                    device=device,
+                    non_blocking=torch.cuda.is_available(),
+                )
+                unlabelled_images = samples[4].to(
+                    device=device,
+                    non_blocking=torch.cuda.is_available(),
+                )
 
                 # labelled outputs
                 outputs = model(images)
                 unlabelled_outputs = model(unlabelled_images)
...
@@ -179,7 +179,7 @@ def run(
     if arguments["epoch"] == 0:
         logwriter.writeheader()
 
-    model.train() #set training mode
+    model.train()  # set training mode
 
     model.to(device)  # set/cast parameters to device
     for state in optimizer.state.values():
@@ -211,11 +211,17 @@ def run(
         ):
             # data forwarding on the existing network
-            images = samples[1].to(device)
-            ground_truths = samples[2].to(device)
+            images = samples[1].to(
+                device=device, non_blocking=torch.cuda.is_available()
+            )
+            ground_truths = samples[2].to(
+                device=device, non_blocking=torch.cuda.is_available()
+            )
             masks = None
             if len(samples) == 4:
-                masks = samples[-1].to(device)
+                masks = samples[-1].to(
+                    device=device, non_blocking=torch.cuda.is_available()
+                )
 
             outputs = model(images)
@@ -242,11 +248,20 @@ def run(
                 valid_loader, desc="valid", leave=False, disable=None
             ):
                 # data forwarding on the existing network
-                images = samples[1].to(device)
-                ground_truths = samples[2].to(device)
+                images = samples[1].to(
+                    device=device,
+                    non_blocking=torch.cuda.is_available(),
+                )
+                ground_truths = samples[2].to(
+                    device=device,
+                    non_blocking=torch.cuda.is_available(),
+                )
                 masks = None
                 if len(samples) == 4:
-                    masks = samples[-1].to(device)
+                    masks = samples[-1].to(
+                        device=device,
+                        non_blocking=torch.cuda.is_available(),
+                    )
 
                 outputs = model(images)
...
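
The non-blocking copies above only overlap with computation when the host tensors live in pinned memory, which is normally arranged on the data-loader side. A hedged sketch of that side follows; the dataset, batch size, and tensor shapes are placeholders and not taken from this repository.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # placeholder dataset standing in for the project's own dataset classes
    dataset = TensorDataset(
        torch.randn(16, 3, 256, 256),            # images
        torch.randint(0, 2, (16, 1, 256, 256)),  # ground-truth masks
    )

    loader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        pin_memory=torch.cuda.is_available(),  # page-locked buffers enable async host->GPU copies
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for images, ground_truths in loader:
        images = images.to(device=device, non_blocking=torch.cuda.is_available())
        ground_truths = ground_truths.to(device=device, non_blocking=torch.cuda.is_available())
        # ... forward pass, loss, backward, optimizer step ...
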