# stiefelfactory

## PURPOSE

Returns a manifold structure to optimize over orthonormal matrices.

## SYNOPSIS

function M = stiefelfactory(n, p, k, gpuflag)

## DESCRIPTION

```
Returns a manifold structure to optimize over orthonormal matrices.

function M = stiefelfactory(n, p)
function M = stiefelfactory(n, p, k)
function M = stiefelfactory(n, p, k, gpuflag)

The Stiefel manifold is the set of orthonormal nxp matrices. If k
is larger than 1, this is the Cartesian product of the Stiefel manifold
taken k times. The metric is such that the manifold is a Riemannian
submanifold of R^nxp equipped with the usual trace inner product, that
is, it is the usual metric.

Points are represented as matrices X of size n x p x k (or n x p if k=1,
which is the default) such that each n x p matrix is orthonormal,
i.e., X'*X = eye(p) if k = 1, or X(:, :, i)' * X(:, :, i) = eye(p) for
i = 1 : k if k > 1. Tangent vectors are represented as matrices the same
size as points.

The default retraction is QR-based: it is only a first-order retraction.
To use the polar retraction (which is second order), run
   M.retr = M.retr_polar;
after creating M with this factory. This can be reverted with
   M.retr = M.retr_qr;
If used, you may also want to update M.invretr similarly.

Set gpuflag = true to have points, tangent vectors and ambient vectors
stored on the GPU. If so, computations can be done on the GPU directly.

By default, k = 1 and gpuflag = false.
```

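As a minimal sketch of the retraction switch described above, assuming Manopt is installed and on the MATLAB path, the following creates St(5, 3), swaps in the polar retraction together with its inverse, and checks a round trip. The sizes, the step length t, and the variable names are illustrative choices, not part of the factory.

```matlab
% Sketch only: requires Manopt on the path; sizes and names are illustrative.
n = 5; p = 3;
M = stiefelfactory(n, p);        % k = 1 and gpuflag = false by default

% Switch from the default QR retraction to the polar retraction,
% and update the inverse retraction to match.
M.retr = M.retr_polar;
M.invretr = M.invretr_polar;

X = M.rand();                    % random point: X'*X = eye(p)
U = M.randvec(X);                % random unit-norm tangent vector at X
t = 0.1;                         % illustrative step length
Y = M.retr(X, U, t);             % move along t*U, then map back to the manifold

% Y should have orthonormal columns, and the inverse retraction should
% recover t*U (up to rounding) for a step this small.
orth_defect = norm(Y'*Y - eye(p), 'fro');
roundtrip_defect = norm(M.invretr(X, Y) - t*U, 'fro');
fprintf('orthonormality defect: %g, round-trip defect: %g\n', ...
        orth_defect, roundtrip_defect);
```

To go back to the default behavior, assign M.retr = M.retr_qr and M.invretr = M.invretr_qr.
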
## CROSS-REFERENCE INFORMATION

This function calls:
• solve_for_triu Solve the linear matrix equation AX + X'A' = H for X upper triangular.
• norm NORM Norm of a TT/MPS tensor.
• norm NORM Norm of a TT/MPS block-mu tensor.
• factorygpuhelper Returns a manifold struct to optimize over unit-norm vectors or matrices.
• hashmd5 Computes the MD5 hash of input data.
• matrixlincomb Linear combination function for tangent vectors represented as matrices.
• multiprod Matrix multiply 2-D slices of N-D arrays
• multisym Returns the symmetric parts of the matrices in a 3D array
• multitransp Transpose the matrix slices of an N-D array (no complex conjugate)
• qr_unique Thin QR factorization ensuring diagonal of R is real, positive if possible.
• sylvester_nochecks Solve Sylvester equation without input checks.
This function is called by:

## SOURCE CODE

```
function M = stiefelfactory(n, p, k, gpuflag)
% Returns a manifold structure to optimize over orthonormal matrices.
%
% function M = stiefelfactory(n, p)
% function M = stiefelfactory(n, p, k)
% function M = stiefelfactory(n, p, k, gpuflag)
%
% The Stiefel manifold is the set of orthonormal nxp matrices. If k
% is larger than 1, this is the Cartesian product of the Stiefel manifold
% taken k times. The metric is such that the manifold is a Riemannian
% submanifold of R^nxp equipped with the usual trace inner product, that
% is, it is the usual metric.
%
% Points are represented as matrices X of size n x p x k (or n x p if k=1,
% which is the default) such that each n x p matrix is orthonormal,
% i.e., X'*X = eye(p) if k = 1, or X(:, :, i)' * X(:, :, i) = eye(p) for
% i = 1 : k if k > 1. Tangent vectors are represented as matrices the same
% size as points.
%
% The default retraction is QR-based: it is only a first-order retraction.
% To use the polar retraction (which is second order), run
%    M.retr = M.retr_polar;
% after creating M with this factory. This can be reverted with
%    M.retr = M.retr_qr;
% If used, you may also want to update M.invretr similarly.
%
% Set gpuflag = true to have points, tangent vectors and ambient vectors
% stored on the GPU. If so, computations can be done on the GPU directly.
%
% By default, k = 1 and gpuflag = false.
%

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%  July  5, 2013 (NB) : Added ehess2rhess.
%  Jan. 27, 2014 (BM) : Bug in ehess2rhess corrected.
%  June 24, 2014 (NB) : Added true exponential map and changed the randvec
%                       function so that it now returns a globally
%                       normalized vector, not a vector where each
%                       component is normalized (this only matters if k>1).
%  July 17, 2018 (NB) : Now both QR (default) and polar retractions are
%                       directly accessible, and their inverses are also
%                       implemented.
%  Aug.  2, 2018 (NB) : Added GPU support: just set gpuflag = true.
%  June 18, 2019 (NB) : Using qr_unique for retr and rand.
%  July  9, 2019 (NB) : Added a comment about QR retraction being first
%                       order only.
%  Jan.  8, 2021 (NB) : Added tangent2ambient+tangent2ambient_is_identity.

    assert(n >= p, 'The dimension n must be at least as large as the dimension p.');

    if ~exist('k', 'var') || isempty(k)
        k = 1;
    end
    if ~exist('gpuflag', 'var') || isempty(gpuflag)
        gpuflag = false;
    end

    % If gpuflag is active, new arrays (e.g., via rand, randn, zeros, ones)
    % are created directly on the GPU; otherwise, they are created in the
    % usual way (in double precision).
    if gpuflag
        array_type = 'gpuArray';
    else
        array_type = 'double';
    end

    if k == 1
        M.name = @() sprintf('Stiefel manifold St(%d, %d)', n, p);
    elseif k > 1
        M.name = @() sprintf('Product Stiefel manifold St(%d, %d)^%d', n, p, k);
    else
        error('k must be an integer no less than 1.');
    end

    M.dim = @() k*(n*p - .5*p*(p+1));

    M.inner = @(x, d1, d2) d1(:).'*d2(:);

    M.norm = @(x, d) norm(d(:));

    M.dist = @(x, y) error('stiefel.dist not implemented yet.');

    M.typicaldist = @() sqrt(p*k);

    M.proj = @projection;
    function Up = projection(X, U)

        XtU = multiprod(multitransp(X), U);
        symXtU = multisym(XtU);
        Up = U - multiprod(X, symXtU);

% The code above is equivalent to, but faster than, the code below.
%
%     Up = zeros(size(U));
%     function A = sym(A), A = .5*(A+A'); end
%     for i = 1 : k
%         Xi = X(:, :, i);
%         Ui = U(:, :, i);
%         Up(:, :, i) = Ui - Xi*sym(Xi'*Ui);
%     end

    end

    M.tangent = M.proj;

    M.tangent2ambient_is_identity = true;
    M.tangent2ambient = @(X, U) U;

    % For Riemannian submanifolds, converting a Euclidean gradient into a
    % Riemannian gradient amounts to an orthogonal projection.
    M.egrad2rgrad = M.proj;

    M.ehess2rhess = @ehess2rhess;
    function rhess = ehess2rhess(X, egrad, ehess, H)
        XtG = multiprod(multitransp(X), egrad);
        symXtG = multisym(XtG);
        HsymXtG = multiprod(H, symXtG);
        rhess = projection(X, ehess - HsymXtG);
    end

    M.retr_qr = @retraction_qr;
    function Y = retraction_qr(X, U, t)
        % It is necessary to call qr_unique rather than simply qr to ensure
        % this is a retraction, to avoid spurious column sign flips.
        if nargin < 3
            Y = qr_unique(X + U);
        else
            Y = qr_unique(X + t*U);
        end
    end

    M.invretr_qr = @invretr_qr;
    function U = invretr_qr(X, Y)
        XtY = multiprod(multitransp(X), Y);
        R = zeros(p, p, k, array_type);
        H = 2*eye(p, array_type);
        for kk = 1 : k
            % For each slice, assuming the inverse retraction is well
            % defined for the given inputs, we have:
            %   X + U = YR
            % Left multiply with X' to get
            %   I + X'U = X'Y R
            % Since X'U is skew symmetric for a tangent vector U at X, add
            % up this equation with its transpose to get:
            %   2I = (X'Y) R + R' (X'Y)'
            % Contrary to the polar factorization, here R is not symmetric
            % but it is upper triangular. As a result, this is not a
            % Sylvester equation and we must solve it differently.
            R(:, :, kk) = solve_for_triu(XtY(:, :, kk), H);
            % Then,
            %   U = YR - X
            % which is what we compute below.
        end
        U = multiprod(Y, R) - X;
    end

    M.retr_polar = @retraction_polar;
    function Y = retraction_polar(X, U, t)
        if nargin < 3
            Y = X + U;
        else
            Y = X + t*U;
        end
        for kk = 1 : k
            [u, s, v] = svd(Y(:, :, kk), 'econ'); %#ok
            Y(:, :, kk) = u*v';
        end
    end

    M.invretr_polar = @invretr_polar;
    function U = invretr_polar(X, Y)
        XtY = multiprod(multitransp(X), Y);
        MM = zeros(p, p, k, array_type);
        H = 2*eye(p, array_type);
        for kk = 1 : k
            % For each slice, assuming the inverse retraction is well
            % defined for the given inputs, we have:
            %   X + U = YM
            % Left multiply with X' to get
            %   I + X'U = X'Y M
            % Since X'U is skew symmetric for a tangent vector U at X, add
            % up this equation with its transpose to get:
            %   2I = (X'Y) M + M' (X'Y)'
            %      = (X'Y) M + M (X'Y)'   since M is symmetric.
            % Solve for M symmetric with a call to sylvester:
            MM(:, :, kk) = sylvester_nochecks(XtY(:, :, kk), XtY(:, :, kk)', H);
            % Note that the above is really a Lyapunov equation: it could
            % be solved faster by exploiting the fact the same matrix
            % appears twice on the left, with one the transpose of the
            % other. Then,
            %   U = YM - X
            % which is what we compute below.
        end
        U = multiprod(Y, MM) - X;
    end

    % By default, we use the QR retraction
    M.retr = M.retr_qr;
    M.invretr = M.invretr_qr;

    M.exp = @exponential;
    function Y = exponential(X, U, t)
        if nargin == 2
            tU = U;
        else
            tU = t*U;
        end
        Y = zeros(size(X), array_type);
        I = eye(p, array_type);
        Z = zeros(p, array_type);
        for kk = 1 : k
            % From a formula by Ross Lippert, Example 5.4.2 in AMS08.
            Xkk = X(:, :, kk);
            Ukk = tU(:, :, kk);
            Y(:, :, kk) = [Xkk Ukk] * ...
                         expm([Xkk'*Ukk , -Ukk'*Ukk ; I , Xkk'*Ukk]) * ...
                         [ expm(-Xkk'*Ukk) ; Z ];
        end

    end

    M.hash = @(X) ['z' hashmd5(X(:))];

    M.rand = @() qr_unique(randn(n, p, k, array_type));

    M.randvec = @randomvec;
    function U = randomvec(X)
        U = projection(X, randn(n, p, k, array_type));
        U = U / norm(U(:));
    end

    M.lincomb = @matrixlincomb;

    M.zerovec = @(x) zeros(n, p, k, array_type);

    M.transp = @(x1, x2, d) projection(x2, d);

    M.vec = @(x, u_mat) u_mat(:);
    M.mat = @(x, u_vec) reshape(u_vec, [n, p, k]);
    M.vecmatareisometries = @() true;


    % Automatically convert a number of tools to support GPU.
    if gpuflag
        M = factorygpuhelper(M);
    end

end
```
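
The projection formula in the listing can be checked numerically without Manopt. The sketch below handles a single slice (k = 1) in plain MATLAB, following the loop version quoted in the source comments; the helper name symm and the sizes are illustrative assumptions.

```matlab
% Sketch only: plain MATLAB, single slice (k = 1); names are illustrative.
n = 5; p = 3;
[X, ~] = qr(randn(n, p), 0);     % economy-size QR: X'*X = eye(p)
U = randn(n, p);                 % arbitrary ambient matrix

symm = @(A) 0.5*(A + A');        % symmetric part of a square matrix
Up = U - X*symm(X'*U);           % same formula as projection(X, U) for k = 1

% Tangent vectors at X satisfy X'*Up + Up'*X = 0 ...
skew_defect = norm(X'*Up + Up'*X, 'fro');
% ... and projecting a tangent vector again leaves it unchanged.
idem_defect = norm((Up - X*symm(X'*Up)) - Up, 'fro');
fprintf('tangency defect: %g, idempotence defect: %g\n', ...
        skew_defect, idem_defect);
```

Both defects should be at the level of rounding error, since X'*Up reduces to the skew-symmetric part of X'*U.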

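The two retractions in the listing can also be compared on a single slice in plain MATLAB. This is a rough sketch, not the library's code: the sign correction below stands in for what qr_unique is described as doing (real case only), and the sizes and step t are illustrative. It only checks that both outputs are orthonormal and agree with X + t*U up to a term of order t^2; it does not test the second-order property of the polar retraction.

```matlab
% Sketch only: plain MATLAB, single slice (k = 1), real matrices.
n = 6; p = 2; t = 1e-3;
[X, ~] = qr(randn(n, p), 0);          % point with orthonormal columns
U = randn(n, p);
U = U - X*(0.5*(X'*U + U'*X));        % project onto the tangent space at X
U = U / norm(U, 'fro');               % unit-norm tangent vector

% QR-based retraction: flip column signs so that diag(R) is positive,
% which avoids spurious sign flips in Q.
[Q, R] = qr(X + t*U, 0);
s = sign(diag(R)); s(s == 0) = 1;
Yqr = Q*diag(s);

% Polar retraction: orthonormal factor of the thin SVD.
[u, ~, v] = svd(X + t*U, 'econ');
Ypol = u*v';

% Distance of each result from the ambient step X + t*U.
err_qr  = norm(Yqr  - (X + t*U), 'fro');
err_pol = norm(Ypol - (X + t*U), 'fro');
fprintf('t^2 = %g, QR error = %g, polar error = %g\n', t^2, err_qr, err_pol);
```
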
Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005