// demo_proMP01.cpp (14.39 KiB) - authored by Sylvain CALINON
/*
* demo_proMP01.cpp
*
* Conditioning on trajectory distributions with ProMP
*
* @incollection{Paraschos13,
* title = {Probabilistic Movement Primitives},
* author = {Paraschos, A. and Daniel, C. and Peters, J. and Neumann, G.},
* booktitle = NIPS,
* pages = {2616--2624},
* year = {2013},
* publisher = {Curran Associates, Inc.}
* }
* @inproceedings{Rueckert15,
* author = "Rueckert, E. and Mundo, J. and Paraschos, A. and Peters, J. and Neumann, G.",
* title = "Extracting Low-Dimensional Control Variables for Movement Primitives",
* booktitle = ICRA,
* year = "2015",
* pages = "1511--1518",
* address = "Seattle, WA, USA"
* }
*
* Authors: Sylvain Calinon, Philip Abbet
*/
#include <stdio.h>
#include <armadillo>
#include <pbdlib/demonstration.h>
#include <pbdlib/mvn.h>
#include <gfx3.h>
#include <gfx_ui.h>
#include <GLFW/glfw3.h>
#include <imgui.h>
#include <imgui_impl_glfw_gl3.h>
using namespace pbdlib;
/*********************************** TYPES ***********************************/
// Settings shared by the computation and by the user interface
struct parameters_t {
    // Number of components in the GMM (used as basis functions)
    int nb_states;

    // Dimension of position data (here: x1, x2)
    int nb_var;

    // Number of datapoints in a trajectory
    int nb_data;

    // Time step (without rescaling, large values such as 1 have the advantage
    // of creating clusters based on position information)
    float dt;
};
/******************************** CONSTANTS **********************************/
// Color palette: one RGB triplet per row, each component in [0, 1]. The last
// three rows repeat the first three. The rendering code blends the first
// 'nb_states' rows, weighted by the basis functions activations.
const arma::mat COLORS({
    { 0.0,  0.0,  1.0  },   // blue
    { 0.0,  0.5,  0.0  },   // dark green
    { 1.0,  0.0,  0.0  },   // red
    { 0.0,  0.75, 0.75 },   // cyan
    { 0.75, 0.0,  0.75 },   // magenta
    { 0.75, 0.75, 0.0  },   // olive
    { 0.25, 0.25, 0.25 },   // dark gray
    { 0.0,  0.0,  1.0  },   // blue (again)
    { 0.0,  0.5,  0.0  },   // dark green (again)
    { 1.0,  0.0,  0.0  },   // red (again)
});
/****************************** HELPER FUNCTIONS *****************************/
// Error handler registered with GLFW (see glfwSetErrorCallback() in main()):
// reports the error code and its textual description on stderr
static void error_callback(int error, const char* description)
{
    fprintf(stderr, "Error %d: %s\n", error, description);
}
//-----------------------------------------------
// Convert a 2D trajectory (of any length) into a demonstration containing
// exactly 'nb_samples' datapoints, by linearly interpolating each coordinate
// over the index of the points
Demonstration sample_trajectory(const std::vector<arma::vec>& trajectory, int nb_samples) {
    const size_t nb_points = trajectory.size();

    // Separate the two coordinates of the recorded points
    arma::vec xs(nb_points);
    arma::vec ys(nb_points);

    for (size_t k = 0; k < nb_points; ++k) {
        const arma::vec& point = trajectory[k];
        xs(k) = point(0);
        ys(k) = point(1);
    }

    // Indices of the recorded points, and (fractional) indices at which the
    // trajectory must be evaluated
    arma::vec src_indices = arma::linspace<arma::vec>(0, nb_points - 1, nb_points);
    arma::vec dst_indices = arma::linspace<arma::vec>(0, nb_points - 1, nb_samples);

    // Resampling of both coordinates
    arma::vec xs_resampled(nb_points);
    arma::vec ys_resampled(nb_points);

    interp1(src_indices, xs, dst_indices, xs_resampled, "*linear");
    interp1(src_indices, ys, dst_indices, ys_resampled, "*linear");

    // Store the resampled points in a new demonstration (one per column)
    Demonstration demo = Demonstration(2, nb_samples);
    mat& datapoints = demo.getDatapoints().getData();

    for (int col = 0; col < nb_samples; ++col) {
        datapoints(0, col) = xs_resampled[col];
        datapoints(1, col) = ys_resampled[col];
    }

    return demo;
}
/********************************* FUNCTIONS *********************************/
// The actual computation
//
// Inputs: demos, parameters
// Outputs: mu_mat, mu2_mat, H
void process(std::vector<Demonstration>& demos, const parameters_t& parameters,
arma::mat &mu_mat, arma::mat &mu2_mat, arma::mat &H) {
// Create basis functions (GMM with components equally split in time)
arma::vec timesteps = arma::linspace<arma::vec>(
0, parameters.nb_data - 1, parameters.nb_data
) * parameters.dt;
arma::vec mu = arma::linspace<arma::vec>(
timesteps(0), timesteps(timesteps.n_rows - 1), parameters.nb_states
);
arma::vec sigma = arma::ones<arma::vec>(parameters.nb_states) * 0.01;
// Compute basis functions activation
H = arma::mat(parameters.nb_states, parameters.nb_data);
for (unsigned int i = 0; i < parameters.nb_states; ++i) {
arma::colvec mu_({ mu(i) });
arma::mat sigma_({ sigma(i) });
GaussianDistribution gauss(mu_, sigma_);
H.row(i) = gauss.getPDFValue(timesteps.t()).t();
}
H = H / repmat(sum(H, 0), parameters.nb_states, 1);
//_____ ProMP __________
arma::mat psi(parameters.nb_var * parameters.nb_data,
parameters.nb_var * parameters.nb_states);
for (unsigned int i = 0; i < parameters.nb_data; ++i) {
psi.rows(i * parameters.nb_var, (i + 1) * parameters.nb_var - 1) =
arma::kron(H.col(i).t(), arma::eye(parameters.nb_var, parameters.nb_var));
}
arma::mat w(parameters.nb_var * parameters.nb_states, demos.size());
for (size_t i = 0; i < demos.size(); ++i) {
w.col(i) = pinv(psi) * (arma::mat) arma::vectorise(demos[i].getDatapoints().getData());
}
// Distribution in parameter space
arma::mat mu_w = arma::mean(w, 1);
//-- First regularization term
arma::mat sigma_w = arma::eye(parameters.nb_var * parameters.nb_states,
parameters.nb_var * parameters.nb_states);
if (w.n_cols == 1) {
sigma_w += arma::rowvec(arma::cov(w.t()))[0];
}
else {
sigma_w += arma::cov(w.t());
}
// Trajectory distribution
mu = psi * mu_w;
arma::mat sigma2 = psi * sigma_w * psi.t() + // Second regularization term
0.01 * arma::eye(parameters.nb_var * parameters.nb_data,
parameters.nb_var * parameters.nb_data);
//_____ Conditioning on trajectory distribution __________
// (reconstruction from partial data)
arma::uvec in_out = arma::linspace<arma::uvec>(
0, parameters.nb_data * parameters.nb_var - 1, parameters.nb_data * parameters.nb_var
);
arma::uvec in(2 * parameters.nb_var);
in.head(parameters.nb_var) = in_out.head(parameters.nb_var);
in.tail(parameters.nb_var) = in_out.tail(parameters.nb_var);
arma::uvec out = in_out.subvec(parameters.nb_var, in_out.n_elem - parameters.nb_var - 1);
// Input data
arma::vec mu_in = arma::reshape(
demos[0].getDatapoints().getData().cols(arma::uvec({0, (unsigned) parameters.nb_data - 1})) +
arma::repmat((arma::randu(parameters.nb_var, 1) - 0.5) * 10.0, 1, 2),
parameters.nb_var * 2, 1
);
// Gaussian conditioning with trajectory distribution
arma::vec mu2(in_out.n_elem, fill::zeros);
mu2.rows(in) = mu_in;
arma::mat A = sigma_w * psi.rows(in).t(); // Used to simplify the
arma::mat B = psi.rows(in) * sigma_w * psi.rows(in).t(); // following operation
arma::mat mu_w_tmp = mu_w +
(arma::inv(B.t()) * A.t()).t() * // == A / B (not
(mu2.rows(in) - psi.rows(in) * mu_w); // supported by armadillo)
mu2.rows(out) = psi.rows(out) * mu_w_tmp;
//_____ Results __________
mu_mat = arma::conv_to<arma::mat>::from(mu);
mu_mat.reshape(parameters.nb_var, parameters.nb_data);
mu2_mat = arma::conv_to<arma::mat>::from(mu2);
mu2_mat.reshape(parameters.nb_var, parameters.nb_data);
}
//-----------------------------------------------
int main(int argc, char **argv) {
arma::arma_rng::set_seed_random();
// Parameters
parameters_t parameters;
parameters.nb_states = 6;
parameters.nb_var = 2;
parameters.dt = 0.01;
parameters.nb_data = 100;
// Initialise GLFW
glfwSetErrorCallback(error_callback);
if (!glfwInit())
return -1;
glfwWindowHint(GLFW_SAMPLES, 4);
glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 3);
glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 3);
glfwWindowHint(GLFW_OPENGL_FORWARD_COMPAT, GL_TRUE);
glfwWindowHint(GLFW_OPENGL_PROFILE, GLFW_OPENGL_CORE_PROFILE);
// Open a window and create its OpenGL context
GLFWwindow* window = glfwCreateWindow(
800, 800, "Demo Conditioning on trajectory distributions", NULL, NULL
);
glfwMakeContextCurrent(window);
// Take 4k screens into account (framebuffer size != window size)
int win_width = ImGui::GetIO().DisplaySize.x;
int win_height = ImGui::GetIO().DisplaySize.y;
int fb_width, fb_height;
glfwGetFramebufferSize(window, &fb_width, &fb_height);
// Setup GLSL
gfx3::init();
glEnable(GL_DEPTH_TEST);
glEnable(GL_CULL_FACE);
glEnable(GL_LINE_SMOOTH);
glDepthFunc(GL_LESS);
// Setup ImGui
ImGui_ImplGlfwGL3_Init(window, true);
// Creation of the Vertex Array Object (VAO)
GLuint vertexArrayID;
glGenVertexArrays(1, &vertexArrayID);
glBindVertexArray(vertexArrayID);
// Projection matrix
arma::fmat projection = gfx3::orthographic(
(float) fb_width, (float) fb_height, 0.1f, 10.0f);
// Camera matrix
arma::fmat view = gfx3::lookAt(
arma::fvec({0, 0, 3}), // Position of the camera
arma::fvec({0, 0, 0}), // Look at the origin
arma::fvec({0, 1, 0}) // Head is up
);
// Loading of the shaders
gfx3::shader_t colored_shader = gfx3::loadShader(gfx3::VERTEX_SHADER_COLORED,
gfx3::FRAGMENT_SHADER_COLORED);
// Main loop
bool adding_line = false;
std::vector<arma::vec> current_trajectory;
std::vector< std::vector<arma::vec> > original_trajectories;
std::vector<Demonstration> demos;
arma::mat mu;
arma::mat mu2;
arma::mat H;
int current_nb_data = parameters.nb_data;
int current_nb_states = parameters.nb_states;
while (!glfwWindowShouldClose(window)) {
glfwPollEvents();
// Detect when the window was resized
if ((ImGui::GetIO().DisplaySize.x != win_width) ||
(ImGui::GetIO().DisplaySize.y != win_height)) {
win_width = ImGui::GetIO().DisplaySize.x;
win_height = ImGui::GetIO().DisplaySize.y;
glfwGetFramebufferSize(window, &fb_width, &fb_height);
// Update the projection matrix
projection = gfx3::orthographic(
(float) fb_width, (float) fb_height, 0.1f, 10.0f);
}
// If the parameters changed, recompute
if ((parameters.nb_data != current_nb_data) ||
(parameters.nb_states != current_nb_states)) {
demos.clear();
for (size_t i = 0; i < original_trajectories.size(); ++i)
demos.push_back(sample_trajectory(original_trajectories[i], parameters.nb_data));
process(demos, parameters, mu, mu2, H);
current_nb_data = parameters.nb_data;
current_nb_states = parameters.nb_states;
}
// Start the rendering
ImGui_ImplGlfwGL3_NewFrame();
glViewport(0, 0, fb_width, fb_height);
glClearColor(0.6f, 0.6f, 0.6f, 0.0f);
glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
// Draw the currently created demonstration (if any)
if (current_trajectory.size() > 1) {
gfx3::model_t line = gfx3::create_line(
colored_shader, arma::fvec({0.33f, 0.97f, 0.33f}), current_trajectory);
gfx3::draw(line, view, projection);
gfx3::destroy(line);
}
// Draw the demonstrations (if any)
for (auto iter = demos.begin(); iter != demos.end(); ++iter) {
Datapoints& datapoints = iter->getDatapoints();
for (uint i = 0; i < datapoints.getNumPOINTS(); ++i) {
arma::fvec color = arma::conv_to<arma::fvec>::from(H.col(i).t() * COLORS.rows(0, parameters.nb_states - 1));
gfx3::model_t point = gfx3::create_square(colored_shader, color, 10.0f);
point.transforms.position = arma::conv_to<arma::fvec>::from(datapoints.getData(i));
point.transforms.position.reshape(3, point.transforms.position.n_cols);
gfx3::draw(point, view, projection);
gfx3::destroy(point);
}
}
// Draw the computed trajectories
if (mu.n_elem > 0) {
gfx3::model_t line = gfx3::create_line(
colored_shader, arma::fvec({0.0f, 0.0f, 0.0f}), mu);
gfx3::draw(line, view, projection);
gfx3::destroy(line);
line = gfx3::create_line(
colored_shader, arma::fvec({0.8f, 0.0f, 0.0f}), mu2);
gfx3::draw(line, view, projection);
gfx3::destroy(line);
}
// Parameter window
ImGui::SetNextWindowSize(ImVec2(400, 110));
ImGui::Begin("Parameters", NULL,
ImGuiWindowFlags_NoResize | ImGuiWindowFlags_NoSavedSettings |
ImGuiWindowFlags_NoMove
);
ImGui::SliderInt("Nb states", ¶meters.nb_states, 2, 10);
ImGui::SliderInt("Nb data", ¶meters.nb_data, 20, 300);
if (ImGui::Button("Clear")){
demos.clear();
original_trajectories.clear();
mu = arma::mat();
mu2 = arma::mat();
H = arma::mat();
}
ImGui::Text("Draw several similar trajectories");
ImGui::End();
// GUI rendering
ImGui::Render();
// Swap buffers
glfwSwapBuffers(window);
// Keyboard input
if (ImGui::IsKeyPressed(GLFW_KEY_ESCAPE))
break;
// Left click: start a new demonstration
if (!adding_line && ImGui::IsMouseClicked(GLFW_MOUSE_BUTTON_1)) {
// Only if not on the UI
if (!ImGui::GetIO().WantCaptureMouse)
adding_line = true;
}
if (adding_line) {
double mouse_x, mouse_y;
glfwGetCursorPos(window, &mouse_x, &mouse_y);
current_trajectory.push_back(gfx3::ui2fb({ mouse_x, mouse_y },
win_width, win_height,
fb_width, fb_height));
}
// Left mouse button release: end the demonstration creation
if (adding_line && !ImGui::IsMouseDown(GLFW_MOUSE_BUTTON_1)) {
adding_line = false;
demos.push_back(sample_trajectory(current_trajectory, parameters.nb_data));
original_trajectories.push_back(current_trajectory);
current_trajectory.clear();
process(demos, parameters, mu, mu2, H);
}
}
// Cleanup
ImGui_ImplGlfwGL3_Shutdown();
glfwTerminate();
return 0;
}