diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py index 98c17e1b56c7e0732870f36d005dbf395446da74..6d378268a2c5aed5e54b34ea88e86c69afe4763b 100644 --- a/bob/ip/binseg/engine/ssltrainer.py +++ b/bob/ip/binseg/engine/ssltrainer.py @@ -23,7 +23,6 @@ import logging logger = logging.getLogger(__name__) - def sharpen(x, T): temp = x ** (1 / T) return temp / temp.sum(dim=1, keepdim=True) @@ -300,7 +299,7 @@ def run( if arguments["epoch"] == 0: logwriter.writeheader() - model.train() #set training mode + model.train() # set training mode model.to(device) # set/cast parameters to device for state in optimizer.state.values(): @@ -336,9 +335,15 @@ def run( # data forwarding on the existing network # labelled - images = samples[1].to(device) - ground_truths = samples[2].to(device) - unlabelled_images = samples[4].to(device) + images = samples[1].to( + device=device, non_blocking=torch.cuda.is_available() + ) + ground_truths = samples[2].to( + device=device, non_blocking=torch.cuda.is_available() + ) + unlabelled_images = samples[4].to( + device=device, non_blocking=torch.cuda.is_available() + ) # labelled outputs outputs = model(images) unlabelled_outputs = model(unlabelled_images) @@ -382,9 +387,18 @@ def run( ): # labelled - images = samples[1].to(device) - ground_truths = samples[2].to(device) - unlabelled_images = samples[4].to(device) + images = samples[1].to( + device=device, + non_blocking=torch.cuda.is_available(), + ) + ground_truths = samples[2].to( + device=device, + non_blocking=torch.cuda.is_available(), + ) + unlabelled_images = samples[4].to( + device=device, + non_blocking=torch.cuda.is_available(), + ) # labelled outputs outputs = model(images) unlabelled_outputs = model(unlabelled_images) diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py index 4c4e558e66bdfdba99d066ca9060d7d4a41678df..e0299cc1f1a29ccbb1fccc8c45146dd15eeae1e3 100644 --- a/bob/ip/binseg/engine/trainer.py +++ b/bob/ip/binseg/engine/trainer.py @@ -179,7 +179,7 @@ def run( if arguments["epoch"] == 0: logwriter.writeheader() - model.train() #set training mode + model.train() # set training mode model.to(device) # set/cast parameters to device for state in optimizer.state.values(): @@ -211,11 +211,17 @@ def run( ): # data forwarding on the existing network - images = samples[1].to(device) - ground_truths = samples[2].to(device) + images = samples[1].to( + device=device, non_blocking=torch.cuda.is_available() + ) + ground_truths = samples[2].to( + device=device, non_blocking=torch.cuda.is_available() + ) masks = None if len(samples) == 4: - masks = samples[-1].to(device) + masks = samples[-1].to( + device=device, non_blocking=torch.cuda.is_available() + ) outputs = model(images) @@ -242,11 +248,20 @@ def run( valid_loader, desc="valid", leave=False, disable=None ): # data forwarding on the existing network - images = samples[1].to(device) - ground_truths = samples[2].to(device) + images = samples[1].to( + device=device, + non_blocking=torch.cuda.is_available(), + ) + ground_truths = samples[2].to( + device=device, + non_blocking=torch.cuda.is_available(), + ) masks = None if len(samples) == 4: - masks = samples[-1].to(device) + masks = samples[-1].to( + device=device, + non_blocking=torch.cuda.is_available(), + ) outputs = model(images)