stiefelfactory

PURPOSE

Returns a manifold structure to optimize over orthonormal matrices.

SYNOPSIS

function M = stiefelfactory(n, p, k, gpuflag)

DESCRIPTION

 Returns a manifold structure to optimize over orthonormal matrices.

 function M = stiefelfactory(n, p)
 function M = stiefelfactory(n, p, k)
 function M = stiefelfactory(n, p, k, gpuflag)

 The Stiefel manifold is the set of n x p matrices with orthonormal
 columns (n >= p). If k is larger than 1, this is the Cartesian product
 of k copies of the Stiefel manifold. The metric is such that the manifold
 is a Riemannian submanifold of R^(n x p) equipped with the usual trace
 inner product <U, V> = trace(U'*V), that is, it is the usual metric.

 Points are represented as matrices X of size n x p x k (or n x p if k=1,
 which is the default) such that each n x p matrix is orthonormal,
 i.e., X'*X = eye(p) if k = 1, or X(:, :, i)' * X(:, :, i) = eye(p) for
 i = 1 : k if k > 1. Tangent vectors are represented as matrices the same
 size as points.
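
As a rough illustration of this representation, the following sketch builds a point and a tangent vector for k = 1 in plain MATLAB, without calling Manopt. The variable names (Z, XtZ, err_point, err_tangent) are chosen here for illustration; the projection formula mirrors the commented reference code in the source listing.

```matlab
% Sketch only: plain MATLAB, no Manopt calls. Sizes are arbitrary.
n = 5; p = 3;

% A point: orthonormalize the columns of a random matrix (thin QR).
[X, ~] = qr(randn(n, p), 0);

% A tangent vector: project an arbitrary n x p matrix Z to the tangent
% space at X, using U = Z - X*sym(X'*Z) with sym(A) = (A + A')/2.
Z = randn(n, p);
XtZ = X.'*Z;
U = Z - X*((XtZ + XtZ.')/2);

% X is a point if X'*X = eye(p); U is tangent at X if X'*U is skew-symmetric.
err_point   = norm(X.'*X - eye(p), 'fro');
err_tangent = norm(X.'*U + U.'*X, 'fro');
```

Both residuals should be at the level of machine precision. For k > 1, the same checks would apply to each slice X(:, :, i).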

 The default retraction is QR-based: it is only a first-order retraction.
 To use the polar retraction (which is second order), run
    M.retr = M.retr_polar;
 after creating M with this factory. This can be reverted with
    M.retr = M.retr_qr;
 If you switch retractions, you may also want to update M.invretr to
 match (M.invretr_polar or M.invretr_qr).
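
To make the difference between the two retractions concrete, here is a minimal sketch for k = 1 in plain MATLAB. It does not call Manopt: the sign fix after qr is an assumption about what qr_unique does (forcing a positive diagonal in R), and the built-in sylvester (MATLAB R2014a or later) stands in for Manopt's sylvester_nochecks. The names s, S, V and the err_* variables are illustrative.

```matlab
% Sketch only: plain MATLAB, no Manopt calls. Sizes and step are arbitrary.
n = 5; p = 3; t = 0.1;
[X, ~] = qr(randn(n, p), 0);             % a point on St(n, p)
Z = randn(n, p);
XtZ = X.'*Z;
U = Z - X*((XtZ + XtZ.')/2);             % a tangent vector at X

% QR retraction: Q-factor of X + t*U, with column signs chosen so that
% the diagonal of R is positive (assumed convention, see lead-in).
[Q, R] = qr(X + t*U, 0);
s = sign(diag(R));
s(s == 0) = 1;
Y_qr = Q*diag(s);

% Polar retraction: orthogonal polar factor of X + t*U via a thin SVD.
[u, ~, v] = svd(X + t*U, 'econ');
Y_polar = u*v.';

% Inverse polar retraction: solve (X'Y) S + S (X'Y)' = 2I for symmetric S,
% then recover the tangent vector as V = Y*S - X.
XtY = X.'*Y_polar;
S = sylvester(XtY, XtY.', 2*eye(p));
V = Y_polar*S - X;

err_qr        = norm(Y_qr.'*Y_qr - eye(p), 'fro');
err_polar     = norm(Y_polar.'*Y_polar - eye(p), 'fro');
err_roundtrip = norm(V - t*U, 'fro');    % V should reproduce t*U
```

Both retractions return a matrix with orthonormal columns, and the inverse polar retraction should recover t*U up to rounding errors.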

 Set gpuflag = true to have points, tangent vectors and ambient vectors
 stored on the GPU. If so, computations can be done on the GPU directly.

 By default, k = 1 and gpuflag = false.

 See also: grassmannfactory rotationsfactory

SOURCE CODE

function M = stiefelfactory(n, p, k, gpuflag)
% Returns a manifold structure to optimize over orthonormal matrices.
%
% function M = stiefelfactory(n, p)
% function M = stiefelfactory(n, p, k)
% function M = stiefelfactory(n, p, k, gpuflag)
%
% The Stiefel manifold is the set of n x p matrices with orthonormal
% columns (n >= p). If k is larger than 1, this is the Cartesian product
% of k copies of the Stiefel manifold. The metric is such that the manifold
% is a Riemannian submanifold of R^(n x p) equipped with the usual trace
% inner product <U, V> = trace(U'*V), that is, it is the usual metric.
%
% Points are represented as matrices X of size n x p x k (or n x p if k=1,
% which is the default) such that each n x p matrix is orthonormal,
% i.e., X'*X = eye(p) if k = 1, or X(:, :, i)' * X(:, :, i) = eye(p) for
% i = 1 : k if k > 1. Tangent vectors are represented as matrices the same
% size as points.
%
% The default retraction is QR-based: it is only a first-order retraction.
% To use the polar retraction (which is second order), run
%    M.retr = M.retr_polar;
% after creating M with this factory. This can be reverted with
%    M.retr = M.retr_qr;
% If you switch retractions, you may also want to update M.invretr to
% match (M.invretr_polar or M.invretr_qr).
%
% Set gpuflag = true to have points, tangent vectors and ambient vectors
% stored on the GPU. If so, computations can be done on the GPU directly.
%
% By default, k = 1 and gpuflag = false.
%
% See also: grassmannfactory rotationsfactory

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%  July  5, 2013 (NB) : Added ehess2rhess.
%  Jan. 27, 2014 (BM) : Bug in ehess2rhess corrected.
%  June 24, 2014 (NB) : Added true exponential map and changed the randvec
%                       function so that it now returns a globally
%                       normalized vector, not a vector where each
%                       component is normalized (this only matters if k>1).
%  July 17, 2018 (NB) : Now both QR (default) and polar retractions are
%                       directly accessible, and their inverses are also
%                       implemented.
%  Aug.  2, 2018 (NB) : Added GPU support: just set gpuflag = true.
%  June 18, 2019 (NB) : Using qr_unique for retr and rand.
%  July  9, 2019 (NB) : Added a comment about QR retraction being first
%                       order only.
%  Jan.  8, 2021 (NB) : Added tangent2ambient+tangent2ambient_is_identity.

    assert(n >= p, 'The dimension n must be larger than or equal to the dimension p.');

    if ~exist('k', 'var') || isempty(k)
        k = 1;
    end
    if ~exist('gpuflag', 'var') || isempty(gpuflag)
        gpuflag = false;
    end

    % If gpuflag is active, new arrays (e.g., via rand, randn, zeros, ones)
    % are created directly on the GPU; otherwise, they are created in the
    % usual way (in double precision).
    if gpuflag
        array_type = 'gpuArray';
    else
        array_type = 'double';
    end

    if k == 1
        M.name = @() sprintf('Stiefel manifold St(%d, %d)', n, p);
    elseif k > 1
        M.name = @() sprintf('Product Stiefel manifold St(%d, %d)^%d', n, p, k);
    else
        error('k must be an integer no less than 1.');
    end

    M.dim = @() k*(n*p - .5*p*(p+1));

    M.inner = @(x, d1, d2) d1(:).'*d2(:);

    M.norm = @(x, d) norm(d(:));

    M.dist = @(x, y) error('stiefel.dist not implemented yet.');

    M.typicaldist = @() sqrt(p*k);

    M.proj = @projection;
    function Up = projection(X, U)

        XtU = multiprod(multitransp(X), U);
        symXtU = multisym(XtU);
        Up = U - multiprod(X, symXtU);

% The code above is equivalent to, but faster than, the code below.
%
%     Up = zeros(size(U));
%     function A = sym(A), A = .5*(A+A'); end
%     for i = 1 : k
%         Xi = X(:, :, i);
%         Ui = U(:, :, i);
%         Up(:, :, i) = Ui - Xi*sym(Xi'*Ui);
%     end

    end

    M.tangent = M.proj;

    M.tangent2ambient_is_identity = true;
    M.tangent2ambient = @(X, U) U;

    % For Riemannian submanifolds, converting a Euclidean gradient into a
    % Riemannian gradient amounts to an orthogonal projection.
    M.egrad2rgrad = M.proj;

    M.ehess2rhess = @ehess2rhess;
    function rhess = ehess2rhess(X, egrad, ehess, H)
        XtG = multiprod(multitransp(X), egrad);
        symXtG = multisym(XtG);
        HsymXtG = multiprod(H, symXtG);
        rhess = projection(X, ehess - HsymXtG);
    end

    M.retr_qr = @retraction_qr;
    function Y = retraction_qr(X, U, t)
        % It is necessary to call qr_unique rather than simply qr to ensure
        % this is a retraction, to avoid spurious column sign flips.
        if nargin < 3
            Y = qr_unique(X + U);
        else
            Y = qr_unique(X + t*U);
        end
    end

    M.invretr_qr = @invretr_qr;
    function U = invretr_qr(X, Y)
        XtY = multiprod(multitransp(X), Y);
        R = zeros(p, p, k, array_type);
        H = 2*eye(p, array_type);
        for kk = 1 : k
            % For each slice, assuming the inverse retraction is well
            % defined for the given inputs, we have:
            %   X + U = YR
            % Left multiply with X' to get
            %   I + X'U = X'Y R
            % Since X'U is skew symmetric for a tangent vector U at X, add
            % up this equation with its transpose to get:
            %   2I = (X'Y) R + R' (X'Y)'
            % Contrary to the polar factorization, here R is not symmetric
            % but it is upper triangular. As a result, this is not a
            % Sylvester equation and we must solve it differently.
            R(:, :, kk) = solve_for_triu(XtY(:, :, kk), H);
            % Then,
            %   U = YR - X
            % which is what we compute below.
        end
        U = multiprod(Y, R) - X;
    end

    M.retr_polar = @retraction_polar;
    function Y = retraction_polar(X, U, t)
        if nargin < 3
            Y = X + U;
        else
            Y = X + t*U;
        end
        for kk = 1 : k
            [u, s, v] = svd(Y(:, :, kk), 'econ'); %#ok
            Y(:, :, kk) = u*v';
        end
    end

    M.invretr_polar = @invretr_polar;
    function U = invretr_polar(X, Y)
        XtY = multiprod(multitransp(X), Y);
        MM = zeros(p, p, k, array_type);
        H = 2*eye(p, array_type);
        for kk = 1 : k
            % For each slice, assuming the inverse retraction is well
            % defined for the given inputs, we have:
            %   X + U = YM
            % Left multiply with X' to get
            %   I + X'U = X'Y M
            % Since X'U is skew symmetric for a tangent vector U at X, add
            % up this equation with its transpose to get:
            %   2I = (X'Y) M + M' (X'Y)'
            %      = (X'Y) M + M (X'Y)'   since M is symmetric.
            % Solve for M symmetric with a call to sylvester:
            MM(:, :, kk) = sylvester_nochecks(XtY(:, :, kk), XtY(:, :, kk)', H);
            % Note that the above is really a Lyapunov equation: it could
            % be solved faster by exploiting the fact the same matrix
            % appears twice on the left, with one the transpose of the
            % other. Then,
            %   U = YM - X
            % which is what we compute below.
        end
        U = multiprod(Y, MM) - X;
    end

    % By default, we use the QR retraction
    M.retr = M.retr_qr;
    M.invretr = M.invretr_qr;

    M.exp = @exponential;
    function Y = exponential(X, U, t)
        if nargin == 2
            tU = U;
        else
            tU = t*U;
        end
        Y = zeros(size(X), array_type);
        I = eye(p, array_type);
        Z = zeros(p, array_type);
        for kk = 1 : k
            % From a formula by Ross Lippert, Example 5.4.2 in AMS08.
            Xkk = X(:, :, kk);
            Ukk = tU(:, :, kk);
            Y(:, :, kk) = [Xkk Ukk] * ...
                         expm([Xkk'*Ukk , -Ukk'*Ukk ; I , Xkk'*Ukk]) * ...
                         [ expm(-Xkk'*Ukk) ; Z ];
        end

    end

    M.hash = @(X) ['z' hashmd5(X(:))];

    M.rand = @() qr_unique(randn(n, p, k, array_type));

    M.randvec = @randomvec;
    function U = randomvec(X)
        U = projection(X, randn(n, p, k, array_type));
        U = U / norm(U(:));
    end

    M.lincomb = @matrixlincomb;

    M.zerovec = @(x) zeros(n, p, k, array_type);

    M.transp = @(x1, x2, d) projection(x2, d);

    M.vec = @(x, u_mat) u_mat(:);
    M.mat = @(x, u_vec) reshape(u_vec, [n, p, k]);
    M.vecmatareisometries = @() true;


    % Automatically convert a number of tools to support GPU.
    if gpuflag
        M = factorygpuhelper(M);
    end

end
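
The exponential map in the listing above can be tried on a single slice (k = 1) with the short sketch below. It uses plain MATLAB only (qr and expm) and copies the formula from the exponential subfunction. The variable names A, O, err_orth and err_first_order are illustrative and not part of Manopt.

```matlab
% Sketch only: plain MATLAB, no Manopt calls. Sizes and step are arbitrary.
n = 5; p = 3; t = 0.5;
[X, ~] = qr(randn(n, p), 0);             % a point on St(n, p)
Z = randn(n, p);
XtZ = X.'*Z;
U = Z - X*((XtZ + XtZ.')/2);             % a tangent vector at X
U = U / norm(U(:));                      % unit norm, as randvec does

% Same formula as in the exponential subfunction, for one slice.
tU = t*U;
A = X.'*tU;                              % skew-symmetric since U is tangent
I = eye(p);
O = zeros(p);
Y = [X, tU] * expm([A, -tU.'*tU; I, A]) * [expm(-A); O];

% The result should again have orthonormal columns.
err_orth = norm(Y.'*Y - I, 'fro');

% For a small step, the geodesic agrees with X + h*U to first order.
h = 1e-4;
hU = h*U;
B = X.'*hU;
Yh = [X, hU] * expm([B, -hU.'*hU; I, B]) * [expm(-B); O];
err_first_order = norm(Yh - (X + hU), 'fro');   % expected to scale as h^2
```

The orthonormality residual should be near machine precision, and the first-order residual should be much smaller than h.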

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005