Home > manopt > manifolds > ttfixedrank > TTeMPS_1.1 > @TTeMPS_tangent > TTeMPS_tangent.m

TTeMPS_tangent

PURPOSE ^

SYNOPSIS ^

This is a class definition file (classdef TTeMPS_tangent).

DESCRIPTION ^

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

0001 classdef TTeMPS_tangent
0002 % TTeMPS_tangent
0003 %
0004 %   A MATLAB class for representing tangent tensors
0005 %   to the manifold of tensors in TT/MPS format.
0006 %
0007 
0008 %   TTeMPS Toolbox.
0009 %   Michael Steinlechner, 2013-2016
0010 %   Questions and contact: michael.steinlechner@epfl.ch
0011 %   BSD 2-clause license, see LICENSE.txt
0012 
0013 
properties( SetAccess = public, GetAccess = public )

    % dU: 1-by-d cell array with one entry per core of the base TT/MPS
    % tensor X. The constructor stores entry i reshaped to
    % [r(i), n(i), r(i+1)] (r = X.rank, n = X.size); these entries are the
    % core variations that parametrize the tangent tensor.
    dU
    
end
0019 
0020 methods( Access = public )
0021 
0022     function Y = TTeMPS_tangent(X, Z, ind, vec)
0023     %TTEMPS_TANGENT projects the d-dimensional array Z into the tangent space at a TT/MPS tensor X.
0024     %
0025     %   P = TTEMPS_TANGENT(X) projects the d-dimensional array Z into the tangent space of the
0026     %    TT-rank-r manifold at a TT/MPS tensor X.
0027     %
0028     %   Important: X has to be left-orthogonalized! (use X = orthogonalize(X, X.order) beforehand)
0029     %
0030         if nargin == 1
0031             % Create a zero tangent tensor
0032             Y.dU = cellfun( @(x) zeros(size(x)), X.U, 'UniformOutput', false );
0033             return
0034         end
0035 
0036 
0037         if ~exist('ind','var')
0038             sampled = false;
0039         else
0040             sampled = true;
0041             if ~exist('vec','var')
0042                 vec = false;
0043             end
0044         end
0045         
0046         d = X.order;
0047         r = X.rank;
0048         n = X.size;
0049     
0050         invXtX = cell(1,d);
0051         tmp = conj(unfold( X.U{d}, 'right')) * unfold( X.U{d}, 'right').';
0052         invXtX{d} = pinv(tmp,1e-8);
0053         for i = d-1:-1:2     
0054             tmp = tensorprod_ttemps( X.U{i}, tmp', 3);
0055             tmp = conj(unfold( tmp, 'right')) * unfold( X.U{i}, 'right').';
0056             invXtX{i} = pinv(tmp,1e-8);
0057         end
0058 
0059         if isa(Z, 'TTeMPS')
0060             
0061             Y.dU = cell(1, d);
0062             
0063             for i = 1:d
0064                 Y.dU{i} = contract( X, Z, i);
0065             end
0066             
0067             for i = 1:d-1
0068                 Y.dU{i} = unfold( Y.dU{i}, 'left' );
0069                 Y.dU{i} = Y.dU{i} * invXtX{i+1};
0070                 
0071                 %U = orth(unfold(X.U{i},'left')); % orth unnecessary
0072                 U = unfold(X.U{i},'left');
0073                 Y.dU{i} = Y.dU{i} - U * ( U' * Y.dU{i});
0074                 Y.dU{i} = reshape( Y.dU{i}, [r(i), n(i), r(i+1)] );
0075             end
0076             
0077         elseif isa(Z, 'TTeMPS_tangent')
0078             
0079             Znew = tangent_to_TTeMPS( Z, X );
0080             Y = TTeMPS_tangent( X, Znew );
0081         
0082         elseif ~sampled % Z is a full array
0083             
0084             ZZ = cell(1,d);
0085 
0086             % right side
0087             ZZ{d} = Z(:);
0088             for i = d-1:-1:1
0089                 zz = reshape( Z, [prod(n(1:i)), n(i+1)*r(i+2)] );
0090                 xx = transpose( unfold( X.U{i+1}, 'right') );
0091                 Z = zz*xx;
0092                 ZZ{i} = Z * invXtX{i+1};
0093             end
0094 
0095             % left side
0096             for k = 2:d
0097                 for i = 1:k-1
0098                     Z_i = reshape( ZZ{k}, [r(i)*prod(n(i)), prod(n(i+1:k))*r(k+1)] );
0099                     X_i = unfold( X.U{i}, 'left');
0100                     Z = X_i' * Z_i;
0101                     ZZ{k} = Z;
0102                 end
0103                 ZZ{k} = reshape( ZZ{k}, [r(k)*n(k), r(k+1)] );
0104             end
0105 
0106             Y.dU = cell(1,d);
0107             % orth. projection (w/o last core)
0108             for i = 1:d-1
0109                 U = orth(unfold(X.U{i},'left'));
0110                 ZZ{i} = ZZ{i} - U * ( U' * ZZ{i});
0111                 Y.dU{i} = reshape( ZZ{i}, [r(i), n(i), r(i+1)] );
0112             end
0113             Y.dU{d} = reshape( ZZ{d}, [r(d), n(d), r(d+1)] );
0114         
0115         else % Z is a sparse array
0116             
0117             vals = Z;
0118             C = cell(1,d);
0119             Y.dU = cell(1,d);
0120             for i=1:d
0121                 C{i} = permute( X.U{i}, [1 3 2]);
0122                 %Y.dU{i} = zeros(size(C{i}));
0123             end
0124             res = TTeMPS_tangent_omega( n, r, C, ind.', vals);
0125             
0126             for i = 1:d
0127                 res{i} = reshape( res{i}, [r(i), r(i+1), n(i)] );
0128                 Y.dU{i} = unfold( permute( res{i}, [1 3 2]), 'left');
0129             end
0130             
0131             for i=1:d-1
0132                 Y.dU{i} = Y.dU{i} * invXtX{i+1};    
0133             end
0134             
0135             for i=1:d-1
0136                 %U = orth(unfold(X.U{i},'left'));
0137                 U = unfold(X.U{i},'left');
0138                 Y.dU{i} = Y.dU{i} - U * ( U' * Y.dU{i});    
0139                 Y.dU{i} = reshape( Y.dU{i}, [r(i), n(i), r(i+1)] );        
0140             end
0141             Y.dU{d} = reshape( Y.dU{d}, [r(d), n(d), r(d+1)] );
0142             
0143         end
0144     
0145         
0146     end
0147     
0148     function Zfull = full( Z, X )
0149     %FULL converts tangent tensor to d-dimensional array
0150     %
0151     %   ZFULL = full(Z, X) converts the tangent tensor Z given in factorized form
0152     %   (class TTeMPS_tangent) to a d-dimensional array ZFULL. X is the TTeMPS tensor at
0153     %    which point the tangent space is taken.
0154     %
0155         d = X.order;
0156         
0157         C = cell(1, d);
0158 
0159         C{1} = cat( 3, Z.dU{1}, X.U{1} );
0160         for i = 2:d-1
0161             zeroblock = zeros( size(X.U{i}) );
0162             tmp1 = cat( 3, X.U{i}, zeroblock );
0163             tmp2 = cat( 3, Z.dU{i}, X.U{i} );
0164             C{i} = cat( 1, tmp1, tmp2 );
0165         end
0166         C{d} = cat( 1, X.U{d}, Z.dU{d} );
0167         
0168         Xnew = TTeMPS( C );
0169         Zfull = full( Xnew );
0170     end
0171     
0172     function res = plus( X, Y)
0173         %PLUS adds two tangent tensors
0174         %
0175         %   RES = plus(X, Y) adds two tangent tensors in factorized form. Both tangent tensors
0176         %   have be elements of the SAME tangent space.
0177         %
0178     
0179         res = X;
0180         res.dU = cellfun(@plus, X.dU, Y.dU, 'UniformOutput', false);
0181     end
0182     
0183     function X = minus( X, Y )
0184         %MINUS substracts two tangent tensors
0185         %
0186         %   RES = minus(X, Y) substracts two tangent tensors in factorized form. Both tangent tensors
0187         %   have be elements of the SAME tangent space.
0188         %
0189     
0190         X.dU = cellfun(@plus, X.dU, Y.dU, 'UniformOutput', false);
0191     end
0192     
0193     function X = mtimes( a, X )
0194         %MTIMES Multiplication of TTeMPS tangent tensor by scalar
0195         %
0196         %   RES = mtimes(a, X) multiplies the TTeMPS tangent tensor X
0197         %    by the scalar a.
0198         %
0199         
0200         X.dU = cellfun(@(x) a*x, X.dU, 'UniformOutput', false);    
0201     end
0202     function X = uminus( X )
0203         %UMINUS Unary minus of TTeMPS tangent tensor.
0204         %
0205         %   RES = uminus(X) negates the TTeMPS tangent tensor X.
0206         %
0207         
0208         X = mtimes( -1, X );
0209     end
0210     
0211     function Xnew = tangentAdd( Z, t, X, retract )
0212     %TANGENTADD adds a tangent vector to a point on the manifold
0213     %
0214     %   RES = tangentAdd(Z, t, X) adds a tangent vector Z to a point X on the rank-r-manifold, scaled by t:
0215     %             res = X + t*Z
0216     %     where the result is stored as a TTeMPS tensor of rank 2*r.
0217     %
0218     %   RES = tangentAdd(Z, t, X, true) adds a tangent vector Z to a point X on the rank-r-manifold, scaled by t:
0219     %             res = X + t*Z
0220     %    and retracts the result back to the manifold:
0221     %            res = R_X( X + t*Z )
0222     %    where the result is stored as a TTeMPS tensor of rank r.
0223     %
0224     
0225         if ~exist( 'retract', 'var' )
0226             retract = false;
0227         end
0228 
0229         d = X.order;
0230         C = cell(1, d);
0231 
0232         C{1} = cat( 3, t*Z.dU{1}, X.U{1} );
0233         for i = 2:d-1
0234             zeroblock = zeros( size(X.U{i}) );
0235             tmp1 = cat( 3, X.U{i}, zeroblock );
0236             tmp2 = cat( 3, t*Z.dU{i}, X.U{i} );
0237             C{i} = cat( 1, tmp1, tmp2 );
0238         end
0239         C{d} = cat( 1, X.U{d}, t*Z.dU{d} + X.U{d} );
0240         
0241         Xnew = TTeMPS( C );
0242     
0243         if retract
0244             Xnew = truncate( Xnew, X.rank );
0245         end
0246         
0247     end
0248     
0249     function res = innerprod( Z1, Z2, X )
0250         
0251         X1 = tangent_to_TTeMPS( Z1, X );
0252         X2 = tangent_to_TTeMPS( Z2, X );
0253         
0254         res = innerprod( X1, X2 );
0255     end
0256     
0257     function res = at_Omega( Z, ind, X)
0258         
0259         Xnew = tangent_to_TTeMPS( Z, X );
0260         res = Xnew(ind);
0261     end
0262     
0263     function res = tangent_to_TTeMPS( Z, X)
0264         d = X.order;
0265         C = cell(1, d);
0266 
0267         C{1} = cat( 3, Z.dU{1}, X.U{1} );
0268         for i = 2:d-1
0269             zeroblock = zeros( size(X.U{i}) );
0270             tmp1 = cat( 3, X.U{i}, zeroblock );
0271             tmp2 = cat( 3, Z.dU{i}, X.U{i} );
0272             C{i} = cat( 1, tmp1, tmp2 );
0273         end
0274         C{d} = cat( 1, X.U{d}, Z.dU{d});
0275         
0276         res = TTeMPS( C );
0277     end
0278 end
0279     
0280     
0281     
0282     
0283 end

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005