From 60d7070fdd41c15853c09bc1c697082039cba29e Mon Sep 17 00:00:00 2001
From: Anjith George <ageorge@idiap.ch>
Date: Wed, 3 Jun 2020 20:46:25 +0200
Subject: [PATCH] WIP: logging with summary writer

---
 bob/learn/pytorch/trainers/GenericTrainer.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/bob/learn/pytorch/trainers/GenericTrainer.py b/bob/learn/pytorch/trainers/GenericTrainer.py
index 89cadfe..639339e 100644
--- a/bob/learn/pytorch/trainers/GenericTrainer.py
+++ b/bob/learn/pytorch/trainers/GenericTrainer.py
@@ -10,6 +10,7 @@ import torch.nn as nn
 from torch.autograd import Variable
 from .tflog import Logger
 
+from torch.utils.tensorboard import SummaryWriter
 import bob.core
 logger = bob.core.log.setup("bob.learn.pytorch")
 
@@ -67,7 +68,7 @@ class GenericTrainer(object):
 
         bob.core.log.set_verbosity_level(logger, verbosity_level)
 
-        self.tf_logger = Logger(tf_logdir)
+        self.tf_logger = SummaryWriter(log_dir=tf_logdir)
 
         # Setting the gradients to true for the layers which needs to be adapted
 
@@ -252,17 +253,17 @@ class GenericTrainer(object):
             # scalar logs
 
             for tag, value in info.items():
-                self.tf_logger.scalar_summary(tag, value, epoch+1)
+                self.tf_logger.add_scalar(tag=tag, scalar_value=value, global_step=epoch+1)
 
             # Log values and gradients of the parameters (histogram summary)
 
             for tag, value in self.network.named_parameters():
                 tag = tag.replace('.', '/')
                 try:
-                    self.tf_logger.histo_summary(
-                        tag, value.data.cpu().numpy(), epoch+1)
-                    self.tf_logger.histo_summary(
-                        tag+'/grad', value.grad.data.cpu().numpy(), epoch+1)
+                    self.tf_logger.add_histogram(
+                        tag=tag, values=value.data.cpu().numpy(), global_step=epoch+1)
+                    self.tf_logger.add_histogram(
+                        tag=tag+'/grad', values=value.grad.data.cpu().numpy(), global_step=epoch+1)
                 except:
                     pass
 
-- 
GitLab