Commit 5d4cb785 authored by Emmanuel PIGNAT

adding files

parent 366cdec5
@@ -2,7 +2,7 @@ import numpy as np
from .model import *
from .functions import multi_variate_normal
from scipy.linalg import block_diag
+ from scipy.misc import logsumexp
from termcolor import colored
from .mvn import MVN
@@ -328,7 +328,6 @@ class GMM(Model):
self.init_params_scikit(data, 'full')
if only_scikit: return
- data = data.T
LL = np.zeros(nb_max_steps)
for it in range(nb_max_steps):
@@ -338,7 +337,7 @@ class GMM(Model):
L_log = np.zeros((self.nb_states, nb_samples))
for i in range(self.nb_states):
- L_log[i, :] = np.log(self.priors[i]) + multi_variate_normal(data.T, self.mu[i],
+ L_log[i, :] = np.log(self.priors[i]) + multi_variate_normal(data, self.mu[i],
    self.sigma[i],
    log=True)
@@ -347,12 +346,13 @@ class GMM(Model):
GAMMA2 = GAMMA / np.sum(GAMMA, axis=1)[:, np.newaxis]
# M-step
- self.mu = np.einsum('ac,ic->ai', GAMMA2,
-     data) # a states, c sample, i dim
- dx = data[None, :] - self.mu[:, :, None] # nb_dim, nb_states, nb_samples
+ # self.mu = np.einsum('ac,ci->ai', GAMMA2,
+ #     data) # a states, c sample, i dim
+ #
+ self.mu = np.dot(GAMMA2, data)
+ dx = data[None] - self.mu[:, None] # states, samples, dim
- self.sigma = np.einsum('acj,aic->aij', np.einsum('aic,ac->aci', dx, GAMMA2),
+ self.sigma = np.einsum('acj,aci->aij', np.einsum('aci,ac->aci', dx, GAMMA2),
    dx) # a states, c sample, i-j dim
self.sigma += self.reg
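Note: the index swap from 'aic' to 'aci' in the covariance einsum follows from the new dx layout (states, samples, dim), since data is no longer transposed. A minimal standalone sketch of the same weighted outer-product sum, with made-up shapes, checked against an explicit loop:

import numpy as np

# Hypothetical shapes, for illustration only: K states, N samples, D dims.
K, N, D = 3, 50, 2
rng = np.random.default_rng(0)
data = rng.normal(size=(N, D))                # samples x dim, as after this commit
GAMMA2 = rng.random((K, N))
GAMMA2 /= GAMMA2.sum(axis=1, keepdims=True)   # responsibilities normalized over samples

mu = GAMMA2.dot(data)                         # (K, D) weighted means
dx = data[None] - mu[:, None]                 # (K, N, D): a states, c samples, i dim
sigma = np.einsum('acj,aci->aij',
                  np.einsum('aci,ac->aci', dx, GAMMA2), dx)

# Same quantity written as an explicit sum of weighted outer products.
sigma_loop = np.stack([
    sum(GAMMA2[k, n] * np.outer(dx[k, n], dx[k, n]) for n in range(N))
    for k in range(K)])
assert np.allclose(sigma, sigma_loop)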
@@ -373,9 +373,9 @@ class GMM(Model):
if it > nb_min_steps:
if LL[it] - LL[it - 1] < max_diff_ll:
if reg_finish is not False:
- self.sigma = np.einsum(
-     'acj,aic->aij', np.einsum('aic,ac->aci', dx, GAMMA2),
-     dx) + reg_finish
+ self.sigma = np.einsum('acj,aci->aij',
+     np.einsum('aci,ac->aci', dx, GAMMA2),
+     dx) + reg_finish
if verbose:
print(
@@ -496,3 +496,6 @@ class GMM(Model):
return -0.5 * np.einsum(eins_idx[0], dx, np.einsum(eins_idx[1], lmbda_, dx)) \
    - mu.shape[1] / 2. * np.log(2 * np.pi) - np.sum(
    np.log(sigma_chol_.diagonal(axis1=1, axis2=2)), axis=1)
+ def log_prob(self, x):
+     return logsumexp(self.mvn_pdf(x) + np.log(self.priors)[None], axis=1)
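The added log_prob computes the mixture log-likelihood by combining per-component log-densities with the log-priors through logsumexp. A standalone sketch of the same computation with made-up parameters, using scipy.stats and scipy.special.logsumexp:

import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

# Made-up 2-component mixture, for illustration only.
priors = np.array([0.3, 0.7])
mus = np.array([[0., 0.], [2., 1.]])
sigmas = np.array([np.eye(2), 0.5 * np.eye(2)])

x = np.random.randn(10, 2)  # 10 query points

# log N(x | mu_k, sigma_k) for each point and component -> shape (10, 2)
log_comp = np.stack([multivariate_normal(mus[k], sigmas[k]).logpdf(x)
                     for k in range(2)], axis=1)

# log p(x) = logsumexp_k( log pi_k + log N(x | mu_k, sigma_k) )
log_prob = logsumexp(log_comp + np.log(priors)[None], axis=1)
print(log_prob.shape)  # (10,)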
@@ -3,6 +3,7 @@ from .utils.gaussian_utils import gaussian_moment_matching
+ from .plot import plot_gmm
class Model(object):
"""
Basis class for Gaussian mixture model (GMM), Hidden Markov Model (HMM), Hidden semi-Markov
@@ -302,6 +303,10 @@ class Model(object):
mu_est, sigma_est = (np.asarray(mu_est), np.asarray(sigma_est))
if return_gmm:
+ if sample_size == 1:
+     from .gmm import GMM
+     return GMM(priors=h[:, 0], mu=mu_est[:, 0], sigma=sigma_est,
+                nb_dim=mu_est.shape[-1], nb_states=mu_est.shape[0])
return h, mu_est, sigma_est
# return np.mean(mu_est, axis=0)
else:
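The new branch wraps the per-state conditional moments in a GMM when a single query point is conditioned on. For reference, a standalone sketch of Gaussian mixture conditioning that produces per-state (h, mu_est, sigma_est) as returned above; condition_gmm is a hypothetical helper, not the library's API, and the shapes below are made up:

import numpy as np
from scipy.stats import multivariate_normal

def condition_gmm(priors, mu, sigma, x_in, dim_in, dim_out):
    # Per-state responsibilities for the observed dimensions.
    K = len(priors)
    h = np.array([priors[k] * multivariate_normal(
        mu[k][dim_in], sigma[k][np.ix_(dim_in, dim_in)]).pdf(x_in)
        for k in range(K)])
    h /= h.sum()
    mu_est, sigma_est = [], []
    for k in range(K):
        # Standard Gaussian conditioning for each state.
        s_ii = sigma[k][np.ix_(dim_in, dim_in)]
        s_oi = sigma[k][np.ix_(dim_out, dim_in)]
        s_oo = sigma[k][np.ix_(dim_out, dim_out)]
        gain = s_oi.dot(np.linalg.inv(s_ii))
        mu_est.append(mu[k][dim_out] + gain.dot(x_in - mu[k][dim_in]))
        sigma_est.append(s_oo - gain.dot(s_oi.T))
    return h, np.array(mu_est), np.array(sigma_est)

# Example: 2 states over input dimension [0] and output dimension [1].
priors = np.array([0.5, 0.5])
mu = np.array([[0., 0.], [1., 2.]])
sigma = np.array([np.eye(2), 0.5 * np.eye(2)])
h, m, s = condition_gmm(priors, mu, sigma, np.array([0.3]), [0], [1])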
@@ -275,6 +275,19 @@ class MVN(object):
return prod
+ def alpha_divergence(self, other, alpha=0.5):
+     "https://mast.queensu.ca/~communications/Papers/gil-msc11.pdf"
+     lmbda = np.linalg.inv(alpha * other.sigma + (1. - alpha) * self.sigma)
+     r = 0.5 * (self.mu - other.mu).T.dot(lmbda).dot(self.mu - other.mu) -\
+         1./(2 * alpha * (alpha - 1.)) * (
+         -np.linalg.slogdet(lmbda)[1]\
+         - (1.-alpha) * np.linalg.slogdet(self.sigma)[1]\
+         - alpha * np.linalg.slogdet(other.sigma)[1])
+     return r
+ def sample(self, size=None):
+     return np.random.multivariate_normal(self.mu, self.sigma, size=size)
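For reference, the closed form used in alpha_divergence (from the Gil thesis linked in its docstring) written as a standalone function over (mu, sigma) pairs; the two checks at the end are illustrative only:

import numpy as np

def alpha_divergence(mu0, sigma0, mu1, sigma1, alpha=0.5):
    # Same expression as in the hunk above, for plain numpy inputs.
    lmbda = np.linalg.inv(alpha * sigma1 + (1. - alpha) * sigma0)
    dmu = mu0 - mu1
    return 0.5 * dmu.dot(lmbda).dot(dmu) - 1. / (2 * alpha * (alpha - 1.)) * (
        -np.linalg.slogdet(lmbda)[1]
        - (1. - alpha) * np.linalg.slogdet(sigma0)[1]
        - alpha * np.linalg.slogdet(sigma1)[1])

mu = np.zeros(2)
sigma = np.eye(2)
print(alpha_divergence(mu, sigma, mu, sigma))            # 0.0 for identical Gaussians
print(alpha_divergence(mu, sigma, mu + 1., 2. * sigma))  # positive otherwise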
@@ -77,7 +77,7 @@ def plot_distpatch(ax, x, mean, var, color=[1, 0, 0], num_std=2, alpha=0.5, line
ax.plot(x, mean, linewidth=linewidth, color=color, alpha=linealpha)  # Mean
def plot_spherical_gmm(Mu, Sigma, dim=None, tp=None, color='r',
- alpha=255, swap=False):
+ alpha=255, swap=False, ax=None, label=None):
"""
:param Mu:
@@ -168,9 +168,14 @@ def plot_spherical_gmm(Mu, Sigma, dim=None, tp=None, color='r',
else:
c = col
- plt.plot(points_int[0, :], points_int[1, :], lw=1, alpha=1, color=col)
- plt.plot(points_ext[0, :], points_ext[1, :], lw=1, alpha=1, color=col)
- plt.plot(points[0, :], points[1, :], lw=1, alpha=1, color='k', ls='--')
+ if ax is None:
+     p, a = (plt, plt.axes())
+ else:
+     p, a = (ax, ax)
+ _label_std = None if label is None else label + ' std'
+ p.plot(points_int[0, :], points_int[1, :], lw=1, alpha=1, color=col, ls='--')
+ p.plot(points_ext[0, :], points_ext[1, :], lw=1, alpha=1, color=col, ls='--')
+ p.plot(points[0, :], points[1, :], lw=1, alpha=1, color=col, label=label)
# plt.fill_between(points_ext[0], points_int[1], points_ext[1])
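The change routes drawing through a caller-supplied Axes when ax is given, keeps the pyplot fallback otherwise, and tags the mean contour with label for legends. A minimal sketch of the same dispatch pattern; plot_circle is a hypothetical helper, not part of the library:

import matplotlib.pyplot as plt
import numpy as np

def plot_circle(center, radius, color='r', ax=None, label=None):
    # Draw on the given Axes if provided, otherwise on the pyplot state machine.
    p = plt if ax is None else ax
    t = np.linspace(0, 2 * np.pi, 100)
    points = center[:, None] + radius * np.stack([np.cos(t), np.sin(t)])
    p.plot(points[0, :], points[1, :], lw=1, color=color, label=label)

fig, axes = plt.subplots(1, 2)
plot_circle(np.zeros(2), 1., color='b', ax=axes[0], label='state 0')
plot_circle(np.ones(2), 0.5, color='r', ax=axes[1], label='state 1')
axes[0].legend(); axes[1].legend()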
@@ -157,7 +157,7 @@ class LQR(object):
@gmm_u.setter
def gmm_u(self, value):
"""
- :param value [float] or [pbd.MVN] or [pbd.GMM] or [(pbd.GMM, list)]
+ :param value [float (std of u)] or [pbd.MVN] or [pbd.GMM] or [(pbd.GMM, list)]
"""
# resetting solution
self._mvn_sol_xi = None
@@ -167,7 +167,7 @@ class LQR(object):
if isinstance(value, float):
self._gmm_u = pbd.MVN(
- mu=np.zeros(self.u_dim), lmbda=10 ** value * np.eye(self.u_dim))
+ mu=np.zeros(self.u_dim), lmbda=value ** -2 * np.eye(self.u_dim))
else:
self._gmm_u = value
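With this change a scalar value is read as the standard deviation of u, so the precision of the control prior becomes value ** -2, whereas the old line used the scalar as a log10 precision via 10 ** value. A small numeric illustration of the new convention:

import numpy as np

# value is the standard deviation of u, so the precision is value ** -2.
u_dim = 2
std_u = 0.1
lmbda = std_u ** -2 * np.eye(u_dim)   # precision = 1 / std^2 -> 100 * I here
sigma_u = np.linalg.inv(lmbda)        # covariance = std^2 * I
print(np.allclose(sigma_u, std_u ** 2 * np.eye(u_dim)))  # True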