"""
1D ergodic control formulated as Spectral Multiscale Coverage (SMC) objective,
with a spatial distribution described as a mixture of Gaussians.

Copyright (c) 2023 Idiap Research Institute <https://www.idiap.ch>
Written by Philip Abbet <philip.abbet@idiap.ch> and
Sylvain Calinon <https://calinon.ch>

This file is part of RCFS <https://rcfs.ch>
License: GPL-3.0-only
"""

import numpy as np
import matplotlib.pyplot as plt


# Parameters
# ===============================
nbData = 500  # Number of datapoints (control steps)
nbFct = 10  # Number of basis functions
nbGaussian = 2  # Number of Gaussians to represent the spatial distribution
dt = 1e-2  # Time step
xlim = [0, 1]  # Domain limit for each dimension (considered to be 1 for each dimension in this implementation)
L = 2 * (xlim[1] - xlim[0])  # Length of the mirrored domain [-xlim[1], xlim[1]]
om = 2 * np.pi / L  # Fundamental angular frequency (omega) associated with L
u_max = 3.0  # Maximum speed allowed
u_norm_reg = 1e-3  # Regularizer to avoid numerical issues when speed is close to zero

# Initial point
x0 = 0.1

# Desired spatial distribution represented as a mixture of Gaussians (GMM)
# Gaussian centers (one entry per component)
Mu = np.array([0.7, 0.5])

# Gaussian covariances, stored as one 1x1 matrix per component along the last axis;
# the second component gets half the variance of the first
Sigma = np.full((1, 1, nbGaussian), 0.01)
Sigma[..., 1] *= 0.5

# Mixing coefficients (uniform over the components)
Priors = np.full(nbGaussian, 1.0 / nbGaussian)


# Compute Fourier series coefficients w_hat of desired spatial distribution
# ===============================
rg = np.arange(nbFct, dtype=float)[:, np.newaxis]  # Basis indices 0..nbFct-1 as a column vector
kk = om * rg  # Angular frequency of each basis function
Lambda = (rg**2 + 1) ** -1  # Weighting vector (decreases with the basis index)

# Explicit description of w_hat by exploiting the Fourier transform
# properties of Gaussians (optimized version by exploiting symmetries):
# each component contributes a cosine at its center, damped by an
# exponential that depends on its variance.
w_hat = np.zeros((nbFct, 1))
for j in range(nbGaussian):
    center_term = np.cos(kk * Mu[j])
    spread_term = np.exp(-.5 * kk**2 * Sigma[:, :, j])
    w_hat += Priors[j] * center_term * spread_term

# Normalize by the length of the mirrored domain
w_hat /= L


# Fourier basis functions (for a discretized map)
# ===============================
nbRes = 200  # Number of grid points used to draw the distributions
xm = np.linspace(xlim[0], xlim[1], nbRes)[np.newaxis, :]  # Spatial range for 1D, as a row vector

# Cosine basis evaluated on the grid (one row per basis function).
# The constant term (row 0) ends up scaled by 2 and all other rows by 4.
# NOTE(review): presumably these factors account for the even (mirrored)
# extension of the distribution over [-xlim[1], xlim[1]] -- confirm.
phim = 2 * np.cos(kk @ xm)
phim[1:] *= 2

# Desired spatial distribution, reconstructed on the grid from w_hat
g = phim.T @ w_hat


# Ergodic control
# ===============================
# At each step the controller compares the time-averaged Fourier coefficients
# of the trajectory (w) with the desired ones (w_hat) and moves along the
# weighted gradient of the basis functions, at a speed bounded by u_max.
x = x0  # Current position, kept as a scalar throughout the loop

wt = np.zeros((nbFct, 1))  # Running sum of the basis functions along the trajectory
w = np.zeros((nbFct, 1))  # Time-averaged coefficients (stays zero if nbData == 0)
r_x = np.zeros((nbData))  # Logged positions, one per step (position after the update)

for t in range(nbData):
    # Fourier basis functions and derivatives for each dimension
    # (only cosine part on [0,L/2] is computed since the signal
    # is even and real by construction)
    phi = np.cos(x * kk) / L

    # Gradient of basis functions
    dphi = -np.sin(x * kk) * kk / L

    # w are the Fourier series coefficients along trajectory
    wt = wt + phi
    w = wt / (t + 1)

    # Controller with constrained velocity norm.
    # The matrix product yields a (1, 1) array; .item() extracts the scalar so
    # that x stays a scalar and r_x[t] = x does not rely on NumPy's deprecated
    # implicit conversion of size-1 arrays to scalars.
    u = (-dphi.T @ (Lambda * (w - w_hat))).item()
    u = u * u_max / (np.linalg.norm(u) + u_norm_reg)  # Velocity command (norm is |u| in 1D)

    # Ensure that we don't go out of limits: if the step would leave the
    # domain, move in the opposite direction instead
    next_x = x + u * dt
    if (next_x < xlim[0]) or (next_x > xlim[1]):
        u = -u

    # Update of position
    x = x + (u * dt)

    # Log data
    r_x[t] = x

# Spatial distribution reconstructed from the final time-averaged coefficients
r_g = phim.T @ w

# Plot
# ===============================
# Four stacked panels: distributions, trajectory, desired coefficients and
# coefficients obtained along the trajectory.
fig, ax = plt.subplots(4, 1, figsize=(16, 12), gridspec_kw={'height_ratios': [3, 3, 1, 1]})
plt.subplots_adjust(hspace=0.4)
ax_dist, ax_traj, ax_target, ax_actual = ax

# Desired distribution (thick light red) against the one reconstructed from
# the trajectory coefficients (black)
ax_dist.plot(xm.T, g, lw=4, c=[1.0, .4, .4])
ax_dist.plot(xm.T, r_g, c="black")
ax_dist.legend(['Desired', 'Reproduced'], fontsize=10, loc='upper left')
ax_dist.set_xlim(xlim[0], xlim[1])
ax_dist.title.set_text('Distributions')
ax_dist.set_yticks([0])

# Trajectory: position on the horizontal axis, step index on the vertical
# axis, with a dot marking the last logged position
ax_traj.plot(r_x[:], np.arange(nbData), linestyle="-", c="black")
ax_traj.plot(r_x[-1], [nbData], marker=".", c="black", markersize=10)
ax_traj.set_xlim(xlim[0], xlim[1])
ax_traj.set_ylim(0, nbData)
ax_traj.set_yticks([0, nbData])
ax_traj.set_yticklabels(['$t-T$', '$t$'])
ax_traj.title.set_text('Trajectory')

# Coefficient strips: each vector is drawn as a 1 x nbFct grayscale image
# surrounded by a rectangular frame (msh holds the frame corners, shifted by
# half a pixel so the frame follows the image border).
# NOTE(review): w_hat is divided by nbData like the running sum wt, although
# w_hat is not a sum over time steps -- confirm whether this is intended.
msh = np.array([[0.0, 0.0, nbFct, nbFct, 0.0], [0.0, 1.0, 1.0, 0.0, 0.0]]) - 0.5
coefficient_panels = (
    (ax_target, r"Desired Fourier coefficients $\hat{w}$", w_hat, [0.8, 0, 0]),
    (ax_actual, r"Fourier coefficients $w$", wt, [0, 0, 0]),
)
for panel, panel_title, coeffs, frame_color in coefficient_panels:
    panel.set_title(panel_title)
    panel.imshow(np.reshape(coeffs / nbData, [1, nbFct]), cmap="gray_r")
    panel.plot(msh[0, :], msh[1, :], linestyle="-", lw=4, c=frame_color)
    panel.set_yticks([])

plt.show()