gaussian_utils.py 2.49 KB
Newer Older
1
2
import numpy as np

Emmanuel PIGNAT's avatar
Emmanuel PIGNAT committed
3
def gaussian_moment_matching(mus, sigmas, h=None):
	"""
	Approximate a weighted mixture of Gaussians by a single Gaussian with the
	same first and second moments (moment matching).

	:param mus:			[np.array([nb_states, nb_timestep, nb_dim])]
				or [np.array([nb_states, nb_dim])]
	:param sigmas:		[np.array([nb_states, nb_timestep, nb_dim, nb_dim])]
				or [np.array([nb_states, nb_dim, nb_dim])]
	:param h: 			[np.array([nb_timestep, nb_states])] mixing weights;
				a 1-D array [nb_states] is treated as a single row.
				Defaults to uniform weights (one row per timestep when mus
				has a time axis, a single row otherwise).
	:return: mu [np.array([nb_rows, nb_dim])],
				sigma [np.array([nb_rows, nb_dim, nb_dim])]
				where nb_rows is the number of rows of h.
	"""
	mus = np.asarray(mus)
	sigmas = np.asarray(sigmas)
	nb_states = mus.shape[0]

	if h is None:
		# For 2-D mus, shape[1] is nb_dim, not a time axis: use a single row.
		nb_rows = mus.shape[1] if mus.ndim == 3 else 1
		h = np.ones((nb_rows, nb_states)) / nb_states
	else:
		h = np.asarray(h)

	if h.ndim == 1:
		h = h[None]

	if mus.ndim == 3:
		mu = np.einsum('ak,kai->ai', h, mus)
		dmus = mus - mu[None]  # nb_states, nb_timestep, nb_dim
		spread = np.einsum('kai,kaj->akij', dmus, dmus)
		if sigmas.ndim == 4:
			# Time-varying covariances.
			mean_sigma = np.einsum('ak,kaij->aij', h, sigmas)
		else:
			# Time-invariant covariances.
			mean_sigma = np.einsum('ak,kij->aij', h, sigmas)
	else:
		mu = np.einsum('ak,ki->ai', h, mus)
		dmus = mus[None] - mu[:, None]  # nb_rows, nb_states, nb_dim
		spread = np.einsum('aki,akj->akij', dmus, dmus)
		mean_sigma = np.einsum('ak,kij->aij', h, sigmas)

	# Law of total covariance: E[Sigma_k] + Cov[mu_k].
	sigma = mean_sigma + np.einsum('ak,akij->aij', h, spread)
	return mu, sigma

def gaussian_conditioning(mu, sigma, data_in, dim_in, dim_out, reg=None):
	"""
	Condition a Gaussian (or a batch of Gaussians, one per timestep) on
	observed values of its input dimensions.

	:param mu: 			[np.array([nb_dim])] when sigma is 2-D,
				[np.array([nb_timestep, nb_dim])] when sigma is 3-D
	:param sigma: 		[np.array([nb_dim, nb_dim])]
				or [np.array([nb_timestep, nb_dim, nb_dim])]
	:param data_in: 	[np.array([nb_timestep, nb_dim_in])] observed values of
				the input dimensions (a 1-D [nb_dim_in] array is also accepted)
	:param dim_in: 		[slice] input (observed) dimensions
	:param dim_out: 	[slice] output dimensions to predict
	:param reg:			[float or None] value added to the diagonal of the
				input covariance block before inversion
	:return: mu_cond [np.array([nb_timestep, nb_dim_out])],
				sigma_cond [np.array([nb_dim_out, nb_dim_out])] or
				[np.array([nb_timestep, nb_dim_out, nb_dim_out])]
	"""
	if sigma.ndim == 2:
		# One covariance shared by every row of data_in; mu is 1-D.
		mu_in, mu_out = mu[dim_in], mu[dim_out]
		sigma_in_in = sigma[dim_in, dim_in]
		sigma_in_out = sigma[dim_in, dim_out]
		sigma_out_out = sigma[dim_out, dim_out]
	else:
		# One covariance (and mean) per timestep along the leading axis.
		mu_in, mu_out = mu[:, dim_in], mu[:, dim_out]
		sigma_in_in = sigma[:, dim_in, dim_in]
		sigma_in_out = sigma[:, dim_in, dim_out]
		sigma_out_out = sigma[:, dim_out, dim_out]

	if reg is not None:
		# Size the identity from the actual block: dim_in.stop - dim_in.start
		# breaks for open-ended, negative or stepped slices.
		sigma_in_in = sigma_in_in + reg * np.eye(sigma_in_in.shape[-1])
	inv_sigma_in_in = np.linalg.inv(sigma_in_in)

	# Gain Sigma_in_out^T Sigma_in_in^-1 (= Sigma_out_in Sigma_in_in^-1 for a
	# symmetric sigma). The '...' axes broadcast, so the single and batched
	# cases share the same expressions.
	inv_sigma_out_in = np.einsum('...ji,...jk->...ik', sigma_in_out, inv_sigma_in_in)
	mu_cond = mu_out + np.einsum('...ij,...j->...i', inv_sigma_out_in,
								 data_in - mu_in)
	sigma_cond = sigma_out_out - np.einsum('...ij,...jk->...ik', inv_sigma_out_in,
										   sigma_in_out)

	return mu_cond, sigma_cond