From 687f8b244c41cc3c1add9810da3a9922af9326a2 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Wed, 13 May 2020 13:21:05 +0200
Subject: [PATCH] [engine.*trainer] Set non-blocking operation for CPU->GPU
 data transfers to make communication asynchronous

---
 bob/ip/binseg/engine/ssltrainer.py | 30 ++++++++++++++++++++++--------
 bob/ip/binseg/engine/trainer.py    | 29 ++++++++++++++++++++++-------
 2 files changed, 44 insertions(+), 15 deletions(-)

diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index 98c17e1b..6d378268 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -23,7 +23,6 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-
 def sharpen(x, T):
     temp = x ** (1 / T)
     return temp / temp.sum(dim=1, keepdim=True)
@@ -300,7 +299,7 @@ def run(
         if arguments["epoch"] == 0:
             logwriter.writeheader()
 
-        model.train() #set training mode
+        model.train() # set training mode
         model.to(device) # set/cast parameters to device
 
         for state in optimizer.state.values():
@@ -336,9 +335,15 @@
 
                 # data forwarding on the existing network
                 # labelled
-                images = samples[1].to(device)
-                ground_truths = samples[2].to(device)
-                unlabelled_images = samples[4].to(device)
+                images = samples[1].to(
+                    device=device, non_blocking=torch.cuda.is_available()
+                )
+                ground_truths = samples[2].to(
+                    device=device, non_blocking=torch.cuda.is_available()
+                )
+                unlabelled_images = samples[4].to(
+                    device=device, non_blocking=torch.cuda.is_available()
+                )
                 # labelled outputs
                 outputs = model(images)
                 unlabelled_outputs = model(unlabelled_images)
@@ -382,9 +387,18 @@
                     ):
 
                         # labelled
-                        images = samples[1].to(device)
-                        ground_truths = samples[2].to(device)
-                        unlabelled_images = samples[4].to(device)
+                        images = samples[1].to(
+                            device=device,
+                            non_blocking=torch.cuda.is_available(),
+                        )
+                        ground_truths = samples[2].to(
+                            device=device,
+                            non_blocking=torch.cuda.is_available(),
+                        )
+                        unlabelled_images = samples[4].to(
+                            device=device,
+                            non_blocking=torch.cuda.is_available(),
+                        )
                         # labelled outputs
                         outputs = model(images)
                         unlabelled_outputs = model(unlabelled_images)
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index 4c4e558e..e0299cc1 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -179,7 +179,7 @@ def run(
         if arguments["epoch"] == 0:
             logwriter.writeheader()
 
-        model.train() #set training mode
+        model.train() # set training mode
        model.to(device) # set/cast parameters to device
 
         for state in optimizer.state.values():
@@ -211,11 +211,17 @@
             ):
 
                 # data forwarding on the existing network
-                images = samples[1].to(device)
-                ground_truths = samples[2].to(device)
+                images = samples[1].to(
+                    device=device, non_blocking=torch.cuda.is_available()
+                )
+                ground_truths = samples[2].to(
+                    device=device, non_blocking=torch.cuda.is_available()
+                )
                 masks = None
                 if len(samples) == 4:
-                    masks = samples[-1].to(device)
+                    masks = samples[-1].to(
+                        device=device, non_blocking=torch.cuda.is_available()
+                    )
 
                 outputs = model(images)
 
@@ -242,11 +248,20 @@
                         valid_loader, desc="valid", leave=False, disable=None
                     ):
                         # data forwarding on the existing network
-                        images = samples[1].to(device)
-                        ground_truths = samples[2].to(device)
+                        images = samples[1].to(
+                            device=device,
+                            non_blocking=torch.cuda.is_available(),
+                        )
+                        ground_truths = samples[2].to(
+                            device=device,
+                            non_blocking=torch.cuda.is_available(),
+                        )
                         masks = None
                         if len(samples) == 4:
-                            masks = samples[-1].to(device)
+                            masks = samples[-1].to(
+                                device=device,
+                                non_blocking=torch.cuda.is_available(),
+                            )
 
                         outputs = model(images)
 
--
GitLab
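
Note on the pattern applied by this patch: passing non_blocking=torch.cuda.is_available() to Tensor.to() lets the host-to-GPU copy be issued asynchronously, so it can overlap with work already queued on the device. The overlap is only effective when the source batch lives in page-locked (pinned) host memory, for example when the DataLoader is built with pin_memory=True; on CPU-only runs torch.cuda.is_available() is False and the call falls back to an ordinary synchronous copy, keeping a single code path for both back ends.

A minimal, self-contained sketch of the same idea, using a toy dataset and model that are not part of bob.ip.binseg (only the .to(..., non_blocking=...) calls mirror the patch):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data and model, for illustration only (not taken from bob.ip.binseg).
images = torch.randn(64, 3, 32, 32)
labels = torch.randint(0, 2, (64, 1, 32, 32)).float()
loader = DataLoader(
    TensorDataset(images, labels),
    batch_size=8,
    pin_memory=True,  # pinned host memory is what makes non-blocking copies effective
)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1).to(device)
criterion = torch.nn.BCEWithLogitsLoss()

for x, y in loader:
    # Same gating as the patch: asynchronous only when a GPU is present.
    non_blocking = torch.cuda.is_available()
    x = x.to(device, non_blocking=non_blocking)
    y = y.to(device, non_blocking=non_blocking)
    outputs = model(x)  # queued on the same CUDA stream, after the copies
    loss = criterion(outputs, y)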