Commit 197035b7 authored by Sylvain Calinon's avatar Sylvain Calinon

Cleaned version of semi-tied GMM example

parent cada04c3
No preview for this file type
......@@ -39,54 +39,44 @@ function demo_semitiedGMM01
addpath('./m_fcts/');
%% Dataset and parameters
%% Parameters
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
load('data/Zshape3D.mat');
model.nbVar = size(Data,1);
model.nbStates = 3;
model.time_dim = false;
model.nbSamples = 1;
%Algorithm parameters
model.params_alpha = 1.0;
model.params_Bsf = 5E-2;
model.nbStates = 3; %Number of states in the GMM
model.nbVar = 3; %Number of variables [x1,x2,x3]
model.nbSamples = 5; %Number of demonstrations
model.params_Bsf = 5E-2; %Initial variance of B in semi-tied GMM
nbData = 300; %Length of each trajectory
%% Learning
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
model = init_GMM_timeBased(Data, model);
if isfield(model,'time_dim')
if ~model.time_dim
model.Mu = model.Mu(2:model.nbVar,:);
model.Sigma = model.Sigma(2:model.nbVar,2:model.nbVar,:);
Data = Data(2:model.nbVar,:);
model.nbVar= model.nbVar - 1;
end
end
load('data/Zshape3D.mat'); %Load 'Data'
model = init_GMM_kbins(Data, model, nbData);
model = EM_semitiedGMM(Data, model);
%% Plot
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
figure('color',[1 1 1],'Position',[10 10 800 650]); hold on; axis off; box off;
xx = round(linspace(1,64,9)); xx2 = round(linspace(1,64,3));
clrmap = colormap('jet'); clrmap = min(clrmap(xx,:),.95);
clrmap2 = colormap('jet'); clrmap2 = min(clrmap2(xx2,:),.95);
for i=1:5
plot3(Data(1,(i-1)*300+1:300*i), Data(2,(i-1)*300 + 1:300*i),Data(3,(i-1)*300+1:300*i),'-','linewidth',1.5,'color',[.5 .5 .5]);
figure('color',[1 1 1],'Position',[10 10 700 650]); hold on; axis off; box off;
clrmap = lines(model.nbVar);
for n=1:model.nbSamples
plot3(Data(1,(n-1)*nbData+1:n*nbData), Data(2,(n-1)*nbData+1:n*nbData), Data(3,(n-1)*nbData+1:n*nbData),'-','linewidth',1.5,'color',[.7 .7 .7]);
end
clrlist1 = [2,3,1];
for i=1:model.nbVar
mArrow3(zeros(model.nbVar,1),model.H(:,i),'color',clrmap2(clrlist1(i),:),'stemWidth',0.75, 'tipWidth',1.0, 'facealpha',0.75);
mArrow3(zeros(model.nbVar,1), model.H(:,i), 'color',clrmap(i,:),'stemWidth',0.75, 'tipWidth',1.0, 'facealpha',0.75);
end
clrlist = [2,7,4];
plotGMM3D(model.Mu, model.Sigma+repmat(eye(model.nbVar)*2E0,[1,1,model.nbStates]), [0 .6 0], .4, 2);
for i=1:model.nbStates
plotGMM3D(model.Mu(:,i), model.Sigma(:,:,i), clrmap(clrlist(i),:), .5);
for j=1:model.nbVar
mArrow3(model.Mu(:,i), model.Mu(:,i) + model.H(:,j).*(model.SigmaDiag(j,j,i).^0.5),'color',clrmap2(clrlist1(j),:),'stemWidth',0.75, 'tipWidth',1.25, 'facealpha',1);
w = model.SigmaDiag(j,j,i).^0.5;
if w>5E-1
mArrow3(model.Mu(:,i), model.Mu(:,i)+model.H(:,j)*w, 'color',clrmap(j,:),'stemWidth',0.75, 'tipWidth',1.25, 'facealpha',1);
end
end
end
view(-40,6); axis equal;
%print('-dpng','graphs/demo_semitiedGMM01.png');
pause;
close all;
......
......@@ -35,6 +35,7 @@ function [model, LL] = EM_semitiedGMM(Data, model)
% You should have received a copy of the GNU General Public License
% along with PbDlib. If not, see <http://www.gnu.org/licenses/>.
%Parameters of the EM algorithm
nbData = size(Data,2);
if ~isfield(model,'params_nbMinSteps')
......@@ -55,20 +56,19 @@ end
if ~isfield(model,'params_nbVariationSteps')
model.params_nbVariationSteps = 50;
end
if ~isfield(model,'params_alpha')
model.params_alpha = 0.99;
end
if ~isfield(model,'B')
model.B = eye(model.nbVar) * model.params_Bsf;
model.InitH = pinv(model.B) + eye(model.nbVar) * model.params_diagRegFact;
for i=1:model.nbStates
%model.InitSigmaDiag(:,:,i) = diag(diag(model.B*squeeze(model.Sigma(:,:,i))*model.B'));
[~,model.InitSigmaDiag(:,:,i)] = eig(squeeze(model.Sigma(:,:,i)));
%model.InitSigmaDiag(:,:,i) = diag(diag(model.B*model.Sigma(:,:,i)*model.B'));
[~,model.InitSigmaDiag(:,:,i)] = eig(model.Sigma(:,:,i));
end
end
for nbIter=1:model.params_nbMaxSteps
fprintf('.');
%E-step
[L, GAMMA] = computeGamma(Data, model); %See 'computeGamma' function below
GAMMA2 = GAMMA ./ repmat(sum(GAMMA,2),1,nbData);
......@@ -86,20 +86,21 @@ for nbIter=1:model.params_nbMaxSteps
%Update A matrix
for lp=1:model.params_nbVariationSteps
for i=1:model.nbStates
model.SigmaDiag(:,:,i) = diag(diag(model.B*squeeze(model.S(:,:,i))*model.B'));
model.SigmaDiag(:,:,i) = diag(diag(model.B * model.S(:,:,i) * model.B')); %Eq.(9)
end
for k=1:model.nbVar
C = pinv(model.B') * det(model.B); %C=cof(model.B);
G = sum(reshape(kron(squeeze(1/(model.SigmaDiag(k,k,:)))',ones(model.nbVar)) .* ...
reshape(model.S, [model.nbVar model.nbVar*model.nbStates]) .* ...
kron(sum(GAMMA,2)', ones(model.nbVar)), [model.nbVar model.nbVar model.nbStates]), 3);
model.B(k,:) = C(k,:) * pinv(G) * (sqrt(sum(sum(GAMMA,2) / (C(k,:) * pinv(G) * C(k,:)'))));
C = pinv(model.B') * det(model.B); %Or C=cof(model.B), Eq.(6)
G = zeros(model.nbVar);
for i=1:model.nbStates
G = G + model.S(:,:,i) * sum(GAMMA(i,:),2) / model.SigmaDiag(k,k,i); %Eq.(7)
end
model.B(k,:) = C(k,:) * pinv(G) * (sqrt(sum(sum(GAMMA,2) / (C(k,:) * pinv(G) * C(k,:)')))); %Eq.(5)
end
end
%Update Sigma
model.H = pinv(model.B) + eye(model.nbVar) * model.params_diagRegFact;
for i=1:model.nbStates
model.Sigma(:,:,i) = model.params_alpha * (model.H * model.SigmaDiag(:,:,i) * model.H') + (1-model.params_alpha) * model.S(:,:,i);
model.Sigma(:,:,i) = model.H * model.SigmaDiag(:,:,i) * model.H'; %Eq.(3)
end
%Compute average log-likelihood
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment