Home > manopt > manifolds > ttfixedrank > TTeMPS_1.1 > algorithms > completion > completion_als.m

completion_als

PURPOSE ^

ALS Completion

SYNOPSIS ^

function [X,cost,test,stats] = completion_als( A_Omega, Omega, A_Gamma, Gamma, X, opts )

DESCRIPTION ^

 ALS Completion
 as described in 
   
   Michael Steinlechner, Riemannian optimization for high-dimensional tensor completion,
   Technical report, March 2015, revised December 2015. 
   To appear in SIAM J. Sci. Comput.

CROSS-REFERENCE INFORMATION ^

This function calls: (none listed)
This function is called by: (none listed)

SUBFUNCTIONS ^

SOURCE CODE ^

0001 % ALS Completion
0002 % as described in
0003 %
0004 %   Michael Steinlechner, Riemannian optimization for high-dimensional tensor completion,
0005 %   Technical report, March 2015, revised December 2015.
0006 %   To appear in SIAM J. Sci. Comput.
0007 %
0008 
0009 %   TTeMPS Toolbox.
0010 %   Michael Steinlechner, 2013-2016
0011 %   Questions and contact: michael.steinlechner@epfl.ch
0012 %   BSD 2-clause license, see LICENSE.txt
function [X,cost,test,stats] = completion_als( A_Omega, Omega, A_Gamma, Gamma, X, opts )
%COMPLETION_ALS  Tensor completion by alternating least squares (ALS) sweeps.
%
%   [X,COST,TEST,STATS] = COMPLETION_ALS(A_OMEGA, OMEGA, A_GAMMA, GAMMA, X, OPTS)
%   alternates a left-to-right half-sweep (cores 1..d-1) and a right-to-left
%   half-sweep (cores d..2) over the cores of X. Each visited core is replaced
%   by the least-squares fit to the samples A_OMEGA at the indices OMEGA
%   (see solve_least_squares) and then orthogonalized with orth_at.
%
%   Inputs:
%     A_Omega - sampled values; A_Omega(s) belongs to sample row Omega(s,:)
%     Omega   - sample indices, one row per sample, one column per mode
%     A_Gamma - test values, used only to report the test error
%     Gamma   - test indices (same layout as Omega)
%     X       - initial iterate; must support .size, .order, .U,
%               orthogonalize, orth_at and evaluation X(Omega)
%     opts    - optional struct with fields (defaults in parentheses):
%                 maxiter (100)  maximum number of full (forward+backward) sweeps
%                 tol     (1e-6) stop when the relative residual on Omega < tol
%                 reltol  (1e-6) stop when the relative change of that residual
%                                between consecutive half-sweeps < reltol
%
%   Outputs:
%     X     - final iterate
%     cost  - column vector, ||A_Omega - X(Omega)|| / ||A_Omega|| after
%             each performed half-sweep
%     test  - column vector, ||A_Gamma - X(Gamma)|| / ||A_Gamma|| after
%             each performed half-sweep
%     stats - struct with fields
%               time: cumulative wall time at the end of each half-sweep,
%                     excluding the initial orthogonalization and the
%                     test-error evaluations
%               conv: true iff the stopping criterion on tol was met

    if nargin < 6, opts = struct(); end
    if ~isfield( opts, 'maxiter');  opts.maxiter = 100;     end
    if ~isfield( opts, 'tol');      opts.tol = 1e-6;        end
    if ~isfield( opts, 'reltol');   opts.reltol = 1e-6;     end

    d = X.order;
    nhalf = 2*opts.maxiter;     % one forward + one backward half-sweep per iteration

    cost = zeros(nhalf,1);
    test = zeros(nhalf,1);

    norm_A_Omega = norm( A_Omega );
    norm_A_Gamma = norm( A_Gamma );

    X = orthogonalize( X, 1 );

    t = tic;
    % Leading 0 makes stats.time(end) valid on the first half-sweep; it is
    % dropped before returning.
    stats.time = 0;
    stats.conv = false;

    for k = 1:nhalf
        forward = mod(k,2) == 1;    % odd half-sweeps go left-to-right

        if forward
            cores = 1:d-1;
            side = 'left';
        else
            cores = d:-1:2;
            side = 'right';
        end

        fprintf(1,'Currently optimizing core: ')
        for mu = cores
            fprintf(1,'%i ', mu)
            X.U{mu} = solve_least_squares( A_Omega, Omega, X, mu );
            X = orth_at( X, mu, side );
        end
        cost(k) = sqrt(2*func(A_Omega, X, Omega)) / norm_A_Omega;

        stop = false;
        if cost(k) < opts.tol
            fprintf(1,'CONVERGED AFTER %i HALF-SWEEPS. Rel. residual smaller than %0.3g\n', ...
                    k, opts.tol);
            stats.conv = true;
            stop = true;
        elseif k > 2 && abs(cost(k) - cost(k-1)) / cost(k) < opts.reltol
            % The stagnation test only starts with the second full sweep.
            fprintf(1,'No more progress in gradient change, but not converged after %i half-sweeps. ABORTING!. \nRelative change is smaller than %0.3g\n', ...
                    k, opts.reltol);
            stop = true;
        end

        % Time is recorded before the test error is evaluated, and the timer is
        % restarted afterwards, so test evaluation is not counted.
        stats.time = [stats.time stats.time(end)+toc(t)];
        test(k) = sqrt(2*func(A_Gamma, X, Gamma)) / norm_A_Gamma;

        if stop
            cost = cost(1:k,1);
            test = test(1:k,1);
            break
        end
        t = tic;

        if forward
            fprintf(1,'\nFinished forward sweep.\n    Cost: %e\n    Test: %e\n', cost(k), test(k) );
        else
            fprintf(1,'\nFinished backward sweep.\n    Cost: %e\n    Test: %e\n', cost(k), test(k) );
            disp('_______________________________________________________________')
        end
    end

    % Drop the artificial leading 0 so stats.time has one entry per half-sweep.
    stats.time = stats.time(2:end);

