'''
Inverse kinematics with visualization of manipulability

Copyright (c) 2024 Idiap Research Institute <https://www.idiap.ch/>
Written by Sylvain Calinon <https://calinon.ch>

This file is part of RCFS <https://robotics-codes-from-scratch.github.io/>
License: GPLv3
'''

import numpy as np
import random
import matplotlib
import matplotlib.pyplot as plt
import math as m
import scipy
import scipy.spatial

# Logarithmic map for R^2 x S^1 manifold
def logmap(f, f0):
    """Residual f - f0 on the R^2 x S^1 manifold (planar position + orientation).

    f, f0: poses [px, py, theta]; the last entry is the orientation angle.
    Returns a length-3 float array: the position difference followed by the
    orientation difference, wrapped by the principal complex logarithm.
    """
    # Orientations as unit complex numbers; conj(z0) * z is the rotation from f0 to f
    z0 = np.exp(f0[-1]*1j)
    z = np.exp(f[-1]*1j)
    dtheta = np.imag(np.log(np.conj(z0) * z))  # orientation residual
    # Position residual followed by the orientation residual
    return np.concatenate((f[:2] - f0[:2], [dtheta]))
	
# Forward kinematics for end-effector (in robot coordinate system)
def fkin(x, param):
    """End-effector pose of the planar arm, in the robot coordinate system.

    x: joint angles (first axis indexes the param.nbVarX joints).
    param: needs nbVarX (number of joints) and l (link lengths).
    Returns [px, py, theta], where theta is the sum of the joint angles wrapped
    to [-pi, pi).
    """
    n = param.nbVarX
    # Absolute angle of each link = cumulative sum of the joint angles
    link_angles = np.tril(np.ones([n, n])) @ x
    px = param.l @ np.cos(link_angles)
    py = param.l @ np.sin(link_angles)
    theta = np.mod(np.sum(x, 0) + np.pi, 2*np.pi) - np.pi  # single Euler angle for a planar robot
    return np.stack([px, py, theta])

# Forward kinematics for all joints (in robot coordinate system)
def fkin0(x, param):
    """Positions of the base, every joint and the end-effector of the planar arm.

    x: joint angles (length param.nbVarX).
    param: needs nbVarX (number of joints) and l (link lengths).
    Returns a 2 x (nbVarX + 1) array whose first column is the base at the origin.
    """
    n = param.nbVarX
    tri = np.tril(np.ones([n, n]))
    link_angles = tri @ x
    # Accumulate the scaled link directions: tri @ diag(l) @ u sums l_i * u_i up to each joint
    accum = tri @ np.diag(param.l)
    joints = np.vstack([accum @ np.cos(link_angles), accum @ np.sin(link_angles)])
    base = np.zeros([2, 1])
    return np.hstack([base, joints])

# Jacobian with analytical computation (for single time step)
def Jkin(x, param):
    """Analytical Jacobian of fkin for a single configuration x.

    Rows: d(px)/dx, d(py)/dx, d(theta)/dx; columns: the param.nbVarX joints.
    Returns a 3 x nbVarX array.
    """
    n = param.nbVarX
    tri = np.tril(np.ones([n, n]))
    link_angles = tri @ x
    lengths = np.diag(param.l)
    rows = (
        -np.sin(link_angles).T @ lengths @ tri,  # d(px)/dx
        np.cos(link_angles).T @ lengths @ tri,   # d(py)/dx
        np.ones([1, n]),                         # orientation = sum of joint angles
    )
    return np.vstack(rows)

## Parameters
# ===============================

param = lambda: None  # Lazy way to define an empty class in python (attributes attached below)
# Simulation settings
param.nbData = 50  # Number of datapoints (iterations of the IK loop)
param.dt = 1e-2  # Time step length
# Dimensions: state (x1,x2,x3), control (dx1,dx2,dx3), objective (end-effector position and orientation)
param.nbVarX = param.nbVarU = param.nbVarF = 3
# Robot geometry
param.l = [2, 2, 1]  # Robot links lengths

fig, ax = plt.subplots()

# Desired target for the end-effector: position (3, 1) and orientation -pi/2
fh = np.array([3, 1, -np.pi/2])
# Initial robot pose: every joint at -pi/nbVarX, with pi added to the first joint
x = np.full(param.nbVarX, -np.pi / param.nbVarX)
x[0] += np.pi

## Inverse kinematics (IK)
# ===============================

ax.scatter(fh[0], fh[1], color='r', marker='.', s=10**2)  # Plot target
for t in range(param.nbData):
    f = fkin(x, param)  # End-effector pose at the current configuration
    J = Jkin(x, param)  # End-effector Jacobian at the current configuration
    # Pseudo-inverse IK step on the manifold residual (logmap wraps the orientation error;
    # the plain Euclidean residual fh - f would ignore the angle wrap-around)
    dx = np.linalg.pinv(J) @ logmap(fh, f) * 10 * param.dt
    x += dx

    # Draw every articulation after the update; gray level fades from 1.0 (white) towards black
    links = fkin0(x, param)
    gray = str(1 - t/param.nbData)
    ax.plot(links[0, :], links[1, :], color=gray, linewidth=2)
	


### MANIPULABILITY ###
	
# Re-evaluate the kinematics at the final configuration drawn by the IK loop: the loop's
# f and J were computed before its last state update, i.e. one step behind the plotted robot.
f = fkin(x, param)  # End-effector pose at the final configuration
J = Jkin(x, param)[:2, :]  # Position rows only: 2 x nbVarX translational Jacobian
center = np.array([f[0], f[1]])  # end-effector position (center of the mapped velocity sets)
length, width, height = 1.8, 1.5, 1  # max joint velocities
size = np.array([length, width, height])
refell = 130 * np.identity(2)  # reference ellipsoid (only referenced by the disabled Riemannian-distance snippet)

# Choice of the Jacobian matrix (the enabled variants below are applied in this order)
J1 = False  # 1. Robot manipulator: J expressed in a frame rotated by theta
J2 = False  # 2. Bounded joint-space: joint-side mask from the joint limits
J3 = False  # 3. Bounded task-space: task-side mask from end-effector position limits
J4 = False  # 4. Object boundaries: task-side mask in a frame rotated around the target
diffJac = [J1, J2, J3, J4]


# 1. Robot manipulator
if J1:
    theta = 5*m.pi/6
    U = np.array([[m.cos(theta), -m.sin(theta)], [m.sin(theta), m.cos(theta)]])
    J = U.T @ J
    print(J)

# 2. Bounded joint-space
if J2:
    jminlim = -np.ones(param.nbVarX)
    jmaxlim = np.ones(param.nbVarX)
    # Mask is 0 for joints strictly inside (jminlim, jmaxlim) and 1 otherwise. It scales the
    # Jacobian's joint columns (like J3 scales its task rows) instead of replacing J.
    jmask = 1 - np.heaviside(x - jminlim, 0) * np.heaviside(jmaxlim - x, 0)
    J = J @ np.diag(jmask)
    print(J)


# 3. Bounded task-space
if J3:
    tminlim = -np.ones(2)
    tmaxlim = np.ones(2)
    # Mask is 0 for task coordinates strictly inside (tminlim, tmaxlim) and 1 otherwise
    tmask = 1 - np.heaviside(f[:2] - tminlim, 0) * np.heaviside(tmaxlim - f[:2], 0)
    J = np.diag(tmask) @ J
    print(J)

# 4. Object boundaries
if J4:
    theta = m.pi/4
    U = np.array([[m.cos(theta), -m.sin(theta)], [m.sin(theta), m.cos(theta)]])
    tminlim = -np.ones(2)
    tmaxlim = 2*np.ones(2)
    rel = U.T @ (f[:2] - fh[:2])  # end-effector position relative to the target, in the rotated frame
    omask = 1 - np.heaviside(rel - tminlim, 0) * np.heaviside(tmaxlim - rel, 0)
    J = np.diag(omask) @ J
    print(J)
        
        

# Boundaries in joint-velocity space

# 1. Rectangular cuboid
showedges = False  # Shows the mapping of the cube's edges (see the commented-out code further below)

# 2. Ellipse
ellBound = True

# 3. Superellipsoid
superBound = False
superVolume = False  # Returns the fraction of the rectangular cuboid's volume covered by the superellipsoid


# 1. Rectangular cuboid
# Row k holds the binary digits of k (most significant first): the 2**nbVarX vertices of [0, 1]^nbVarX
nb_vertices = 2 ** param.nbVarX
bit_shifts = np.arange(param.nbVarX - 1, -1, -1)
cube = ((np.arange(nb_vertices)[:, None] >> bit_shifts) & 1).astype(float)
# Center the cube at the origin ([-1, 1] per axis), then scale each axis by its joint-velocity limit
cube = (cube * 2 - 1) * size

# Manipulability polytope: image of every cube vertex through J, placed at the end-effector
polytope = np.array([J @ corner + center for corner in cube])
xpoints, ypoints = polytope.T

if not any(diffJac):
    hull = scipy.spatial.ConvexHull(polytope)

    # vertices of the convex hull (might come in handy)
    vertices = polytope[hull.vertices]
    cube_norms = np.linalg.norm(vertices, axis=1)

    for simplex in hull.simplices:
        plt.plot(polytope[simplex, 0], polytope[simplex, 1], 'b--')

def norm(vec, coeff, exp):
    """Superellipsoid "norm" (sum_i |vec_i / coeff_i| ** exp) ** (1 / exp).

    vec: a vector, or a stack of vectors with the components along the LAST axis
         (then one norm per vector is returned).
    coeff: semi-axis lengths (scalar or one entry per component).
    exp: exponent (2 = ellipsoid; larger values approach the rectangular cuboid).
    """
    terms = np.abs(np.asarray(vec) / coeff) ** exp
    # Reduce over the component axis only, so batched input yields per-vector norms
    return np.sum(terms, axis=-1) ** (1 / exp)

def sample(npoints, coeff, exp):
    """Draw joint-velocity vectors lying on the superellipsoid sum_i |q_i / coeff_i| ** exp = 1.

    Points are drawn uniformly in the box [-coeff, coeff] and then radially projected
    onto the surface, so their distribution on the surface is not uniform.

    npoints: number of samples.
    coeff: semi-axis lengths, one entry per dimension (array-like); their count sets the
           dimension of the samples (previously taken from the global param.nbVarX, and
           the pre-scaling from the global size).
    exp: superellipsoid exponent (2 = ellipsoid).
    Returns an array of shape (npoints, len(coeff)).
    """
    coeff = np.atleast_1d(np.asarray(coeff, dtype=float))
    vecs = (np.random.rand(npoints, coeff.size) * 2 - 1) * coeff
    # Radial projection: divide each row by its own superellipsoid norm
    radii = np.sum(np.abs(vecs / coeff) ** exp, axis=1) ** (1 / exp)
    return vecs / radii[:, None]

# 2. Ellipsoid
if ellBound:
    num_iter = 1000
    # coeff = np.array([1,1,1]) # these are the dimensions of the ellipsoid in joint-velocity space
    coeff = size  # if one wants the ellipsoid to be contained in the cuboid
    exp = 2

    # Joint velocities on the ellipsoid surface and their task-space images
    ell_jvlim = sample(num_iter, coeff, exp)
    ell_tvlim = ell_jvlim @ J.T + center
    ell_x, ell_y = ell_tvlim.T

    # Task-space velocity ellipse: Q = J A J^T with A = diag(coeff**2)
    A = np.diag(coeff ** 2)
    Q = J @ A @ J.T
    # Q is symmetric: eigh returns real eigenvalues and orthonormal eigenvectors stored as COLUMNS
    eigenvals, eigenvecs = np.linalg.eigh(Q)

    # Sort eigenvalues in decreasing order and permute the eigenvector columns to match
    idx = eigenvals.argsort()[::-1]
    eigenvals = eigenvals[idx]
    eigenvecs = eigenvecs[:, idx]

    # the sqrt of the eigenvalues give the length of the semi-axes
    print(f"Ellipsoid eigenvalues: {eigenvals}")
    # Clip round-off negatives (e.g. singular Q with a masked Jacobian) before taking the sqrt
    semi_axes = np.sqrt(np.clip(eigenvals, 0.0, None))
    vec1, vec2 = eigenvecs.T
    vec1 = vec1 * semi_axes[0] + center  # tip of the major semi-axis
    vec2 = vec2 * semi_axes[1] + center  # tip of the minor semi-axis

    if not any(diffJac):
        # Outline of the mapped ellipse: convex hull of the task-space samples
        polytope = np.array([ell_x, ell_y]).T
        hull = scipy.spatial.ConvexHull(polytope)
        for simplex in hull.simplices:
            plt.plot(polytope[simplex, 0], polytope[simplex, 1], 'r--')



# 3. Superellipsoid (rigorously this is not the most general form of a superellipsoid)
if superBound:
    num_iter = 1000
    # coeff = np.array([1,1,1]) # these are the dimensions of the superellipsoid in joint-velocity space
    coeff = size  # if one wants the superellipsoid to be contained in the cuboid
    exp = 4  # exp = 2 for an ellipse, exp = 4 for squircle, exp --> infty for rectangular cuboid

    if superVolume:
        # Volume of {sum_i |q_i/a_i|**exp <= 1} divided by the volume of the box prod_i 2*a_i
        vol = scipy.special.gamma(1/exp + 1)**param.nbVarX / scipy.special.gamma(param.nbVarX/exp + 1)
        print(f"fraction of the rectangular cuboid's volume: {vol}")

    # Joint velocities on the superellipsoid surface and their task-space images
    jvlim = sample(num_iter, coeff, exp)
    tvlim = jvlim @ J.T + center
    xpoints, ypoints = tvlim.T

    # Idea: approximate whatever shape I get with an ellipsoid, so that the reasoning on the eigenvalues apply!
    # Note: it does not just give the same ellipsoid as if exp = 2
    tvmax = tvlim[np.argmax(np.linalg.norm(tvlim - center, axis=1))]  # sample farthest from the center
    cov_mat = np.cov(tvlim.T)
    # cov_mat is symmetric: eigh returns real eigenpairs with the eigenvectors as columns
    eigenvals, eigenvecs = np.linalg.eigh(cov_mat)
    idx = eigenvals.argsort()[::-1]
    eigenvals = np.clip(eigenvals[idx], 0.0, None)  # clip round-off negatives before the sqrt
    eigenvecs = eigenvecs[:, idx]
    # Scale each principal direction by its standard deviation, then rescale all of them
    # so that the major semi-axis reaches the farthest sample
    eigenvecs = eigenvecs * np.sqrt(eigenvals)
    ratio = np.linalg.norm(tvmax - center) / m.sqrt(eigenvals[0])
    eigenvecs *= ratio

    vec1, vec2 = eigenvecs.T

    # Disabled: manipulability matrix and Riemannian distance to the reference ellipsoid
    # Q = eigenvecs @ np.array([[eigenvals[0],0],[0,eigenvals[1]]]) @ np.linalg.inv(eigenvecs)
    # A = np.linalg.inv(scipy.linalg.sqrtm(refell)) @ Q @ np.linalg.inv(scipy.linalg.sqrtm(refell))
    # d = np.linalg.norm(scipy.linalg.logm(A))
    # print(f"Riemannian distance: {d}")

    print(f"Superellipsoid eigenvalues: {(ratio * np.sqrt(eigenvals))**2}")

    # Trace the approximating ellipse. Stored in super_curve rather than x so that the
    # robot's joint state x is not overwritten.
    phi = np.linspace(0, 2*m.pi, 200)
    super_curve = center + np.outer(np.cos(phi), vec1) + np.outer(np.sin(phi), vec2)

    if not any(diffJac):
        ax.plot(super_curve[:, 0], super_curve[:, 1], "g1", label="superellipsoid")
        # Absolute end points of the semi-axes (new arrays; eigenvecs is left untouched)
        tip1 = center + vec1
        tip2 = center + vec2
        ax.plot([center[0], tvmax[0]], [center[1], tvmax[1]])
        ax.plot([center[0], tip1[0]], [center[1], tip1[1]])
        ax.plot([center[0], tip2[0]], [center[1], tip2[1]])


# Plots
showhull = True  # to show the convex hull of the superellipsoid
showpoints = False  # to show the image of all the sampled points

# xpoints/ypoints hold the superellipsoid samples only when superBound is enabled; otherwise
# they still hold the cuboid polytope vertices, whose hull has already been drawn in blue.
if showhull and superBound and not any(diffJac):
    polytope = np.array([xpoints, ypoints]).T
    hull = scipy.spatial.ConvexHull(polytope)
    for simplex in hull.simplices:
        plt.plot(polytope[simplex, 0], polytope[simplex, 1], 'g--')

if showpoints and superBound:
    ax.plot(xpoints, ypoints, "kx")

#fig = plt.figure()
#ax2 = fig.add_subplot(projection='3d')

#ax2.scatter(cube[:,0], cube[:,1], cube[:,2], c = "blue", label = "rectangular cuboid")

#if ellBound == True:
#        ax2.scatter(ell_jvlim[:,0], ell_jvlim[:,1], ell_jvlim[:,2], c = "red", label = "ellipsoid")
#    
#if superBound == True:
#        ax2.scatter(jvlim[:,0], jvlim[:,1], jvlim[:,2], c = "green", label = "superellipsoid")

#legend = ax2.legend(loc='upper right')


#if showedges == True:
#        num_points = 50
#        jvlim, tvlim = np.zeros((num_points,3)), np.zeros((num_points,2))
#        
#        edges = np.vstack((np.unique(cube[:,:2], axis = 0), np.unique(cube[:,1:3], axis = 0), np.unique(cube[:,0:3:2], axis = 0)))
#        for edge in edges[:4]:
#                for count in range(num_points):
#                        z = (random.random() * 2 - 1) * height
#                        jvlim[count] = np.array([edge[0],edge[1],z])
#                        tvlim[count] = J @ jvlim[count] + center
#                ax2.scatter(jvlim[:,0], jvlim[:,1], jvlim[:,2], c = "blue")
#                ax.plot(tvlim[:,0], tvlim[:,1], "bx")

#        for edge in edges[4:8]:
#                for count in range(num_points):
#                        x = (random.random() * 2 - 1) * length
#                        jvlim[count] = np.array([x,edge[0],edge[1]])
#                        tvlim[count] = J @ jvlim[count] + center
#                ax2.scatter(jvlim[:,0], jvlim[:,1], jvlim[:,2], c = "red")
#                ax.plot(tvlim[:,0], tvlim[:,1], "rx")
#                        
#        for edge in edges[8:]:
#                for count in range(num_points):
#                        y = (random.random() * 2 - 1) * width
#                        jvlim[count] = np.array([edge[0],y,edge[1]])
#                        tvlim[count] = J @ jvlim[count] + center
#                ax2.scatter(jvlim[:,0], jvlim[:,1], jvlim[:,2], c = "green")
#                ax.plot(tvlim[:,0], tvlim[:,1], "gx")
#        


#ax.axis('off')
# Use the same scale on both axes so that the drawn ellipses and polytopes keep their true shape
ax.axis('equal')
#ax2.axis('equal')
#plt.title(f"Length: {length}, width: {width}, height: {height}, p = {exp}")

# Display the figure (robot, target, and the enabled manipulability overlays)
plt.show()