Skip to content
Snippets Groups Projects
demo_proMP01.cpp 14.39 KiB
/*
 * demo_proMP01.cpp
 *
 * Conditioning on trajectory distributions with ProMP
 *
 * @incollection{Paraschos13,
 *   title = {Probabilistic Movement Primitives},
 *   author = {Paraschos, A. and Daniel, C. and Peters, J. and Neumann, G.},
 *   booktitle = NIPS,
 *   pages = {2616--2624},
 *   year = {2013},
 *   publisher = {Curran Associates, Inc.}
 * }
 * @inproceedings{Rueckert15,
 *   author = "Rueckert, E. and  Mundo, J. and  Paraschos, A. and  Peters, J. and  Neumann, G.",
 *   title = "Extracting Low-Dimensional Control Variables for Movement Primitives",
 *   booktitle =	ICRA,
 *   year = "2015",
 *   pages = "1511--1518",
 *   address = "Seattle, WA, USA"
 * }
 *
 * Authors: Sylvain Calinon, Philip Abbet
 */


#include <stdio.h>
#include <stdexcept>
#include <armadillo>
#include <pbdlib/demonstration.h>
#include <pbdlib/mvn.h>

#include <gfx3.h>
#include <gfx_ui.h>
#include <GLFW/glfw3.h>
#include <imgui.h>
#include <imgui_impl_glfw_gl3.h>


using namespace pbdlib;


/*********************************** TYPES ***********************************/

// Settings of the computation (filled in main(), two of them are editable
// through the UI sliders)
struct parameters_t {
    int   nb_states;    // Number of components in the GMM (one basis function
                        // per component)
    int   nb_var;       // Dimension of position data (here: x1, x2)
    int   nb_data;      // Number of datapoints in a trajectory
    float dt;           // Time step (without rescaling, large values such
                        // as 1 have the advantage of creating clusters based
                        // on position information)
};


/******************************** CONSTANTS **********************************/

// Color table: one RGB triplet per row (10 rows). The first 'nb_states' rows
// are blended using the activations of the basis functions to color the
// datapoints of the demonstrations. The last three rows repeat the first
// three ones.
const arma::mat COLORS({
    { 0.00000, 0.00000, 1.00000},
    { 0.00000, 0.50000, 0.00000},
    { 1.00000, 0.00000, 0.00000},
    { 0.00000, 0.75000, 0.75000},
    { 0.75000, 0.00000, 0.75000},
    { 0.75000, 0.75000, 0.00000},
    { 0.25000, 0.25000, 0.25000},
    { 0.00000, 0.00000, 1.00000},
    { 0.00000, 0.50000, 0.00000},
    { 1.00000, 0.00000, 0.00000},
});


/****************************** HELPER FUNCTIONS *****************************/

// Error callback registered with GLFW: writes the error code and its
// human-readable description on the standard error stream
static void error_callback(int error, const char* description)
{
    ::fprintf(
        stderr,
        "Error %d: %s\n",
        error,
        description
    );
}

//-----------------------------------------------

// Create a demonstration (with a length of 'nb_samples') from a trajectory
// (of any length)
//
// Only the first two components (x, y) of each point of the trajectory are
// used. They are linearly interpolated at 'nb_samples' positions evenly spread
// over the index range of the trajectory.
//
// A trajectory made of a single point produces a demonstration that stays on
// that point (interp1() needs at least two points to interpolate).
//
// Throws std::invalid_argument if the trajectory is empty.
Demonstration sample_trajectory(const std::vector<arma::vec>& trajectory, int nb_samples) {

    if (trajectory.empty())
        throw std::invalid_argument("sample_trajectory(): the trajectory is empty");

    const size_t nb_points = trajectory.size();

    // Split the trajectory into its x and y components
    arma::vec x(nb_points);
    arma::vec y(nb_points);

    for (size_t i = 0; i < nb_points; ++i) {
        x(i) = trajectory[i](0);
        y(i) = trajectory[i](1);
    }

    // Resampling of the trajectory
    arma::vec x2;
    arma::vec y2;

    if (nb_points == 1) {
        // Nothing to interpolate: repeat the only point
        x2.set_size(nb_samples);
        y2.set_size(nb_samples);
        x2.fill(x(0));
        y2.fill(y(0));
    }
    else {
        arma::vec from_indices = arma::linspace<arma::vec>(0, nb_points - 1, nb_points);
        arma::vec to_indices = arma::linspace<arma::vec>(0, nb_points - 1, nb_samples);

        arma::interp1(from_indices, x, to_indices, x2, "*linear");
        arma::interp1(from_indices, y, to_indices, y2, "*linear");
    }

    // Create the demonstration
    Demonstration demo = Demonstration(2, nb_samples);
    arma::mat& datapoints = demo.getDatapoints().getData();
    for (int i = 0; i < nb_samples; ++i) {
        datapoints(0, i) = x2[i];
        datapoints(1, i) = y2[i];
    }

    return demo;
}


/********************************* FUNCTIONS *********************************/

// The actual computation
//
// Inputs: demos, parameters
// Outputs: mu_mat, mu2_mat, H
void process(std::vector<Demonstration>& demos, const parameters_t& parameters,
             arma::mat &mu_mat, arma::mat &mu2_mat, arma::mat &H) {

    // Create basis functions (GMM with components equally split in time)
    arma::vec timesteps = arma::linspace<arma::vec>(
        0, parameters.nb_data - 1, parameters.nb_data
    ) * parameters.dt;

    arma::vec mu = arma::linspace<arma::vec>(
        timesteps(0), timesteps(timesteps.n_rows - 1), parameters.nb_states
    );

    arma::vec sigma = arma::ones<arma::vec>(parameters.nb_states) * 0.01;

    // Compute basis functions activation
    H = arma::mat(parameters.nb_states, parameters.nb_data);

    for (unsigned int i = 0; i < parameters.nb_states; ++i) {
        arma::colvec mu_({ mu(i) });
        arma::mat sigma_({ sigma(i) });

        GaussianDistribution gauss(mu_, sigma_);

        H.row(i) = gauss.getPDFValue(timesteps.t()).t();
    }

    H = H / repmat(sum(H, 0), parameters.nb_states, 1);


    //_____ ProMP __________

    arma::mat psi(parameters.nb_var * parameters.nb_data,
                  parameters.nb_var * parameters.nb_states);

    for (unsigned int i = 0; i < parameters.nb_data; ++i) {
        psi.rows(i * parameters.nb_var, (i + 1) * parameters.nb_var - 1) =
            arma::kron(H.col(i).t(), arma::eye(parameters.nb_var, parameters.nb_var));
    }

    arma::mat w(parameters.nb_var * parameters.nb_states, demos.size());
    for (size_t i = 0; i < demos.size(); ++i) {
        w.col(i) = pinv(psi) * (arma::mat) arma::vectorise(demos[i].getDatapoints().getData());
    }

    // Distribution in parameter space
    arma::mat mu_w = arma::mean(w, 1);

    //-- First regularization term
    arma::mat sigma_w = arma::eye(parameters.nb_var * parameters.nb_states,
                                  parameters.nb_var * parameters.nb_states);

    if (w.n_cols == 1) {
        sigma_w += arma::rowvec(arma::cov(w.t()))[0];
    }
    else {
        sigma_w += arma::cov(w.t());
    }

    // Trajectory distribution
    mu = psi * mu_w;

    arma::mat sigma2 = psi * sigma_w * psi.t() +    // Second regularization term
                       0.01 * arma::eye(parameters.nb_var * parameters.nb_data,
                                        parameters.nb_var * parameters.nb_data);


    //_____ Conditioning on trajectory distribution __________
    // (reconstruction from partial data)

    arma::uvec in_out = arma::linspace<arma::uvec>(
        0, parameters.nb_data * parameters.nb_var - 1, parameters.nb_data * parameters.nb_var
    );

    arma::uvec in(2 * parameters.nb_var);

    in.head(parameters.nb_var) = in_out.head(parameters.nb_var);
    in.tail(parameters.nb_var) = in_out.tail(parameters.nb_var);

    arma::uvec out = in_out.subvec(parameters.nb_var, in_out.n_elem - parameters.nb_var - 1);

    // Input data
    arma::vec mu_in = arma::reshape(
        demos[0].getDatapoints().getData().cols(arma::uvec({0, (unsigned) parameters.nb_data - 1})) +
        arma::repmat((arma::randu(parameters.nb_var, 1) - 0.5) * 10.0, 1, 2),
        parameters.nb_var * 2, 1
    );

    // Gaussian conditioning with trajectory distribution
    arma::vec mu2(in_out.n_elem, fill::zeros);

    mu2.rows(in) = mu_in;

    arma::mat A = sigma_w * psi.rows(in).t();                   // Used to simplify the
    arma::mat B = psi.rows(in) * sigma_w * psi.rows(in).t();    // following operation

    arma::mat mu_w_tmp = mu_w +
                        (arma::inv(B.t()) * A.t()).t() *        // == A / B (not
                        (mu2.rows(in) - psi.rows(in) * mu_w);   // supported by armadillo)

    mu2.rows(out) = psi.rows(out) * mu_w_tmp;


    //_____ Results __________

    mu_mat = arma::conv_to<arma::mat>::from(mu);
    mu_mat.reshape(parameters.nb_var, parameters.nb_data);

    mu2_mat = arma::conv_to<arma::mat>::from(mu2);
    mu2_mat.reshape(parameters.nb_var, parameters.nb_data);
}

//-----------------------------------------------

int main(int argc, char **argv) {
    arma::arma_rng::set_seed_random();

    // Parameters
    parameters_t parameters;
    parameters.nb_states = 6;
    parameters.nb_var    = 2;
    parameters.dt        = 0.01;
    parameters.nb_data   = 100;


    // Initialise GLFW
    glfwSetErrorCallback(error_callback);

    if (!glfwInit())
        return -1;

    glfwWindowHint(GLFW_SAMPLES, 4);
    glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 3);
    glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 3);
    glfwWindowHint(GLFW_OPENGL_FORWARD_COMPAT, GL_TRUE);
    glfwWindowHint(GLFW_OPENGL_PROFILE, GLFW_OPENGL_CORE_PROFILE);

    // Open a window and create its OpenGL context
    GLFWwindow* window = glfwCreateWindow(
        800, 800, "Demo Conditioning on trajectory distributions", NULL, NULL
    );

    glfwMakeContextCurrent(window);

    // Take 4k screens into account (framebuffer size != window size)
    int win_width = ImGui::GetIO().DisplaySize.x;
    int win_height = ImGui::GetIO().DisplaySize.y;

    int fb_width, fb_height;
    glfwGetFramebufferSize(window, &fb_width, &fb_height);

    // Setup GLSL
    gfx3::init();
    glEnable(GL_DEPTH_TEST);
    glEnable(GL_CULL_FACE);
    glEnable(GL_LINE_SMOOTH);
    glDepthFunc(GL_LESS);

    // Setup ImGui
    ImGui_ImplGlfwGL3_Init(window, true);

    // Creation of the Vertex Array Object (VAO)
    GLuint vertexArrayID;
    glGenVertexArrays(1, &vertexArrayID);
    glBindVertexArray(vertexArrayID);


    // Projection matrix
    arma::fmat projection = gfx3::orthographic(
        (float) fb_width, (float) fb_height, 0.1f, 10.0f);

    // Camera matrix
    arma::fmat view = gfx3::lookAt(
        arma::fvec({0, 0, 3}),  // Position of the camera
        arma::fvec({0, 0, 0}),  // Look at the origin
        arma::fvec({0, 1, 0})   // Head is up
    );


    // Loading of the shaders
    gfx3::shader_t colored_shader = gfx3::loadShader(gfx3::VERTEX_SHADER_COLORED,
                                                     gfx3::FRAGMENT_SHADER_COLORED);


    // Main loop
    bool adding_line = false;
    std::vector<arma::vec> current_trajectory;
    std::vector< std::vector<arma::vec> > original_trajectories;
    std::vector<Demonstration> demos;
    arma::mat mu;
    arma::mat mu2;
    arma::mat H;
    int current_nb_data = parameters.nb_data;
    int current_nb_states = parameters.nb_states;

    while (!glfwWindowShouldClose(window)) {
        glfwPollEvents();

        // Detect when the window was resized
        if ((ImGui::GetIO().DisplaySize.x != win_width) ||
            (ImGui::GetIO().DisplaySize.y != win_height)) {

            win_width = ImGui::GetIO().DisplaySize.x;
            win_height = ImGui::GetIO().DisplaySize.y;

            glfwGetFramebufferSize(window, &fb_width, &fb_height);

            // Update the projection matrix
            projection = gfx3::orthographic(
                (float) fb_width, (float) fb_height, 0.1f, 10.0f);
        }


        // If the parameters changed, recompute
        if ((parameters.nb_data != current_nb_data) ||
            (parameters.nb_states != current_nb_states)) {

            demos.clear();

            for (size_t i = 0; i < original_trajectories.size(); ++i)
                demos.push_back(sample_trajectory(original_trajectories[i], parameters.nb_data));

            process(demos, parameters, mu, mu2, H);

            current_nb_data = parameters.nb_data;
            current_nb_states = parameters.nb_states;
        }


        // Start the rendering
        ImGui_ImplGlfwGL3_NewFrame();
        glViewport(0, 0, fb_width, fb_height);
        glClearColor(0.6f, 0.6f, 0.6f, 0.0f);
        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);

        // Draw the currently created demonstration (if any)
        if (current_trajectory.size() > 1) {
            gfx3::model_t line = gfx3::create_line(
                colored_shader, arma::fvec({0.33f, 0.97f, 0.33f}), current_trajectory);

            gfx3::draw(line, view, projection);

            gfx3::destroy(line);
        }


        // Draw the demonstrations (if any)
        for (auto iter = demos.begin(); iter != demos.end(); ++iter) {
            Datapoints& datapoints = iter->getDatapoints();

            for (uint i = 0; i < datapoints.getNumPOINTS(); ++i) {
                arma::fvec color = arma::conv_to<arma::fvec>::from(H.col(i).t() * COLORS.rows(0, parameters.nb_states - 1));

                gfx3::model_t point = gfx3::create_square(colored_shader, color, 10.0f);

                point.transforms.position = arma::conv_to<arma::fvec>::from(datapoints.getData(i));
                point.transforms.position.reshape(3, point.transforms.position.n_cols);

                gfx3::draw(point, view, projection);

                gfx3::destroy(point);
            }
        }


        // Draw the computed trajectories
        if (mu.n_elem > 0) {
            gfx3::model_t line = gfx3::create_line(
                colored_shader, arma::fvec({0.0f, 0.0f, 0.0f}), mu);

            gfx3::draw(line, view, projection);

            gfx3::destroy(line);

            line = gfx3::create_line(
                colored_shader, arma::fvec({0.8f, 0.0f, 0.0f}), mu2);

            gfx3::draw(line, view, projection);

            gfx3::destroy(line);
        }


        // Parameter window
        ImGui::SetNextWindowSize(ImVec2(400, 110));
        ImGui::Begin("Parameters", NULL,
                     ImGuiWindowFlags_NoResize | ImGuiWindowFlags_NoSavedSettings |
                     ImGuiWindowFlags_NoMove
        );

        ImGui::SliderInt("Nb states", &parameters.nb_states, 2, 10);
        ImGui::SliderInt("Nb data", &parameters.nb_data, 20, 300);

        if (ImGui::Button("Clear")){
            demos.clear();
            original_trajectories.clear();
            mu = arma::mat();
            mu2 = arma::mat();
            H = arma::mat();
        }

        ImGui::Text("Draw several similar trajectories");
        ImGui::End();

        // GUI rendering
        ImGui::Render();

        // Swap buffers
        glfwSwapBuffers(window);

        // Keyboard input
        if (ImGui::IsKeyPressed(GLFW_KEY_ESCAPE))
            break;


        // Left click: start a new demonstration
        if (!adding_line && ImGui::IsMouseClicked(GLFW_MOUSE_BUTTON_1)) {
            // Only if not on the UI
            if (!ImGui::GetIO().WantCaptureMouse)
                adding_line = true;
        }

        if (adding_line) {
            double mouse_x, mouse_y;
            glfwGetCursorPos(window, &mouse_x, &mouse_y);

            current_trajectory.push_back(gfx3::ui2fb({ mouse_x, mouse_y },
                                                     win_width, win_height,
                                                     fb_width, fb_height));
        }

        // Left mouse button release: end the demonstration creation
        if (adding_line && !ImGui::IsMouseDown(GLFW_MOUSE_BUTTON_1)) {
            adding_line = false;

            demos.push_back(sample_trajectory(current_trajectory, parameters.nb_data));

            original_trajectories.push_back(current_trajectory);
            current_trajectory.clear();

            process(demos, parameters, mu, mu2, H);
        }
    }


    // Cleanup
    ImGui_ImplGlfwGL3_Shutdown();
    glfwTerminate();

    return 0;
}