end
0131 
0132 
function res = func(A_Omega, X, Omega)
%FUNC  Half of the squared 2-norm of the residual at the sample indices.
%   RES = FUNC(A_OMEGA, X, OMEGA) returns 0.5*norm(A_OMEGA - X(OMEGA))^2,
%   where X(OMEGA) evaluates X at the indices OMEGA.
    residual = A_Omega - X(Omega);
    res = 0.5 * norm(residual)^2;
end
0136 
0137 
function res = solve_least_squares( A_Omega, Omega, X, mu )
%SOLVE_LEAST_SQUARES  Least-squares update of the mu-th core of X.
%
%   RES = SOLVE_LEAST_SQUARES(A_OMEGA, OMEGA, X, MU) keeps every core of X
%   except the mu-th fixed. It returns a replacement for X.U{mu} (same size,
%   r(mu) x n(mu) x r(mu+1)) that fits the samples in the least-squares sense.
%   Each sample touches exactly one slice U{mu}(:,i,:), namely i = Omega(s,mu),
%   so the fit decouples into n(mu) independent least-squares problems.
%
%   Inputs:
%     A_Omega - sampled values; A_Omega(s) belongs to sample row Omega(s,:)
%     Omega   - sample indices, one row per sample, one column per mode
%     X       - current iterate (uses .size, .order, .rank, .U)
%     mu      - index of the core to update
%
%   Errors if some slice i = 1..n(mu) has no samples.

    n = X.size;
    d = X.order;
    r = X.rank;

    % Sort the samples by their mu-th index so that the samples of each slice
    % form one contiguous run of rows.
    [jmu,idx] = sort(Omega(:,mu),'ascend');
    Omega = Omega(idx,:);
    A_Omega = A_Omega(idx);

    % Cores permuted to r(i) x r(i+1) x n(i), so C{i}(:,:,j) is slice j.
    C = cell(1,d);
    for i=1:d
        C{i} = permute( X.U{i}, [1 3 2]);
    end
    res = zeros( size(C{mu}) );

    % B has one row per (sorted) sample and r(mu)*r(mu+1) columns.
    % NOTE(review): als_solve_mex is assumed to build the same matrix as this
    % plain-MATLAB reference loop (TODO confirm against the mex source):
    %   for s = 1:size(Omega,1)
    %       L = 1;  for i = 1:mu-1,     L = L * C{i}(:,:,Omega(s,i)); end
    %       R = 1;  for i = d:-1:mu+1,  R = C{i}(:,:,Omega(s,i)) * R; end
    %       B(s,:) = reshape( L'*R', 1, r(mu)*r(mu+1) );
    %   end
    B = als_solve_mex( n, r, C, Omega', mu)';

    % Run boundaries for each slice in one pass: slice i owns rows
    % first(i):last(i). This replaces a find() over all samples per slice.
    counts = accumarray( double(jmu(:)), 1, [n(mu) 1] );
    last = cumsum( counts );
    first = last - counts + 1;

    for i = 1:n(mu)
        if counts(i) == 0
            error('No samples for this slice!')
        end
        rows = first(i):last(i);
        res(:,:,i) = reshape(B(rows,:)\A_Omega(rows), r(mu), r(mu+1));
    end

    % Back to the r(mu) x n(mu) x r(mu+1) layout of X.U{mu}.
    res = permute( res, [1 3 2] );

end

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005