0001
0002
0003
0004
0005
function [X, residual, cost, times] = amen_fast( L, F, X, opts )
    %AMEN_FAST Alternating core-wise solver with residual-based enrichment.
    %
    %   [X, residual, cost, times] = amen_fast( L, F, X, opts )
    %
    %   Sweeps forward (cores 1..d-1) and backward (cores d..2) over X. At
    %   every core the local system obtained from contract() is solved, the
    %   core is enriched with the leading singular directions of the
    %   (optionally preconditioned) projected residual, truncated by an SVD,
    %   and the neighbouring core that is visited next is re-initialised with
    %   random entries of matching size.
    %
    %   Inputs
    %     L    - operator understood by apply() and contract(); L.L0 is read
    %            when opts.prec is true or opts.solver is 'pcg'.
    %     F    - right-hand side, compatible with norm(), contract() and
    %            F - apply(L, X).
    %     X    - initial guess exposing .order, .size, .rank and the cell
    %            array of cores .U.
    %     opts - optional struct; missing fields get these defaults:
    %              nSweeps    = 4       number of forward+backward sweeps
    %              maxrank    = 20      cap on the rank kept after truncation
    %              maxrankRes = 4       cap on the enrichment rank (0 = no cap)
    %              tolRes     = 1e-13   relative singular value cutoff, residual
    %              tol        = 1e-13   relative singular value cutoff, core
    %              solver     = 'direct' ('direct' or 'pcg', case-insensitive)
    %              prec       = true    precondition the projected residual
    %
    %   Outputs
    %     X        - iterate after the last sweep (see NOTE at the end).
    %     residual - column vector of relative residuals norm(.)/norm(F); the
    %                first entry belongs to the initial guess, then one entry
    %                per local solve.
    %     cost     - column vector of cost_function values, same layout.
    %     times    - column vector of elapsed seconds since entry; the first
    %                entry is taken after the initial evaluation, then one
    %                entry per completed core step.

    t_start = tic();

    if ~exist( 'opts', 'var');         opts = struct();        end
    if ~isfield( opts, 'nSweeps');     opts.nSweeps = 4;       end
    if ~isfield( opts, 'maxrank');     opts.maxrank = 20;      end
    if ~isfield( opts, 'maxrankRes');  opts.maxrankRes = 4;    end
    if ~isfield( opts, 'tolRes');      opts.tolRes = 1e-13;    end
    if ~isfield( opts, 'tol');         opts.tol = 1e-13;       end
    if ~isfield( opts, 'solver');      opts.solver = 'direct'; end
    if ~isfield( opts, 'prec');        opts.prec = true;       end

    d = X.order;
    n = X.size;

    % Quantities for the initial guess open the three history vectors.
    normF = norm(F);
    cost = cost_function( L, X, F );
    residual = norm( apply(L, X) - F ) / normF;
    times = toc(t_start);

    for sweep = 1:opts.nSweeps
        X = orthogonalize(X, 1);

        % ---------------- forward sweep: cores 1 .. d-1 ----------------
        for mu = 1:d-1
            disp( ['Current core: ', num2str(mu)] )

            % Local solve for core mu.
            X.U{mu} = amen_fast_solve_core( L, F, X, mu, opts );

            % Full residual after the update; logged before enrichment.
            res = F - apply(L, X);
            residual = [residual; norm( res ) / normF];
            cost = [cost; cost_function( L, X, F )];
            disp(['Rel. residual: ' num2str(residual(end)) ', Current rank: ' num2str(X.rank) ]);

            % Residual contracted onto cores (mu, mu+1), merged into a matrix.
            R = contract( X, res, [mu, mu+1] );
            R_combined = unfold(R{1},'left') * unfold(R{2},'right');
            if opts.prec
                R_combined = precond_residual( L.L0, X, R_combined, mu );
            end
            [uu,ss,~] = svd( R_combined, 'econ');
            s = find( diag(ss) > opts.tolRes*norm(diag(ss)), 1, 'last' );
            if isempty(s)
                % All singular values are zero (vanishing residual): there
                % is nothing to enrich with. Without this guard the reshape
                % below fails on an empty dimension.
                s = 0;
            end
            if opts.maxrankRes ~= 0
                s = min( s, opts.maxrankRes );
            end
            R{1} = reshape( uu(:,1:s)*ss(1:s,1:s), [X.rank(mu), n(mu), s]);

            % Append the enrichment along the right rank index.
            left = cat(3, X.U{mu}, R{1});

            % Truncate the enlarged core and keep the left singular vectors.
            [U,S,~] = svd( unfold(left,'left'), 'econ' );
            t = find( diag(S) > opts.tol*norm(diag(S)), 1, 'last' );
            t = min( t, opts.maxrank );
            X.U{mu} = reshape( U(:,1:t), [X.rank(mu), n(mu), t] );
            % The next core is solved for in the next step; it only has to
            % match the new rank t.
            X.U{mu+1} = rand( t, n(mu+1), X.rank(mu+2));

            times = [times; toc(t_start)];
        end

        % ---------------- backward sweep: cores d .. 2 ----------------
        for mu = d:-1:2
            disp( ['Current core: ', num2str(mu)] )

            % Local solve for core mu.
            X.U{mu} = amen_fast_solve_core( L, F, X, mu, opts );

            res = F - apply(L, X);
            residual = [residual; norm( res ) / normF];
            disp(['Rel. residual: ' num2str(residual(end)) ', Current rank: ' num2str(X.rank) ]);
            cost = [cost; cost_function( L, X, F )];

            % Residual contracted onto cores (mu-1, mu), merged into a matrix.
            R = contract( X, res, [mu-1, mu] );
            R_combined = unfold(R{1},'left') * unfold(R{2},'right');
            if opts.prec
                R_combined = precond_residual( L.L0, X, R_combined, mu-1 );
            end
            [~,ss,vv] = svd( R_combined, 'econ');
            s = find( diag(ss) > opts.tolRes*norm(diag(ss)), 1, 'last' );
            if isempty(s)
                % Vanishing residual: skip enrichment (see forward sweep).
                s = 0;
            end
            if opts.maxrankRes ~= 0
                s = min( s, opts.maxrankRes );
            end
            R{2} = reshape( ss(1:s,1:s)*vv(:,1:s)', [s, n(mu), X.rank(mu+1)]);

            % Append the enrichment along the left rank index.
            right = cat(1, X.U{mu}, R{2});

            % Truncate the enlarged core and keep the right singular vectors.
            [~,S,V] = svd( unfold(right,'right'), 'econ' );
            t = find( diag(S) > opts.tol*norm(diag(S)), 1, 'last' );
            t = min( t, opts.maxrank );
            X.U{mu} = reshape( V(:,1:t)', [t, n(mu), X.rank(mu+1)] );
            X.U{mu-1} = rand( X.rank(mu-1), n(mu-1), t);

            times = [times; toc(t_start)];
        end

    end

    % NOTE(review): the last backward step (mu = 2) overwrites X.U{1} with
    % random entries and no further solve follows in the final sweep, so the
    % returned X does not correspond to residual(end). Confirm whether
    % callers are expected to perform one more local solve on core 1.

end

function core = amen_fast_solve_core( L, F, X, mu, opts )
    % Solve the local system for core mu and return it in core shape.
    %
    %   'direct' : backslash on the matrix returned by contract( L, X, mu ).
    %   'pcg'    : MATLAB pcg with Afun as operator, apply_precond as
    %              preconditioner, tolerance 1e-10, at most 1000 iterations
    %              and the current core as starting vector.
    %
    % Raises an error for any other opts.solver value.

    F_mu = contract( X, F, mu );
    sz = [X.rank(mu), X.size(mu), X.rank(mu+1)];

    if strcmpi( opts.solver, 'direct' )
        L_mu = contract( L, X, mu );
        U_mu = L_mu \ F_mu(:);
    elseif strcmpi( opts.solver, 'pcg' )
        [left, right] = Afun_prepare( L, X, mu );
        [B2, V, E] = prepare_precond( L.L0, X, mu );

        U_mu = pcg( @(y) Afun( L, y, mu, sz, left, right), ...
                    F_mu(:), ...
                    1e-10, 1000, ...
                    @(y) apply_precond( B2, V, E, y, sz ), [], ...
                    X.U{mu}(:) );
    else
        error( 'Unknown opts.solver type. Use either ''direct'' (default) or ''pcg''.' )
    end

    core = reshape( U_mu, sz );
end
0148
function res = cost_function( L, X, F )
    % Evaluate 0.5*innerprod(X, apply(L, X)) - innerprod(X, F).
    %
    %   L - operator passed to apply()
    %   X - current iterate
    %   F - right-hand side
    quadratic_term = innerprod( X, apply(L, X) );
    linear_term = innerprod( X, F );
    res = 0.5*quadratic_term - linear_term;
end
0152
function res = euclid_grad( L, X, F )
    % Return apply(L, X) - F, i.e. the operator applied to X minus the
    % right-hand side. Not called anywhere in the visible part of this file.
    LX = apply(L, X);
    res = LX - F;
end
0156
function res = precond_residual( L0, X, R_combined, idx )
    % Precondition the merged two-core residual attached to cores
    % (idx, idx+1).
    %
    %   L0         - square matrix; n = size(L0, 1) is used as the size of
    %                both merged modes
    %   X          - current iterate (reads .rank, .order, .U)
    %   R_combined - residual matrix holding rl*n*n*rr entries, where
    %                rl = X.rank(idx) and rr = X.rank(idx+2)
    %   idx        - index of the left core of the pair
    %
    % Returns a matrix of size [rl*n, n*rr].
    %
    % The system matrix is never formed: the small rl*rr-sized part built
    % from B1 and B3 is diagonalised once, then one shifted sparse solve
    % with B2 is carried out per eigenvalue.

    n = size(L0, 1);
    nn = n*n;
    rl = X.rank(idx);
    rr = X.rank(idx+2);

    % Contribution of the cores to the left of the pair.
    B1 = zeros( rl );
    for k = 1:idx-1
        Xk = X;
        Xk.U{k} = tensorprod_ttemps( Xk.U{k}, L0, 2 );
        B1 = B1 + innerprod( X, Xk, 'LR', idx-1 );
    end

    % Contribution of the cores to the right of the pair.
    B3 = zeros( rr );
    for k = idx+2:X.order
        Xk = X;
        Xk.U{k} = tensorprod_ttemps( Xk.U{k}, L0, 2 );
        B3 = B3 + innerprod( X, Xk, 'RL', idx+2 );
    end

    % Contribution of the two merged modes.
    In = speye(n);
    B2 = kron( L0, In ) + kron( In, L0 );

    % Diagonalise the rank-space part.
    [V, lambda] = eig( kron( eye(rr), B1 ) + kron( B3, eye(rl) ) );
    lambda = diag(lambda);

    % Move to the eigenbasis, solve column by column, move back.
    shaped = reshape( R_combined, [rl, nn, rr] );
    rhs = matricize( shaped, 2 ) * V;
    Y = zeros( size(rhs) );
    Inn = speye(nn);
    for k = 1:length(lambda)
        Y(:,k) = (B2 + lambda(k)*Inn) \ rhs(:,k);
    end

    back = tensorize( Y*V', 2, [rl, nn, rr] );
    res = reshape( back, [rl*n, n*rr] );
end
0194
function [left, right] = Afun_prepare( A, x, idx )
    % Precompute the interface terms that Afun needs for core idx.
    %
    %   left  - innerprod( x, A.apply(x), 'LR', idx-1 ), or [] when idx == 1
    %   right - innerprod( x, A.apply(x), 'RL', idx+1 ), or [] when idx is
    %           the last core (and not the first)

    y = A.apply(x);

    % Same case split as an if/elseif/else on "first", "last", "interior".
    need_left = (idx ~= 1);
    need_right = (idx == 1) || (idx ~= x.order);

    left = [];
    right = [];
    if need_left
        left = innerprod( x, y, 'LR', idx-1 );
    end
    if need_right
        right = innerprod( x, y, 'RL', idx+1 );
    end
end
0208
function res = Afun( A, U, idx, sz, left, right )
    % Operator handle body for pcg: apply the local operator to the
    % vectorised core U.
    %
    %   A           - operator exposing .apply and .order
    %   U           - vectorised core, reshaped to sz before use
    %   idx         - index of the core
    %   sz          - core dimensions
    %   left, right - interface terms from Afun_prepare
    %
    % Returns the result as a column vector.

    core = A.apply( reshape( U, sz ), idx );

    % First core: right term only; last core: left term only; interior
    % core: right term first, then left term.
    use_left = (idx ~= 1);
    use_right = (idx == 1) || (idx ~= A.order);

    if use_right
        core = tensorprod_ttemps( core, right, 3 );
    end
    if use_left
        core = tensorprod_ttemps( core, left, 1 );
    end

    res = core(:);
end
function [B2, V, E] = prepare_precond( L0, X, idx )
    % Set up the single-core preconditioner that apply_precond uses.
    %
    %   L0  - matrix applied along mode 2 of the cores; returned unchanged
    %         as B2
    %   X   - current iterate (reads .rank, .order, .U)
    %   idx - index of the core being solved for
    %
    % Returns
    %   B2 - L0
    %   V  - eigenvectors of kron(eye(rr), B1) + kron(B3, eye(rl))
    %   E  - the matching eigenvalues as a column vector
    % with rl = X.rank(idx) and rr = X.rank(idx+1).

    rl = X.rank(idx);
    rr = X.rank(idx+1);

    % Contribution of the cores to the left of idx.
    B1 = zeros( rl );
    for i = 1:idx-1
        tmp = X;
        tmp.U{i} = tensorprod_ttemps( tmp.U{i}, L0, 2 );
        B1 = B1 + innerprod( X, tmp, 'LR', idx-1);
    end

    % Contribution of the mode belonging to core idx itself.
    B2 = L0;

    % Contribution of the cores to the right of idx.
    B3 = zeros( rr );
    for i = idx+1:X.order
        tmp = X;
        tmp.U{i} = tensorprod_ttemps( tmp.U{i}, L0, 2 );
        B3 = B3 + innerprod( X, tmp, 'RL', idx+1);
    end

    % Diagonalise the rank-space part once; apply_precond reuses V and E.
    [V,E] = eig( kron( eye(rr), B1 ) + kron( B3, eye(rl) ) );
    E = diag(E);
end
function res = apply_precond( B2, V, E, rhs, sz )
    % Preconditioner handle body for pcg, using the output of
    % prepare_precond.
    %
    %   B2   - square matrix shifted by each eigenvalue before solving
    %   V, E - eigenvectors / eigenvalues from prepare_precond
    %   rhs  - vector to precondition, reshaped to sz before use
    %   sz   - core dimensions
    %
    % Returns the preconditioned vector as a column.

    n = size(B2, 1);
    In = speye(n);

    % Move to the eigenbasis of the rank-space part.
    W = matricize( reshape( rhs, sz ), 2 ) * V;

    % One shifted solve per eigenvalue.
    Y = zeros( size(W) );
    for k = 1:length(E)
        Y(:,k) = (B2 + E(k)*In) \ W(:,k);
    end

    % Move back and vectorise.
    out = tensorize( Y*V', 2, sz );
    res = out(:);
end
0264
0265