Commit b9dfd56d authored by Emmanuel PIGNAT

adding GMMLQR

adding contour on/off
various additions to MTMM
initialize GMM with parameters
parent 86da3fea
@@ -8,7 +8,7 @@ from .model import Model
 from .mvn import *
 from .plot import *
 from .pylqr import *
-from .poglqr import PoGLQR, LQR
+from .poglqr import PoGLQR, LQR, GMMLQR
 from .mtmm import MTMM, VBayesianGMM, VMBayesianGMM
 try:
......
@@ -8,11 +8,20 @@ from mvn import MVN
 class GMM(Model):
-    def __init__(self, nb_states=1, nb_dim=None, init_zeros=False):
+    def __init__(self, nb_states=1, nb_dim=None, init_zeros=False, mu=None, lmbda=None, sigma=None, priors=None):
+        if mu is not None:
+            nb_states = mu.shape[0]
+            nb_dim = mu.shape[-1]
         Model.__init__(self, nb_states, nb_dim)
         # flag to indicate that publishing was not init
         self.publish_init = False
+        self._mu = mu
+        self._lmbda = lmbda
+        self._sigma = sigma
+        self._priors = priors
         if init_zeros:
             self.init_zeros()
......
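The new keyword arguments let a GMM be built directly from known parameters rather than fitted; nb_states and nb_dim are then inferred from mu. A minimal usage sketch, assuming the package is imported as pbd (the parameter values below are purely illustrative):

import numpy as np
import pbdlib as pbd

# two hand-specified 2-D components (illustrative values)
mu = np.array([[0., 0.], [1., 1.]])         # (nb_states, nb_dim)
sigma = np.array([np.eye(2) * 0.1] * 2)     # (nb_states, nb_dim, nb_dim)
priors = np.array([0.5, 0.5])

# nb_states=2 and nb_dim=2 are deduced from mu.shape inside __init__
gmm = pbd.GMM(mu=mu, sigma=sigma, priors=priors)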
@@ -2,7 +2,7 @@ import numpy as np
 from .gmm import GMM, MVN
 from functions import multi_variate_normal, multi_variate_t
 from utils import gaussian_moment_matching
-from scipy.special import gamma, gammaln
+from scipy.special import gamma, gammaln, logsumexp
 class MTMM(GMM):
     """
@@ -10,9 +10,9 @@ class MTMM(GMM):
     """
     def __init__(self, *args, **kwargs):
+        self._nu = kwargs.pop('nu', None)
         GMM.__init__(self, *args, **kwargs)
-        self._nu = None
         self._k = None
     def __add__(self, other):
@@ -39,6 +39,10 @@ class MTMM(GMM):
         return mtmm
+    def get_matching_gmm(self):
+        return GMM(mu=self.mu, sigma=self.sigma * (self.nu/(self.nu-2.))[:, None, None],
+                   priors=self.priors)
     @property
     def k(self):
         return self._k
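get_matching_gmm moment-matches each Student-t component to a Gaussian: for nu > 2, a multivariate t with scale matrix sigma has covariance sigma * nu / (nu - 2), so the returned GMM keeps the means and priors and only rescales the covariances. A self-contained sketch of that rescaling with hypothetical values (not the library call itself):

import numpy as np

nu = np.array([3., 10.])                        # degrees of freedom per component
sigma = np.array([np.eye(2), 2. * np.eye(2)])   # scale matrices, (nb_states, nb_dim, nb_dim)

# covariance of each t component, the quantity used by get_matching_gmm
cov = sigma * (nu / (nu - 2.))[:, None, None]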
@@ -103,6 +107,9 @@ class MTMM(GMM):
         return gmm_out
+    def log_prob(self, x):
+        return logsumexp(self.log_prob_components(x) + np.log(self.priors)[:, None], axis=0)
     def log_prob_components(self, x):
         dx = self.mu[:, None] - x[None]  # [nb_states, nb_samples, nb_dim]
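The added log_prob evaluates the mixture log-density log p(x) = logsumexp_k(log priors_k + log p_k(x)) from the per-component log-densities, which is numerically safer than summing the probabilities themselves. An equivalent standalone sketch with hypothetical inputs:

import numpy as np
from scipy.special import logsumexp

def mixture_log_prob(log_prob_components, priors):
    # log_prob_components: (nb_states, nb_samples) log-densities per component
    # priors: (nb_states,) mixing weights summing to one
    return logsumexp(log_prob_components + np.log(priors)[:, None], axis=0)

print(mixture_log_prob(np.log([[0.2, 0.5], [0.4, 0.1]]), np.array([0.5, 0.5])))  # -> log([0.3, 0.3])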
@@ -373,6 +380,11 @@ class VBayesianGMM(MTMM):
             self._posterior_samples += [_gmm]
+    def get_used_states(self):
+        keep = self.nu - 1. > self.nu_prior
+        return MTMM(mu=self.mu[keep], lmbda=self.lmbda[keep],
+                    sigma=self.sigma[keep], nu=self.nu[keep], priors=self.priors[keep])
     def posterior(self, data, mean_scale=10., cov=None, dp=True):
         self.nb_dim = data.shape[1]
@@ -392,6 +404,7 @@ class VBayesianGMM(MTMM):
         self.nu = np.copy(m.degrees_of_freedom_) - self.nb_dim + 1
+        self.nu_prior = m.degrees_of_freedom_prior
         w_k = np.linalg.inv(m.covariances_ * m.degrees_of_freedom_[:, None, None])
         l_k = ((m.degrees_of_freedom_[:, None, None] + 1 - self.nb_dim) * m.mean_precision_[:, None, None])/ \
......
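get_used_states prunes the variational posterior: a component that absorbed essentially no data keeps degrees of freedom close to the prior value, so only states passing the test nu - 1. > nu_prior are returned, packed into a reduced MTMM. A standalone sketch of the selection rule with hypothetical numbers:

import numpy as np

nu = np.array([3.1, 30.2, 3.4])   # posterior degrees of freedom per component
nu_prior = 3.0                    # prior degrees of freedom

keep = nu - 1. > nu_prior         # same test as in get_used_states
print(keep)                       # -> [False  True False]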
@@ -269,7 +269,7 @@ def plot_linear_system(K, b=None, name=None, nb_sub=10, ax0=None, xlim=[-1, 1],
     return [strm]
-def plot_function_map(f, nb_sub=10, ax=None, xlim=[-1, 1], ylim=[-1, 1], opp=False, exp=False):
+def plot_function_map(f, nb_sub=10, ax=None, xlim=[-1, 1], ylim=[-1, 1], opp=False, exp=False, vmin=None, vmax=None, contour=True):
     """
     :param f: [function]
@@ -296,12 +296,18 @@ def plot_function_map(f, nb_sub=10, ax=None, xlim=[-1, 1], ylim=[-1, 1], opp=Fal
     if ax is None:
         ax = plt
-    CS = ax.contour(xx, yy, z, cmap='viridis')
-    ax.clabel(CS, inline=1, fontsize=10)
+    if contour:
+        try:
+            CS = ax.contour(xx, yy, z, cmap='viridis')
+            ax.clabel(CS, inline=1, fontsize=10)
+        except:
+            pass
     if opp: z = -z
     if exp: z = np.exp(z)
     ax.imshow(z, interpolation='bilinear', origin='lower', extent=xlim + ylim,
-              alpha=0.5, cmap='viridis')
+              alpha=0.5, cmap='viridis', vmin=vmin, vmax=vmax)
     return np.min(z), np.max(z)
 def plot_mixture_linear_system(model, mode='glob', nb_sub=20, gmm=True, min_alpha=0.,
                                cmap=plt.cm.jet, A=None, b=None, gmr=False, return_strm=False,
......
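The new contour, vmin and vmax arguments make it possible to switch the contour overlay off and to share one colour scale across several maps, using the (min, max) values that the function returns. A hedged usage sketch; the cost function below and the array shape it receives are assumptions, not specified by this diff:

import numpy as np
import matplotlib.pyplot as plt
import pbdlib as pbd

f = lambda x: -np.sum(x ** 2, axis=-1)   # illustrative cost over the plane

fig, (ax0, ax1) = plt.subplots(1, 2)
vmin, vmax = pbd.plot_function_map(f, nb_sub=30, ax=ax0)                          # contours on
pbd.plot_function_map(f, nb_sub=30, ax=ax1, contour=False, vmin=vmin, vmax=vmax)  # same scale, no contours
plt.show()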
@@ -240,6 +240,57 @@ class LQR(object):
         else:
             return np.array(xis), np.array(us)
+class GMMLQR(LQR):
+    """
+    LQR with a GMM cost on the state; approximation still to be verified
+    """
+    def __init__(self, *args, **kwargs):
+        self._full_gmm_xi = None
+        LQR.__init__(self, *args, **kwargs)
+    @property
+    def full_gmm_xi(self):
+        """
+        Distribution of state
+        :return:
+        """
+        return self._full_gmm_xi
+    @full_gmm_xi.setter
+    def full_gmm_xi(self, value):
+        """
+        :param value: [pbd.GMM] or [(pbd.GMM, list)]
+        """
+        self._full_gmm_xi = value
+    def ricatti(self, x0, n_best=None):
+        costs = []
+        if isinstance(self._full_gmm_xi, pbd.MTMM):
+            full_gmm = self.full_gmm_xi.get_matching_gmm()
+        else:
+            full_gmm = self.full_gmm_xi
+        if n_best is not None:
+            log_prob_components = self.full_gmm_xi.log_prob_components(x0)
+            a = np.sort(log_prob_components, axis=0)[-n_best - 1][0]
+        for i in range(self.full_gmm_xi.nb_states):
+            if n_best is not None and log_prob_components[i] < a:
+                costs += [-np.inf]
+            else:
+                self.gmm_xi = full_gmm, [i for j in range(self.horizon)]
+                LQR.ricatti(self)
+                xis, us = self.get_seq(x0)
+                costs += [np.sum(self.gmm_u.log_prob(us) + self.full_gmm_xi.log_prob(xis))]
+        max_lqr = np.argmax(costs)
+        self.gmm_xi = full_gmm, [max_lqr for j in range(self.horizon)]
+        LQR.ricatti(self)
 class PoGLQR(LQR):
     """
     Implementation of LQR with Product of Gaussian as described in
......
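GMMLQR.ricatti solves one standard LQR per component of the state-cost mixture (an MTMM cost is first moment-matched to a GMM), optionally restricted to the components whose log-probability at x0 is not below the (n_best + 1)-th largest, scores each candidate by the summed control and state log-likelihood of its rollout, and finally re-solves the LQR for the winning component. A standalone sketch of that selection loop, with a hypothetical solve_and_score callable standing in for the per-component LQR solve:

import numpy as np

def best_component(log_prob_components, solve_and_score, nb_states, n_best=None):
    # Mirrors the component selection in GMMLQR.ricatti().
    # log_prob_components: (nb_states, 1) log-probabilities of x0 under each component
    # solve_and_score(i):  hypothetical, returns the rollout log-likelihood when tracking component i
    costs = []
    if n_best is not None:
        # threshold = (n_best + 1)-th largest log-probability; strictly smaller components are skipped
        a = np.sort(log_prob_components, axis=0)[-n_best - 1][0]
    for i in range(nb_states):
        if n_best is not None and log_prob_components[i] < a:
            costs += [-np.inf]
        else:
            costs += [solve_and_score(i)]
    return int(np.argmax(costs))

print(best_component(np.log([[0.1], [0.6], [0.3]]), lambda i: (-3., -1., -2.)[i], 3, n_best=2))  # -> 1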