Commit 99bd2722 authored by Emmanuel PIGNAT's avatar Emmanuel PIGNAT
Browse files

restoring poglqr

parent a3cdf732
import numpy as np
from utils.utils import lifted_transfer_matrix
from .utils.utils import lifted_transfer_matrix
import pbdlib as pbd
class LQR(object):
def __init__(self, A=None, B=None, nb_dim=2, dt=0.01, horizon=50):
self._horizon = horizon
......@@ -19,9 +20,11 @@ class LQR(object):
self._seq_xi, self._seq_u = None, None
self._S, self._v, self._K, self._Kv, self._ds, self._cs , self._Q = \
self._S, self._v, self._K, self._Kv, self._ds, self._cs, self._Qc = \
None, None, None, None, None, None, None
self._Q, self._z = None, None
@property
def K(self):
assert self._K is not None, "Solve Ricatti before"
......@@ -30,10 +33,38 @@ class LQR(object):
@property
def Q(self):
assert self._Q is not None, "Solve Ricatti before"
return self._Q
@Q.setter
def Q(self, value):
"""
value :
(ndim_xi, ndim_xi) or
((N, ndim_xi, ndim_xi), (nb_timestep, )) or
(nb_timestep, ndim_xi, ndim_xi)
"""
self._Q = value
@property
def z(self):
return self._z
@z.setter
def z(self, value):
"""
value :
(ndim_xi, ) or
((N, ndim_xi, ), (nb_timestep, )) or
(nb_timestep, ndim_xi)
"""
self._z = value
@property
def Qc(self):
assert self._Qc is not None, "Solve Ricatti before"
return self._Qc
@property
def cs(self):
"""
......@@ -158,6 +189,29 @@ class LQR(object):
:param t:
:return:
"""
if self._gmm_xi is None:
z, Q = None, None
if self._z is None:
z = np.zeros(self.A.shape[-1])
elif isinstance(self._z, tuple):
z = self._z[0][self._z[1][t]]
elif isinstance(self._z, np.ndarray):
if self._z.ndim == 1:
z = self._z
elif self._z.ndim == 2:
z = self._z[t]
if isinstance(self._Q, tuple):
Q = self._Q[0][self._Q[1][t]]
elif isinstance(self._Q, np.ndarray):
if self._Q.ndim == 2:
Q = self._Q
elif self._z.ndim == 3:
Q = self._Q[t]
return Q, z
else:
if isinstance(self._gmm_xi, tuple):
gmm, seq = self._gmm_xi
return gmm.lmbda[seq[t]], gmm.mu[seq[t]]
......@@ -166,7 +220,7 @@ class LQR(object):
elif isinstance(self._gmm_xi, pbd.MVN):
return self._gmm_xi.lmbda, self._gmm_xi.mu
else:
raise ValueError, "Not supported gmm_xi"
raise ValueError("Not supported gmm_xi")
def get_R(self, t):
if isinstance(self._gmm_u, pbd.MVN):
......@@ -177,7 +231,7 @@ class LQR(object):
elif isinstance(self._gmm_u, pbd.GMM):
return self._gmm_u.lmbda[t]
else:
raise ValueError, "Not supported gmm_u"
raise ValueError("Not supported gmm_u")
def ricatti(self):
"""
......@@ -191,7 +245,7 @@ class LQR(object):
_v = [None for i in range(self._horizon)]
_K = [None for i in range(self._horizon-1)]
_Kv = [None for i in range(self._horizon-1)]
_Q = [None for i in range(self._horizon-1)]
_Qc = [None for i in range(self._horizon - 1)]
# _S = np.empty((self._horizon, self.xi_dim, self.xi_dim))
# _v = np.empty((self._horizon, self.xi_dim))
# _K = np.empty((self._horizon-1, self.u_dim, self.xi_dim))
......@@ -204,8 +258,8 @@ class LQR(object):
Q, z = self.get_Q_z(t)
R = self.get_R(t)
_Q[t] = np.linalg.inv(R + self.B.T.dot(_S[t+1]).dot(self.B))
_Kv[t] = _Q[t].dot(self.B.T)
_Qc[t] = np.linalg.inv(R + self.B.T.dot(_S[t + 1]).dot(self.B))
_Kv[t] = _Qc[t].dot(self.B.T)
_K[t] = _Kv[t].dot(_S[t+1]).dot(self.A)
AmBK = self.A - self.B.dot(_K[t])
......@@ -213,11 +267,11 @@ class LQR(object):
_S[t] = self.A.T.dot(_S[t+1]).dot(AmBK) + Q
_v[t] = AmBK.T.dot(_v[t+1]) + Q.dot(z)
self._S = np.array(_S)
self._v = np.array(_v)
self._K = np.array(_K)
self._Kv = np.array(_Kv)
self._Q = np.array(_Q)
self._S = _S
self._v = _v
self._K = _K
self._Kv = _Kv
self._Qc = _Qc
self._ds = None
self._cs = None
......@@ -375,7 +429,6 @@ class PoGLQR(LQR):
else:
return self.nb_dim * self.horizon * 2
@property
def mvn_sol_u(self):
"""
......@@ -406,7 +459,6 @@ class PoGLQR(LQR):
return self._seq_u
@property
def mvn_sol_xi(self):
"""
......@@ -483,6 +535,7 @@ class PoGLQR(LQR):
self._s_xi, self._s_u = lifted_transfer_matrix(self.A, self.B,
horizon=self.horizon, dt=self.dt, nb_dim=self.nb_dim)
return self._s_u
@property
def s_xi(self):
if self._s_xi is None:
......@@ -507,5 +560,4 @@ class PoGLQR(LQR):
def horizon(self, value):
self.reset_params()
self._horizon = value
\ No newline at end of file
self._horizon = value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment