Commit 30929683 by Sylvain CALINON

Opetions for EM added

parent b0cfc8c5
......@@ -6,11 +6,16 @@ function [model, GAMMA2, LL] = EM_GMM(Data, model)
% of the algorithms, please reward the authors by citing the related publications,
% and consider making your own research available in this way.
%
% @article{Calinon15,
% @article{Calinon16JIST,
% author="Calinon, S.",
% title="A Tutorial on Task-Parameterized Movement Learning and Retrieval",
% journal="Intelligent Service Robotics",
% year="2015"
% publisher="Springer Berlin Heidelberg",
% doi="10.1007/s11370-015-0187-9",
% year="2016",
% volume="9",
% number="1",
% pages="1--29"
% }
%
% Copyright (c) 2015 Idiap Research Institute, http://idiap.ch/
......@@ -43,11 +48,10 @@ if ~isfield(model,'params_maxDiffLL')
model.params_maxDiffLL = 1E-4; %Likelihood increase threshold to stop the algorithm
end
if ~isfield(model,'params_diagRegFact')
%model.params.diagRegFact = 1E-8; %Regularization term is optional
model.params_diagRegFact = 1E-4; %Regularization term is optional
end
if ~isfield(model,'params_updateComp')
model.params_updateComp = ones(3,1);
model.params_updateComp = ones(3,1); %pi,Mu,Sigma
end
for nbIter=1:model.params_nbMaxSteps
......
function [model, GAMMA] = EM_HMM(s, model)
function [model, GAMMA2, LL] = EM_HMM(s, model)
% Estimation of HMM parameters with an EM algorithm.
%
% Writing code takes time. Polishing it and making it available to others takes longer!
......@@ -37,10 +37,22 @@ function [model, GAMMA] = EM_HMM(s, model)
%Parameters of the EM algorithm
nbMinSteps = 5; %Minimum number of iterations allowed
nbMaxSteps = 50; %MaZETAmum number of iterations allowed
maxDiffLL = 1E-4; %Likelihood increase threshold to stop the algorithm
diagRegularizationFactor = 1E-8; %Optional regularization term
if ~isfield(model,'params_nbMinSteps')
model.params_nbMinSteps = 5; %Minimum number of iterations allowed
end
if ~isfield(model,'params_nbMaxSteps')
model.params_nbMaxSteps = 50; %Maximum number of iterations allowed
end
if ~isfield(model,'params_maxDiffLL')
model.params_maxDiffLL = 1E-4; %Likelihood increase threshold to stop the algorithm
end
if ~isfield(model,'params_diagRegFact')
model.params_diagRegFact = 1E-8; %Regularization term is optional
end
if ~isfield(model,'params_updateComp') || length(model.params_updateComp)<4
model.params_updateComp = ones(4,1); %Mu,Sigma,Pi,A
end
%Initialization of the parameters
nbSamples = length(s);
......@@ -51,7 +63,7 @@ for n=1:nbSamples
end
[nbVar, nbData] = size(Data);
for nbIter=1:nbMaxSteps
for nbIter=1:model.params_nbMaxSteps
fprintf('.');
%E-step
......@@ -105,23 +117,28 @@ for nbIter=1:nbMaxSteps
%M-step
for i=1:model.nbStates
%Update the centers
model.Mu(:,i) = Data * GAMMA2(i,:)';
if model.params_updateComp(1)
model.Mu(:,i) = Data * GAMMA2(i,:)';
end
%Update the covariance matrices
Data_tmp = Data - repmat(model.Mu(:,i),1,nbData);
model.Sigma(:,:,i) = Data_tmp * diag(GAMMA2(i,:)) * Data_tmp'; %Eq. (54) Rabiner
%Optional regularization term
model.Sigma(:,:,i) = model.Sigma(:,:,i) + eye(nbVar) * diagRegularizationFactor;
if model.params_updateComp(2)
Data_tmp = Data - repmat(model.Mu(:,i),1,nbData);
model.Sigma(:,:,i) = Data_tmp * diag(GAMMA2(i,:)) * Data_tmp'; %Eq. (54) Rabiner
%Optional regularization term
model.Sigma(:,:,i) = model.Sigma(:,:,i) + eye(nbVar) * model.params_diagRegFact;
end
end
%Update initial state probability vector
model.StatesPriors = mean(GAMMA_INIT,2);
if model.params_updateComp(3)
model.StatesPriors = mean(GAMMA_INIT,2);
end
%Update transition probabilities
model.Trans = sum(ZETA,3)./ repmat(sum(GAMMA_TRK,2)+realmin, 1, model.nbStates);
if model.params_updateComp(4)
model.Trans = sum(ZETA,3)./ repmat(sum(GAMMA_TRK,2)+realmin, 1, model.nbStates);
end
%Compute the average log-likelihood through the ALPHA scaling factors
LL(nbIter)=0;
......@@ -130,14 +147,14 @@ for nbIter=1:nbMaxSteps
end
LL(nbIter) = LL(nbIter)/nbSamples;
%Stop the algorithm if EM converged
if nbIter>nbMinSteps
if LL(nbIter)-LL(nbIter-1)<maxDiffLL
if nbIter>model.params_nbMinSteps
if LL(nbIter)-LL(nbIter-1)<model.params_maxDiffLL
disp(['EM converged after ' num2str(nbIter) ' iterations.']);
return;
end
end
end
disp(['The maximum number of ' num2str(nbMaxSteps) ' EM iterations has been reached.']);
disp(['The maximum number of ' num2str(model.params_nbMaxSteps) ' EM iterations has been reached.']);
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment