Attachment 'ssa_simple.m'

Download

   1 function [ Ps, An, y, converged, iter ] = ssa_simple(X, dd, nreps, max_iter, quiet, nomeans)
   2 %SSA_SIMPLE      Stationary Subspace Analysis.
   3 %
   4 %usage 
   5 %  [Ps, An, y, converged, iter] = ssa_simple(X, dd, { nreps: 5,  max_iter: 100, quiet: false, nomeans: false })
   6 %
   7 %input
   8 %  X          Data in one of two possible formats:
   9 %               1. cell array where each X{i} is a (d x n_i)-dataset
  10 %               2. cell array with two elements: 
  11 %                   X{1} is a (d x n) matrix of epoch means and
  12 %                   X{2} is a (d x d x n) array of epoch covariance matrices 
  13 
  14 %  dd         Dimensionality of stationary subspace
  15 %  nreps      Optional: number of restarts w/ different init (default: none) 
  16 %  max_iter   Optional: maximum number of iterations (default: 100)
  17 %  quiet      Optional: suppress output (default: false) 
  18 %  nomeans    Optional: Perform SSA w.r.t. to the covariance matrix only. 
  19 %
  20 %output
  21 %  Ps         Projection matrix to stationary sources (dd x d)
  22 %  An         Estimated basis of the non-stationary subspace (d x d)
  23 %  y          Minimum objective function value
  24 %  converged  True, if the optimization has converged
  25 %  iter       Number of iterations
  26 %
  27 %author
  28 %  buenau@cs.tu-berlin.de
  29 
  30 % Set default parameters.
  31 if ~exist('max_iter', 'var') || isempty(max_iter), max_iter = 100; end
  32 if ~exist('nreps', 'var'), nreps = []; end
  33 if ~exist('max_iter', 'var'), max_iter = 100; end
  34 if ~exist('quiet', 'var'), quiet = false; end
  35 if ~exist('nomeans', 'var'), nomeans = false; end
  36 
  37 % Loop over repetitions and return the solution with the lowest objective function 
  38 % value. 
  39 if ~isempty(nreps)
  40   r_Ps = cell(1, nreps);
  41   r_An = cell(1, nreps);
  42   r_y = zeros(1, nreps);
  43   r_converged = zeros(1, nreps);
  44   r_iter = zeros(1, nreps);
  45 
  46   for i=1:nreps
  47     [r_Ps{i}, r_An{i}, r_y(i), r_converged(i), r_iter(i) ] = ssa_simple(X, dd, [], max_iter, quiet, nomeans);
  48   end
  49 
  50   [foo, mini] = min(r_y);
  51   Ps = r_Ps{mini};
  52   An = r_An{mini};
  53   y = r_y(mini);
  54   iter = r_iter(mini);
  55   converged = r_converged(mini);
  56   return; 
  57 end
  58 
  59 % Sanity check.
  60 if length(X) < 2, error('X must contain at least two datasets or two matrices: X{1} contains the means and X{2} contains the covariance matrices\n'); end
  61 
  62 % Distinguish the two parametrization variants: data or means+covariance matrices.
  63 X_contains_data = (ndims(X{2}) == 2);
  64 
  65 d = size(X{1}, 1);
  66 
  67 % Sanity check.
  68 if dd > d, error('Dimensionality of stationary subspace (dd) must be less than or equal to the dimensionality of the input data!\n'); end
  69 
  70 % Distinguish different formats. 
  71 
  72 if X_contains_data
  73   % If the input is data, compute means and covariance matrices per epoch 
  74   % and the whitening matrix.
  75   n_X = length(X);
  76   C = zeros(d, d, n_X);
  77   mu = zeros(d, n_X);
  78   all_data = [];
  79   for i=1:n_X
  80     all_data = [ all_data X{i} ];
  81     mu(:,i) = mean(X{i}, 2);
  82     C(:,:,i) = cov(X{i}'); 
  83   end
  84 else
  85   % If the input is means+covmats, compute only the whitening matrix. 
  86   mu = X{1};
  87   C = X{2};
  88   n_X = size(mu, 2);
  89 end
  90 
  91 converged = false;
  92 y_new = [];
  93 
  94 % Parameters for backtracking linesearch.
  95 ls_alpha = 0.5*(0.01+0.3);
  96 ls_beta = 0.4;
  97 
  98 if ~quiet, fprintf('*** iter=[iteration] y=[function value] ||grad||=[norm of gradient] ||step||=[norm of step] ([number of line search iterations]) rel_dec=[relative function value decrease in percent]\n'); end
  99 
 100 % Centering and Whitening. 
 101 W = inv(sqrtm(squeeze(mean(C,3))));
 102 mu = mu - repmat(mean(mu,2), [1 n_X]);
 103 
 104 % Initialization: random rotation.
 105 B = randrot(d)*W;
 106 
 107 % Apply initialization to means and covariance matrices.
 108 mu = B*mu;
 109 C = mult3(C, B);
 110 
 111 % Optimization loop.
 112 for iter=1:max_iter
 113   % Get current function value and gradient.
 114   [y, grad] = objfun(zeros(dd*(d-dd),1), C, mu, d, dd, nomeans);
 115 
 116   % Sanity check.
 117   if ~isempty(y_new) && y_new ~= y, error('Something is utterly wrong.\n'); end
 118 
 119   % Print progress (if not suppressed).
 120   if ~quiet, fprintf('iter=%d y=%.5g ||grad||=%.5g ', iter, y, norm(grad)); end
 121 
 122   % Conjugate gradient: compute search direction.
 123   if iter == 1
 124     alpha = -grad;
 125   else
 126     gamma = grad'*(grad-grad_old)/(grad_old'*grad_old);
 127     alpha = -grad + gamma*alpha_old;
 128   end
 129   grad_old = grad;
 130   alpha_old = alpha;
 131 
 132   % Alpha is the current search direction, so we do a line-search along t*alpha
 133   % Alpha is a vector, that contains the elements of "Z" (see paper)
 134 
 135   % Normalize search direction. 
 136   alpha = alpha ./ (2*norm(alpha));
 137 
 138   % Fill in nonzero values: this means: put the Z into the bigger M so that we 
 139   % we can do expm(M)
 140   M_alpha = reshape(alpha, [dd (d-dd)]);
 141   M_alpha = [ zeros(dd, dd) M_alpha; -M_alpha' zeros(d-dd, d-dd) ];
 142 
 143   % Backtracking line search loop.
 144   t = 1;
 145   for j=1:10
 146     M_new = t*M_alpha;
 147     R = expm(M_new);
 148 
 149     y_new = objfun(zeros(dd*(d-dd), 1), mult3(C, R), R*mu, d, dd, nomeans);
 150 
 151     % Stop if function decrease is sufficient.
 152     if y_new <= (y + ls_alpha*t*grad'*alpha)
 153       break;
 154     end
 155 
 156     t = ls_beta*t;
 157   end
 158 
 159   % Stop if line search failed. 
 160   if y_new >= y
 161     if ~quiet, fprintf('no step found\n'); end
 162     converged = true;
 163     break;
 164   end
 165 
 166   % Stop if relative function decrease is below threshold.
 167   rel_dec = (y-y_new)/y;
 168   rel_dec_thr = 1e-8;
 169   if rel_dec < rel_dec_thr
 170     if ~quiet, fprintf('rel_dec < %f\n', rel_dec_thr); end
 171     converged = true;
 172     break;
 173   end
 174 
 175   % Print progress.
 176   if ~quiet, fprintf('||step||=%.3g (%d) rel_dec=%.3g%% y=%.3g\n', t, j, 100*rel_dec, y); end
 177 
 178   % Rotate basis (= multiplicative update step).
 179   C = mult3(C, R);
 180   mu = R*mu;
 181   
 182   B = R*B;
 183 end
 184 
 185 % Display warning message if algorithm has not converged.
 186 if ~converged
 187   if ~quiet, fprintf('Reached maximum number of iterations\n'); end
 188 end
 189 
 190 % Compute estimated mixing matrix.
 191 A = inv(B);
 192 
 193 % Split estimated de-mixing matrix into two projection matrices.
 194 Ps = B(1:dd,:);
 195 An = A(:,(dd+1):end);
 196 
 197 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 198 function C = mult3(C, R)
 199 % Compute R*C(:,:,i)*R' for all i.
 200 
 201 [d1, d2, d3] = size(C);
 202 
 203 % Multiply from the left with R.
 204 C = reshape(C, [d1 d2*d3]);
 205 C = reshape(R*C, [d1 d2 d3]);
 206 
 207 % Multiply from the right with R'
 208 C = permute(C, [2 1 3]);
 209 C = reshape(C, [d1 d2*d3]);
 210 C = reshape(R*C, [d1 d2 d3]);
 211 
 212 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 213 % Objective function
 214 function [fx, grad] = objfun(M, C, mu, d, dd, nomeans)
 215 
 216 n = size(C, 3);
 217 
 218 % Degrees of freedoms.
 219 dof = n*(dd*(dd+1)/2 + dd);
 220 
 221 % Fill in non-zero values. 
 222 M = reshape(M, [dd (d-dd)]);
 223 M = [ zeros(dd, dd) M; -M' zeros(d-dd, d-dd) ];
 224 
 225 % Compute rotation.
 226 R = expm(M);
 227 
 228 % Projection to stationary signals.
 229 P = R(1:dd,:);
 230 
 231 log_det = zeros(1, n);
 232 inv_pC = zeros(dd, dd, n);
 233 pC = zeros(dd, dd, n);
 234 log_det = zeros(1, n);
 235 pmu = zeros(dd, n);
 236 
 237 opts.UT = true;
 238 
 239 % Prepare some values.
 240 for i=1:n
 241   pC(:,:,i) = P*C(:,:,i)*P';
 242   L = chol(pC(:,:,i));
 243   inv_L = linsolve(L, eye(dd), opts);
 244   inv_pC(:,:,i) = inv_L*inv_L';
 245   log_det(i) = 2*sum(log(diag(L)));
 246   pmu(:,i) = P*mu(:,i);
 247 end
 248 
 249 fx = 0; 
 250 grad = zeros(dd,d);
 251 
 252 % Loop over epochs.
 253 for i=1:n
 254   fx = fx - log_det(i);
 255   if ~nomeans, fx = fx + pmu(:,i)'*pmu(:,i); end
 256 
 257   grad = grad -2*inv_pC(:,:,i)*P*C(:,:,i);
 258   if ~nomeans, grad = grad + 2*P*mu(:,i)*mu(:,i)'; end
 259 end
 260 
 261 grad = [ grad; zeros(d-dd, d) ];
 262 grad = grad*R' - R*grad';
 263 
 264 grad = grad(1:dd,dd+1:end);
 265 grad = grad(:);
 266 
 267 %fx = (fx-dof)/sqrt(2*dof);
 268 
 269 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 270 function [ R, M ] = randrot(d)
 271 %RANDROT        Generate random orthogonal matrix. 
 272 %
 273 %usage
 274 %  [R,M] = randrot(d)
 275 %
 276 %author
 277 %  buenau@cs.tu-berlin.de
 278 
 279 M = 10*(rand(d,d)-0.5);
 280 M = 0.5*(M-M');
 281 R = expm(M);

Attached Files

To refer to attachments on a page, use attachment:filename, as shown below in the list of files. Do NOT use the URL of the [get] link, since this is subject to change and can break easily.
  • [get | view] (2011-06-22 07:45:04, 2873.9 KB) [[attachment:bio_intro.pdf]]
  • [get | view] (2011-05-02 16:51:13, 1341.1 KB) [[attachment:buchner_chap2007.pdf]]
  • [get | view] (2011-04-19 12:05:37, 3719.0 KB) [[attachment:cca_lecture.pdf]]
  • [get | view] (2011-07-20 11:25:53, 1778.9 KB) [[attachment:chemo_lecture.pdf]]
  • [get | view] (2011-05-16 22:58:50, 723.1 KB) [[attachment:deep_lecture.pdf]]
  • [get | view] (2011-04-11 15:08:58, 95.6 KB) [[attachment:full_sheet01.pdf]]
  • [get | view] (2011-04-19 11:33:46, 130.0 KB) [[attachment:full_sheet02.pdf]]
  • [get | view] (2011-05-02 16:53:54, 123.4 KB) [[attachment:full_sheet03.pdf]]
  • [get | view] (2011-05-10 10:03:36, 112.2 KB) [[attachment:full_sheet04.pdf]]
  • [get | view] (2011-05-17 09:43:04, 34.8 KB) [[attachment:full_sheet05.pdf]]
  • [get | view] (2011-05-23 12:18:47, 95.7 KB) [[attachment:full_sheet06.pdf]]
  • [get | view] (2011-05-30 13:35:32, 98.3 KB) [[attachment:full_sheet07.pdf]]
  • [get | view] (2011-06-07 15:23:13, 125.3 KB) [[attachment:full_sheet08.pdf]]
  • [get | view] (2011-06-21 08:43:34, 164.4 KB) [[attachment:full_sheet09.pdf]]
  • [get | view] (2011-07-05 09:39:18, 100.4 KB) [[attachment:full_sheet10.pdf]]
  • [get | view] (2011-05-31 20:33:20, 2656.3 KB) [[attachment:ids_lect.pdf]]
  • [get | view] (2011-05-24 15:25:30, 2121.4 KB) [[attachment:kern_struct.pdf]]
  • [get | view] (2011-06-14 10:50:39, 1169.4 KB) [[attachment:mkl_intro.pdf]]
  • [get | view] (2011-05-04 09:54:42, 512.3 KB) [[attachment:ml2_bss.pdf]]
  • [get | view] (2011-05-17 09:43:15, 7656.4 KB) [[attachment:mnist.mat]]
  • [get | view] (2011-06-07 18:29:45, 192.5 KB) [[attachment:opt_intro.pdf]]
  • [get | view] (2011-05-17 09:43:21, 1.7 KB) [[attachment:rbm.m]]
  • [get | view] (2011-05-02 16:54:43, 0.7 KB) [[attachment:sfft.m]]
  • [get | view] (2011-05-30 13:37:07, 4.3 KB) [[attachment:sheet07.m]]
  • [get | view] (2011-06-21 08:43:43, 1.1 KB) [[attachment:sheet09.m]]
  • [get | view] (2011-05-17 09:43:25, 0.1 KB) [[attachment:sigmoid.m]]
  • [get | view] (2011-06-21 08:43:50, 129.6 KB) [[attachment:splice-test-data.txt]]
  • [get | view] (2011-06-21 08:43:55, 5.4 KB) [[attachment:splice-test-label.txt]]
  • [get | view] (2011-06-21 08:44:00, 59.6 KB) [[attachment:splice-train-data.txt]]
  • [get | view] (2011-06-21 08:44:05, 2.5 KB) [[attachment:splice-train-label.txt]]
  • [get | view] (2011-04-18 07:30:47, 1515.8 KB) [[attachment:ssa_data.mat]]
  • [get | view] (2011-04-18 07:28:13, 585.7 KB) [[attachment:ssa_lecture.pdf]]
  • [get | view] (2011-04-18 07:30:20, 7.4 KB) [[attachment:ssa_simple.m]]
  • [get | view] (2011-07-05 11:05:27, 1202.4 KB) [[attachment:structured_lecture.pdf]]
  • [get | view] (2011-05-30 13:36:38, 1217.5 KB) [[attachment:stud-data.mat.gz]]
  • [get | view] (2011-04-18 07:29:43, 1.0 KB) [[attachment:tkcca_example.m]]
  • [get | view] (2011-04-18 07:29:54, 4.1 KB) [[attachment:tkcca_simple.m]]
  • [get | view] (2011-04-18 07:30:07, 150.9 KB) [[attachment:tkcca_toy_data.mat]]
  • [get | view] (2011-05-02 16:51:28, 4.5 KB) [[attachment:umfbss.m]]
  • [get | view] (2011-05-02 16:03:05, 531.3 KB) [[attachment:x1.wav]]
  • [get | view] (2011-05-02 16:03:12, 531.3 KB) [[attachment:x2.wav]]
 All files | Selected Files: delete move to page copy to page

You are not allowed to attach a file to this page.