Skip to content
Snippets Groups Projects
demo_Riemannian_SPD_interp02.cpp 8.48 KiB
/*
 * demo_Riemannian_cov_interp02.cpp
 *
 * Covariance interpolation on Riemannian manifold from a GMM with augmented
 * covariances (Implementation based on Pennec, Fillard and Ayache (2006)
 * "A Riemannian Framework For Tensor Computing")
 *
 * Authors: Sylvain Calinon, Philip Abbet
 */


#include <stdio.h>
#include <armadillo>

#include <imgui.h>
#include <imgui_impl_glfw_gl2.h>
#include <gfx2.h>
#include <gfx_ui.h>
#include <GLFW/glfw3.h>

using namespace arma;


/***************************** ALGORITHM SECTION *****************************/

void trans2d_to_gauss(const ui::Trans2d& gaussian_transforms,
					  const gfx2::window_size_t& window_size,
					  arma::vec &mu, arma::mat &sigma) {

	mu = gfx2::ui2fb_centered(vec({ gaussian_transforms.pos.x, gaussian_transforms.pos.y }),
							  window_size);

	vec t_x({
		gaussian_transforms.x.x * window_size.scale_x(),
		gaussian_transforms.x.y * window_size.scale_y()
	});

	vec t_y({
		gaussian_transforms.y.x * window_size.scale_x(),
		gaussian_transforms.y.y * window_size.scale_y()
	});

	mat RG = {
		{ t_x(0), t_y(0) },
		{ -t_x(1), -t_y(1) }
	};

	sigma = RG * RG.t();
}

//---------------------------------------------------------

arma::mat expmap(const arma::mat& U, const arma::mat& S) {
	return real(sqrtmat(S) * expmat(inv(sqrtmat(S)) * U * inv(sqrtmat(S))) * sqrtmat(S));
}

//---------------------------------------------------------

arma::mat logmap(const arma::mat& X, const arma::mat& S) {
	return real(sqrtmat(S) * logmat(inv(sqrtmat(S)) * X * inv(sqrtmat(S))) * sqrtmat(S));
}

//---------------------------------------------------------

void interpolate(const std::vector<ui::Trans2d>& gaussian_transforms,
				 int nb_data,
				 std::vector<arma::vec> &interpolated_mu,
				 std::vector<arma::mat> &interpolated_sigma,
				 const gfx2::window_size_t& window_size) {
	const int nb_var = 2;	// Number of variables (fixed, since we use a
							// ui::Trans2d to define a gaussian)
	const int nb_var2 = nb_var + 1;

	// Transformation to Gaussians with augmented covariances centered on zero
	std::vector<arma::mat> augmented_sigma;

	for (size_t i = 0; i < gaussian_transforms.size(); ++i) {

		arma::vec current_mu;
		arma::mat sigma;

		trans2d_to_gauss(gaussian_transforms[i], window_size, current_mu, sigma);

		arma::mat current_sigma(nb_var2, nb_var2);

		current_sigma(0, 0, arma::size(nb_var, nb_var)) = sigma + current_mu * current_mu.t();
		current_sigma(0, nb_var, arma::size(nb_var, 1)) = current_mu;
		current_sigma(nb_var, 0, arma::size(1, nb_var)) = current_mu.t();
		current_sigma(nb_var, nb_var) = 1;

		augmented_sigma.push_back(current_sigma);
	}

	// Geodesic interpolation
	arma::vec w = arma::linspace(0, 1, nb_data);

	for (size_t i = 1; i < augmented_sigma.size(); ++i) {
		for (size_t t = 0; t < nb_data; ++t) {
			// Interpolation between two covariances can be computed in closed form
			arma::mat sigma = expmap(w(t) * logmap(augmented_sigma[i], augmented_sigma[i-1]),
									 augmented_sigma[i-1]);

			double beta = sigma(sigma.n_elem - 1);
			arma::vec mu = sigma(sigma.n_rows - 1, 0, arma::size(1, sigma.n_cols-1)).t() / beta;

			interpolated_mu.push_back(mu);
			interpolated_sigma.push_back(sigma(0, 0, arma::size(sigma.n_rows-1, sigma.n_cols-1)) - beta * mu * mu.t());
		}
	}
}


/*************************** DEMONSTRATION SECTION ***************************/

static void error_callback(int error, const char* description) {
	fprintf(stderr, "Error %d: %s\n", error, description);
}

// -----------------------------------------------------------------------------
// Render a 2d gaussian from its parameters (mu and sigma)
// -----------------------------------------------------------------------------
void render_gaussian(const arma::vec& mu, const arma::mat& sigma,
					 const arma::vec& color,
					 const gfx2::window_size_t& window_size) {

	// Rendering of the Gaussian
	gfx2::draw_gaussian(conv_to<fvec>::from(color), mu, sigma);

	glClear(GL_DEPTH_BUFFER_BIT);

	// Rendering of the Gaussian position
	fvec position({ (float) mu(0), (float) mu(1), 0.0f });

	fvec darker_color = conv_to<fvec>::from(color(span(0, 2))) * 0.5f;

	gfx2::draw_rectangle(darker_color, 4.0f * window_size.scale_x(),
						 4.0f * window_size.scale_y(), position);

	glClear(GL_DEPTH_BUFFER_BIT);
}


/******************************* MAIN FUNCTION *******************************/

int main(int argc, char **argv) {
	arma_rng::set_seed_random();

	// Parameters
	int nb_states = 2; // Number of states in the GMM
	int nb_data = 20;  // Length of each trajectory


	// Take 4k screens into account (framebuffer size != window size)
	gfx2::window_size_t window_size;
	window_size.win_width = 600;
	window_size.win_height = 600;
	window_size.fb_width = -1;	// Will be known later
	window_size.fb_height = -1;

	// Setup GUI
	glfwSetErrorCallback(error_callback);
	if (!glfwInit())
		exit(1);

	glfwWindowHint(GLFW_SAMPLES, 4);
	glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 2);
	glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 1);

	// Open a window and create its OpenGL context
	GLFWwindow* window = gfx2::create_window_at_optimal_size(
		"Covariance interpolation",
		window_size.win_width, window_size.win_height
	);

	glfwMakeContextCurrent(window);

	// Setup OpenGL
	gfx2::init();
	glEnable(GL_DEPTH_TEST);
	glEnable(GL_CULL_FACE);
	glEnable(GL_LINE_SMOOTH);
	// glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA);

	// Setup ImGui
	ImGui::CreateContext();
	ImGui_ImplGlfwGL2_Init(window, true);
	ImVec4 clear_color = ImColor(255, 255, 255);


	// Main loop
	std::vector<ui::Trans2d> gaussian_transforms;

	while (!glfwWindowShouldClose(window)) {
		glfwPollEvents();

		// Handling of the resizing of the window
		gfx2::window_result_t window_result =
			gfx2::handle_window_resizing(window, &window_size);

		if (window_result == gfx2::INVALID_SIZE)
			continue;

		// Start of rendering
		ImGui_ImplGlfwGL2_NewFrame();
		glViewport(0, 0, window_size.fb_width, window_size.fb_height);
		glClearColor(clear_color.x, clear_color.y, clear_color.z, clear_color.w);
		glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);

		glMatrixMode(GL_PROJECTION);
		glLoadIdentity();
		glOrtho(-window_size.fb_width / 2, window_size.fb_width / 2,
				-window_size.fb_height / 2, window_size.fb_height / 2,
				-1.0f, 10.0f);

		glMatrixMode(GL_MODELVIEW);
		glLoadIdentity();

		glPushMatrix();

		// Ensure that the number of desired states hasn't changed
		if (nb_states > gaussian_transforms.size()) {
			for (int i = gaussian_transforms.size(); i < nb_states; ++i) {
				arma::vec mu = arma::randu(2);
				mu(0) = mu(0) * (window_size.win_width - 200) + 100;
				mu(1) = mu(1) * (window_size.win_height - 200) + 100;

				arma::vec xy = arma::randu(2);
				xy(0) = (xy(0) * window_size.win_width / 6 + 20);
				xy(1) = (xy(1) * window_size.win_height / 6 + 20);

				gaussian_transforms.push_back(ui::Trans2d(ImVec2((int) xy(0), 0),
														  ImVec2(0, (int) xy(1)),
														  ImVec2((int) mu(0), (int) mu(1))));
			}
		}
		else if (nb_states < gaussian_transforms.size()) {
			gaussian_transforms.resize(nb_states);
		}

		// Interpolation between the gaussians
		std::vector<arma::vec> interpolated_mu;
		std::vector<arma::mat> interpolated_sigma;

		interpolate(gaussian_transforms, nb_data, interpolated_mu, interpolated_sigma,
					window_size);

		// Rendering of the gaussians
		for (size_t i = 0; i < interpolated_sigma.size(); ++i) {
			render_gaussian(interpolated_mu[i], interpolated_sigma[i],
							arma::vec({ 0.8, 0.0, 0.0, 0.02 }), window_size);
		}

		for (size_t i = 0; i < gaussian_transforms.size(); ++i) {
			arma::vec mu;
			arma::mat sigma;

			trans2d_to_gauss(gaussian_transforms[i], window_size, mu, sigma);

			render_gaussian(mu, sigma, arma::vec({ 0.5, 0.5, 0.5, 0.5 }), window_size);
		}

		gfx2::draw_line(arma::fvec({ 0.8f, 0.0f, 0.0f }), interpolated_mu);

		// Gaussian UI widgets
		ui::begin("Gaussians");
		for (size_t i = 0; i < gaussian_transforms.size(); ++i)
			gaussian_transforms[i] = ui::affineSimple(i, gaussian_transforms[i]);
		ui::end();

		// Parameter window
		ImGui::Begin("Parameters", NULL, ImVec2(250, 80), 0.5f,
					 ImGuiWindowFlags_NoResize | ImGuiWindowFlags_NoSavedSettings);
		ImGui::SliderInt("Nb states", &nb_states, 2, 5);
		ImGui::SliderInt("Nb data", &nb_data, 5, 30);
		ImGui::End();

		// GUI rendering
		ImGui::Render();
		ImGui_ImplGlfwGL2_RenderDrawData(ImGui::GetDrawData());

		// End of rendering
		glPopMatrix();
		glfwSwapBuffers(window);

		// Keyboard input
		if (ImGui::IsKeyPressed(GLFW_KEY_ESCAPE))
			break;
	}

	// Cleanup
	ImGui_ImplGlfwGL2_Shutdown();
	glfwTerminate();

	return 0;
}