Home > manopt > manifolds > stiefel > stiefelfactory.m

stiefelfactory

PURPOSE ^

Returns a manifold structure to optimize over orthonormal matrices.

SYNOPSIS ^

function M = stiefelfactory(n, p, k, gpuflag)

DESCRIPTION ^

 Returns a manifold structure to optimize over orthonormal matrices.

 function M = stiefelfactory(n, p)
 function M = stiefelfactory(n, p, k)
 function M = stiefelfactory(n, p, k, gpuflag)

 The Stiefel manifold is the set of n x p matrices with orthonormal
 columns (n >= p). If k is larger than 1, this is the Cartesian product of
 the Stiefel manifold taken k times. The metric is such that the manifold
 is a Riemannian submanifold of R^(n x p x k) equipped with the usual
 trace (Frobenius) inner product, that is, it is the usual embedded metric.

 Points are represented as arrays X of size n x p x k (or n x p if k=1,
 which is the default) such that each n x p matrix is orthonormal,
 i.e., X'*X = eye(p) if k = 1, or X(:, :, i)' * X(:, :, i) = eye(p) for
 i = 1 : k if k > 1. Tangent vectors are represented as arrays of the same
 size as points.

 Set gpuflag = true to have points, tangent vectors and ambient vectors
 stored on the GPU. If so, computations can be done on the GPU directly.

 By default, k = 1 and gpuflag = false.

 See also: grassmannfactory rotationsfactory

CROSS-REFERENCE INFORMATION ^

This function calls:
This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

0001 function M = stiefelfactory(n, p, k, gpuflag)
0002 % Returns a manifold structure to optimize over orthonormal matrices.
0003 %
0004 % function M = stiefelfactory(n, p)
0005 % function M = stiefelfactory(n, p, k)
0006 % function M = stiefelfactory(n, p, k, gpuflag)
0007 %
0008 % The Stiefel manifold is the set of n x p matrices with orthonormal
0009 % columns (n >= p). If k is larger than 1, this is the Cartesian product
0010 % of the Stiefel manifold taken k times. The manifold is a Riemannian
0011 % submanifold of R^(n x p x k) equipped with the usual trace (Frobenius)
0012 % inner product, that is, it uses the usual embedded metric.
0013 %
0014 % Points are represented as arrays X of size n x p x k (or n x p if k=1,
0015 % which is the default) such that each n x p matrix is orthonormal,
0016 % i.e., X'*X = eye(p) if k = 1, or X(:, :, i)' * X(:, :, i) = eye(p) for
0017 % i = 1 : k if k > 1. Tangent vectors are represented as arrays the same
0018 % size as points.
0019 %
0020 % Set gpuflag = true to have points, tangent vectors and ambient vectors
0021 % stored on the GPU. If so, computations can be done on the GPU directly.
0022 %
0023 % By default, k = 1 and gpuflag = false.
0024 %
0025 % See also: grassmannfactory rotationsfactory
0026 
0027 % This file is part of Manopt: www.manopt.org.
0028 % Original author: Nicolas Boumal, Dec. 30, 2012.
0029 % Contributors:
0030 % Change log:
0031 %  July  5, 2013 (NB) : Added ehess2rhess.
0032 %  Jan. 27, 2014 (BM) : Bug in ehess2rhess corrected.
0033 %  June 24, 2014 (NB) : Added true exponential map and changed the randvec
0034 %                       function so that it now returns a globally
0035 %                       normalized vector, not a vector where each
0036 %                       component is normalized (this only matters if k>1).
0037 %  July 17, 2018 (NB) : Now both QR (default) and polar retractions are
0038 %                       directly accessible, and their inverses are also
0039 %                       implemented.
0040 %  Aug.  2, 2018 (NB) : Added GPU support: just set gpuflag = true.
0041 
0042     assert(n >= p, 'The dimension n must be at least as large as the dimension p.');
0043     
0044     if ~exist('k', 'var') || isempty(k)
0045         k = 1;
0046     end
0047     if ~exist('gpuflag', 'var') || isempty(gpuflag)
0048         gpuflag = false;
0049     end
0050     
0051     % If gpuflag is active, new arrays (e.g., via rand, randn, zeros, ones)
0052     % are created directly on the GPU; otherwise, they are created in the
0053     % usual way (in double precision).
0054     if gpuflag
0055         array_type = 'gpuArray';
0056     else
0057         array_type = 'double';
0058     end
0059     % k must be a positive integer: fractional values are rejected too.
0060     if k == 1
0061         M.name = @() sprintf('Stiefel manifold St(%d, %d)', n, p);
0062     elseif k > 1 && k == round(k)
0063         M.name = @() sprintf('Product Stiefel manifold St(%d, %d)^%d', n, p, k);
0064     else
0065         error('k must be an integer no less than 1.');
0066     end
0067     % Per factor: n*p entries minus p(p+1)/2 constraints from X'*X = I.
0068     M.dim = @() k*(n*p - .5*p*(p+1));
0069     % Euclidean (Frobenius) inner product; it does not depend on x.
0070     M.inner = @(x, d1, d2) d1(:).'*d2(:);
0071     
0072     M.norm = @(x, d) norm(d(:));
0073     % No geodesic distance is implemented: calling M.dist raises an error.
0074     M.dist = @(x, y) error('stiefel.dist not implemented yet.');
0075     % Every point has norm sqrt(p*k), since trace(X'*X) = p for each slice.
0076     M.typicaldist = @() sqrt(p*k);
0077     
0078     M.proj = @projection;
0079     function Up = projection(X, U)
0080         % Orthogonal projection of the ambient array U onto the tangent
0081         % space at X: from each slice, remove X times the symmetric part
0082         % of X'*U, so that X'*Up is skew-symmetric (given X'*X = I).
0083         Up = U - multiprod(X, multisym(multiprod(multitransp(X), U)));
0084         
0085 % Vectorized form of the slice-by-slice reference loop below (faster).
0086 %
0087 %     Up = zeros(size(U));
0088 %     sym = @(A) .5*(A+A');
0089 %     for i = 1 : k
0090 %         Xi = X(:, :, i);
0091 %         Ui = U(:, :, i);
0092 %         Up(:, :, i) = Ui - Xi*sym(Xi'*Ui);
0093 %     end
0094 
0095     end
0096     % M.tangent maps an ambient vector to the tangent space at X.
0097     M.tangent = M.proj;
0098     % On a Riemannian submanifold, the Riemannian gradient is the
0099     % orthogonal projection of the Euclidean gradient onto the tangent
0100     % space, so the same projection serves as egrad2rgrad.
0101     M.egrad2rgrad = M.proj;
0102     
0103     M.ehess2rhess = @ehess2rhess;
0104     function rhess = ehess2rhess(X, egrad, ehess, H)
0105         % Riemannian Hessian along tangent direction H: project onto the
0106         % tangent space the Euclidean Hessian minus H*sym(X'*egrad).
0107         correction = multiprod(H, multisym(multiprod(multitransp(X), egrad)));
0108         rhess = projection(X, ehess - correction);
0109     end
0110     
0111     M.retr_qr = @retraction_qr;
0112     function Y = retraction_qr(X, U, t)
0113         % QR retraction: Q factor of each slice of X + t*U (t = 1 default).
0114         if nargin >= 3
0115             U = t*U;
0116         end
0117         Y = X + U;
0118         for slice = 1 : k
0119             [Q, R] = qr(Y(:, :, slice), 0);
0120             % Flip the columns of Q whose diagonal entry in R is negative
0121             % (a zero entry keeps its column), so the result is the Q
0122             % factor with nonnegative diag(R) whatever qr's sign convention.
0123             Y(:, :, slice) = Q * diag(sign(sign(diag(R)) + .5));
0124         end
0125     end
0126 
0127     % Inverse of the QR retraction (not of the polar retraction): given
0128     % X and Y = retr_qr(X, U), recovers the tangent vector U.
0129     M.invretr_qr = @invretr_qr;
0130     function U = invretr_qr(X, Y)
0131         XtY = multiprod(multitransp(X), Y);
0132         R = zeros(p, p, k, array_type);
0133         H = 2*eye(p, array_type);
0134         for kk = 1 : k
0135             % For each slice, assuming the inverse retraction is well
0136             % defined for the given inputs, we have:
0137             %   X + U = YR, with R upper triangular.
0138             % Left multiply with X' to get
0139             %   I + X'U = X'Y R
0140             % Since X'U is skew symmetric for a tangent vector U at X, add
0141             % up this equation with its transpose to get:
0142             %   2I = (X'Y) R + R' (X'Y)'
0143             % Contrary to the polar factorization, here R is not symmetric
0144             % but it is upper triangular. As a result, this is not a
0145             % Sylvester equation: solve_for_triu is used instead.
0146             R(:, :, kk) = solve_for_triu(XtY(:, :, kk), H);
0147             % Then,
0148             %   U = YR - X
0149             % which is what we compute below.
0150         end
0151         U = multiprod(Y, R) - X;
0152     end
0153     
0154     M.retr_polar = @retraction_polar;
0155     function Y = retraction_polar(X, U, t)
0156         % Polar retraction: orthonormal polar factor of X + t*U (t = 1 default).
0157         if nargin >= 3
0158             U = t*U;
0159         end
0160         Y = X + U;
0161         for slice = 1 : k
0162             [left, unused, right] = svd(Y(:, :, slice), 'econ'); %#ok<ASGLU>
0163             Y(:, :, slice) = left*right';
0164         end
0165     end
0166     
0167     % Inverse of the polar retraction (not of the QR retraction): given
0168     % X and Y = retr_polar(X, U), recovers the tangent vector U.
0169     M.invretr_polar = @invretr_polar;
0170     function U = invretr_polar(X, Y)
0171         XtY = multiprod(multitransp(X), Y);
0172         MM = zeros(p, p, k, array_type);
0173         H = 2*eye(p, array_type);
0174         for kk = 1 : k
0175             % For each slice, assuming the inverse retraction is well
0176             % defined for the given inputs, we have:
0177             %   X + U = YM, with M symmetric.
0178             % Left multiply with X' to get
0179             %   I + X'U = X'Y M
0180             % Since X'U is skew symmetric for a tangent vector U at X, add
0181             % up this equation with its transpose to get:
0182             %   2I = (X'Y) M + M' (X'Y)'
0183             %      = (X'Y) M + M (X'Y)'   since M is symmetric.
0184             % Solve this Sylvester equation for M (slice extracted once):
0185             A = XtY(:, :, kk);
0186             MM(:, :, kk) = sylvester_nochecks(A, A', H);
0187             % Note that this is really a Lyapunov equation (A and its
0188             % transpose both appear on the left): it could be solved
0189             % faster by a solver exploiting that structure. Then,
0190             %   U = YM - X
0191             % which is what we compute below.
0192         end
0193         U = multiprod(Y, MM) - X;
0194     end
0195     
0196     % The QR retraction and its inverse are the defaults.
0197     M.retr = M.retr_qr;
0198     M.invretr = M.invretr_qr;
0199 
0200     M.exp = @exponential;
0201     function Y = exponential(X, U, t)
0202         % Exponential map Exp_X(t*U) (t = 1 if omitted), computed slice
0203         % by slice with Ross Lippert's formula (Example 5.4.2 in AMS08):
0204         % Y = [X U] * expm([A, -S; I, A]) * [expm(-A); 0],
0205         % where A = X'*U and S = U'*U.
0206         if nargin ~= 2
0207             U = t*U;
0208         end
0209         Y = zeros(size(X), array_type);
0210         I = eye(p, array_type);
0211         Z = zeros(p, array_type);
0212         for slice = 1 : k
0213             Xs = X(:, :, slice);
0214             Us = U(:, :, slice);
0215             A = Xs'*Us;
0216             Y(:, :, slice) = [Xs Us] * expm([A, -Us'*Us ; I, A]) * ...
0217                              [expm(-Xs'*Us) ; Z];
0218         end
0219     end
0220 
0221     M.hash = @(X) ['z' hashmd5(X(:))];
0222     % Random point: sign-normalized Q factors of Gaussian slices.
0223     M.rand = @random;
0224     function X = random()
0225         X = randn(n, p, k, array_type);
0226         for kk = 1 : k
0227             [Q, R] = qr(X(:, :, kk), 0);
0228             % Make diag(R) > 0 so Q is Haar (uniformly) distributed.
0229             X(:, :, kk) = Q * diag(sign(sign(diag(R))+.5));
0230         end
0231     end
0232     M.randvec = @randomvec;
0233     function U = randomvec(X)
0234         % Random tangent vector at X, scaled to unit norm over all k slices.
0235         Up = projection(X, randn(n, p, k, array_type));
0236         U = Up / norm(Up(:));
0237     end
0238     M.lincomb = @matrixlincomb;
0239     % Zero tangent vector: an all-zero array with the configured type.
0240     M.zerovec = @(X) zeros(n, p, k, array_type);
0241     % Vector transport by projection onto the tangent space at x2.
0242     M.transp = @(x1, x2, d) projection(x2, d);
0243     % vec stacks entries column-major; since M.inner is the plain dot
0244     % product of such stacks, vec and mat are isometries.
0245     M.vec = @(X, U) U(:);
0246     M.mat = @(X, u) reshape(u, n, p, k);
0247     M.vecmatareisometries = @() true;
0248 
0249     % Automatically convert a number of tools to support GPU arrays.
0250     if gpuflag
0251         M = factorygpuhelper(M);
0252     end
0253 
0254 end

Generated on Mon 10-Sep-2018 11:48:06 by m2html © 2005