diff --git a/bob/learn/pytorch/trainers/tflog.py b/bob/learn/pytorch/trainers/tflog.py index d872817ec2465ade5c78b0bd94b29acdb12c6143..18e08cf42c8ddf3b4f49eca6c5853d54126df6c2 100644 --- a/bob/learn/pytorch/trainers/tflog.py +++ b/bob/learn/pytorch/trainers/tflog.py @@ -1,7 +1,7 @@ # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 import tensorflow as tf import numpy as np -import scipy.misc +import matplotlib.pyplot as plt try: from StringIO import StringIO # Python 2.7 except ImportError: @@ -29,7 +29,7 @@ class Logger(object): s = StringIO() except: s = BytesIO() - scipy.misc.toimage(img).save(s, format="png") + plt.imsave(s, img, format='png') # Create an Image object img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),