Commit 386a0418 authored by Emmanuel PIGNAT

adding plot method to MVN

saving temporary variables in mtmm for conditioning and marginal model
saving log_normalization
parent ddb8177c
import numpy as np
from scipy.interpolate import interp1d
from scipy.special import gamma, gammaln
colvec = lambda x: np.array(x).reshape(-1, 1)
rowvec = lambda x: np.array(x).reshape(1, -1)
......@@ -262,7 +263,7 @@ def multi_variate_t(x, nu, mu, sigma=None, log=True, gmm=False, lmbda=None):
:param log: bool
:return:
"""
from scipy.special import gamma, gammaln
if not gmm:
if type(sigma) is float:
sigma = np.array(sigma, ndmin=2)
......
......@@ -27,6 +27,8 @@ class Model(object):
self._has_finish_state = False
self._has_init_state = False
self._log_normalization = None
@property
def has_finish_state(self):
return self._has_finish_state
......@@ -148,6 +150,7 @@ class Model(object):
self._lmbda = None
self._sigma_chol = None
self._sigma = value
self._log_normalization = None
@property
def lmbda(self):
......@@ -166,6 +169,7 @@ class Model(object):
self._sigma = None # reset sigma
self._sigma_chol = None
self._lmbda = value
self._log_normalization = None
def get_dep_mask(self, deps):
mask = np.eye(self.nb_dim)
......
......@@ -2,6 +2,7 @@ import numpy as np
from .gmm import GMM, MVN
from functions import multi_variate_normal, multi_variate_t
from utils import gaussian_moment_matching
from scipy.special import gamma, gammaln
class MTMM(GMM):
"""
......@@ -29,6 +30,15 @@ class MTMM(GMM):
else:
raise NotImplementedError
def marginal_model(self, dims):
mtmm = MTMM(nb_dim=dims.stop - dims.start, nb_states=self.nb_states)
mtmm.priors = self.priors
mtmm.mu = self.mu[:, dims]
mtmm.sigma = self.sigma[:, dims, dims]
mtmm.nu = self.nu
return mtmm
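# Hypothetical usage sketch (the parameter values are illustrative, not from this repo):
# the marginal of a Student-t mixture over a slice of dimensions is again a Student-t mixture.
#   mtmm = MTMM(nb_dim=4, nb_states=2)
#   mtmm.priors, mtmm.nu = np.array([0.5, 0.5]), np.array([5., 5.])
#   mtmm.mu, mtmm.sigma = np.zeros((2, 4)), np.tile(np.eye(4)[None], (2, 1, 1))
#   marg = mtmm.marginal_model(slice(0, 2))  # mixture over the first two dimensions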
@property
def k(self):
return self._k
......@@ -51,7 +61,6 @@ class MTMM(GMM):
# compute responsibilities
mu_in, sigma_in = self.get_marginal(dim_in)
h = np.zeros((self.nb_states, sample_size))
for i in range(self.nb_states):
h[i, :] = multi_variate_t(data_in[None], self.nu[i],
......@@ -94,8 +103,24 @@ class MTMM(GMM):
return gmm_out
# @profile
def log_prob_components(self, x):
dx = self.mu[:, None] - x[None] # [nb_states, nb_samples, nb_dim]
s = np.sum(np.einsum('kij,kai->kaj', self.lmbda, dx) * dx, axis=2) # [nb_states, nb_samples]
log_norm = self.log_normalization[:, None]
return log_norm + (-(self.nu + self.nb_dim) / 2)[:, None] * np.log(1 + s/ self.nu[:, None])
def condition(self, data_in, dim_in, dim_out, h=None, return_gmm=False, reg_in=1e-20):
@property
def log_normalization(self):
if self._log_normalization is None:
self._log_normalization = gammaln((self.nu + self.nb_dim) / 2) + 0.5 * np.linalg.slogdet(self.lmbda)[1] - \
gammaln(self.nu / 2) - self.nb_dim / 2. * (np.log(self.nu) + np.log(np.pi))
return self._log_normalization
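# For reference, the cached value is the log normalization constant of a d-dimensional
# multivariate Student-t with dof nu_k and precision lmbda_k (a restatement of the two
# methods above, not new behaviour):
#   log Z_k = gammaln((nu_k + d) / 2) + 0.5 * logdet(lmbda_k)
#             - gammaln(nu_k / 2) - (d / 2) * (log(nu_k) + log(pi))
# so log_prob_components returns log Z_k - ((nu_k + d) / 2) * log(1 + dx^T lmbda_k dx / nu_k)
# with dx = x - mu_k.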
# @profile
def condition(self, data_in, dim_in, dim_out, h=None, return_gmm=False, reg_in=1e-20,
concat=True, return_linear=False, tmp=False):
"""
[1] M. Hofert, 'On the Multivariate t Distribution,' R J., vol. 5, pp. 129-136, 2013.
......@@ -112,56 +137,103 @@ class MTMM(GMM):
:return:
"""
if data_in.ndim == 1:
data_in = data_in[None]
was_not_batch = True
else:
was_not_batch = False
sample_size = data_in.shape[0]
if tmp and hasattr(self, '_tmp_slices') and not self._tmp_slices == (dim_in, dim_out):
del self._tmp_inv_sigma_out_in, self._tmp_inv_sigma_in_in, self._tmp_slices, self._tmp_marginal_model
# compute marginal probabilities of states given observation p(k|x_in)
mu_in, sigma_in = self.get_marginal(dim_in)
if h is None:
h = np.zeros((self.nb_states, sample_size))
for i in range(self.nb_states):
h[i, :] = multi_variate_t(data_in, self.nu[i],
mu_in[i],
sigma_in[i])
if tmp and hasattr(self, '_tmp_marginal_model'):
marginal_model = self._tmp_marginal_model
else:
marginal_model = self.marginal_model(dim_in)
if tmp:
self._tmp_marginal_model = marginal_model
if h is None:
h = marginal_model.log_prob_components(data_in)
h += np.log(self.priors)[:, None]
h = np.exp(h).T
h /= np.sum(h, axis=1, keepdims=True)
h = h.T
#[nb_samples, nb_states]
self._h = h # storing value
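# The block above computes the usual mixture responsibilities,
#   h_k(x_in) = priors_k * p_k(x_in) / sum_j priors_j * p_j(x_in),
# from the log-densities of the marginal model plus the log priors, normalized over states.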
mu_out, sigma_out = self.get_marginal(dim_out) # get marginal distribution of x_out
mu_est, sigma_est = ([], [])
# get conditional distribution of x_out given x_in for each states p(x_out|x_in, k)
inv_sigma_in_in, inv_sigma_out_in = ([], [])
_, sigma_in_out = self.get_marginal(dim_in, dim_out)
for i in range(self.nb_states):
inv_sigma_in_in += [np.linalg.inv(sigma_in[i] + reg_in * np.eye(sigma_in.shape[-1]))]
inv_sigma_out_in += [sigma_in_out[i].T.dot(inv_sigma_in_in[-1])]
dx = data_in - mu_in[i]
mu_est += [mu_out[i] + np.einsum('ij,aj->ai',
inv_sigma_out_in[-1], dx)]
if not concat: # faster when there are more datapoints
mu_est, sigma_est = ([], [])
inv_sigma_in_in, inv_sigma_out_in = ([], [])
s = np.sum(np.einsum('ai,ij->aj', dx, inv_sigma_in_in[-1]) * dx, axis=1)
a = (self.nu[i] + s)/(self.nu[i] + mu_in.shape[1])
for i in range(self.nb_states):
inv_sigma_in_in += [np.linalg.inv(sigma_in[i] + reg_in * np.eye(sigma_in.shape[-1]))]
inv_sigma_out_in += [sigma_in_out[i].T.dot(inv_sigma_in_in[-1])]
dx = data_in - mu_in[i]
mu_est += [mu_out[i] + np.einsum('ij,aj->ai',
inv_sigma_out_in[-1], dx)]
sigma_est += [a[:, None, None] *
(sigma_out[i] - inv_sigma_out_in[-1].dot(sigma_in_out[i]))[None]]
s = np.sum(np.einsum('ai,ij->aj', dx, inv_sigma_in_in[-1]) * dx, axis=1)
a = (self.nu[i] + s)/(self.nu[i] + mu_in.shape[1])
mu_est, sigma_est = (np.asarray(mu_est), np.asarray(sigma_est))
sigma_est += [a[:, None, None] *
(sigma_out[i] - inv_sigma_out_in[-1].dot(sigma_in_out[i]))[None]]
mu_est, sigma_est = (np.asarray(mu_est), np.asarray(sigma_est))
else:
# test if slices change and reset
if tmp and hasattr(self, '_tmp_inv_sigma_in_in'):
inv_sigma_in_in = self._tmp_inv_sigma_in_in
inv_sigma_out_in = self._tmp_inv_sigma_out_in
else:
inv_sigma_in_in = np.linalg.inv(sigma_in + reg_in * np.eye(sigma_in.shape[-1])[None])
inv_sigma_out_in = np.einsum('aji,ajk->aik', sigma_in_out, inv_sigma_in_in)
if tmp and not hasattr(self, '_tmp_inv_sigma_in_in'):
self._tmp_inv_sigma_in_in = inv_sigma_in_in
self._tmp_inv_sigma_out_in = inv_sigma_out_in
self._tmp_slices = (dim_in, dim_out)
# [nb_states, nb_sample, nb_dim]
dx = data_in[None] - mu_in[:, None]
mu_est = mu_out[:, None] + np.einsum('aij,abj->abi', inv_sigma_out_in, dx)
s = np.sum(np.einsum('kij,kai->kaj',inv_sigma_in_in, dx) * dx, axis=2)
a = (self.nu[:, None] + s) / (self.nu[:, None] + mu_in.shape[1])
sigma_est = a[:, :, None, None] * (sigma_out - np.einsum(
'aij,ajk->aik', inv_sigma_out_in, sigma_in_out))[:, None]
nu = self.nu + mu_in.shape[1]
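# For reference (cf. Hofert [1], cited in the docstring above): conditioning a joint
# Student-t on x_in yields another Student-t with
#   nu_cond    = nu + d_in
#   mu_cond    = mu_out + S_oi S_ii^-1 (x_in - mu_in)
#   scale_cond = (nu + delta) / (nu + d_in) * (S_oo - S_oi S_ii^-1 S_io),
# where delta = (x_in - mu_in)^T S_ii^-1 (x_in - mu_in); the factor `a` above is exactly
# (nu + delta) / (nu + d_in), and the covariance used below for moment matching is
# scale_cond * nu_cond / (nu_cond - 2).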
# the conditional distribution is still a mixture
if return_gmm:
return mu_est, sigma_est * nu/(nu-2.)
elif return_linear:
As = inv_sigma_out_in
bs = mu_out - np.matmul(inv_sigma_out_in, mu_in[:, :, None])[:, :, 0]
A = np.einsum('ak,kij->aij', h, As)
b = np.einsum('ak,ki->ai', h, bs)
if was_not_batch:
return A[0], b[0], gaussian_moment_matching(mu_est, sigma_est * (nu/(nu-2.))[:, None, None, None], h)[1][0]
else:
return A, b, gaussian_moment_matching(mu_est, sigma_est * (nu/(nu-2.))[:, None, None, None], h)[1]
else:
# apply moment matching to get a single MVN for each datapoint
return gaussian_moment_matching(mu_est, sigma_est * (nu/(nu-2.))[:, None, None, None], h.T)
return gaussian_moment_matching(mu_est, sigma_est * (nu/(nu-2.))[:, None, None, None], h)
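# Hypothetical usage sketch (the slices and shapes are illustrative assumptions):
#   x_in = np.random.randn(100, 2)
#   mu, sigma = mtmm.condition(x_in, slice(0, 2), slice(2, 4), tmp=True)
# With tmp=True the marginal model and the inverted in-in / out-in blocks are cached on
# the instance and reused as long as (dim_in, dim_out) stay the same slices.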
def get_pred_post_uncertainty(self, data_in, dim_in, dim_out):
"""
......@@ -298,68 +370,25 @@ class VBayesianGMM(MTMM):
self._sk_model.fit(data)
states = np.where(self._sk_model.weights_ > -5e-2)[0]
self.nb_states = states.shape[0]
self.nb_states = self._sk_model.weights_.shape[0]
# see [1] K. P. Murphy, 'Conjugate Bayesian analysis of the Gaussian distribution,' vol. 0, no. 7, 2007. par 9.4
# or [1] E. Fox, 'Bayesian nonparametric learning of complex dynamical phenomena,' 2009, p
# m.covariances_ = W_k^-1 / m.degrees_of_freedom_
m = self._sk_model
self.priors = np.copy(m.weights_[states])
self.mu = np.copy(m.means_[states])
self.k = np.copy(m.mean_precision_[states])
self.priors = np.copy(m.weights_)
self.mu = np.copy(m.means_)
self.k = np.copy(m.mean_precision_)
self.nu = np.copy(m.degrees_of_freedom_[states]) - self.nb_dim + 1
self.nu = np.copy(m.degrees_of_freedom_) - self.nb_dim + 1
w_k = np.linalg.inv(m.covariances_ * m.degrees_of_freedom_[:, None, None])
l_k = ((m.degrees_of_freedom_[:, None, None] + 1 - self.nb_dim) * m.mean_precision_[:, None, None])/ \
(1. + m.mean_precision_[:, None, None]) * w_k
self.sigma = np.copy(np.linalg.inv(l_k))[states]
# self.sigma = np.copy(self._sk_model.covariances_[states]) * (
# self.k[:, None, None] + 1) * self.nu[:, None, None] \
# / (self.k[:, None, None] * (self.nu[:, None, None] - self.nb_dim + 1))
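# For reference (Murphy 2007, par. 9.4, cited above): the posterior predictive of each state
# is a Student-t with
#   nu_k    = degrees_of_freedom_k - D + 1
#   sigma_k = (1 + kappa_k) / (kappa_k * (degrees_of_freedom_k - D + 1)) * W_k^-1,
# which is what l_k and self.sigma compute, using W_k^-1 = covariances_k * degrees_of_freedom_k.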
# add new state, base measure. TODO: make this less heuristic
if dp:
self.priors = np.concatenate([self.priors, 0.02 * np.ones((1,))], 0)
self.priors /= np.sum(self.priors)
self.mu = np.concatenate([self.mu, np.zeros((1, self.nb_dim))], axis=0)
if cov is None:
cov = mean_scale ** 2 * np.eye(self.nb_dim)
self.sigma = np.concatenate([self.sigma, cov[None]], axis=0)
self.k = np.concatenate([self.k, np.ones((1, ))], axis=0)
_nu_p = self._sk_model.degrees_of_freedom_prior_
self.nu = np.concatenate([self.nu, _nu_p * np.ones((1, ))], axis=0)
self.nb_states = states.shape[0] + 1
# add new state, base measure. TODO: make this less heuristic
if dp:
self.priors = np.concatenate([self.priors, 0.02 * np.ones((1,))], 0)
self.priors /= np.sum(self.priors)
self.mu = np.concatenate([self.mu, np.zeros((1, self.nb_dim))], axis=0)
if cov is None:
cov = mean_scale ** 2 * np.eye(self.nb_dim)
self.sigma = np.concatenate([self.sigma, cov[None]], axis=0)
self.k = np.concatenate([self.k, np.ones((1, ))], axis=0)
self.nu = np.concatenate([self.nu, np.ones((1, ))], axis=0)
self.nb_states = states.shape[0] + 1
self.sigma = np.copy(np.linalg.inv(l_k))
def condition(self, *args, **kwargs):
"""
......@@ -378,9 +407,10 @@ class VBayesianGMM(MTMM):
:return:
"""
if not kwargs.get('samples', False):
kwargs.pop('return_samples', True)
return MTMM.condition(self, *args, **kwargs)
kwargs.pop('samples')
return_samples = kwargs.pop('return_samples', False)
mus, sigmas = [], []
for _gmm in self.posterior_samples:
......@@ -394,7 +424,11 @@ class VBayesianGMM(MTMM):
sigma = np.mean(sigmas, axis=0) + \
np.einsum('aki,akj->kij', dmu, dmu) / len(self.posterior_samples)
return mu, sigma
if return_samples:
return mu, sigma, mus
else:
return mu, sigma
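# Hypothetical usage sketch (positional arguments are illustrative):
#   mu, sigma = model.condition(x_in, dim_in, dim_out)  # plain MTMM conditioning
#   mu, sigma, mus = model.condition(x_in, dim_in, dim_out, samples=True, return_samples=True)
# With samples=True the conditioning is averaged over the stored posterior samples;
# return_samples=True additionally returns the per-sample means.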
class VMBayesianGMM(VBayesianGMM):
def __init__(self, n, sk_parameters, *args, **kwargs):
......
......@@ -91,6 +91,9 @@ class MVN(object):
self._eta = None
def plot(self, *args, **kwargs):
pbd.plot_gaussian(self.mu, self.sigma, *args, **kwargs)
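# Hypothetical usage sketch (the signature of pbd.plot_gaussian is assumed, not verified here):
#   mvn.plot()  # forwards (self.mu, self.sigma, *args, **kwargs) to pbd.plot_gaussian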
@property
def muT(self):
"""
......