/*
 * demo_online_gmm.cpp
 *
 *	Online GMM learning and LQR-based trajectory generation.
 *
 *     Author: Sylvain Calinon
 */

#include <imgui.h>
#include "imgui_impl_glfw.h"
#include <stdio.h>
#include <GLFW/glfw3.h>

#include "armadillo"
#include <pbdlib/gmm.h>
#include <pbdlib/gmr.h>
#include <pbdlib/lqr.h>

using namespace std;
using namespace pbdlib;
using namespace arma;

/* GLFW error hook: prints the numeric error code and GLFW's message to stderr. */
static void error_callback(int error, const char* description){
	const char* const format = "Error %d: %s\n";
	fprintf(stderr, format, error, description);
}

int main(int argc, char **argv){

	//Setup GMM
	GMM_Model gmm(1,2);
  float minSigma = 2E2;
  float lambda = 50.0f;
  mat minSIGMA = eye(2,2) * minSigma;
  vector<GaussianDistribution> comps; 
  vector<Demonstration> demos;
  vector<mat> repros;
  
  //Setup LQR
  mat A(4,4), B(4,2);
	float dt = 0.01f;
	
	int iFactor = -8;  	
	mat R = eye(2,2) * pow(10.0f,iFactor);
	std::vector<mat> Q;
	mat A1d; A1d << 0 << 1 << endr << 0 << 0 << endr;
	mat B1d; B1d << 0 << endr << 1 << endr;
	A = kron(A1d, eye(2,2)); //See Eq. (5.1.1) in doc/TechnicalReport.pdf
	B = kron(B1d, eye(2,2)); //See Eq. (5.1.1) in doc/TechnicalReport.pdf
	LQR lqr(A,B,dt);



	//Setup GUI
	glfwSetErrorCallback(error_callback);
	if (!glfwInit())
		exit(1);
	GLFWwindow* window = glfwCreateWindow(1280, 520, "PbdLib GUI", NULL, NULL);
	glfwMakeContextCurrent(window);

	// Setup ImGui binding
	ImGui_ImplGlfw_Init(window, true);	
	ImVec4 clear_color = ImColor(114, 144, 154);

	ImVector<ImVec2> points;
  bool adding_line = false;
  bool dispGMM = false;
  int nbPts = 0;  

	while (!glfwWindowShouldClose(window)){

		glfwPollEvents();
		ImGui_ImplGlfw_NewFrame();
		
		//Control panel GUI
		ImGui::SetNextWindowPos(ImVec2(2,2));
		//ImGui::SetNextWindowSize(ImVec2(250,200));
		//ImGui::Begin("Control Panel");
		
		ImGui::Begin("Control Panel", NULL, ImVec2(350,160), 1.0f, 
			ImGuiWindowFlags_NoTitleBar|ImGuiWindowFlags_NoResize|
			ImGuiWindowFlags_NoMove|ImGuiWindowFlags_NoSavedSettings);
		
		//cout<<ImGui::IsItemHovered()<<endl;
    
    if (ImGui::Button("Clear")){
    	demos.clear();
    	repros.clear();
    	gmm.clear();
    	dispGMM = false;
    	//cout<<gmm.getNumSTATES()<<endl;
    }
    //ImGui::SameLine(); 
    //if (ImGui::Button("Train")) { 
		//	cout << "\n Number of EM iterations: " << gmm.EM_learn();
		//	dispGMM = true;
		//}
    ImGui::Text("Left-click to collect demonstrations");
    ImGui::Text("nbDemos: %d, nbPoints: %d, nbStates: %d", (int)demos.size(), (int)points.size(), (int)gmm.getNumSTATES());
    ImGui::SliderFloat("minSigma", &minSigma, 1E1, 2E2);
		ImGui::SliderFloat("lambda", &lambda, 10.0f, 100.0f);
		if (ImGui::SliderInt("rFactor", &iFactor, -9, -6)){
			R = eye(2,2) * pow(10.0f,iFactor);
		}
		
		//Get data
    ImVec2 mouse_pos_in_canvas = ImVec2(ImGui::GetIO().MousePos.x, 
    	ImGui::GetIO().DisplaySize.y - ImGui::GetIO().MousePos.y);
    //ImGui::GetCursorScreenPos()
    	
    //ImGuiWindow* hw = FindHoveredWindow(ImGui::GetIO().MousePos, false);
    
    if ((ImGui::GetIO().MousePos.x<352 && ImGui::GetIO().MousePos.y<162)==0){ //Is outside gui?
		  if (!adding_line && ImGui::GetIO().MouseClicked[0]){ //Button pushed
		    adding_line = true;
				if (!dispGMM){
					colvec p(2);
				  p(0) = mouse_pos_in_canvas.x; 
				  p(1) = mouse_pos_in_canvas.y; 
					GaussianDistribution componentTmp(p,minSIGMA);	
					comps.push_back(componentTmp);
					gmm.setCOMPONENTS(comps);
					comps.clear();
					dispGMM = true;
				}
		  }
		  if (adding_line){ //Trajectory recording
		    points.push_back(mouse_pos_in_canvas);
		    nbPts++;
		    vec p(2); 
		    p(0) = mouse_pos_in_canvas.x; 
		    p(1) = mouse_pos_in_canvas.y; 
		    gmm.onlineEMDP(nbPts, p, lambda, minSigma);
		    if (!ImGui::GetIO().MouseDown[0]){ //Button released
		      adding_line = false;
		      //Add demonstration
			    Demonstration demo = Demonstration(2,points.size());
			    for (int t=0; t<(int)points.size(); t++){
			    	demo.getDatapoints().getData()(0,t) = points[t].x;
			    	demo.getDatapoints().getData()(1,t) = points[t].y;
			    }
			    demos.push_back(demo);
			    //gmm.addDemo(demo);
				
					//Compute sequence of states
					mat h(gmm.getNumSTATES(), points.size());
					for (int i=0; i<(int)gmm.getNumSTATES(); i++){
						h.row(i) = trans(gmm.getCOMPONENTS(i).getPDFValue(demo.getDatapoints().getData()));
					}
					uword imax[points.size()];
					for (int t=0; t<points.size(); t++){
						vec vTmp = h.col(t);
						vTmp.max(imax[t]);
						//cout<<imax[t]<<"-";
					}
					//cout<<endl<<endl;
					
					//LQR
					vec vTmp(4,fill::zeros);
					mat QTmp(4,4,fill::zeros);
					mat Target(4,points.size(),fill::zeros);
					for (int t=0; t<points.size(); t++){
						QTmp.submat(0,0,1,1) = inv(gmm.getSIGMA(imax[t]));
						Q.push_back(QTmp);
						vTmp.subvec(0,1) = gmm.getMU(imax[t]);
						Target.col(t) = vTmp;
					}
					lqr.setProblem(R,Q,Target);
					mat S(4,4,fill::zeros);
					vec d(4,fill::zeros);
					//lqr.evaluate_gains_finiteHorizon(S,d);
					lqr.evaluate_gains_infiniteHorizon();
				
					//Retrieve data
					mat rData(2,points.size());
					//vec x = demo.getDatapoints().getData().col(0);
					//vec dx(2,fill::zeros), ddx(2); 
					//for (int t=0; t<points.size(); t++){
						//rData.col(t) = x;
						//ddx = 100.0f * (Target.submat(0,t,1,t)-x) - sqrtf(2*100.0f) * dx; 
						//ddx = lqr.getGains().at(t).cols(0,1) * (Target.submat(0,t,1,t)-x) - lqr.getGains().at(t).cols(2,3) * dx;
						//ddx += lqr.getFF().at(t);
						//dx += ddx * dt;
						//x += dx * dt;
					//}
					vec u(2);
					vec X = join_cols(demo.getDatapoints().getData().col(0), zeros<vec>(2));
					for (int t=0; t<points.size(); t++){
						rData.col(t) = X.rows(0,1);
						u = lqr.getGains().at(t) * (Target.col(t)-X);
						X += (A*X+B*u) * dt;
					}
					repros.push_back(rData);
				
					//Clean up
					points.clear();
					Q.clear();
		  	}	  	
			}    
		}
    ImGui::End();
    
		
		//GUI rendering
		glViewport(0, 0, (int)ImGui::GetIO().DisplaySize.x, (int)ImGui::GetIO().DisplaySize.y);
		glClearColor(clear_color.x, clear_color.y, clear_color.z, clear_color.w);
		glClear(GL_COLOR_BUFFER_BIT);
		ImGui::Render();
		
		//PbDlib rendering
		glPushMatrix();
		glTranslatef(-1.0f,-1.0f,0);
		glScalef(2.0f/(float)ImGui::GetIO().DisplaySize.x, 2.0f/(float)ImGui::GetIO().DisplaySize.y, 1.0f);
		glLineWidth(2.0f);
				
		//Draw current demo
		glColor3f(0.0f, 0.0f, 0.0f);
		glBegin(GL_LINE_STRIP);
		for (int t=0; t<(int)points.size(); t++){
			glVertex2f(points[t].x, points[t].y);
		}
		glEnd();
		
		//Draw demos
		glColor3f(0.3f, 0.3f, 0.3f);
		for (int n=0; n<(int)demos.size(); n++){
			glBegin(GL_LINE_STRIP);
			for (int t=0; t<(int)demos[n].getDatapoints().getNumPOINTS(); t++){
				glVertex2f(demos[n].getDatapoints().getData()(0,t), demos[n].getDatapoints().getData()(1,t));
			}
			glEnd();
		}
		
		//Draw repros
		glColor3f(0.0f, 0.8f, 0.0f);
		for (int n=0; n<(int)repros.size(); n++){
			glBegin(GL_LINE_STRIP);
			for (int t=0; t<(int)repros[n].n_cols; t++){
				glVertex2f(repros[n](0,t), repros[n](1,t));
			}
			glEnd();
		}
		
		//Draw Gaussians
    if (dispGMM){
		  vec d(2);
			mat V(2,2), R(2,2), pts(2,30);
			mat pts0(2,30);
			pts0 = join_cols(cos(linspace<rowvec>(0,2*PI,30)), sin(linspace<rowvec>(0,2*PI,30)));
			glColor3f(0.8f, 0.0f, 0.0f);
			for (int i=0; i<(int)gmm.getNumSTATES(); i++){
				eig_sym(d, V, gmm.getSIGMA(i));
				R = V * sqrt(diagmat(d)); 
				pts = R * pts0;
				glBegin(GL_LINE_STRIP);
				for (int t=0; t<(int)pts.n_cols; t++){
					glVertex2f((pts(0,t)+gmm.getMU(i)(0)), (pts(1,t)+gmm.getMU(i)(1)));
				}
				glEnd();
				glBegin(GL_POINTS);
				for (int t=0; t<(int)pts.n_cols; t++){
					glVertex2f(gmm.getMU(i)(0), gmm.getMU(i)(1));
				}
				glEnd();
		  }
		}
		
		glPopMatrix();
		glfwSwapBuffers(window);
	} 
	
	//Cleanup
	ImGui_ImplGlfw_Shutdown();
	glfwTerminate();
	return 0;
}