0001
0002
0003
0004
0005
function [X, residual, cost] = amen( L, F, X, opts )
%AMEN  AMEn-style alternating solver with residual enrichment for L(X) = F.
%
%   [X, residual, cost] = amen( L, F, X, opts ) performs opts.nSweeps sweeps,
%   each consisting of a forward pass (cores 1 .. d-1) and a backward pass
%   (cores d .. 2). At core mu it
%     1. solves the projected system contract(L,X,mu) \ contract(X,F,mu),
%     2. enriches core mu with a low-rank approximation of the residual
%        restricted to cores (mu, mu+1) resp. (mu-1, mu),
%     3. re-orthogonalizes / truncates the enriched core with an SVD, and
%     4. pushes the coefficients of the old core into the neighbouring core,
%        so X keeps representing the current iterate (up to truncation).
%
%   Inputs
%     L     operator; used via apply(L, X) and contract(L, X, mu)
%     F     right-hand side; used via norm(F), innerprod and contract(X, F, mu)
%     X     initial guess exposing order, size, rank and cores X.U{mu};
%           cores are reshaped below as rank(mu) x size(mu) x rank(mu+1)
%     opts  optional struct (defaults in brackets):
%             nSweeps     number of forward+backward sweeps               [4]
%             maxrank     cap on the core rank after enrichment; 0 = none [20]
%             maxrankRes  cap on the residual enrichment rank; 0 = none   [4]
%             tolRes      relative SVD cutoff for the residual            [1e-13]
%             tol         relative SVD cutoff for the enriched core       [1e-13]
%             verbose     print progress after each core update           [true]
%
%   Outputs
%     X         iterate after the last sweep
%     residual  column of relative residuals norm(F - apply(L,X)) / norm(F):
%               the initial value followed by one entry per core solve
%     cost      column of 0.5*<X, L(X)> - <X, F>, recorded at the same points
%
%   NOTE(review): the cost is only a meaningful energy if L is symmetric
%   positive definite; this is assumed, not checked.

if ~exist( 'opts', 'var'),         opts = struct();      end
if ~isfield( opts, 'nSweeps'),     opts.nSweeps = 4;     end
if ~isfield( opts, 'maxrank'),     opts.maxrank = 20;    end
if ~isfield( opts, 'maxrankRes'),  opts.maxrankRes = 4;  end
if ~isfield( opts, 'tolRes'),      opts.tolRes = 1e-13;  end
if ~isfield( opts, 'tol'),         opts.tol = 1e-13;     end
if ~isfield( opts, 'verbose'),     opts.verbose = true;  end

d = X.order;
n = X.size;

normF = norm(F);
cost = cost_function( L, X, F );
residual = norm( apply(L, X) - F ) / normF;

