diff --git a/bob/learn/tensorflow/image/__init__.py b/bob/learn/tensorflow/image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a8962844d659e6b1e13f6eeeb5e1d11dc698e8d --- /dev/null +++ b/bob/learn/tensorflow/image/__init__.py @@ -0,0 +1,19 @@ +from .filter import gaussian_kernel, GaussianFilter + +# gets sphinx autodoc done right - don't remove it +def __appropriate__(*args): + """Says object was actually declared here, an not on the import module. + + Parameters: + + *args: An iterable of objects to modify + + Resolves `Sphinx referencing issues + <https://github.com/sphinx-doc/sphinx/issues/3048>` + """ + for obj in args: + obj.__module__ = __name__ + + +__appropriate__(GaussianFilter) +__all__ = [_ for _ in dir() if not _.startswith("_")] diff --git a/bob/learn/tensorflow/image/filter.py b/bob/learn/tensorflow/image/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..3ac149db3113cf5166d46fe5c4ca80ed11052c2c --- /dev/null +++ b/bob/learn/tensorflow/image/filter.py @@ -0,0 +1,38 @@ +import tensorflow as tf + + +def gaussian_kernel(size: int, mean: float, std: float): + """Makes 2D gaussian Kernel for convolution. + Code adapted from: https://stackoverflow.com/a/52012658/1286165""" + + d = tf.distributions.Normal(mean, std) + + vals = d.prob(tf.range(start=-size, limit=size + 1, dtype=tf.float32)) + + gauss_kernel = tf.einsum("i,j->ij", vals, vals) + + return gauss_kernel / tf.reduce_sum(gauss_kernel) + + +class GaussianFilter: + """A class for blurring images""" + + def __init__(self, size=13, mean=0.0, std=3.0, **kwargs): + super().__init__(**kwargs) + self.size = size + self.mean = mean + self.std = std + self.gauss_kernel = gaussian_kernel(size, mean, std)[:, :, None, None] + + def __call__(self, image): + shape = tf.shape(image) + image = tf.reshape(image, [-1, shape[-3], shape[-2], shape[-1]]) + input_channels = shape[-1] + gauss_kernel = tf.tile(self.gauss_kernel, [1, 1, input_channels, 1]) + return tf.nn.depthwise_conv2d( + image, + gauss_kernel, + strides=[1, 1, 1, 1], + padding="SAME", + data_format="NHWC", + )