Commit 5b6b0aa1 authored by Emmanuel PIGNAT's avatar Emmanuel PIGNAT
Browse files

changing xi_dim, u_dim

parent 6d841d58
......@@ -84,7 +84,7 @@ class GMM(Model):
return gmm
def concatenate_gaussian(self, q, get_mvn=True):
def concatenate_gaussian(self, q, get_mvn=True, reg=None):
"""
Get a concatenated-block-diagonal replication of the GMM with sequence of state
given by q.
......@@ -95,15 +95,29 @@ class GMM(Model):
:return:
"""
if not get_mvn:
return np.concatenate([self.mu[i] for i in q]), block_diag(*[self.sigma[i] for i in q])
if reg is None:
if not get_mvn:
return np.concatenate([self.mu[i] for i in q]), block_diag(*[self.sigma[i] for i in q])
else:
mvn = MVN()
mvn.mu = np.concatenate([self.mu[i] for i in q])
mvn._sigma = block_diag(*[self.sigma[i] for i in q])
mvn._lmbda = block_diag(*[self.lmbda[i] for i in q])
return mvn
else:
mvn = MVN()
mvn.mu = np.concatenate([self.mu[i] for i in q])
mvn._sigma = block_diag(*[self.sigma[i] for i in q])
mvn._lmbda = block_diag(*[self.lmbda[i] for i in q])
if not get_mvn:
return np.concatenate([self.mu[i] for i in q]), block_diag(
*[self.sigma[i] + reg for i in q])
else:
mvn = MVN()
mvn.mu = np.concatenate([self.mu[i] for i in q])
mvn._sigma = block_diag(*[self.sigma[i] + reg for i in q])
mvn._lmbda = block_diag(*[np.linalg.inv(self.sigma[i] + reg) for i in q])
return mvn
return mvn
def compute_resp(self, demo=None, dep=None, table=None, marginal=None, ):
sample_size = demo.shape[0]
......@@ -169,7 +183,7 @@ class GMM(Model):
self.priors = np.ones(self.nb_states) / self.nb_states
def em(self, data, reg=1e-8, maxiter=100, minstepsize=1e-5, diag=False, reg_finish=False,
kmeans_init=False, random_init=True, dep_mask=None):
kmeans_init=False, random_init=True, dep_mask=None, no_init=False):
"""
:param data: [np.array([nb_timesteps, nb_dim])]
......@@ -199,13 +213,13 @@ class GMM(Model):
nb_samples = data.shape[0]
if random_init:
self.init_params_random(data)
elif kmeans_init:
self.init_params_kmeans(data)
else:
self.init_params_scikit(data)
if not no_init:
if random_init:
self.init_params_random(data)
elif kmeans_init:
self.init_params_kmeans(data)
else:
self.init_params_scikit(data)
data = data.T
......
......@@ -16,7 +16,7 @@ class PoGLQR(object):
self.B = B
self.nb_dim = nb_dim
self.dt = dt
self._s_xi, self._s_u = None, None
self._x0 = None
......@@ -44,7 +44,7 @@ class PoGLQR(object):
self._B = value
@property
def u_dim(self):
def mvn_u_dim(self):
"""
Number of dimension of input sequence lifted form
:return:
......@@ -55,7 +55,7 @@ class PoGLQR(object):
return self.nb_dim * self.horizon
@property
def xi_dim(self):
def mvn_xi_dim(self):
"""
Number of dimension of state sequence lifted form
......@@ -66,6 +66,29 @@ class PoGLQR(object):
else:
return self.nb_dim * self.horizon * 2
@property
def u_dim(self):
"""
Number of dimension of input
:return:
"""
if self.B is not None:
return self.B.shape[1]
else:
return self.nb_dim
@property
def xi_dim(self):
"""
Number of dimension of state
:return:
"""
if self.A is not None:
return self.A.shape[0]
else:
return self.nb_di * 2
@property
def mvn_sol_u(self):
"""
......@@ -85,14 +108,14 @@ class PoGLQR(object):
@property
def seq_xi(self):
if self._seq_xi is None:
self._seq_xi = self.mvn_sol_xi.mu.reshape(self.horizon, self.nb_dim * 2)
self._seq_xi = self.mvn_sol_xi.mu.reshape(self.horizon, self.xi_dim)
return self._seq_xi
@property
def seq_u(self):
if self._seq_u is None:
self._seq_u = self.mvn_sol_u.mu.reshape(self.horizon, self.nb_dim)
self._seq_u = self.mvn_sol_u.mu.reshape(self.horizon, self.u_dim)
return self._seq_u
......@@ -153,7 +176,7 @@ class PoGLQR(object):
self._mvn_u = value
else:
self._mvn_u = pbd.MVN(
mu=np.zeros(self.u_dim), lmbda=10 ** value * np.eye(self.u_dim))
mu=np.zeros(self.mvn_u_dim), lmbda=10 ** value * np.eye(self.mvn_u_dim))
@property
def x0(self):
......@@ -179,7 +202,7 @@ class PoGLQR(object):
def k(self):
# return self.mvn_sol_u.sigma.dot(self.s_u.T.dot(self.mvn_xi.lmbda)).dot(self.s_xi).reshape(
return self.mvn_sol_u.sigma.dot(self.s_u.T.dot(self.mvn_xi.lmbda)).dot(self.s_xi).reshape(
(self.horizon, self.u_dim/self.horizon, self.xi_dim/self.horizon))
(self.horizon, self.mvn_u_dim/self.horizon, self.mvn_xi_dim/self.horizon))
@property
def s_xi(self):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment