grassmannfactory

PURPOSE

Returns a manifold struct to optimize over the space of vector subspaces.

SYNOPSIS

function M = grassmannfactory(n, p, k, gpuflag)

DESCRIPTION

 Returns a manifold struct to optimize over the space of vector subspaces.

 function M = grassmannfactory(n, p)
 function M = grassmannfactory(n, p, k)
 function M = grassmannfactory(n, p, k, gpuflag)

 Grassmann manifold: each point on this manifold is a collection of k
 vector subspaces of dimension p embedded in R^n.

 The metric is obtained by making the Grassmannian a Riemannian quotient
 manifold of the Stiefel manifold, i.e., the manifold of orthonormal
 matrices, itself endowed with a metric by making it a Riemannian
 submanifold of the Euclidean space, endowed with the usual inner product.
 In short: this is the standard (canonical) metric on the Grassmannian.
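
 As a sketch of what this means concretely for k = 1, the following is read off from the M.inner and M.proj definitions in the source listing. The symbols H_X (horizontal space at X) and Proj_X are notation introduced here for illustration, not names used by the code:

```latex
\[
\mathcal{H}_X = \{\, U \in \mathbb{R}^{n \times p} : X^\top U = 0 \,\}, \qquad
\langle U, V \rangle_X = \operatorname{trace}(U^\top V), \qquad
\operatorname{Proj}_X(Z) = Z - X X^\top Z .
\]
```

 For k > 1, the same formulas apply slice by slice, and the inner product sums the traces over the k slices.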
 
 This structure deals with matrices X of size n x p x k (or n x p if
 k = 1, which is the default) such that each n x p matrix has orthonormal
 columns, i.e., X'*X = eye(p) if k = 1, or X(:, :, i)' * X(:, :, i) = eye(p)
 for i = 1 : k if k > 1. Each n x p matrix is a numerical representation of
 the vector subspace its columns span.
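
 A minimal sketch of this representation, assuming Manopt is on the MATLAB path. It checks that a random point has orthonormal columns and that X and X*Q (with Q orthogonal, an arbitrary choice made here for illustration) represent the same subspace:

```matlab
% Assumes Manopt (grassmannfactory and its helpers) is on the MATLAB path.
n = 5; p = 2;
M = grassmannfactory(n, p);          % k = 1 by default

X = M.rand();                        % n x p matrix with orthonormal columns
orth_err = norm(X'*X - eye(p), 'fro');

% Any orthogonal p x p matrix Q yields another representative X*Q of the
% same subspace, so the Grassmann distance between X and X*Q should be
% numerically zero.
[Q, ~] = qr(randn(p));
same_point_dist = M.dist(X, X*Q);

fprintf('Orthonormality error: %g\n', orth_err);
fprintf('dist(X, X*Q):         %g\n', same_point_dist);
```

 The reported distance is typically small but not at the level of machine precision, because acos is ill-conditioned for arguments close to 1.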

 Set gpuflag = true to have points, tangent vectors and ambient vectors
 stored on the GPU. If so, computations can be done on the GPU directly.
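
 A hedged sketch of GPU usage follows. It assumes Manopt is on the path and that the Parallel Computing Toolbox and a supported GPU are available; the sizes are arbitrary. Passing [] for k keeps its default, since the factory treats an empty k as k = 1:

```matlab
% Detect a usable GPU; gpuDeviceCount errors if the Parallel Computing
% Toolbox is not installed, hence the try/catch.
has_gpu = false;
try
    has_gpu = gpuDeviceCount > 0;
catch
    % No toolbox: stay on the CPU.
end

M = grassmannfactory(1000, 10, [], has_gpu);   % [] keeps k = 1

X = M.rand();                 % point (stored on the GPU if has_gpu is true)
U = M.randvec(X);             % unit-norm tangent vector at X
Y = M.retr(X, U, 0.1);        % retraction along 0.1*U

if has_gpu
    Y = gather(Y);            % bring the result back to host memory
end
fprintf('Result is %d x %d, class %s\n', size(Y, 1), size(Y, 2), class(Y));
```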

 By default, k = 1 and gpuflag = false.
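
 To illustrate the defaults, here is a short sketch (assuming Manopt is on the MATLAB path) that builds the manifold with and without k and inspects the name, dimension and point size reported by the returned struct:

```matlab
% Assumes Manopt is on the MATLAB path.
M1 = grassmannfactory(5, 2);         % k = 1, gpuflag = false
M3 = grassmannfactory(5, 2, 3);      % three subspaces handled at once

disp(M1.name());                     % Grassmann manifold Gr(5, 2)
disp(M3.name());                     % Multi Grassmann manifold Gr(5, 2)^3

d1 = M1.dim();                       % p*(n-p)   = 2*3   = 6
d3 = M3.dim();                       % k*p*(n-p) = 3*2*3 = 18

sz1 = size(M1.rand());               % [5 2]
sz3 = size(M3.rand());               % [5 2 3]
```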

 See also: stiefelfactory grassmanncomplexfactory grassmanngeneralizedfactory
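
 As an illustration of optimizing over subspaces with this struct, the sketch below computes a dominant invariant subspace of a symmetric matrix with Manopt's trustregions solver. It assumes Manopt is on the MATLAB path; the matrix A, the sizes and the cost function are choices made here for illustration only. The cost depends only on the column space of X, so it is well defined on the Grassmannian:

```matlab
% Assumes Manopt (grassmannfactory, trustregions) is on the MATLAB path.
n = 20; p = 3;
A = randn(n);
A = (A + A')/2;                       % symmetric test matrix (illustrative)

problem.M     = grassmannfactory(n, p);
problem.cost  = @(X) -0.5*trace(X'*A*X);
problem.egrad = @(X) -A*X;            % Euclidean gradient of the cost
problem.ehess = @(X, H) -A*H;         % Euclidean Hessian applied to H

options.verbosity = 0;                % silence solver output
[X, xcost] = trustregions(problem, [], options);   % [] = random start

% Reference value from a dense eigendecomposition.
lambda = sort(eig(A), 'descend');
ref_cost = -0.5*sum(lambda(1:p));
fprintf('Solver cost: %.10f, reference: %.10f\n', xcost, ref_cost);
```

 Only Euclidean derivatives are supplied here; Manopt converts them to their Riemannian counterparts through the egrad2rgrad and ehess2rhess fields of the manifold struct.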

SOURCE CODE

function M = grassmannfactory(n, p, k, gpuflag)
% Returns a manifold struct to optimize over the space of vector subspaces.
%
% function M = grassmannfactory(n, p)
% function M = grassmannfactory(n, p, k)
% function M = grassmannfactory(n, p, k, gpuflag)
%
% Grassmann manifold: each point on this manifold is a collection of k
% vector subspaces of dimension p embedded in R^n.
%
% The metric is obtained by making the Grassmannian a Riemannian quotient
% manifold of the Stiefel manifold, i.e., the manifold of orthonormal
% matrices, itself endowed with a metric by making it a Riemannian
% submanifold of the Euclidean space, endowed with the usual inner product.
% In short: this is the standard (canonical) metric on the Grassmannian.
%
% This structure deals with matrices X of size n x p x k (or n x p if
% k = 1, which is the default) such that each n x p matrix has orthonormal
% columns, i.e., X'*X = eye(p) if k = 1, or X(:, :, i)' * X(:, :, i) = eye(p)
% for i = 1 : k if k > 1. Each n x p matrix is a numerical representation of
% the vector subspace its columns span.
%
% Set gpuflag = true to have points, tangent vectors and ambient vectors
% stored on the GPU. If so, computations can be done on the GPU directly.
%
% By default, k = 1 and gpuflag = false.
%
% See also: stiefelfactory grassmanncomplexfactory grassmanngeneralizedfactory

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%   March 22, 2013 (NB):
%       Implemented geodesic distance.
%
%   April 17, 2013 (NB):
%       Retraction changed to the polar decomposition, so that the vector
%       transport is now correct, in the sense that it is compatible with
%       the retraction, i.e., transporting a tangent vector G from U to V
%       where V = Retr(U, H) will give Z, and transporting GQ from UQ to VQ
%       will give ZQ: there is no dependence on the representation, which
%       is as it should be. Notice that the polar factorization requires an
%       SVD whereas the qfactor retraction requires a QR decomposition,
%       which is cheaper. Hence, if the retraction happens to be a
%       bottleneck in your application and you are not using vector
%       transports, you may want to replace the retraction with a qfactor.
%
%   July  4, 2013 (NB):
%       Added support for the logarithmic map 'log'.
%
%   July  5, 2013 (NB):
%       Added support for ehess2rhess.
%
%   June 24, 2014 (NB):
%       Small bug fix in the retraction, and added final
%       re-orthonormalization at the end of the exponential map. This
%       follows discussions on the forum where it appeared there is a
%       significant loss in orthonormality without that extra step. Also
%       changed the randvec function so that it now returns a globally
%       normalized vector, not a vector where each component is normalized
%       (this only matters if k>1).
%
%   July 8, 2018 (NB):
%       Inverse retraction implemented.
%
%   Aug. 3, 2018 (NB):
%       Added GPU support: just set gpuflag = true.

    assert(n >= p, ...
           ['The dimension n of the ambient space must be at least ' ...
            'the dimension p of the subspaces.']);

    if ~exist('k', 'var') || isempty(k)
        k = 1;
    end
    if ~exist('gpuflag', 'var') || isempty(gpuflag)
        gpuflag = false;
    end

    % If gpuflag is active, new arrays (e.g., via rand, randn, zeros, ones)
    % are created directly on the GPU; otherwise, they are created in the
    % usual way (in double precision).
    if gpuflag
        array_type = 'gpuArray';
    else
        array_type = 'double';
    end

    if k == 1
        M.name = @() sprintf('Grassmann manifold Gr(%d, %d)', n, p);
    elseif k > 1
        M.name = @() sprintf('Multi Grassmann manifold Gr(%d, %d)^%d', ...
                             n, p, k);
    else
        error('k must be an integer no less than 1.');
    end

    M.dim = @() k*p*(n-p);

    % Riemannian metric: Euclidean inner product of the vectorized inputs.
    M.inner = @(x, d1, d2) d1(:).'*d2(:);

    M.norm = @(x, d) norm(d(:));

    % Geodesic distance: 2-norm of all principal angles between the
    % subspaces. Taking real() discards the imaginary part that acos
    % returns when a singular value slightly exceeds 1 due to round-off.
    M.dist = @distance;
    function d = distance(x, y)
        square_d = 0;
        XtY = multiprod(multitransp(x), y);
        for kk = 1 : k
            cos_princ_angle = svd(XtY(:, :, kk));
            square_d = square_d + sum(real(acos(cos_princ_angle)).^2);
        end
        d = sqrt(square_d);
    end

    M.typicaldist = @() sqrt(p*k);

    % Orthogonal projection of an ambient vector U to the horizontal space
    % at X.
    M.proj = @projection;
    function Up = projection(X, U)

        XtU = multiprod(multitransp(X), U);
        Up = U - multiprod(X, XtU);

    end

    M.tangent = M.proj;

    M.egrad2rgrad = M.proj;

    M.ehess2rhess = @ehess2rhess;
    function rhess = ehess2rhess(X, egrad, ehess, H)
        PXehess = projection(X, ehess);
        XtG = multiprod(multitransp(X), egrad);
        HXtG = multiprod(H, XtG);
        rhess = PXehess - HXtG;
    end

    M.retr = @retraction;
    function Y = retraction(X, U, t)
        if nargin < 3
            Y = X + U;
        else
            Y = X + t*U;
        end
        for kk = 1 : k

            % Compute the polar factorization of Y = X+tU
            [u, s, v] = svd(Y(:, :, kk), 'econ'); %#ok
            Y(:, :, kk) = u*v';

            % Another popular retraction uses QR instead of SVD.
            % As compared with the Stiefel factory, we do not need to
            % worry about flipping signs of columns here, since only
            % the column space is important, not the actual columns.
            % [Q, unused] = qr(Y(:, :, kk), 0); %#ok
            % Y(:, :, kk) = Q;

        end
    end

    % This inverse retraction is valid for both the QR retraction and the
    % polar retraction.
    M.invretr = @invretr;
    function U = invretr(X, Y)
        XtY = multiprod(multitransp(X), Y);
        U = zeros(n, p, k, array_type);
        for kk = 1 : k
            U(:, :, kk) = Y(:, :, kk) / XtY(:, :, kk);
        end
        U = U - X;
    end

    % See Eq. (2.65) in Edelman, Arias and Smith 1998.
    M.exp = @exponential;
    function Y = exponential(X, U, t)
        if nargin == 3
            tU = t*U;
        else
            tU = U;
        end
        Y = zeros(size(X), array_type);
        for kk = 1 : k
            [u, s, v] = svd(tU(:, :, kk), 0);
            cos_s = diag(cos(diag(s)));
            sin_s = diag(sin(diag(s)));
            Y(:, :, kk) = X(:, :, kk)*v*cos_s*v' + u*sin_s*v';
            % From numerical experiments, it seems necessary to
            % re-orthonormalize. This is overall quite expensive.
            [q, unused] = qr(Y(:, :, kk), 0); %#ok
            Y(:, :, kk) = q;
        end
    end

    % Test code for the logarithm:
    % Gr = grassmannfactory(5, 2, 3);
    % x = Gr.rand()
    % y = Gr.rand()
    % u = Gr.log(x, y)
    % Gr.dist(x, y) % These two numbers should
    % Gr.norm(x, u) % be the same.
    % z = Gr.exp(x, u) % z need not be the same matrix as y, but it should
    % v = Gr.log(x, z) % be the same point as y on Grassmann: dist almost 0.
    M.log = @logarithm;
    function U = logarithm(X, Y)
        U = zeros(n, p, k, array_type);
        for kk = 1 : k
            x = X(:, :, kk);
            y = Y(:, :, kk);
            ytx = y.'*x;
            At = y.'-ytx*x.';
            Bt = ytx\At;
            [u, s, v] = svd(Bt.', 'econ');

            u = u(:, 1:p);
            s = diag(s);
            s = s(1:p);
            v = v(:, 1:p);

            U(:, :, kk) = u*diag(atan(s))*v.';
        end
    end

    M.hash = @(X) ['z' hashmd5(X(:))];

    M.rand = @random;
    function X = random()
        X = randn(n, p, k, array_type);
        for kk = 1 : k
            [Q, unused] = qr(X(:, :, kk), 0); %#ok
            X(:, :, kk) = Q;
        end
    end

    M.randvec = @randomvec;
    function U = randomvec(X)
        U = projection(X, randn(n, p, k, array_type));
        U = U / norm(U(:));
    end

    M.lincomb = @matrixlincomb;

    M.zerovec = @(x) zeros(n, p, k, array_type);

    % This transport is compatible with the polar retraction.
    M.transp = @(x1, x2, d) projection(x2, d);

    M.vec = @(x, u_mat) u_mat(:);
    M.mat = @(x, u_vec) reshape(u_vec, [n, p, k]);
    M.vecmatareisometries = @() true;


    % Automatically convert a number of tools to support GPU.
    if gpuflag
        M = factorygpuhelper(M);
    end

end
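
The exponential, logarithm, distance and (inverse) retraction in the listing can be sanity-checked against each other. The sketch below extends the test recipe given in the comments above M.log and assumes Manopt is on the MATLAB path; the variable names for the gaps are introduced here for illustration:

```matlab
% Assumes Manopt is on the MATLAB path.
Gr = grassmannfactory(5, 2, 3);
x = Gr.rand();
y = Gr.rand();

% The norm of log_x(y) should match the geodesic distance dist(x, y).
u = Gr.log(x, y);
gap_log_dist = abs(Gr.dist(x, y) - Gr.norm(x, u));

% exp_x(log_x(y)) need not be the same matrix as y, but it should be the
% same point on the Grassmannian, i.e., at distance almost 0 from y.
z = Gr.exp(x, u);
gap_exp_log = Gr.dist(y, z);

% The inverse retraction should undo the retraction of a tangent vector.
v = Gr.randvec(x);
w = Gr.invretr(x, Gr.retr(x, v));
gap_retr = Gr.norm(x, w - v);

fprintf('|dist - norm(log)| = %g\n', gap_log_dist);
fprintf('dist(y, exp(log))  = %g\n', gap_exp_log);
fprintf('norm(invretr(retr(v)) - v) = %g\n', gap_retr);
```

All three printed values are expected to be small; the distance-based ones are limited by the conditioning of acos near 1 rather than by machine precision.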

Generated on Mon 10-Sep-2018 11:48:06 by m2html © 2005