Attachment 'ssa_simple.m'
function [ Ps, An, y, converged, iter ] = ssa_simple(X, dd, nreps, max_iter, quiet, nomeans)
%SSA_SIMPLE Stationary Subspace Analysis.
%
%usage
%  [Ps, An, y, converged, iter] = ssa_simple(X, dd, { nreps: [], max_iter: 100, quiet: false, nomeans: false })
%
%input
%  X         Data in one of two possible formats:
%              1. cell array where each X{i} is a (d x n_i)-dataset
%              2. cell array with two elements:
%                   X{1} is a (d x n) matrix of epoch means and
%                   X{2} is a (d x d x n) array of epoch covariance matrices
%            The second format is recognised by X{2} having more than two
%            dimensions.
%  dd        Dimensionality of stationary subspace
%  nreps     Optional: number of restarts w/ different random init; the
%            solution with the lowest objective value is returned
%            (default: [] = a single run)
%  max_iter  Optional: maximum number of iterations (default: 100)
%  quiet     Optional: suppress output (default: false)
%  nomeans   Optional: perform SSA w.r.t. the covariance matrices only
%            (default: false)
%  An omitted or empty ([]) optional argument selects its default.
%
%output
%  Ps        Projection matrix to stationary sources (dd x d)
%  An        Estimated basis of the non-stationary subspace (d x (d-dd))
%  y         Objective function value of the returned solution
%  converged True, if the optimization has converged
%  iter      Number of iterations
%
%author
%  buenau@cs.tu-berlin.de

% Set default parameters. Empty values are treated like omitted ones:
% an empty flag would otherwise make "if ~flag" evaluate to false and
% silently behave as if the flag were set.
if ~exist('nreps', 'var'), nreps = []; end
if ~exist('max_iter', 'var') || isempty(max_iter), max_iter = 100; end
if ~exist('quiet', 'var') || isempty(quiet), quiet = false; end
if ~exist('nomeans', 'var') || isempty(nomeans), nomeans = false; end

% Loop over repetitions and return the solution with the lowest objective
% function value.
if ~isempty(nreps)
    r_Ps = cell(1, nreps);
    r_An = cell(1, nreps);
    r_y = zeros(1, nreps);
    r_converged = false(1, nreps);   % logical, as in the single-run case
    r_iter = zeros(1, nreps);

    for i=1:nreps
        % nreps = [] makes the recursive call perform exactly one run.
        [r_Ps{i}, r_An{i}, r_y(i), r_converged(i), r_iter(i) ] = ssa_simple(X, dd, [], max_iter, quiet, nomeans);
    end

    [unused, mini] = min(r_y); %#ok<ASGLU>
    Ps = r_Ps{mini};
    An = r_An{mini};
    y = r_y(mini);
    iter = r_iter(mini);
    converged = r_converged(mini);
    return;
end

% Sanity check. (error() with a single argument does not interpret escape
% sequences, so the messages carry no trailing '\n'.)
if ~iscell(X) || length(X) < 2
    error('X must contain at least two datasets or two matrices: X{1} contains the means and X{2} contains the covariance matrices');
end

% Distinguish the two parametrization variants: data or means+covariance matrices.
X_contains_data = (ndims(X{2}) == 2);

d = size(X{1}, 1);

% Sanity check.
if dd > d
    error('Dimensionality of stationary subspace (dd) must be less than or equal to the dimensionality of the input data!');
end

if X_contains_data
    % If the input is data, compute means and covariance matrices per epoch.
    n_X = length(X);
    C = zeros(d, d, n_X);
    mu = zeros(d, n_X);
    for i=1:n_X
        mu(:,i) = mean(X{i}, 2);
        C(:,:,i) = cov(X{i}');
    end
else
    % If the input is means+covmats, use them directly.
    mu = X{1};
    C = X{2};
    n_X = size(mu, 2);
end

converged = false;
y_new = [];

% Parameters for backtracking linesearch.
ls_alpha = 0.5*(0.01+0.3);
ls_beta = 0.4;

% Threshold on the relative function decrease used as stopping criterion.
rel_dec_thr = 1e-8;

if ~quiet, fprintf('*** iter=[iteration] y=[function value] ||grad||=[norm of gradient] ||step||=[norm of step] ([number of line search iterations]) rel_dec=[relative function value decrease in percent]\n'); end

% Centering and Whitening (w.r.t. the average epoch covariance matrix).
W = inv(sqrtm(squeeze(mean(C,3))));
mu = mu - repmat(mean(mu,2), [1 n_X]);

% Initialization: random rotation.
B = randrot(d)*W;

% Apply initialization to means and covariance matrices.
mu = B*mu;
C = mult3(C, B);

% Optimization loop.
for iter=1:max_iter
    % Get current function value and gradient (at the identity rotation,
    % because all previous rotations are already applied to C and mu).
    [y, grad] = objfun(zeros(dd*(d-dd),1), C, mu, d, dd, nomeans);

    % Sanity check: the value accepted by the previous line search must be
    % reproduced exactly after applying the rotation.
    if ~isempty(y_new) && y_new ~= y, error('Something is utterly wrong.'); end

    % Print progress (if not suppressed).
    if ~quiet, fprintf('iter=%d y=%.5g ||grad||=%.5g ', iter, y, norm(grad)); end

    % Conjugate gradient: compute search direction.
    if iter == 1
        alpha = -grad;
    else
        gamma = grad'*(grad-grad_old)/(grad_old'*grad_old);
        alpha = -grad + gamma*alpha_old;
    end
    grad_old = grad;
    alpha_old = alpha;

    % Alpha is the current search direction, so we do a line-search along t*alpha
    % Alpha is a vector, that contains the elements of "Z" (see paper)

    % A vanishing search direction cannot be normalized (0/0 would inject
    % NaNs into the rotation); no step can be found along it.
    alpha_norm = norm(alpha);
    if alpha_norm == 0
        if ~quiet, fprintf('no step found\n'); end
        converged = true;
        break;
    end

    % Normalize search direction.
    alpha = alpha ./ (2*alpha_norm);

    % Fill in nonzero values: this means: put the Z into the bigger M so
    % that we can do expm(M)
    M_alpha = reshape(alpha, [dd (d-dd)]);
    M_alpha = [ zeros(dd, dd) M_alpha; -M_alpha' zeros(d-dd, d-dd) ];

    % Backtracking line search loop.
    t = 1;
    for j=1:10
        M_new = t*M_alpha;
        R = expm(M_new);

        y_new = objfun(zeros(dd*(d-dd), 1), mult3(C, R), R*mu, d, dd, nomeans);

        % Stop if function decrease is sufficient.
        if y_new <= (y + ls_alpha*t*grad'*alpha)
            break;
        end

        t = ls_beta*t;
    end

    % Stop if line search failed.
    if y_new >= y
        if ~quiet, fprintf('no step found\n'); end
        converged = true;
        break;
    end

    % Stop if relative function decrease is below threshold.
    rel_dec = (y-y_new)/y;
    if rel_dec < rel_dec_thr
        if ~quiet, fprintf('rel_dec < %f\n', rel_dec_thr); end
        converged = true;
        break;
    end

    % Print progress.
    if ~quiet, fprintf('||step||=%.3g (%d) rel_dec=%.3g%% y=%.3g\n', t, j, 100*rel_dec, y); end

    % Rotate basis (= multiplicative update step).
    C = mult3(C, R);
    mu = R*mu;

    B = R*B;

    % Keep the returned objective value in sync with the returned basis:
    % if max_iter is reached, B already contains this step.
    y = y_new;
end

% Display warning message if algorithm has not converged.
if ~converged
    if ~quiet, fprintf('Reached maximum number of iterations\n'); end
end

% Compute estimated mixing matrix.
A = inv(B);

% Split estimated de-mixing matrix into two projection matrices.
Ps = B(1:dd,:);
An = A(:,(dd+1):end);
196
197 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function C = mult3(C, R)
%MULT3 Apply the congruence transform R*C(:,:,i)*R' to every slice of C.
%
% C is treated as a stack of square matrices along the third dimension.
% The slices are not transposed back after the second multiplication, so
% the result equals R*C(:,:,i)*R' for symmetric slices (e.g. covariance
% matrices), which is how this file uses it.

[n_rows, n_cols, n_slices] = size(C);

% Left-multiply all slices at once by laying them out side by side.
stacked = R * reshape(C, n_rows, n_cols*n_slices);

% Transpose every slice, then left-multiply again: R*(R*C_i)'.
flipped = permute(reshape(stacked, n_rows, n_cols, n_slices), [2 1 3]);
C = reshape(R * reshape(flipped, n_rows, n_cols*n_slices), n_rows, n_cols, n_slices);
211
212 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
213 % Objective function
function [fx, grad] = objfun(M, C, mu, d, dd, nomeans)
%OBJFUN Objective function value and gradient.
%
% M        vector with dd*(d-dd) elements; it is reshaped to (dd x (d-dd))
%          and embedded into a d x d skew-symmetric matrix whose matrix
%          exponential is the rotation R
% C        (d x d x n) array of epoch covariance matrices
% mu       (d x n) matrix of epoch means
% d, dd    input dimensionality and number of projected dimensions
% nomeans  if true, the mean terms are left out of value and gradient
%
% fx       sum over epochs of -log det(P*C_i*P') (+ ||P*mu_i||^2 unless
%          nomeans), where P consists of the first dd rows of R
% grad     column vector with dd*(d-dd) elements: the upper-right block of
%          the skew-symmetrized matrix gradient

n_epochs = size(C, 3);

% Build the skew-symmetric generator from the free parameters.
Z = reshape(M, [dd (d-dd)]);
generator = [ zeros(dd, dd) Z; -Z' zeros(d-dd, d-dd) ];

% Rotation and projection to the stationary signals.
R = expm(generator);
P = R(1:dd,:);

% Tell linsolve that the Cholesky factor is upper triangular.
solve_opts.UT = true;

fx = 0;
G = zeros(dd, d);

% Accumulate value and gradient epoch by epoch.
for k=1:n_epochs
    C_k = C(:,:,k);
    mu_k = mu(:,k);

    % Projected covariance, its inverse and log-determinant via Cholesky.
    pC_k = P*C_k*P';
    U = chol(pC_k);
    inv_U = linsolve(U, eye(dd), solve_opts);
    inv_pC_k = inv_U*inv_U';
    log_det_k = 2*sum(log(diag(U)));

    % Projected mean.
    pmu_k = P*mu_k;

    fx = fx - log_det_k;
    G = G - 2*inv_pC_k*P*C_k;

    if ~nomeans
        fx = fx + pmu_k'*pmu_k;
        G = G + 2*P*mu_k*mu_k';
    end
end

% Pad to d x d, skew-symmetrize and keep the free (upper-right) block.
G = [ G; zeros(d-dd, d) ];
G = G*R' - R*G';

grad = G(1:dd,dd+1:end);
grad = grad(:);

% Disabled normalization, with dof = n_epochs*(dd*(dd+1)/2 + dd):
%fx = (fx-dof)/sqrt(2*dof);
268
269 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [ R, M ] = randrot(d)
%RANDROT Generate random orthogonal matrix.
%
%usage
%  [R,M] = randrot(d)
%
%output
%  R  (d x d) matrix exponential of M
%  M  (d x d) random skew-symmetric matrix
%
%author
%  buenau@cs.tu-berlin.de

% Uniform random entries in [-5, 5].
G = 10*(rand(d,d)-0.5);

% Keep only the skew-symmetric part, so that expm yields an orthogonal matrix.
M = 0.5*(G-G');
R = expm(M);
Attached Files
To refer to attachments on a page, use attachment:filename, as shown below in the list of files. Do NOT use the URL of the [get] link, since this is subject to change and can break easily. You are not allowed to attach a file to this page.