Home > manopt > manifolds > ttfixedrank > TTeMPS_1.1 > @TTeMPS_tangent_orth > TTeMPS_tangent_orth.m

TTeMPS_tangent_orth

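PURPOSE
MATLAB class representing tangent tensors to the TT/MPS format in the core-by-core orthogonalized form.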

SYNOPSIS ^

This is a class definition file (classdef TTeMPS_tangent_orth), not a script file.

DESCRIPTION ^

CROSS-REFERENCE INFORMATION ^

This function calls: (none listed)
This function is called by: (none listed)

SUBFUNCTIONS ^

SOURCE CODE ^

classdef TTeMPS_tangent_orth
    %TTEMPS_TANGENT_ORTH Tangent tensors to the TT/MPS format (core-by-core orthogonalization)
    %
    %   A MATLAB class for representing tangent tensors to the TT/MPS
    %   format in the core-by-core orthogonalization presented in the paper
    %
    %   Michael Steinlechner. Riemannian optimization for high-dimensional tensor completion
    %   Technical report, March 2015. Revised December 2015. To appear in SIAM J. Sci. Comput.
    %
    %   A tangent tensor Z is stored through the cores U (of the point XL),
    %   the cores V (of the point XR) and the variation cores dU. As built
    %   by tangent_to_TTeMPS, it represents
    %
    %       Z = sum_{k=1}^{d} U{1} ... U{k-1} dU{k} V{k+1} ... V{d}.
    %
    %   The projections in the constructor enforce, for k < d,
    %       unfold(U{k},'left')' * unfold(dU{k},'left') = 0,
    %   which (together with the orthogonality of U and V) is why innerprod
    %   only needs the dU cores.

    %   TTeMPS Toolbox.
    %   Michael Steinlechner, 2013-2016
    %   Questions and contact: michael.steinlechner@epfl.ch
    %   BSD 2-clause license, see LICENSE.txt

    properties (SetAccess = public, GetAccess = public)

        dU      % 1-by-d cell of variation cores, same sizes as the cores in U
        U       % cores of XL (copied from XL.U; assumed left-orthogonal)
        V       % cores of XR (copied from XR.U; assumed right-orthogonal)
        rank    % TT rank vector of the point (copied from XL.rank, length d+1)
        order   % number of cores d (copied from XL.order)
        size    % mode sizes n (copied from XL.size)

    end

    methods (Access = public)

        function Y = TTeMPS_tangent_orth(xL, xR, Z, ind, storedb, key)
            %TTEMPS_TANGENT_ORTH Construct a tangent tensor or project onto a tangent space.
            %
            %   Y = TTEMPS_TANGENT_ORTH() returns an object with empty
            %   properties. MATLAB needs this form for default construction.
            %
            %   Y = TTEMPS_TANGENT_ORTH(XL) returns the zero tangent vector at XL.
            %   The right-orthogonal cores are obtained from orthogonalize(XL, 1).
            %
            %   Y = TTEMPS_TANGENT_ORTH(XL, XR) returns a random tangent vector of
            %   unit norm. Gaussian cores are projected onto the tangent space,
            %   then every nonzero core is rescaled so that the squared core
            %   norms sum to one.
            %
            %   Y = TTEMPS_TANGENT_ORTH(XL, XR, Z) projects Z onto the tangent
            %   space. Z may be a TTeMPS tensor, a TTeMPS_tangent_orth, or a
            %   full d-dimensional array.
            %
            %   Y = TTEMPS_TANGENT_ORTH(XL, XR, VALS, IND) projects the sparse
            %   tensor with values VALS at the indices IND. The work is done by
            %   the private MEX routine TTeMPS_tangent_orth_omega, which
            %   receives IND transposed.
            %
            %   Y = TTEMPS_TANGENT_ORTH(XL, XR, VALS, IND, STOREDB, KEY) does the
            %   same. It also saves the raw per-core contractions in the field
            %   inner_dU_ of the Manopt store STOREDB.getWithShared(KEY).
            %
            %   XL and XR are expected to be left- and right-orthogonal
            %   representations of the same point (NOTE: assumed, not checked).

            % Default (no-argument) construction: leave all properties empty.
            if nargin == 0
                return
            end

            d = xL.order;
            r = xL.rank;
            n = xL.size;

            % MANOPT: the one-argument (zero vector) form derives the
            % right-orthogonal cores itself.
            if nargin == 1
                xR = orthogonalize(xL, 1);
            end

            Y.order = d;
            Y.rank = r;
            Y.size = n;
            Y.U = xL.U;
            Y.V = xR.U;

            if nargin == 1
                % MANOPT: zero tangent vector.
                Y.dU = cell(1, d);
                for k = 1:d
                    Y.dU{k} = zeros(size(Y.U{k}));
                end

            elseif nargin == 2
                % Random tangent vector of unit norm.
                Y.dU = cell(1, d);
                for k = 1:d
                    Y.dU{k} = randn(size(Y.U{k}));
                end

                % Project the Gaussian cores onto the tangent space.
                Y = TTeMPS_tangent_orth(xL, xR, Y);

                % Rescale each core with norm above 1e-14 to norm
                % 1/sqrt(num_nonzero). innerprod sums core-wise inner
                % products, so the result has unit norm. If the manifold has
                % positive dimension, num_nonzero > 0 with probability 1.
                nrms = zeros(1, d);
                for k = 1:d
                    nrms(k) = norm(Y.dU{k}(:));
                end
                nonzero = nrms > 1e-14;
                num_nonzero = nnz(nonzero);
                for k = 1:d
                    if nonzero(k)
                        Y.dU{k} = Y.dU{k} / (sqrt(num_nonzero) * nrms(k));
                    end
                end

            else
                % Four or more arguments means Z holds sampled values at IND.
                sampled = nargin >= 4;

                if isa(Z, 'TTeMPS')

                    Y.dU = cell(1, d);

                    % Partial contractions of Z with XR (from the right)
                    % and XL (from the left).
                    right = innerprod(xR, Z, 'RL', 2, true);
                    left = innerprod(xL, Z, 'LR', d - 1, true);

                    % First core: only a right contraction.
                    Y.dU{1} = tensorprod_ttemps(Z.U{1}, right{2}, 3);
                    % Inner cores: contract from both sides.
                    for i = 2:d - 1
                        res = tensorprod_ttemps(Z.U{i}, left{i - 1}, 1);
                        Y.dU{i} = tensorprod_ttemps(res, right{i + 1}, 3);
                    end
                    % Last core: only a left contraction.
                    Y.dU{d} = tensorprod_ttemps(Z.U{d}, left{d - 1}, 1);

                    % Remove the component along U{i} (gauge condition),
                    % for all cores but the last.
                    for i = 1:d - 1
                        W = unfold(Y.dU{i}, 'left');
                        U = unfold(Y.U{i}, 'left');
                        W = W - U * (U' * W);
                        Y.dU{i} = reshape(W, [r(i), n(i), r(i + 1)]);
                    end

                elseif isa(Z, 'TTeMPS_tangent_orth')

                    % Project the tangent tensor via its TTeMPS representation.
                    Znew = tangent_to_TTeMPS(Z);
                    Y = TTeMPS_tangent_orth(xL, xR, Znew);

                elseif ~sampled
                    % Z is a full d-dimensional array.
                    ZZ = cell(1, d);

                    % Right sweep: ZZ{d} = vec(Z). For i < d, ZZ{i} is Z with
                    % modes i+1..d contracted against V{i+1..d}.
                    ZZ{d} = Z(:);
                    W = Z;
                    for i = d - 1:-1:1
                        W = reshape(W, [prod(n(1:i)), n(i + 1) * r(i + 2)]);
                        W = W * transpose(unfold(Y.V{i + 1}, 'right'));
                        ZZ{i} = W;
                    end

                    % Left sweep: contract modes 1..k-1 of ZZ{k} against
                    % U{1..k-1}, leaving an (r(k)*n(k)) x r(k+1) matrix.
                    for k = 2:d
                        T = ZZ{k};
                        for i = 1:k - 1
                            T = reshape(T, [r(i) * n(i), prod(n(i + 1:k)) * r(k + 1)]);
                            T = unfold(Y.U{i}, 'left')' * T;
                        end
                        ZZ{k} = reshape(T, [r(k) * n(k), r(k + 1)]);
                    end

                    % Orthogonal projection (all cores but the last).
                    Y.dU = cell(1, d);
                    for i = 1:d - 1
                        U = unfold(Y.U{i}, 'left');
                        ZZ{i} = ZZ{i} - U * (U' * ZZ{i});
                        Y.dU{i} = reshape(ZZ{i}, [r(i), n(i), r(i + 1)]);
                    end
                    Y.dU{d} = reshape(ZZ{d}, [r(d), n(d), r(d + 1)]);

                else
                    % Z holds the values of a sparse tensor at the indices IND.
                    vals = Z;
                    CU = cell(1, d);
                    CV = cell(1, d);
                    for i = 1:d
                        CU{i} = permute(Y.U{i}, [1 3 2]);
                        CV{i} = permute(Y.V{i}, [1 3 2]);
                    end

                    res = TTeMPS_tangent_orth.TTeMPS_tangent_orth_omega(n, r, CU, CV, ind.', vals);

                    % MANOPT: keep the raw contractions in the store for later
                    % reuse (the original note mentions an efficient Weingarten
                    % approximation). An empty STOREDB disables caching.
                    if nargin >= 5 && ~isempty(storedb)
                        store = storedb.getWithShared(key);
                        store.inner_dU_ = res;
                        storedb.setWithShared(store, key);
                    end

                    Y.dU = cell(1, d);
                    for i = 1:d
                        res{i} = reshape(res{i}, [r(i), r(i + 1), n(i)]);
                        Y.dU{i} = unfold(permute(res{i}, [1 3 2]), 'left');
                    end

                    % Orthogonal projection (all cores but the last).
                    for i = 1:d - 1
                        U = unfold(Y.U{i}, 'left');
                        Y.dU{i} = Y.dU{i} - U * (U' * Y.dU{i});
                        Y.dU{i} = reshape(Y.dU{i}, [r(i), n(i), r(i + 1)]);
                    end
                    Y.dU{d} = reshape(Y.dU{d}, [r(d), n(d), r(d + 1)]);

                end

            end

        end

        function Zfull = full(Z)
            %FULL Convert a tangent tensor to a full d-dimensional array.
            %
            %   ZFULL = full(Z) builds the TTeMPS representation of Z
            %   (see tangent_to_TTeMPS) and expands it to a full array.

            Zfull = full(tangent_to_TTeMPS(Z));
        end

        function res = plus(X, Y)
            %PLUS Add two tangent tensors.
            %
            %   RES = plus(X, Y) adds two tangent tensors in factorized form.
            %   Both tangent tensors must be elements of the SAME tangent
            %   space. Only the dU cores are added; U and V are taken from X.

            res = X;
            res.dU = cellfun(@plus, X.dU, Y.dU, 'UniformOutput', false);
        end

        function X = minus(X, Y)
            %MINUS Subtract two tangent tensors.
            %
            %   RES = minus(X, Y) subtracts two tangent tensors in factorized
            %   form. Both tangent tensors must be elements of the SAME tangent
            %   space. Only the dU cores are subtracted; U and V are taken from X.

            X.dU = cellfun(@minus, X.dU, Y.dU, 'UniformOutput', false);
        end

        function X = mtimes(a, X)
            %MTIMES Multiply a TTeMPS tangent tensor by a scalar.
            %
            %   RES = mtimes(a, X) returns a*X and RES = mtimes(X, a) returns
            %   X*a. Only the dU cores are scaled; U and V are unchanged.

            % For X*a MATLAB passes the tangent tensor as the first argument.
            if isa(a, 'TTeMPS_tangent_orth')
                [a, X] = deal(X, a);
            end
            X.dU = cellfun(@(x) a * x, X.dU, 'UniformOutput', false);
        end

        function X = uminus(X)
            %UMINUS Unary minus of a TTeMPS tangent tensor.
            %
            %   RES = uminus(X) negates the TTeMPS tangent tensor X.

            X = mtimes(-1, X);
        end

        function Xnew = tangentAdd(Z, t, retract)
            %TANGENTADD Add a scaled tangent tensor to its base point.
            %
            %   XNEW = tangentAdd(Z, t) returns X + t*Z as a TTeMPS tensor.
            %   X is the point whose cores are Z.U. The interior TT ranks of
            %   the result are doubled (2*r).
            %
            %   XNEW = tangentAdd(Z, t, true) additionally calls
            %   truncate(XNEW, r) to retract the result to the ranks r of the
            %   base point, i.e. XNEW = R_X(X + t*Z). The original
            %   documentation states the result is right-orthogonal.

            if nargin < 3
                retract = false;
            end

            d = length(Z.dU);
            r = ones(1, d + 1);
            C = cell(1, d);

            % Block cores:
            %   C{1} = [t*dU{1}, U{1}]
            %   C{i} = [V{i}, 0; t*dU{i}, U{i}]   for 1 < i < d
            %   C{d} = [V{d}; U{d} + t*dU{d}]
            C{1} = cat(3, t * Z.dU{1}, Z.U{1});
            for i = 2:d - 1
                sz = size(Z.U{i});
                r(i) = sz(1);
                C{i} = cat(1, cat(3, Z.V{i}, zeros(sz)), ...
                              cat(3, t * Z.dU{i}, Z.U{i}));
            end
            r(d) = size(Z.U{d}, 1);
            C{d} = cat(1, Z.V{d}, Z.U{d} + t * Z.dU{d});

            Xnew = TTeMPS(C);
            if retract
                Xnew = truncate(Xnew, r);
            end
        end

        function res = innerprod(Z1, Z2)
            %INNERPROD Inner product of two tangent tensors of the same tangent space.
            %
            %   RES = innerprod(Z1, Z2) returns the sum over k of
            %   Z1.dU{k}(:)' * Z2.dU{k}(:). Because of the left and right
            %   orthogonality and the gauge condition, this equals the inner
            %   product of the represented tensors.

            res = 0;
            for i = 1:length(Z1.dU)
                res = res + Z1.dU{i}(:)' * Z2.dU{i}(:);
            end
        end

        function n = norm(Z)
            %NORM Norm of a tangent tensor.
            %
            %   N = norm(Z) is computed from the TTeMPS representation of Z
            %   (see tangent_to_TTeMPS), not from the dU cores alone.

            Z_tt = tangent_to_TTeMPS(Z);
            n = norm(Z_tt);
        end

        function res = at_Omega(Z, ind)
            %AT_OMEGA Entries of a tangent tensor at given indices.
            %
            %   RES = at_Omega(Z, IND) indexes the TTeMPS representation of Z
            %   with IND.

            Xnew = tangent_to_TTeMPS(Z);
            res = Xnew(ind);
        end

        function res = tangent_to_TTeMPS(Z)
            %TANGENT_TO_TTEMPS Represent a tangent tensor as a TTeMPS tensor.
            %
            %   RES = tangent_to_TTeMPS(Z) builds a TTeMPS tensor with
            %   (interior) doubled ranks from the block cores
            %       C{1} = [dU{1}, U{1}]
            %       C{i} = [V{i}, 0; dU{i}, U{i}]   for 1 < i < d
            %       C{d} = [V{d}; dU{d}]

            d = length(Z.dU);
            C = cell(1, d);

            C{1} = cat(3, Z.dU{1}, Z.U{1});
            for i = 2:d - 1
                zeroblock = zeros(size(Z.U{i}));
                tmp1 = cat(3, Z.V{i}, zeroblock);
                tmp2 = cat(3, Z.dU{i}, Z.U{i});
                C{i} = cat(1, tmp1, tmp2);
            end
            C{d} = cat(1, Z.V{d}, Z.dU{d});

            res = TTeMPS(C);
        end

        function z = vectorize_tangent(Z)
            %VECTORIZE_TANGENT Stack all dU cores into one column vector.
            %
            %   Z = vectorize_tangent(Z) concatenates dU{1}(:), ..., dU{d}(:).

            z = cellfun(@(x) x(:), Z.dU, 'UniformOutput', false);
            z = cell2mat(z(:));
        end

        function Z = fill_with_vectorized(Z, z)
            %FILL_WITH_VECTORIZED Inverse of vectorize_tangent.
            %
            %   Z = fill_with_vectorized(Z, z) overwrites the dU cores of Z
            %   with consecutive slices of the vector z. The current core
            %   sizes are kept.

            d = length(Z.dU);
            k = 1;
            for i = 1:d
                s = size(Z.dU{i});
                Z.dU{i} = reshape(z(k:k + prod(s) - 1), s);
                k = k + prod(s);
            end
        end

        function xi = tangent_orth_to_tangent(Z)
            %TANGENT_ORTH_TO_TANGENT Transform a TTeMPS_tangent_orth to a TTeMPS_tangent.
            %
            %   XI = tangent_orth_to_tangent(Z) calls gauge_matrices on
            %   TTeMPS(Z.U) to obtain the matrices G. For the first d-1 cores
            %   it multiplies dU along mode 3 by inv(G{k}'). The last core is
            %   copied unchanged.

            d = Z.order;
            xL = TTeMPS(Z.U);
            [xL, ~, G] = gauge_matrices(xL);

            xi = TTeMPS_tangent(xL);

            for ii = 1:d - 1
                xi.dU{ii} = tensorprod_ttemps(Z.dU{ii}, inv(G{ii}'), 3);
            end

            % The last core does not need to be changed.
            xi.dU{d} = Z.dU{d};
        end

    end

    methods (Static, Access = private)

        x = TTeMPS_tangent_orth_omega(n, r, CU, CV, ind, vals);
        x = TTeMPS_tangent_orth_omega_openmp(n, r, CU, CV, ind, vals);

    end

end

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005