Commit a8e6efea authored by Emmanuel PIGNAT's avatar Emmanuel PIGNAT
Browse files

some utils for memory LQR

parent 2515737e
import numpy as np
def vec(x):
    """Flatten the trailing two axes of a batched matrix tensor.

    :param x: tensor of shape [batch_shape, d1, d2]
    :return: tensor of shape [batch_shape, d1 * d2]
    """
    # TF1-style Dimension objects: .value yields the static int size.
    rows, cols = (dim.value for dim in x.shape[1:])
    return tf.reshape(x, (-1, rows * cols))
def log_normalize(x, axis=0):
    """Normalize log-weights so they sum to one in log space.

    :param x: tensor of unnormalized log-probabilities
    :param axis: axis along which to normalize
    :return: tensor of the same shape as x, with logsumexp(result, axis) == 0
    """
    # keepdims=True keeps the reduced axis as size 1 so the subtraction
    # broadcasts correctly for ANY axis; without it, reducing a middle or
    # last axis of a >=2-D tensor misaligns (or errors) under broadcasting.
    return x - tf.reduce_logsumexp(x, axis=axis, keepdims=True)
def matmatmul(a=None, b=None, transpose_a=False, transpose_b=False):
    """Matrix-matrix product supporting operands of different batch ranks.

    :param a: tensor ...ij, or None (treated as identity)
    :param b: tensor ...jk, or None (treated as identity)
    :param transpose_a: transpose the last two axes of a
    :param transpose_b: transpose the last two axes of b
    :return: tensor ...ik
    """
    if a is None:
        return b
    if b is None:
        return a
    if a.shape.ndims == b.shape.ndims:
        # Equal ranks: tf.matmul already broadcasts the batch dimensions.
        return tf.matmul(a, b, transpose_a=transpose_a, transpose_b=transpose_b)
    # Unequal ranks: build an einsum spec, prefixing each operand with
    # letters for its batch dimensions (up to 3 supported).
    batch_letters = ['', 'a', 'ab', 'abc']
    lhs = batch_letters[a.shape.ndims - 2] + ('ji,' if transpose_a else 'ij,')
    rhs = batch_letters[b.shape.ndims - 2] + ('kj->' if transpose_b else 'jk->')
    out = batch_letters[max(a.shape.ndims, b.shape.ndims) - 2] + 'ik'
    return tf.einsum(lhs + rhs + out, a, b)
def matvecmul(m=None, v=None, transpose_m=False):
    """Matrix-vector product supporting batched matrices and/or vectors.

    :param m: matrix tensor ...ij, or None (treated as identity)
    :param v: vector tensor ...j
    :param transpose_m: transpose the last two axes of m
    :return: tensor ...i
    """
    if m is None:
        return v
    if m.shape.ndims == 2 and v.shape.ndims == 1:
        # Single matrix, single vector: lift v to a column, drop it after.
        return tf.matmul(m, v[:, None], transpose_a=transpose_m)[:, 0]
    if m.shape.ndims == 2 and v.shape.ndims == 2:
        # Single matrix, batch of row vectors.
        return tf.transpose(tf.matmul(m, v, transpose_b=True, transpose_a=transpose_m))
    # General case: einsum with batch-dimension letters (up to 3 supported).
    batch_letters = ['', 'a', 'ab', 'abc']
    lhs = batch_letters[m.shape.ndims - 2] + ('ji,' if transpose_m else 'ij,')
    rhs = batch_letters[v.shape.ndims - 1] + 'j->'
    out = batch_letters[max(m.shape.ndims - 2, v.shape.ndims - 1)] + 'i'
    return tf.einsum(lhs + rhs + out, m, v)
def vecvecadd(v1=None, v2=None, opposite_a=False, opposite_b=False):
    """Add two (possibly batched) vectors, optionally negating either one.

    :param v1: tensor, or None (treated as zero)
    :param v2: tensor, or None (treated as zero)
    :param opposite_a: negate v1 before adding
    :param opposite_b: negate v2 before adding
    :return: elementwise sum, aligning a one-level rank mismatch
    """
    if v1 is None:
        return v2
    if v2 is None:
        return v1
    if opposite_a:
        v1 = -v1
    if opposite_b:
        v2 = -v2
    n1, n2 = v1.shape.ndims, v2.shape.ndims
    if n1 == n2:
        return v1 + v2
    # One operand is batched exactly one level deeper than the other:
    # prepend an axis to the shallower one so shapes broadcast.
    if n1 - n2 == 1 and n2 in (1, 2):
        return v1 + v2[None, :]
    if n2 - n1 == 1 and n1 in (1, 2):
        return v1[None, :] + v2
    raise NotImplementedError
......@@ -23,6 +23,68 @@ def get_canonical(nb_dim, nb_deriv=2, dt=0.01):
return np.kron(A1d, np.eye(nb_dim)), np.kron(B1d, np.eye(nb_dim))
def multi_timestep_matrix(A, B, nb_step=4):
    """Augment a linear system so the state keeps a memory of past states.

    Only the newest block evolves with A and receives the control through
    B; each older block is a copy of the previous step's block (identity
    shift).

    :param A: [xi_dim, xi_dim] state transition matrix
    :param B: [xi_dim, u_dim] control matrix
    :param nb_step: number of time steps kept in the augmented state
    :return: (_A, _B) with shapes [xi_dim * nb_step, xi_dim * nb_step]
        and [xi_dim * nb_step, u_dim]
    """
    dim_x, dim_u = A.shape[0], B.shape[1]
    A_aug = np.zeros((dim_x * nb_step, dim_x * nb_step))
    B_aug = np.zeros((dim_x * nb_step, dim_u))
    A_aug[:dim_x, :dim_x] = A
    B_aug[:dim_x, :dim_u] = B
    # Shift chain: block k at the next step is block k-1 at this step.
    for k in range(1, nb_step):
        rows = slice(dim_x * k, dim_x * (k + 1))
        cols = slice(dim_x * (k - 1), dim_x * k)
        A_aug[rows, cols] = np.eye(dim_x)
    return A_aug, B_aug
def fd_transform(d, xi_dim, nb_past, dt=0.1):
    """Build a finite-difference transform matrix.

    Each column block i carries the order-d backward-difference stencil
    (signed binomial coefficients) scaled by dt ** d, applied to the
    state blocks i .. i + d of a stacked state of nb_past blocks.

    :param d: finite-difference order (1 <= d <= 4)
    :param xi_dim: dimension of one state block
    :param nb_past: number of stacked past states
    :param dt: time-step scaling factor
    :return: array of shape [xi_dim * nb_past, xi_dim * (nb_past - d)]
    """
    # Signed binomial coefficients of (1 - z)**d for d = 0..4.
    stencils = [[1],
                [1, -1],
                [1., -2, 1],
                [1., -3, 3, -1],
                [1., -4., 6., -4., 1.]]
    coeffs = stencils[d]
    T = np.zeros((xi_dim * nb_past, xi_dim * (nb_past - d)))
    scale = np.eye(xi_dim) * dt ** d
    for i in range(nb_past - d):
        # The full stencil goes down every column block. (In the original,
        # the inner coefficient loop had fallen outside the column loop, so
        # only the final column was fully stenciled and an empty range left
        # `i` undefined; the intended nesting is restored here.)
        for j, c in enumerate(coeffs):
            T[xi_dim * (i + j):xi_dim * (i + j + 1),
              xi_dim * i:xi_dim * (i + 1)] = c * scale
    return T
def multi_timestep_fd_q(rs, xi_dim, dt):
    """Build a weight matrix penalizing finite-difference derivatives.

    Derivative order i + 1 contributes T Q T^T, where T is the matching
    finite-difference transform and Q an isotropic precision built from
    the standard deviation rs[i].

    :param rs: list of standard deviations, one per derivative order
    :param xi_dim: dimension of one state block
    :param dt: time step
    :return: array of shape [xi_dim * len(rs), xi_dim * len(rs)]
    """
    nb_past = len(rs)
    terms = []
    for order, r in enumerate(rs, start=1):
        T = fd_transform(order, xi_dim, nb_past, dt)
        precision = np.eye(xi_dim * (nb_past - order)) * r ** -2
        terms.append(T.dot(precision).dot(T.T))
    return np.sum(terms, axis=0)
def lifted_noise_matrix(A=None, B=None, nb_dim=3, dt=0.01, horizon=50):
r"""
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment