Commit ffe507f8 authored by Hakan GIRGIN's avatar Hakan GIRGIN
Browse files

Code cleaning and compatibility with Python3

parent bc8aff8e
......@@ -5,8 +5,8 @@ from scipy.special import gamma, gammaln
colvec = lambda x: np.array(x).reshape(-1, 1)
rowvec = lambda x: np.array(x).reshape(1, -1)
realmin = np.finfo(np.float32).tiny
realmax = np.finfo(np.float32).max
realmin = np.finfo(np.float64).tiny
realmax = np.finfo(np.float64).max
def limit_gains(gains, gain_limit):
"""
......
This diff is collapsed.
......@@ -48,7 +48,6 @@ class HMM(GMM):
def Trans(self, value):
self.trans = value
def make_finish_state(self, demos, dep_mask=None):
self.has_finish_state = True
self.nb_states += 1
......@@ -75,7 +74,7 @@ class HMM(GMM):
self.priors = np.concatenate([self.priors, np.zeros(1)], axis=0)
pass
def viterbi(self, demo, reg=False):
def viterbi(self, demo, reg=True):
"""
Compute most likely sequence of state given observations
......@@ -284,7 +283,6 @@ class HMM(GMM):
self.Trans += self_trans * np.eye(self.nb_states)
self.init_priors = np.ones(self.nb_states)/ self.nb_states
def gmm_init(self, data, **kwargs):
if isinstance(data, list):
data = np.concatenate(data, axis=0)
......@@ -449,7 +447,6 @@ class HMM(GMM):
return True
print("EM did not converge")
return False
......@@ -466,13 +463,17 @@ class HMM(GMM):
return ll
def condition(self, data_in, dim_in, dim_out, h=None, gmm=False, return_gmm=False):
if gmm:
return super(HMM, self).condition(data_in, dim_in, dim_out, return_gmm=return_gmm)
def condition(self, data_in, dim_in, dim_out, h=None, return_gmm=False):
if return_gmm:
return super().condition(data_in, dim_in, dim_out, return_gmm=return_gmm)
else:
a, _, _, _, _ = self.compute_messages(data_in, marginal=dim_in)
if dim_in == slice(0, 1):
dim_in_msg = []
else:
dim_in_msg = dim_in
a, _, _, _, _ = self.compute_messages(data_in, marginal=dim_in_msg)
return super(HMM, self).condition(data_in, dim_in, dim_out, h=a)
return super().condition(data_in, dim_in, dim_out, h=a)
"""
To ensure compatibility
......
......@@ -4,6 +4,7 @@ from .hmm import *
from .functions import *
from .model import *
class OnlineForwardVariable():
def __init__(self):
self.nbD = None
......@@ -37,7 +38,7 @@ class HSMM(HMM):
@mu_d.setter
def mu_d(self, value):
self._mu_d= value
self._mu_d = value
@property
def sigma_d(self):
......@@ -65,7 +66,6 @@ class HSMM(HMM):
# self.Trans_Pd = self.Trans - np.diag(np.diag(self.Trans)) + realmin
# self.Trans_Pd /= colvec(np.sum(self.Trans_Pd, axis=1))
# init duration components
self.Mu_Pd = np.zeros(self.nb_states)
self.Sigma_Pd = np.zeros(self.nb_states)
......@@ -73,7 +73,7 @@ class HSMM(HMM):
# reconstruct sequence of states from all demonstrations
state_seq = []
trans_list = np.zeros((self.nb_states, self.nb_states))# create a table to count the transition
trans_list = np.zeros((self.nb_states, self.nb_states)) # create a table to count the transition
s = demos if demos is not None else sequ
# reformat transition matrix by counting the transition
......@@ -84,14 +84,13 @@ class HSMM(HMM):
state_seq_tmp = d.tolist()
prev_state = 0
for i, state in enumerate(state_seq_tmp):
if i == 0: # first state of sequence :
if i == 0: # first state of sequence :
pass
elif i == len(state_seq_tmp)-1 and last: # last state of sequence
elif i == len(state_seq_tmp) - 1 and last: # last state of sequence
trans_list[state][state] += 1.0
elif state != prev_state: # transition detected
elif state != prev_state: # transition detected
trans_list[prev_state][state] += 1.0
prev_state = state
......@@ -101,11 +100,11 @@ class HSMM(HMM):
self.Trans_Pd = trans_list
# make sum to one
for i in range(self.nb_states):
sum = np.sum(self.Trans_Pd[i,:])
sum = np.sum(self.Trans_Pd[i, :])
if sum > realmin:
self.Trans_Pd[i,:] /= sum
self.Trans_Pd[i, :] /= sum
#print state_seq
# print state_seq
# list of duration
stateDuration = [[] for i in range(self.nb_states)]
......@@ -113,8 +112,8 @@ class HSMM(HMM):
currState = state_seq[0]
cnt = 1
for i,state in enumerate(state_seq):
if i == len(state_seq)-1: # last state of sequence
for i, state in enumerate(state_seq):
if i == len(state_seq) - 1: # last state of sequence
stateDuration[currState] += [cnt]
elif state == currState:
cnt += 1
......@@ -123,10 +122,10 @@ class HSMM(HMM):
cnt = 1
currState = state
#print stateDuration
# print stateDuration
for i in range(self.nb_states):
self.Mu_Pd[i] = np.mean(stateDuration[i])
if len(stateDuration[i])>1:
if len(stateDuration[i]) > 1:
self.Sigma_Pd[i] = np.std(stateDuration[i]) + dur_reg
else:
self.Sigma_Pd[i] = dur_reg
......@@ -148,7 +147,6 @@ class HSMM(HMM):
return alpha, beta, gamma, zeta, c
def forward_variable_ts(self, n_step, p0=None):
"""
Compute forward variables without any observation of the sequence.
......@@ -158,7 +156,7 @@ class HSMM(HMM):
:return:
"""
nbD = np.round(4* n_step/self.nb_states)
nbD = np.round(4 * n_step // self.nb_states)
self.Pd = np.zeros((self.nb_states, nbD))
# Precomputation of duration probabilities
......@@ -177,7 +175,6 @@ class HSMM(HMM):
h /= np.sum(h, axis=0)
return h
def _fwd_init_ts(self, nbD, p0=None):
"""
Initiatize forward variable computation based only on duration (no observation)
......@@ -189,7 +186,7 @@ class HSMM(HMM):
else:
ALPHA = np.tile(p0, [nbD, 1]).T * self.Pd
S = np.dot(self.Trans_Pd.T, ALPHA[:, [0]]) # use [idx] to keep the dimension
S = np.dot(self.Trans_Pd.T, ALPHA[:, [0]]) # use [idx] to keep the dimension
return ALPHA, S, np.sum(ALPHA, axis=1)
......@@ -198,14 +195,13 @@ class HSMM(HMM):
Step of forward variable computation based only on duration (no observation)
:return:
"""
ALPHA = np.concatenate((S[:, [-1]] * self.Pd[:, 0:nbD-1] + ALPHA[:, 1:nbD],
S[:, [-1]] * self.Pd[:, [nbD-1]]), axis=1)
ALPHA = np.concatenate((S[:, [-1]] * self.Pd[:, 0:nbD - 1] + ALPHA[:, 1:nbD],
S[:, [-1]] * self.Pd[:, [nbD - 1]]), axis=1)
S = np.concatenate((S, np.dot(self.Trans_Pd.T, ALPHA[:, [0]])), axis=1)
return ALPHA, S, np.sum(ALPHA, axis=1)
def forward_variable(self, n_step=None, demo=None, marginal=None, dep=None, p_obs=None):
"""
Compute the forward variable with some observations
......@@ -231,16 +227,14 @@ class HSMM(HMM):
elif isinstance(demo, dict):
n_step = demo['x'].shape[0]
nbD = np.round(4* n_step/self.nb_states)
nbD = np.round(4 * n_step // self.nb_states)
if nbD == 0:
nbD = 10
self.Pd = np.zeros((self.nb_states, nbD))
# Precomputation of duration probabilities
for i in range(self.nb_states):
self.Pd[i, :] = multi_variate_normal(np.arange(nbD), self.Mu_Pd[i], self.Sigma_Pd[i], log=False)
self.Pd[i, :] = self.Pd[i, :] / np.sum(self.Pd[i, :])
self.Pd[i, :] = self.Pd[i, :] / (np.sum(self.Pd[i, :])+realmin)
# compute observation marginal probabilities
p_obs, _ = self.obs_likelihood(demo, dep, marginal, n_step)
......@@ -265,7 +259,6 @@ class HSMM(HMM):
bmx = np.zeros((self.nb_states, 1))
Btmp = priors
ALPHA = np.tile(self.init_priors, [nbD, 1]).T * self.Pd
# r = Btmp.T * np.sum(ALPHA, axis=1)
......@@ -273,7 +266,7 @@ class HSMM(HMM):
bmx[:, 0] = Btmp / r
E = bmx * ALPHA[:, [0]]
S = np.dot(self.Trans_Pd.T, E) # use [idx] to keep the dimension
S = np.dot(self.Trans_Pd.T, E) # use [idx] to keep the dimension
return bmx, ALPHA, S, Btmp * np.sum(ALPHA, axis=1)
......@@ -289,8 +282,8 @@ class HSMM(HMM):
Btmp = obs_marginal
ALPHA = np.concatenate((S[:, [-1]] * self.Pd[:, 0:nbD-1] + bmx[:,[-1]] * ALPHA[:, 1:nbD],
S[:, [-1]] * self.Pd[:, [nbD-1]]), axis=1)
ALPHA = np.concatenate((S[:, [-1]] * self.Pd[:, 0:nbD - 1] + bmx[:, [-1]] * ALPHA[:, 1:nbD],
S[:, [-1]] * self.Pd[:, [nbD - 1]]), axis=1)
r = np.dot(Btmp.T, np.sum(ALPHA, axis=1))
bmx = np.concatenate((bmx, Btmp[:, None] / r), axis=1)
......@@ -305,7 +298,6 @@ class HSMM(HMM):
## SANDBOX ABOVE
########################################################################################
@property
def Sigma_Pd(self):
return self.sigma_d
......@@ -349,10 +341,9 @@ class HSMM(HMM):
except:
# print "No task-parametrized transition matrix : normal transition matrix will be used"
self.Trans_Fw = self.Trans_Pd
else: # compute the transition matrix for current parameters
else: # compute the transition matrix for current parameters
self._update_transition_matrix(tp_param)
# nbD = np.round(2 * n_step/self.nb_states)
nbD = np.round(2 * n_step)
......@@ -361,8 +352,8 @@ class HSMM(HMM):
# Precomputation of duration probabilities
for i in range(self.nb_states):
self.Pd[i, :] = multi_variate_normal(np.arange(nbD), self.Mu_Pd[i], self.Sigma_Pd[i])
if np.sum(self.Pd[i,:])< 1e-50:
self.Pd[i,:] = 1.0 / self.Pd[i, :].shape[0]
if np.sum(self.Pd[i, :]) < 1e-50:
self.Pd[i, :] = 1.0 / self.Pd[i, :].shape[0]
else:
self.Pd[i, :] = self.Pd[i, :] / np.sum(self.Pd[i, :])
......@@ -393,11 +384,11 @@ class HSMM(HMM):
try:
# self.Trans_Fw = self.tp_trans.Prior_Trans
self.Trans_Fw = self.Trans_Pd
# print self.Trans_Fw
# print self.Trans_Fw
except:
print("No task-parametrized transition matrix : normal transition matrix will be used")
self.Trans_Fw = self.Trans_Pd
# print self.Trans_Fw
# print self.Trans_Fw
else: # compute the transition matrix for current parameters
self._update_transition_matrix(tp_param)
......@@ -425,7 +416,7 @@ class HSMM(HMM):
priors /= np.sum(priors)
self.ol.bmx, self.ol.ALPHA, self.ol.S, self.ol.h = self._fwd_init_priors(self.ol.nbD, priors,
start_priors=start_priors)
start_priors=start_priors)
# for i in range(1, n_step):
# bmx, ALPHA, S, h[:, [i]] = self._fwd_step_priors(bmx, ALPHA, S, self.ol.nbD, priors)
......@@ -452,7 +443,6 @@ class HSMM(HMM):
# traceback.print_exc(file=sys.stdout)
return None
def online_forward_variable_prob_predict(self, n_step, priors):
"""
Compute prediction for n_step timestep on the current online forward variable.
......@@ -467,7 +457,7 @@ class HSMM(HMM):
priors /= np.sum(priors)
# bmx, ALPHA, S, h[:, [0]] = self._fwd_init_priors(nbD, priors, start_priors=start_priors)
h[:,[0]] = self.ol.h
h[:, [0]] = self.ol.h
bmx = self.ol.bmx
ALPHA = self.ol.ALPHA
S = self.ol.S
......@@ -479,7 +469,7 @@ class HSMM(HMM):
except:
h = np.tile(self.ol.h, (1, n_step))
# traceback.print_exc(file=sys.stdout)
# traceback.print_exc(file=sys.stdout)
h /= np.sum(h, axis=0)
......@@ -499,7 +489,7 @@ class HSMM(HMM):
if tp_param is None:
# self.Trans_Fw = self.tp_trans.Prior_Trans
self.Trans_Fw = self.Trans_Pd
else: # compute the transition matrix for current parameters
else: # compute the transition matrix for current parameters
self._update_transition_matrix(tp_param)
# nbD = np.round(2 * n_step/self.nb_states)
......@@ -518,16 +508,15 @@ class HSMM(HMM):
h = np.zeros((self.nb_states, n_step))
bmx, ALPHA, S, h[:, [0]] = self._fwd_init_hsum(nbD, Data[:,1])
bmx, ALPHA, S, h[:, [0]] = self._fwd_init_hsum(nbD, Data[:, 1])
for i in range(1, n_step):
bmx, ALPHA, S, h[:, [i]] = self._fwd_step_hsum(bmx, ALPHA, S, nbD, Data[:,i])
bmx, ALPHA, S, h[:, [i]] = self._fwd_step_hsum(bmx, ALPHA, S, nbD, Data[:, i])
h /= np.sum(h, axis=0)
return h
def _fwd_init_priors(self, nbD, priors,start_priors=None):
def _fwd_init_priors(self, nbD, priors, start_priors=None):
"""
:param nbD:
......@@ -546,7 +535,7 @@ class HSMM(HMM):
bmx[:, [0]] = Btmp / r
E = bmx * ALPHA[:, [0]]
S = np.dot(self.Trans_Fw.T, E) # use [idx] to keep the dimension
S = np.dot(self.Trans_Fw.T, E) # use [idx] to keep the dimension
return bmx, ALPHA, S, Btmp * colvec(np.sum(ALPHA, axis=1))
......@@ -562,14 +551,15 @@ class HSMM(HMM):
Btmp = priors
ALPHA = np.concatenate((S[:, [-1]] * self.Pd[:, 0:nbD-1] + bmx[:,[-1]] * ALPHA[:, 1:nbD],
S[:, [-1]] * self.Pd[:, [nbD-1]]), axis=1)
ALPHA = np.concatenate((S[:, [-1]] * self.Pd[:, 0:nbD - 1] + bmx[:, [-1]] * ALPHA[:, 1:nbD],
S[:, [-1]] * self.Pd[:, [nbD - 1]]), axis=1)
r = np.dot(Btmp.T, np.sum(ALPHA, axis=1))
bmx = np.concatenate((bmx, Btmp / r), axis=1)
E = bmx[:, [-1]] * ALPHA[:, [0]]
S = np.concatenate((S, np.dot(self.Trans_Fw.T + np.eye(self.nb_states) * trans_diag + trans_reg, ALPHA[:, [0]])), axis=1)
S = np.concatenate(
(S, np.dot(self.Trans_Fw.T + np.eye(self.nb_states) * trans_diag + trans_reg, ALPHA[:, [0]])), axis=1)
alpha = Btmp * colvec(np.sum(ALPHA, axis=1))
alpha /= np.sum(alpha)
return bmx, ALPHA, S, alpha
......@@ -585,7 +575,7 @@ class HSMM(HMM):
Btmp = np.zeros((self.nb_states, 1))
for i in range(self.nb_states):
Btmp[i] = multi_variate_normal(Data.reshape(-1,1), self.Mu[:,i], self.Sigma[:,:,i]) + 1e-12
Btmp[i] = multi_variate_normal(Data.reshape(-1, 1), self.Mu[:, i], self.Sigma[:, :, i]) + 1e-12
Btmp /= np.sum(Btmp)
......@@ -595,7 +585,7 @@ class HSMM(HMM):
bmx[:, [0]] = Btmp / r
E = bmx * ALPHA[:, [0]]
S = np.dot(self.Trans_Fw.T, E) # use [idx] to keep the dimension
S = np.dot(self.Trans_Fw.T, E) # use [idx] to keep the dimension
return bmx, ALPHA, S, Btmp * colvec(np.sum(ALPHA, axis=1))
......@@ -612,12 +602,12 @@ class HSMM(HMM):
Btmp = np.zeros((self.nb_states, 1))
for i in range(self.nb_states):
Btmp[i] = multi_variate_normal(Data.reshape(-1,1), self.Mu[:,i], self.Sigma[:,:,i]) + 1e-12
Btmp[i] = multi_variate_normal(Data.reshape(-1, 1), self.Mu[:, i], self.Sigma[:, :, i]) + 1e-12
Btmp /= np.sum(Btmp)
ALPHA = np.concatenate((S[:, [-1]] * self.Pd[:, 0:nbD-1] + bmx[:,[-1]] * ALPHA[:, 1:nbD],
S[:, [-1]] * self.Pd[:, [nbD-1]]), axis=1)
ALPHA = np.concatenate((S[:, [-1]] * self.Pd[:, 0:nbD - 1] + bmx[:, [-1]] * ALPHA[:, 1:nbD],
S[:, [-1]] * self.Pd[:, [nbD - 1]]), axis=1)
r = np.dot(Btmp.T, np.sum(ALPHA, axis=1))
bmx = np.concatenate((bmx, Btmp / r), axis=1)
......@@ -626,4 +616,4 @@ class HSMM(HMM):
S = np.concatenate((S, np.dot(self.Trans_Fw.T, ALPHA[:, [0]])), axis=1)
alpha = Btmp * colvec(np.sum(ALPHA, axis=1))
alpha /= np.sum(alpha)
return bmx, ALPHA, S, alpha
\ No newline at end of file
return bmx, ALPHA, S, alpha
import numpy as np
from .functions import *
from .utils.gaussian_utils import gaussian_moment_matching
from .plot import plot_gmm
class Model(object):
"""
Basis class for Gaussian mixture model (GMM), Hidden Markov Model (HMM), Hidden semi-Markov
......@@ -14,7 +14,6 @@ class Model(object):
self.nb_dim = nb_dim
self.nb_states = nb_states
self._mu = None
self._sigma = None # covariance matrix
self._sigma_chol = None # covariance matrix, cholesky decomposition
......@@ -181,7 +180,6 @@ class Model(object):
dGrid = np.ix_(dep, dep)
mask[dGrid] = 1.
return mask
def dep_mask(self, deps):
......@@ -220,7 +218,6 @@ class Model(object):
self._mu = np.array([np.zeros(self.nb_dim) for i in range(self.nb_states)])
self._sigma = np.array([np.eye(self.nb_dim) for i in range(self.nb_states)])
def plot(self, *args, **kwargs):
"""
Plot GMM, circle is 1 std
......@@ -239,8 +236,7 @@ class Model(object):
"""
zs = np.array([np.random.multinomial(1, self.priors) for _ in range(size)]).T
xs = [z[:, None] * np.random.multivariate_normal(m, s, size=size)
for z, m, s in zip(zs, self.mu, self.sigma)]
xs = [z[:, None] * np.random.multivariate_normal(m, s, size=size) for z, m, s in zip(zs, self.mu, self.sigma)]
return np.sum(xs, axis=0)
......@@ -251,7 +247,6 @@ class Model(object):
# get conditional distribution of x_out given x_in for each states p(x_out|x_in, k)
_, sigma_in_out = self.get_marginal(dim_in, dim_out)
inv_sigma_in_in = np.linalg.inv(
sigma_in)
inv_sigma_out_in = np.einsum('aji,ajk->aik', sigma_in_out, inv_sigma_in_in)
......@@ -265,7 +260,6 @@ class Model(object):
def condition(self, data_in, dim_in, dim_out, h=None, return_gmm=False):
"""
:param data_in: [np.array([nb_timestep, nb_dim])
:param dim_in:
:param dim_out:
......@@ -274,7 +268,6 @@ class Model(object):
"""
sample_size = data_in.shape[0]
# compute responsabilities
mu_in, sigma_in = self.get_marginal(dim_in)
......@@ -287,10 +280,10 @@ class Model(object):
h += np.log(self.priors)[:, None]
h = np.exp(h).T
h /= np.sum(h, axis=1, keepdims=True)
h /= (np.sum(h, axis=1, keepdims=True) + realmin)
h = h.T
self._h = h
# self._h = h
mu_out, sigma_out = self.get_marginal(dim_out)
mu_est, sigma_est = ([], [])
......@@ -302,15 +295,14 @@ class Model(object):
inv_sigma_in_in += [np.linalg.inv(sigma_in[i])]
inv_sigma_out_in += [sigma_in_out[i].T.dot(inv_sigma_in_in[-1])]
mu_est += [mu_out[i] + np.einsum('ij,aj->ai',
inv_sigma_out_in[-1], data_in - mu_in[i])]
mu_est += [mu_out[i] + np.einsum('ij,aj->ai', inv_sigma_out_in[-1], data_in - mu_in[i])]
sigma_est += [sigma_out[i] - inv_sigma_out_in[-1].dot(sigma_in_out[i])]
mu_est, sigma_est = (np.asarray(mu_est), np.asarray(sigma_est))
if return_gmm:
return h, mu_est, sigma_est
return h, mu_est, sigma_est
# return np.mean(mu_est, axis=0)
else:
......
import numpy as np
from .gmm import GMM, MVN
from .hmm import HMM
from .functions import multi_variate_normal, multi_variate_t
from .utils.gaussian_utils import gaussian_moment_matching
from scipy.special import gamma, gammaln, logsumexp
from scipy.special import logsumexp
from sklearn import mixture
from scipy.stats import wishart
from .model import *
from .utils import gaussian_moment_matching
class MTMM(GMM):
class MTMM(Model):
"""
Multivariate t-distribution mixture
"""
def __init__(self, *args, **kwargs):
self._nu = kwargs.pop('nu', None)
GMM.__init__(self, *args, **kwargs)
def __init__(self, nb_states=1, nb_dim=None, init_zeros=False, mu=None, lmbda=None, sigma=None, priors=None,
nu=None):
"""
:param
"""
if mu is not None:
nb_states = mu.shape[0]
nb_dim = mu.shape[-1]
super().__init__(nb_states, nb_dim)
# flag to indicate that publishing was not init
self.publish_init = False
self._mu = mu
self._lmbda = lmbda
self._sigma = sigma
self._priors = priors
self._nu = nu
self._k = None
self.cond = None
self.aleatoric = None
self.epistemic = None
if init_zeros:
self.init_zeros()
def __add__(self, other):
if isinstance(other, MVN):
......@@ -41,8 +66,15 @@ class MTMM(GMM):
return mtmm
def get_matching_gmm(self):
return GMM(mu=self.mu, sigma=self.sigma * (self.nu/(self.nu-2.))[:, None, None],
priors=self.priors)
if self.mu.ndim == 3:
return GMM(mu=self.mu, sigma=self.sigma * (self.nu / (self.nu - 2.))[:, None, None, None],
priors=self.priors)
else:
return GMM(mu=self.mu, sigma=self.sigma * (self.nu / (self.nu - 2.))[:, None, None], priors=self.priors)
def get_matching_gaussian(self):
gmm = self.get_matching_gmm()
return gaussian_moment_matching(gmm.mu, gmm.sigma, gmm.priors)
@property
def k(self):
......@@ -63,14 +95,12 @@ class MTMM(GMM):
def condition_gmm(self, data_in, dim_in, dim_out):
sample_size = data_in.shape[0]
# compute responsabilities
# compute responsibilities
mu_in, sigma_in = self.get_marginal(dim_in)
h = np.zeros((self.nb_states, sample_size))
for i in range(self.nb_states):
h[i, :] = multi_variate_t(data_in[None], self.nu[i],
mu_in[i],
sigma_in[i])
h[i, :] = multi_variate_t(data_in[None], self.nu[i], mu_in[i], sigma_in[i])
h += np.log(self.priors)[:, None]
h = np.exp(h).T
......@@ -99,7 +129,6 @@ class MTMM(GMM):
mu_est, sigma_est = (np.asarray(mu_est)[:, 0], np.asarray(sigma_est)[:, 0])
gmm_out = MTMM(nb_states=self.nb_states, nb_dim=mu_out.shape[1])
gmm_out.nu = self.nu + gmm_out.nb_dim
gmm_out.mu = mu_est
......@@ -118,25 +147,28 @@ class MTMM(GMM):
# s = np.sum(np.einsum('kij,kai->kaj', self.lmbda, dx) * dx, axis=2) # [nb_states, nb_samples]
# faster
s = np.sum(np.matmul(self.lmbda[:, None], dx[:, :, :, None])[:, :, :, 0] * dx, axis=2) # [nb_states, nb_samples]
s = np.sum(np.matmul(self.lmbda[:, None], dx[:, :, :, None])[:, :, :, 0] * dx,
axis=2) # [nb_states, nb_samples]