% EM_tensorGMM.m (3.13 KB) -- last committed by Milad Malekzadeh
function model = EM_tensorGMM(Data, model)
% Training of a task-parameterized Gaussian mixture model (GMM) with an expectation-maximization (EM) algorithm.
% The approach allows the modulation of the centers and covariance matrices of the Gaussians with respect to
% external parameters represented in the form of candidate coordinate systems.
%
% Inputs:
%   Data  - data tensor indexed as Data(var, frame, datapoint): size(Data,3) is the
%           number of datapoints and Data(:,m,:) holds the observations seen from
%           candidate frame m.
%   model - struct with fields nbStates, nbFrames, nbVar, Priors, Mu, Sigma holding
%           the initial parameters; Mu is indexed Mu(:,m,i) and Sigma Sigma(:,:,m,i)
%           for frame m and state i.
% Output:
%   model - the same struct with Priors, Mu and Sigma re-estimated by EM.
%
% Prints a dot per iteration, then a message telling whether EM converged
% (average log-likelihood increase below maxDiffLL after at least nbMinSteps
% iterations) or stopped after nbMaxSteps iterations.
%
% Author:	Sylvain Calinon, 2014
%         http://programming-by-demonstration.org/SylvainCalinon
%
% This source code is given for free! In exchange, I would be grateful if you cite
% the following reference in any academic publication that uses this code or part of it:
%
% @inproceedings{Calinon14ICRA,
%   author="Calinon, S. and Bruno, D. and Caldwell, D. G.",
%   title="A task-parameterized probabilistic model with minimal intervention control",
%   booktitle="Proc. {IEEE} Intl Conf. on Robotics and Automation ({ICRA})",
%   year="2014",
%   month="May-June",
%   address="Hong Kong, China",
%   pages="3339--3344"
% }

%Parameters of the EM algorithm
nbMinSteps = 5; %Minimum number of iterations allowed
nbMaxSteps = 100; %Maximum number of iterations allowed
maxDiffLL = 1E-4; %Likelihood increase threshold to stop the algorithm
diagRegularizationFactor = 1E-4; %Diagonal term added to each covariance (see Eq. (2.1.2)); previously tried value: 1E-2

nbVarDat = size(Data,1);
nbData = size(Data,3);
LL = zeros(1,nbMaxSteps); %Average log-likelihood at each iteration

for nbIter=1:nbMaxSteps
	fprintf('.');
	%E-step (the third output of computeGamma is not needed here)
	[L, GAMMA] = computeGamma(Data, model); %See 'computeGamma' function below and Eq. (2.0.5) in doc/TechnicalReport.pdf
	%Responsibilities normalized over datapoints (realmin avoids 0/0 = NaN for a state that receives no responsibility)
	GAMMA2 = GAMMA ./ repmat(sum(GAMMA,2)+realmin, 1, nbData);
	%M-step
	for i=1:model.nbStates
		%Update Priors
		model.Priors(i) = sum(GAMMA(i,:)) / nbData; %See Eq. (2.0.6) in doc/TechnicalReport.pdf
		for m=1:model.nbFrames
			%Matricization/flattening of tensor: observations in frame m as a nbVarDat x nbData matrix
			DataMat = reshape(Data(:,m,:), nbVarDat, nbData);
			%Update Mu
			model.Mu(:,m,i) = DataMat * GAMMA2(i,:)'; %See Eq. (2.0.7) in doc/TechnicalReport.pdf
			%Update Sigma (regularization term is optional)
			%Scaling the columns of DataTmp is equivalent to right-multiplying by diag(GAMMA2(i,:)),
			%without allocating an nbData x nbData matrix
			DataTmp = DataMat - repmat(model.Mu(:,m,i), 1, nbData);
			model.Sigma(:,:,m,i) = (DataTmp .* repmat(GAMMA2(i,:), nbVarDat, 1)) * DataTmp' + eye(model.nbVar) * diagRegularizationFactor; %See Eq. (2.0.8) and (2.1.2) in doc/TechnicalReport.pdf
		end
	end
	%Compute average log-likelihood (of the parameters used in this iteration's E-step)
	LL(nbIter) = sum(log(sum(L,1))) / size(L,2); %See Eq. (2.0.4) in doc/TechnicalReport.pdf
	%Stop the algorithm if EM converged (small change of LL)
	if nbIter>nbMinSteps && LL(nbIter)-LL(nbIter-1)<maxDiffLL
		disp(['EM converged after ' num2str(nbIter) ' iterations.']);
		return;
	end
end
disp(['The maximum number of ' num2str(nbMaxSteps) ' EM iterations has been reached.']);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [L, GAMMA, GAMMA0] = computeGamma(Data, model)
%E-step of the task-parameterized GMM: likelihoods and responsibilities of each state.
%See Eq. (2.0.5) in doc/TechnicalReport.pdf
%
%Inputs:
%  Data   - data tensor indexed as Data(var, frame, datapoint)
%  model  - struct with fields nbStates, nbFrames, Priors, Mu(:,m,i), Sigma(:,:,m,i)
%Outputs:
%  L      - nbStates x nbData matrix: Priors(i) times the product over frames m of
%           GAMMA0(i,m,:)
%  GAMMA  - nbStates x nbData responsibilities: L normalized over states
%  GAMMA0 - nbStates x nbFrames x nbData per-frame values returned by gaussPDF
%           (prior not included)
%
%gaussPDF is defined elsewhere; it is assumed to return the Gaussian density of each
%column of its first argument -- NOTE(review): confirm against its definition.
nbVarDat = size(Data, 1);
nbData = size(Data, 3);
L = ones(model.nbStates, nbData);
GAMMA0 = zeros(model.nbStates, model.nbFrames, nbData);
for m=1:model.nbFrames
	DataMat = reshape(Data(:,m,:), nbVarDat, nbData); %Matricization/flattening of tensor
	for i=1:model.nbStates
		GAMMA0(i,m,:) = gaussPDF(DataMat, model.Mu(:,m,i), model.Sigma(:,:,m,i));
		L(i,:) = L(i,:) .* reshape(GAMMA0(i,m,:), 1, nbData);
	end
end
%Weight each state's product of Gaussians by its prior exactly once
%(multiplying inside the frame loop would raise the prior to the power nbFrames)
L = L .* repmat(model.Priors(:), 1, nbData);
%Normalization (realmin avoids division by zero when all likelihoods underflow)
GAMMA = L ./ repmat(sum(L,1)+realmin, size(L,1), 1);
end