Skip to content
Snippets Groups Projects
Commit 36c3f4b9 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Add a Gaussian blur filter

parent f481311f
Branches
Tags
1 merge request!79Add keras-based models, add pixel-wise loss, other improvements
Pipeline #36777 failed
from .filter import gaussian_kernel, GaussianFilter
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
"""Says object was actually declared here, an not on the import module.
Parameters:
*args: An iterable of objects to modify
Resolves `Sphinx referencing issues
<https://github.com/sphinx-doc/sphinx/issues/3048>`
"""
for obj in args:
obj.__module__ = __name__
__appropriate__(GaussianFilter)
__all__ = [_ for _ in dir() if not _.startswith("_")]
import tensorflow as tf
def gaussian_kernel(size: int, mean: float, std: float):
"""Makes 2D gaussian Kernel for convolution.
Code adapted from: https://stackoverflow.com/a/52012658/1286165"""
d = tf.distributions.Normal(mean, std)
vals = d.prob(tf.range(start=-size, limit=size + 1, dtype=tf.float32))
gauss_kernel = tf.einsum("i,j->ij", vals, vals)
return gauss_kernel / tf.reduce_sum(gauss_kernel)
class GaussianFilter:
"""A class for blurring images"""
def __init__(self, size=13, mean=0.0, std=3.0, **kwargs):
super().__init__(**kwargs)
self.size = size
self.mean = mean
self.std = std
self.gauss_kernel = gaussian_kernel(size, mean, std)[:, :, None, None]
def __call__(self, image):
shape = tf.shape(image)
image = tf.reshape(image, [-1, shape[-3], shape[-2], shape[-1]])
input_channels = shape[-1]
gauss_kernel = tf.tile(self.gauss_kernel, [1, 1, input_channels, 1])
return tf.nn.depthwise_conv2d(
image,
gauss_kernel,
strides=[1, 1, 1, 1],
padding="SAME",
data_format="NHWC",
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment