Home > manopt > manifolds > ttfixedrank > TTeMPS_1.1 > algorithms > linearsystem > amen_fast.m

amen_fast

PURPOSE ^

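TTeMPS Toolbox. AMEn-type alternating solver for linear systems L(X) = F with X in the tensor-train (TT/MPS) format.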

SYNOPSIS ^

function [X, residual, cost, times] = amen_fast( L, F, X, opts )

DESCRIPTION ^

   TTeMPS Toolbox. 
   Michael Steinlechner, 2013-2016
   Questions and contact: michael.steinlechner@epfl.ch
   BSD 2-clause license, see LICENSE.txt

CROSS-REFERENCE INFORMATION ^

This function calls: (none listed)
This function is called by: (none listed)

SUBFUNCTIONS ^

SOURCE CODE ^

0001 %   TTeMPS Toolbox.
0002 %   Michael Steinlechner, 2013-2016
0003 %   Questions and contact: michael.steinlechner@epfl.ch
0004 %   BSD 2-clause license, see LICENSE.txt
0005 
function [X, residual, cost, times] = amen_fast( L, F, X, opts )
%AMEN_FAST  AMEn-type alternating solver for L(X) = F in TT format.
%
%   [X, residual, cost, times] = AMEN_FAST( L, F, X ) runs opts.nSweeps
%   forward/backward sweeps starting from the initial guess X.
%   [...] = AMEN_FAST( L, F, X, opts ) overrides the defaults below.
%
%   Each micro-step at core mu
%     1. solves the projected (local) system for core mu,
%     2. records relative residual, cost and elapsed time,
%     3. forms a low-rank (optionally preconditioned) approximation of the
%        residual on core mu and its neighbour in the sweep direction,
%     4. enriches core mu with that approximation, truncates by SVD, keeps
%        the orthonormal factor in core mu and multiplies the remaining
%        factor into the (zero-padded) neighbouring core, so the
%        represented tensor is unchanged up to the truncation.
%
%   opts fields (default):
%     nSweeps    (4)        number of forward+backward sweeps
%     maxrank    (20)       maximal rank after truncation
%     maxrankRes (4)        maximal enrichment rank; 0 disables the cap
%     tolRes     (1e-13)    relative singular value cutoff for the residual
%     tol        (1e-13)    relative singular value cutoff for truncation
%     solver     ('direct') local solver: 'direct' (backslash) or 'pcg'
%     prec       (true)     precondition the residual before enrichment
%     pcgTol     (1e-10)    tolerance passed to pcg
%     pcgMaxit   (1000)     maximal number of pcg iterations
%
%   Outputs:
%     X        - approximate solution
%     residual - ||F - L(X)|| / ||F||: initial value, then one entry per
%                micro-step (2*(d-1) entries per sweep)
%     cost     - 0.5*<X, L(X)> - <X, F> at the same points
%     times    - seconds elapsed since the call, at the same points
%
%   opts.prec = true and opts.solver = 'pcg' read L.L0; the residual
%   preconditioner assumes every mode size equals size(L.L0, 1)
%   (NOTE(review): confirm this holds for the intended operators).

t_start = tic();

% set default opts
if ~exist( 'opts', 'var' );        opts = struct();          end
if ~isfield( opts, 'nSweeps' );    opts.nSweeps = 4;         end
if ~isfield( opts, 'maxrank' );    opts.maxrank = 20;        end
if ~isfield( opts, 'maxrankRes' ); opts.maxrankRes = 4;      end
if ~isfield( opts, 'tolRes' );     opts.tolRes = 1e-13;      end
if ~isfield( opts, 'tol' );        opts.tol = 1e-13;         end
if ~isfield( opts, 'solver' );     opts.solver = 'direct';   end
if ~isfield( opts, 'prec' );       opts.prec = true;         end
if ~isfield( opts, 'pcgTol' );     opts.pcgTol = 1e-10;      end
if ~isfield( opts, 'pcgMaxit' );   opts.pcgMaxit = 1000;     end

d = X.order;
n = X.size;

normF = norm( F );
cost = cost_function( L, X, F );
residual = norm( apply(L, X) - F ) / normF;
times = toc( t_start );

for sweep = 1:opts.nSweeps
    X = orthogonalize( X, 1 );

    % Forward half-sweep: cores 1, ..., d-1; orthogonality moves right.
    for mu = 1:d-1
        disp( ['Current core: ', num2str(mu)] )

        % STEP 1: Solve mu-th core optimization
        X = solve_local_core( L, F, X, mu, opts );

        % STEP 2: Calculate current residual and cost function
        res = F - apply(L, X);
        residual = [residual; norm( res ) / normF];
        cost = [cost; cost_function( L, X, F )];
        disp(['Rel. residual: ' num2str(residual(end)) ', Current rank: ' num2str(X.rank) ]);

        % STEP 3: Low-rank approximation of the residual on cores
        % (mu, mu+1); its left factor Z is used to enrich core mu
        R = contract( X, res, [mu, mu+1] );
        R_combined = unfold(R{1},'left') * unfold(R{2},'right');
        if opts.prec
            R_combined = precond_residual( L.L0, X, R_combined, mu );
        end
        [uu, ss, ~] = svd( R_combined, 'econ' );
        s = enrichment_rank( diag(ss), opts );
        Z = reshape( uu(:,1:s)*ss(1:s,1:s), [X.rank(mu), n(mu), s] );

        % STEP 4: Enrich core mu, truncate, and move the non-orthogonal
        % factor into core mu+1
        X = enrich_and_shift_right( X, mu, Z, opts );

        times = [times; toc(t_start)];
    end

    % Backward half-sweep: cores d, ..., 2; orthogonality moves left.
    for mu = d:-1:2
        disp( ['Current core: ', num2str(mu)] )

        % STEP 1: Solve mu-th core optimization
        X = solve_local_core( L, F, X, mu, opts );

        % STEP 2: Calculate current residual and cost function
        res = F - apply(L, X);
        residual = [residual; norm( res ) / normF];
        disp(['Rel. residual: ' num2str(residual(end)) ', Current rank: ' num2str(X.rank) ]);
        cost = [cost; cost_function( L, X, F )];

        % STEP 3: Low-rank approximation of the residual on cores
        % (mu-1, mu); its right factor Z is used to enrich core mu
        R = contract( X, res, [mu-1, mu] );
        R_combined = unfold(R{1},'left') * unfold(R{2},'right');
        if opts.prec
            R_combined = precond_residual( L.L0, X, R_combined, mu-1 );
        end
        [~, ss, vv] = svd( R_combined, 'econ' );
        s = enrichment_rank( diag(ss), opts );
        Z = reshape( ss(1:s,1:s)*vv(:,1:s)', [s, n(mu), X.rank(mu+1)] );

        % STEP 4: Enrich core mu, truncate, and move the non-orthogonal
        % factor into core mu-1
        X = enrich_and_shift_left( X, mu, Z, opts );

        times = [times; toc(t_start)];
    end
end

end

function X = solve_local_core( L, F, X, mu, opts )
% Replace core mu of X by the solution of the projected system for that core.
F_mu = contract( X, F, mu );
sz = [X.rank(mu), X.size(mu), X.rank(mu+1)];

if strcmpi( opts.solver, 'direct' )
    % assemble the local matrix and solve with backslash
    L_mu = contract( L, X, mu );
    U_mu = L_mu \ F_mu(:);
elseif strcmpi( opts.solver, 'pcg' )
    % matrix-free local operator; the current core is the initial guess
    [left, right] = Afun_prepare( L, X, mu );
    [B2, V, E] = prepare_precond( L.L0, X, mu );
    U_mu = pcg( @(y) Afun( L, y, mu, sz, left, right ), ...
                F_mu(:), ...
                opts.pcgTol, opts.pcgMaxit, ...
                @(y) apply_precond( B2, V, E, y, sz ), [], ...
                X.U{mu}(:) );
else
    error( 'Unknown opts.solver type. Use either ''direct'' (default) or ''pcg''.' )
end

X.U{mu} = reshape( U_mu, sz );
end

function s = enrichment_rank( sv, opts )
% Number of residual singular directions used for enrichment (may be 0).
s = find( sv > opts.tolRes*norm(sv), 1, 'last' );
if isempty( s )
    % residual is numerically zero: nothing to add
    s = 0;
end
if opts.maxrankRes ~= 0
    s = min( s, opts.maxrankRes );
end
end

function t = truncation_rank( sv, opts )
% Rank kept after truncating with relative tolerance opts.tol and cap opts.maxrank.
t = find( sv > opts.tol*norm(sv), 1, 'last' );
if isempty( t )
    % all singular values vanish; keep a single direction so the core stays valid
    t = 1;
end
t = min( t, opts.maxrank );
end

function X = enrich_and_shift_right( X, mu, Z, opts )
% Append Z to core mu (mode 3), truncate by SVD, keep the left-orthonormal
% factor in core mu and absorb S*V' into the zero-padded core mu+1.
Umu = X.U{mu};
[r1, nmu, r2] = size( Umu );
s = size( Z, 3 );

enriched = cat( 3, Umu, Z );
[U, S, V] = svd( reshape( enriched, r1*nmu, r2 + s ), 'econ' );
t = truncation_rank( diag(S), opts );
X.U{mu} = reshape( U(:,1:t), [r1, nmu, t] );

% Zero rows for the enrichment directions leave the represented tensor
% unchanged before truncation.
C = X.U{mu+1};
[rc, nnext, r3] = size( C );
Cpad = cat( 1, C, zeros( s, nnext, r3 ) );
X.U{mu+1} = reshape( S(1:t,1:t) * V(:,1:t)' * reshape( Cpad, rc + s, nnext*r3 ), ...
                     [t, nnext, r3] );
end

function X = enrich_and_shift_left( X, mu, Z, opts )
% Append Z to core mu (mode 1), truncate by SVD, keep the right-orthonormal
% factor in core mu and absorb U*S into the zero-padded core mu-1.
Umu = X.U{mu};
[r1, nmu, r2] = size( Umu );
s = size( Z, 1 );

enriched = cat( 1, Umu, Z );
[U, S, V] = svd( reshape( enriched, r1 + s, nmu*r2 ), 'econ' );
t = truncation_rank( diag(S), opts );
X.U{mu} = reshape( V(:,1:t)', [t, nmu, r2] );

% Zero columns for the enrichment directions leave the represented tensor
% unchanged before truncation.
C = X.U{mu-1};
[r0, nprev, rc] = size( C );
Cpad = cat( 3, C, zeros( r0, nprev, s ) );
X.U{mu-1} = reshape( reshape( Cpad, r0*nprev, rc + s ) * U(:,1:t) * S(1:t,1:t), ...
                     [r0, nprev, t] );
end
0148 
function res = cost_function( L, X, F )
% Energy functional 0.5*<X, L(X)> - <X, F>, recorded after every
% micro-step by amen_fast.
LX = apply( L, X );
res = 0.5 * innerprod( X, LX ) - innerprod( X, F );
end
0152 
function res = euclid_grad( L, X, F )
% Euclidean gradient L(X) - F of the energy functional (for symmetric L).
% Not referenced elsewhere in this file.
LX = apply( L, X );
res = LX - F;
end
0156 
function res = precond_residual( L0, X, R_combined, idx )
% Approximately invert the projected operator on the two-core residual
% supercore of cores idx and idx+1.
%   L0         - matrix applied along a single mode; every mode size is
%                taken to be size(L0, 1)
%   X          - current iterate; interfaces are formed from its cores
%   R_combined - supercore unfolded as (rl*n) x (n*rr)
%   idx        - index of the first of the two cores
% Returns the preconditioned supercore in the same (rl*n) x (n*rr) layout.
%
% The operator is split into a left/right interface part (B1, B3), which is
% diagonalised with eig, and a two-site part B2 = L0 (x) I + I (x) L0. Each
% eigencomponent is then solved with B2 shifted by the eigenvalue.
% NOTE(review): Q' is used as the inverse of the eigenvector matrix, which
% assumes B1 and B3 are symmetric (e.g. symmetric L0) - confirm.
nn = size( L0, 1 );
rl = X.rank( idx );
rr = X.rank( idx + 2 );

% left interface: L0 applied to each core left of idx in turn
B1 = zeros( rl );
for k = 1:idx-1
    Y = X;
    Y.U{k} = tensorprod_ttemps( Y.U{k}, L0, 2 );
    B1 = B1 + innerprod( X, Y, 'LR', idx-1 );
end

% two-site part acting on the merged mode of size nn*nn
I_n = speye( nn );
B2 = kron( L0, I_n ) + kron( I_n, L0 );

% right interface: L0 applied to each core right of idx+1 in turn
B3 = zeros( rr );
for k = idx+2:X.order
    Y = X;
    Y.U{k} = tensorprod_ttemps( Y.U{k}, L0, 2 );
    B3 = B3 + innerprod( X, Y, 'RL', idx+2 );
end

[Q, D] = eig( kron( eye(rr), B1 ) + kron( B3, eye(rl) ) );
lambda = diag( D );

rhs = matricize( reshape( R_combined, [rl, nn*nn, rr] ), 2 ) * Q;
sol = zeros( size(rhs) );
I_nn = speye( nn*nn );
for k = 1:numel( lambda )
    sol(:,k) = ( B2 + lambda(k)*I_nn ) \ rhs(:,k);
end

res = reshape( tensorize( sol*Q', 2, [rl, nn*nn, rr] ), [rl*nn, nn*rr] );
end
0194 
function [left, right] = Afun_prepare( A, x, idx )
% Precompute the interface matrices used by Afun for core idx.
%   left  - innerprod of x with A.apply(x) over the cores before idx
%           ('LR', idx-1); [] when idx is the first core
%   right - innerprod of x with A.apply(x) over the cores after idx
%           ('RL', idx+1); [] when idx is the last core (and not also the first)
y = A.apply( x );
left = [];
right = [];
if idx ~= 1
    left = innerprod( x, y, 'LR', idx-1 );
end
if idx == 1 || idx ~= x.order
    right = innerprod( x, y, 'RL', idx+1 );
end
end
0208 
function res = Afun( A, U, idx, sz, left, right )
% Matrix-free local operator for pcg: reshape the vector U to a core of
% size sz, apply A at core idx, contract with the interfaces from
% Afun_prepare (right in mode 3, then left in mode 1), and vectorise.
W = A.apply( reshape( U, sz ), idx );

if idx == 1 || idx ~= A.order
    W = tensorprod_ttemps( W, right, 3 );
end
if idx ~= 1
    W = tensorprod_ttemps( W, left, 1 );
end

res = W(:);
end
function [B2, V, E] = prepare_precond( L0, X, idx )
% Data for the single-core preconditioner used by apply_precond.
%   B2 - the mode operator L0 itself
%   V  - eigenvectors of kron(eye(rr), B1) + kron(B3, eye(rl)), where B1 and
%        B3 are the left/right interface projections of L0 around core idx
%   E  - the corresponding eigenvalues (column vector)
rl = X.rank( idx );
rr = X.rank( idx + 1 );

% left interface: L0 applied to each core before idx in turn
B1 = zeros( rl );
for k = 1:idx-1
    Y = X;
    Y.U{k} = tensorprod_ttemps( Y.U{k}, L0, 2 );
    B1 = B1 + innerprod( X, Y, 'LR', idx-1 );
end

% right interface: L0 applied to each core after idx in turn
B3 = zeros( rr );
for k = idx+1:X.order
    Y = X;
    Y.U{k} = tensorprod_ttemps( Y.U{k}, L0, 2 );
    B3 = B3 + innerprod( X, Y, 'RL', idx+1 );
end

B2 = L0;

[V, D] = eig( kron( eye(rr), B1 ) + kron( B3, eye(rl) ) );
E = diag( D );
end
function res = apply_precond( B2, V, E, rhs, sz )
% Preconditioner for pcg: transform the mode-2 matricisation of rhs into
% the eigenbasis V, solve (B2 + E(k)*I) per eigencomponent, transform back
% with V' and return a column vector.
I_n = speye( size(B2, 1) );
W = matricize( reshape( rhs, sz ), 2 ) * V;

Y = zeros( size(W) );
for k = 1:numel( E )
    Y(:,k) = ( B2 + E(k)*I_n ) \ W(:,k);
end

res = reshape( tensorize( Y*V', 2, sz ), [], 1 );
end
0264 
0265

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005