for sweep = 1:opts.nSweeps
    % Make cores 2..d right-orthogonal so the local systems are well posed.
    X = orthogonalize(X, 1);

    % ---------------- forward half-sweep: cores 1 .. d-1 ----------------
    for mu = 1:d-1
        [X, res, residual, cost] = amen_solve_core( L, F, X, mu, normF, ...
                                                    residual, cost, opts.verbose );

        % Low-rank approximation of the residual on the core pair (mu, mu+1);
        % only its left factor is used for the enrichment.
        R = contract( X, res, [mu, mu+1] );
        R_combined = unfold(R{1},'left') * unfold(R{2},'right');
        [uu, ss, ~] = svd( R_combined, 'econ' );
        s = amen_truncation_rank( diag(ss), opts.tolRes, opts.maxrankRes, 0 );
        R{1} = reshape( uu(:,1:s)*ss(1:s,1:s), [X.rank(mu), n(mu), s] );

        % Enrich core mu with the residual directions and left-orthogonalize.
        Xmu_old = X.U{mu};
        left = cat(3, Xmu_old, R{1});
        [U, S, ~] = svd( unfold(left,'left'), 'econ' );
        t = amen_truncation_rank( diag(S), opts.tol, opts.maxrank, 1 );
        Ut = U(:,1:t);
        X.U{mu} = reshape( Ut, [X.rank(mu), n(mu), t] );

        % Express the old core in the new basis (t x rank(mu+1)) and absorb
        % these coefficients into core mu+1. Previously core mu+1 was replaced
        % by a random tensor, which left X inconsistent with the solution.
        C = Ut' * reshape( Xmu_old, [], size(Xmu_old, 3) );
        nextCore = X.U{mu+1};
        X.U{mu+1} = reshape( C * reshape( nextCore, size(nextCore, 1), [] ), ...
                             [t, n(mu+1), size(nextCore, 3)] );
    end

    % ---------------- backward half-sweep: cores d .. 2 -----------------
    for mu = d:-1:2
        [X, res, residual, cost] = amen_solve_core( L, F, X, mu, normF, ...
                                                    residual, cost, opts.verbose );

        % Low-rank approximation of the residual on the core pair (mu-1, mu);
        % only its right factor is used for the enrichment.
        R = contract( X, res, [mu-1, mu] );
        R_combined = unfold(R{1},'left') * unfold(R{2},'right');
        [~, ss, vv] = svd( R_combined, 'econ' );
        s = amen_truncation_rank( diag(ss), opts.tolRes, opts.maxrankRes, 0 );
        R{2} = reshape( ss(1:s,1:s)*vv(:,1:s)', [s, n(mu), X.rank(mu+1)] );

        % Enrich core mu with the residual directions and right-orthogonalize.
        Xmu_old = X.U{mu};
        right = cat(1, Xmu_old, R{2});
        [~, S, V] = svd( unfold(right,'right'), 'econ' );
        t = amen_truncation_rank( diag(S), opts.tol, opts.maxrank, 1 );
        Vt = V(:,1:t);
        X.U{mu} = reshape( Vt', [t, n(mu), X.rank(mu+1)] );

        % Express the old core in the new basis (rank(mu) x t) and absorb
        % these coefficients into core mu-1 (instead of a random core, which
        % at mu = 2 used to leave the returned X with a random first core).
        C = reshape( Xmu_old, size(Xmu_old, 1), [] ) * Vt;
        prevCore = X.U{mu-1};
        X.U{mu-1} = reshape( reshape( prevCore, [], size(prevCore, 3) ) * C, ...
                             [size(prevCore, 1), n(mu-1), t] );
    end
end

end

function [X, res, residual, cost] = amen_solve_core( L, F, X, mu, normF, residual, cost, verbose )
% Solve the projected local system for core mu, then append the new relative
% residual and cost value. Also returns the full residual res = F - L(X),
% which the caller uses for the enrichment step.
if verbose
    disp( ['Current core: ', num2str(mu)] )
end

L_mu = contract( L, X, mu );
F_mu = contract( X, F, mu );
X.U{mu} = reshape( L_mu \ F_mu(:), size(X.U{mu}) );

res = F - apply(L, X);
residual = [residual; norm( res ) / normF];
cost = [cost; cost_function( L, X, F )];

if verbose
    disp(['Rel. residual: ' num2str(residual(end)) ', Current rank: ' num2str(X.rank) ]);
end
end

function r = amen_truncation_rank( sv, tol, maxr, minr )
% Number of singular values in sv exceeding tol*norm(sv), capped at maxr
% (no cap when maxr is 0) and raised to at least minr. An all-zero sv
% (e.g. an exactly vanishing residual) yields minr instead of [].
r = find( sv > tol*norm(sv), 1, 'last' );
if isempty(r)
    r = 0;
end
if maxr > 0
    r = min( r, maxr );
end
r = max( r, minr );
end
0120
function res = cost_function( L, X, F )
% Quadratic functional J(X) = 1/2 <X, L(X)> - <X, F>.
% NOTE(review): J is an energy whose minimizer solves L(X) = F only when L is
% symmetric positive definite; that property is assumed, not checked.
quadTerm = innerprod( X, apply(L, X) );
linTerm  = innerprod( X, F );
res = 0.5*quadTerm - linTerm;
end
0124
function res = euclid_grad( L, X, F )
% Returns L(X) - F, i.e. the negated residual. It equals the Euclidean gradient
% of cost_function when L is symmetric (assumed, not checked). Not called from
% the code visible in this file.
LX = apply(L, X);
res = LX - F;
end
0128