generalized_procrustes

PURPOSE

Rotationally align clouds of points (generalized Procrustes problem)

SYNOPSIS

function [A, R] = generalized_procrustes(A_measure)

DESCRIPTION

 Rotationally align clouds of points (generalized Procrustes problem)

 function [A, R] = generalized_procrustes(A_measure)

 The input is a 3D matrix A_measure of size nxmxN. Each of the N slices
 A_measure(:, :, i) is a cloud of m points in R^n. These clouds are
 assumed to be noisy, rotated versions of a reference cloud Atrue. The
 algorithm jointly estimates the reference cloud and the rotations: it
 seeks a cloud A and rotations R(:, :, i) such that R(:, :, i) * A
 matches A_measure(:, :, i) as closely as possible in the least-squares
 sense.
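
For reference, here is the least-squares cost that this function minimizes, together with the Euclidean gradients that the nested grad function computes. Write A_i for A_measure(:, :, i) and R_i for R(:, :, i). The simplification of the A-gradient uses R_i^T R_i = I. The last line is the invariance under a global rotation Q, which is why A can only be recovered up to rotation.

```latex
\begin{aligned}
f(R_1, \dots, R_N, A) &= \frac{1}{2N} \sum_{i=1}^{N} \lVert R_i A - A_i \rVert_F^2,
  \qquad R_i \in \mathrm{SO}(n),\; A \in \mathbb{R}^{n \times m}, \\
\nabla_{R_i} f &= \frac{1}{N} (R_i A - A_i) A^\top, \\
\nabla_{A} f &= \frac{1}{N} \sum_{i=1}^{N} R_i^\top (R_i A - A_i)
  = A - \frac{1}{N} \sum_{i=1}^{N} R_i^\top A_i, \\
f(R_1 Q^\top, \dots, R_N Q^\top, Q A) &= f(R_1, \dots, R_N, A)
  \quad \text{for all } Q \in \mathrm{SO}(n).
\end{aligned}
```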

 The output A is an estimate of the cloud Atrue (up to rotation). The
 output R is a 3D matrix of size nxnxN containing the rotation matrices
 such that R(:, :, i) * A is approximately equal to A_measure(:, :, i).
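
The data model and the cost can be mirrored in NumPy for experimentation. This is a sketch under stated assumptions and is not part of Manopt. Arrays are stored as (N, n, m) stacks, so slice i is A_measure[i] rather than MATLAB's A_measure(:, :, i), and NumPy's @ broadcasting plays the role of Manopt's multiprod. The helpers random_rotations, cost and egrad are hypothetical names. The sketch checks three things:

- the cost is invariant under a global rotation, which is why A is recovered only up to rotation;
- the cost is exactly zero on noise-free data;
- the Euclidean gradients agree with a finite-difference directional derivative, in the spirit of Manopt's checkgradient.

```python
import numpy as np


def random_rotations(n, N, rng):
    """Return N uniformly distributed rotations as an array of shape (N, n, n)."""
    Rs = np.empty((N, n, n))
    for i in range(N):
        # QR of a Gaussian matrix, with the signs of diag(r) absorbed into q,
        # gives a uniformly distributed orthogonal matrix.
        q, r = np.linalg.qr(rng.standard_normal((n, n)))
        q = q * np.sign(np.diag(r))
        # Flip one column if needed so that det(q) = +1, i.e., q is in SO(n).
        if np.linalg.det(q) < 0:
            q[:, 0] = -q[:, 0]
        Rs[i] = q
    return Rs


def cost(R, A, A_measure):
    """f(R, A) = 1/(2N) * sum_i ||R_i A - A_i||_F^2 (hypothetical helper)."""
    N = A_measure.shape[0]
    E = R @ A - A_measure  # residuals; R @ A broadcasts over the N slices
    return np.sum(E ** 2) / (2 * N)


def egrad(R, A, A_measure):
    """Euclidean gradients of f with respect to the stack R and to A."""
    N = A_measure.shape[0]
    E = R @ A - A_measure
    grad_R = E @ A.T / N  # slice i: (R_i A - A_i) A^T / N
    grad_A = np.mean(np.transpose(R, (0, 2, 1)) @ E, axis=0)  # mean_i R_i^T E_i
    return grad_R, grad_A


# Synthetic data with the same sizes as the MATLAB test block.
rng = np.random.default_rng(0)
n, m, N, sigma = 3, 10, 50, 0.3
Atrue = rng.standard_normal((n, m))
Rtrue = random_rotations(n, N, rng)
A_measure = Rtrue @ Atrue + sigma * rng.standard_normal((N, n, m))

# Invariance: A -> Q A and R_i -> R_i Q^T leave the cost unchanged.
Q = random_rotations(n, 1, rng)[0]
f1 = cost(Rtrue, Atrue, A_measure)
f2 = cost(Rtrue @ Q.T, Q @ Atrue, A_measure)

# Central finite-difference check of the gradients along a random direction.
dR = rng.standard_normal((N, n, n))
dA = rng.standard_normal((n, m))
t = 1e-6
fd = (cost(Rtrue + t * dR, Atrue + t * dA, A_measure)
      - cost(Rtrue - t * dR, Atrue - t * dA, A_measure)) / (2 * t)
gR, gA = egrad(Rtrue, Atrue, A_measure)
lin = np.sum(gR * dR) + np.sum(gA * dA)
print(f1, f2, fd, lin)
```

Here grad_A uses the general form mean_i R_i^T (R_i A - A_i). This form is valid for any R in the ambient space, which is what the finite-difference check perturbs. It reduces to the MATLAB expression A - mean_i R_i^T A_i once R_i is a rotation.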

SOURCE CODE

function [A, R] = generalized_procrustes(A_measure)
% Rotationally align clouds of points (generalized Procrustes problem)
%
% function [A, R] = generalized_procrustes(A_measure)
%
% The input is a 3D matrix A_measure of size nxmxN. Each of the N slices
% A_measure(:, :, i) is a cloud of m points in R^n. These clouds are
% assumed to be noisy, rotated versions of a reference cloud Atrue. The
% algorithm jointly estimates the reference cloud and the rotations: it
% seeks a cloud A and rotations R(:, :, i) such that R(:, :, i) * A
% matches A_measure(:, :, i) as closely as possible in the least-squares
% sense.
%
% The output A is an estimate of the cloud Atrue (up to rotation). The
% output R is a 3D matrix of size nxnxN containing the rotation matrices
% such that R(:, :, i) * A is approximately equal to A_measure(:, :, i).

% This file is part of Manopt and is copyrighted. See the license file.
%
% Main author: Nicolas Boumal, July 8, 2013
% Contributors:
%
% Change log:
%
%    Xiaowen Jiang Aug. 20, 2021
%       Added AD to compute the grad and the hess

    if ~exist('A_measure', 'var')
        % Generate random data to test the method.
        % There are N clouds of m points in R^n. Each of them is a noisy,
        % rotated version of a reference cloud Atrue. Rotations are
        % uniformly random and noise on each rotated cloud is iid normal
        % with standard deviation sigma.
        n = 3;
        m = 10;
        N = 50;
        % The reference cloud
        Atrue = randn(n, m);
        % A 3D matrix containing the N measured clouds
        sigma = .3;
        A_measure = multiprod(randrot(n, N), Atrue) + sigma*randn(n, m, N);
    else
        [n, m, N] = size(A_measure);
    end

    % Construct a manifold structure representing the product of groups of
    % rotations with the Euclidean space for A. We optimize simultaneously
    % for the reference cloud and for the rotations that affect each of the
    % measured clouds. Notice that there is a group invariance because
    % there is no way of telling which orientation the reference cloud
    % should be in.
    tuple.R = rotationsfactory(n, N);
    tuple.A = euclideanfactory(n, m);
    M = productmanifold(tuple);

    % Define the cost function here. Points on the manifold M are
    % structures with fields X.A and X.R, containing arrays of sizes nxm
    % and nxnxN, respectively. The store structure (the caching system) is
    % used to keep the residual matrix E in memory, as it is also used in
    % the computation of the gradient and of the Hessian. This way, we
    % prevent redundant computations.
    function [f, store] = cost(X, store)
        if ~isfield(store, 'E')
            R = X.R;
            A = X.A;
            store.E = multiprod(R, A) - A_measure;
        end
        E = store.E;
        f = (E(:)'*E(:))/(2*N);
    end

    % Riemannian gradient of the cost function.
    function [g, store] = grad(X, store)
        R = X.R;
        A = X.A;
        if ~isfield(store, 'E')
            [~, store] = cost(X, store);
        end
        E = store.E;
        % Compute the Euclidean gradient of the cost wrt the rotations R
        % and wrt the cloud A,
        egrad.R = multiprod(E, A'/N);
        egrad.A = A - mean(multiprod(multitransp(R), A_measure), 3);
        % then transform this Euclidean gradient into the Riemannian
        % gradient.
        g = M.egrad2rgrad(X, egrad);
        store.egrad = egrad;
    end

    % It is not necessary to define the Hessian of the cost. We do it
    % mostly to illustrate how to do it and to study the spectrum of the
    % Hessian at the solution (see further down).
    function [h, store] = hess(X, Xdot, store)
        R = X.R;
        A = X.A;
        % Careful: tangent vectors on the rotation group are represented as
        % skew-symmetric matrices. To obtain the corresponding vectors in
        % the ambient space, we need a little transformation. This
        % transformation is typically not needed when we compute the
        % formulas for the gradient and the Hessian directly in Riemannian
        % form instead of resorting to egrad2rgrad and ehess2rhess. These
        % latter tools are convenient for prototyping but are not always
        % the most efficient way to execute the computations.
        Rdot = tuple.R.tangent2ambient(R, Xdot.R);
        Adot = Xdot.A;
        if ~isfield(store, 'egrad')
            [~, store] = grad(X, store);
        end
        E = store.E;
        egrad = store.egrad;

        ehess.R = multiprod(multiprod(Rdot, A) + multiprod(R, Adot), A') + ...
                  multiprod(E, Adot');
        ehess.R = ehess.R / N;
        ehess.A = Adot - mean(multiprod(multitransp(Rdot), A_measure), 3);

        h = M.ehess2rhess(X, egrad, ehess, Xdot);
    end

    % Set up the problem structure with the manifold M and the cost,
    % gradient and Hessian functions.
    problem.M = M;
    problem.cost = @cost;
    problem.grad = @grad;
    problem.hess = @hess;

    % An alternative way to compute the gradient and the Hessian is to use
    % automatic differentiation, which relies on the Deep Learning Toolbox
    % (slower):
    % problem.cost = @cost_AD;
    %    function f = cost_AD(X)
    %        R = X.R;
    %        A = X.A;
    %        E = multiprod(R, A) - A_measure;
    %        f = (E(:)'*E(:))/(2*N);
    %    end
    % Call manoptAD to prepare AD for the problem structure.
    % problem = manoptAD(problem);

    % For debugging, it is good practice to check the gradient and the
    % Hessian numerically a few times.
    % checkgradient(problem);
    % pause;
    % checkhessian(problem);
    % pause;

    % Call a solver on our problem. Without an initial point, trustregions
    % starts from a random one; a clever initial guess could probably
    % improve this considerably.
    X = trustregions(problem);
    A = X.A;
    R = X.R;

    % To evaluate the performance of the algorithm, see how well Atrue (the
    % reference cloud) matches A (the found cloud). Since the recovery is
    % only up to a global rotation, first align A to Atrue by solving an
    % orthogonal Procrustes problem (as in the Kabsch algorithm), i.e.,
    % multiply A by the orthogonal polar factor U*V' of Atrue*A'.
    if exist('Atrue', 'var')
        [U, ~, V] = svd(Atrue*A');
        Ahat = (U*V')*A;
        fprintf('Registration error: %g.\n', norm(Atrue-Ahat, 'fro'));
    end

    % Plot the spectrum of the Hessian at the solution found.
    % Notice that the invariance of f under a global rotation (replacing A
    % by Q*A and each R(:, :, i) by R(:, :, i)*Q' for any rotation Q)
    % yields dim SO(n), that is, n*(n-1)/2 zero eigenvalues in the Hessian
    % spectrum at the solution. This indicates that critical points are not
    % isolated, which can in theory prevent quadratic convergence. One way
    % to circumvent this is to fix one rotation arbitrarily. Another is to
    % work on a quotient manifold. Both can be achieved in Manopt: they
    % simply require a little more work on the manifold description side.
    if M.dim() <= 512
        stairs(sort(hessianspectrum(problem, X)));
        title('Spectrum of the Hessian at the solution found.');
        xlabel('Eigenvalue number (sorted)');
        ylabel('Value of the eigenvalue');
    end

end
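
The solver above is Manopt's trustregions on the product manifold. For comparison, a different and simpler technique for the same cost is the classical alternating (block coordinate descent) scheme for generalized Procrustes analysis, sketched below in NumPy. This file does not use it. The names nearest_rotation, alternating_gpa and registration_error are hypothetical, and arrays are stored as (N, n, m) stacks.

The scheme alternates two closed-form steps:

- With A fixed, each R_i solves an orthogonal Procrustes problem. The solution is the Kabsch rotation, with a determinant correction so that det R_i = +1.
- With the R_i fixed, the cost is minimized by A = mean_i R_i^T A_i.

Neither step can increase the cost. However, the scheme typically converges only linearly and, like any local method, may stop at a non-global critical point. The same helper also computes the registration error. Unlike the MATLAB alignment U*V', it applies the determinant correction, so A is aligned by a rotation and never by a reflection.

```python
import numpy as np


def nearest_rotation(M):
    """Rotation R in SO(n) maximizing trace(R^T M) (Kabsch solution)."""
    U, _, Vt = np.linalg.svd(M)
    d = np.ones(M.shape[0])
    d[-1] = np.sign(np.linalg.det(U @ Vt))  # enforce det(R) = +1
    return (U * d) @ Vt  # U @ diag(d) @ Vt


def alternating_gpa(A_measure, max_iter=200, tol=1e-12):
    """Block coordinate descent on f(R, A) = 1/(2N) sum_i ||R_i A - A_i||_F^2.

    A_measure has shape (N, n, m); returns A of shape (n, m) and R of
    shape (N, n, n).
    """
    N = A_measure.shape[0]
    A = A_measure[0].copy()  # initial guess: the first measured cloud
    f_old = np.inf
    for _ in range(max_iter):
        # R-step: with A fixed, each R_i solves an orthogonal Procrustes problem.
        R = np.stack([nearest_rotation(Ai @ A.T) for Ai in A_measure])
        # A-step: with R fixed, the gradient in A vanishes at mean_i R_i^T A_i.
        A = np.mean(np.transpose(R, (0, 2, 1)) @ A_measure, axis=0)
        f = np.sum((R @ A - A_measure) ** 2) / (2 * N)
        if f_old - f < tol:  # the cost is nonincreasing; stop when it stalls
            break
        f_old = f
    return A, R


def registration_error(Atrue, A):
    """Frobenius error after aligning A to Atrue with the best rotation."""
    Q = nearest_rotation(Atrue @ A.T)
    return np.linalg.norm(Atrue - Q @ A)


rng = np.random.default_rng(0)
n, m, N, sigma = 3, 10, 50, 0.3
Atrue = rng.standard_normal((n, m))
# Random rotations from det-corrected polar factors (not exactly uniform).
Rtrue = np.stack([nearest_rotation(rng.standard_normal((n, n))) for _ in range(N)])
A_measure = Rtrue @ Atrue + sigma * rng.standard_normal((N, n, m))

A, R = alternating_gpa(A_measure)
err = registration_error(Atrue, A)
print(f"Registration error: {err:g}")
```

The sketch initializes A with the first measured cloud rather than at random. This is one simple instance of the "clever initial guess" mentioned in the source comments, and the same idea could be used to build an initial point for trustregions.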

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005