Home > manopt > manifolds > ttfixedrank > TTeMPS_1.1 > algorithms > completion > completion_orth.m

completion_orth

PURPOSE

RTTC: Riemannian Tensor Train Completion

SYNOPSIS

function [xL,cost,test,stats] = completion_orth( A_Omega, Omega, A_Gamma, Gamma, X, opts )

DESCRIPTION

 RTTC: Riemannian Tensor Train Completion
 as described in 
   
   Michael Steinlechner, Riemannian optimization for high-dimensional tensor completion,
   Technical report, March 2015, revised December 2015. 
   To appear in SIAM J. Sci. Comput.

CROSS-REFERENCE INFORMATION

This function calls: orthogonalize, innerprod, TTeMPS_tangent_orth, at_Omega, tangentAdd (TTeMPS toolbox)
This function is called by: (not listed)

SUBFUNCTIONS

function res = func(A_Omega, X, Omega)
function res = euclidgrad(A_Omega, X, Omega)

SOURCE CODE

0001 % RTTC: Riemannian Tensor Train Completion
0002 % as described in
0003 %
0004 %   Michael Steinlechner, Riemannian optimization for high-dimensional tensor completion,
0005 %   Technical report, March 2015, revised December 2015.
0006 %   To appear in SIAM J. Sci. Comput.
0007 %
0008 
0009 %   TTeMPS Toolbox.
0010 %   Michael Steinlechner, 2013-2016
0011 %   Questions and contact: michael.steinlechner@epfl.ch
0012 %   BSD 2-clause license, see LICENSE.txt
function [xL,cost,test,stats] = completion_orth( A_Omega, Omega, A_Gamma, Gamma, X, opts )
% RTTC: Riemannian Tensor Train Completion (Riemannian nonlinear CG).
%
%   [xL,cost,test,stats] = completion_orth(A_Omega, Omega, A_Gamma, Gamma, X, opts)
%
%   Inputs:
%     A_Omega - known entries of the target tensor on the sampling set Omega
%     Omega   - sampling (training) index set, used as X(Omega)
%     A_Gamma - known entries of the target tensor on the test set Gamma
%     Gamma   - test index set, used as X(Gamma)
%     X       - initial guess; must support orthogonalize, .order and
%               X(Omega) indexing (presumably a TTeMPS tensor -- confirm).
%               In the first iteration X is used as-is for xL (only xR is
%               recomputed), so X is presumably expected to be
%               left-orthogonalized already -- TODO confirm with callers.
%     opts    - options struct (optional). Missing fields default to:
%                 maxiter = 100      maximum number of iterations
%                 cg      = true     use CG directions (false: steepest descent)
%                 tol     = 1e-6     stop when rel. residual on Omega < tol
%                 reltol  = 1e-6     stop when rel. change of that residual < reltol
%                 gradtol = 10*eps   stop when Riemannian gradient norm < gradtol
%                 verbose = false    print which kind of direction is used
%
%   Outputs:
%     xL    - final iterate, as returned by orthogonalize(X, X.order)
%     cost  - cost(k) = ||A_Omega - xL(Omega)|| / ||A_Omega|| for iterate k
%     test  - test(k) = ||A_Gamma - xL(Gamma)|| / ||A_Gamma|| for iterate k
%     stats - struct with fields
%               gradnorm : norm of the Riemannian gradient per iteration
%               time     : cumulative run time per iteration (test-set
%                          evaluation excluded)
%               conv     : true if the residual tolerance opts.tol was reached
%
%   cost, test, stats.gradnorm and stats.time have one entry per performed
%   iteration. If the gradient-norm test stops the method, the last entry of
%   cost/test refers to the current iterate (no step was taken).

    if nargin < 6 || isempty( opts ),  opts = struct();  end
    if ~isfield( opts, 'maxiter');  opts.maxiter = 100;     end
    if ~isfield( opts, 'cg');       opts.cg = true;         end
    if ~isfield( opts, 'tol');      opts.tol = 1e-6;        end
    if ~isfield( opts, 'reltol');   opts.reltol = 1e-6;     end
    if ~isfield( opts, 'gradtol');  opts.gradtol = 10*eps;  end
    if ~isfield( opts, 'verbose');  opts.verbose = false;   end

    xL = X;
    xR = orthogonalize(X, 1);

    norm_A_Omega = norm( A_Omega );
    norm_A_Gamma = norm( A_Gamma );

    cost = zeros(opts.maxiter,1);
    test = zeros(opts.maxiter,1);
    stats.gradnorm = zeros(opts.maxiter,1);

    t = tic;
    stats.time = 0;      % leading placeholder, dropped before returning
    stats.conv = false;

    for i = 1:opts.maxiter
        % Riemannian gradient: Euclidean gradient on Omega, mapped by
        % TTeMPS_tangent_orth using the left/right orthogonalized iterates.
        grad = euclidgrad(A_Omega, xL, Omega);
        xi = TTeMPS_tangent_orth(xL, xR, grad, Omega);
        ip_xi_xi = innerprod(xi, xi);
        stats.gradnorm(i) = sqrt(abs(ip_xi_xi));

        if stats.gradnorm(i) < opts.gradtol
            % No step is taken in this iteration, so cost(i) and test(i) have
            % not been filled yet: evaluate them for the current iterate
            % instead of testing the zero placeholder.
            stats.time = [stats.time stats.time(end) + toc(t)];
            cost(i) = rel_residual(A_Omega, xL, Omega, norm_A_Omega);
            test(i) = rel_residual(A_Gamma, xL, Gamma, norm_A_Gamma);
            if cost(i) < opts.tol
                fprintf('CONVERGED AFTER %i STEPS. Gradient is smaller than %0.3g\n', ...
                        i, opts.gradtol);
                stats.conv = true;
            else
                disp('No more progress in gradient change, but not converged. Aborting!')
                stats.conv = false;
            end
            [cost, test, stats] = truncate_history(cost, test, stats, i);
            return
        end

        % Search direction: steepest descent in the first iteration or when
        % CG is disabled. Otherwise use a Fletcher-Reeves CG direction, but
        % restart with steepest descent when <xi_trans, xi> / <xi, xi> >= 0.1.
        if (i == 1) || (~opts.cg)
            eta = -xi;
        else
            ip_xitrans_xi = innerprod( xi_trans, xi );
            theta = ip_xitrans_xi / ip_xi_xi;
            if theta >= 0.1
                if opts.verbose, disp('steepest descent step'), end
                eta = -xi;
            else
                if opts.verbose, disp('CG step'), end
                beta = ip_xi_xi/ip_xi_xi_old;
                % previous direction mapped into the current tangent space
                % (presumably projection-based vector transport)
                eta = -xi + beta*TTeMPS_tangent_orth( xL, xR, eta );
            end
        end

        % Line search: alpha minimizes 0.5*||grad + alpha*eta(Omega)||^2, the
        % cost along eta before retraction.
        eta_Omega = at_Omega( eta, Omega );
        alpha = -(eta_Omega'*grad) / norm(eta_Omega)^2;

        % Update (presumably xL + alpha*eta followed by retraction -- see
        % tangentAdd) and re-orthogonalize in both directions.
        X = tangentAdd( eta, alpha, true );
        xL = orthogonalize( X, X.order );
        xR = orthogonalize( X, 1 );
        cost(i) = rel_residual(A_Omega, xL, Omega, norm_A_Omega);

        if cost(i) < opts.tol
            fprintf('CONVERGED AFTER %i STEPS. Rel. residual smaller than %0.3g\n', ...
                    i, opts.tol);
            stats.conv = true;
            stats.time = [stats.time stats.time(end)+toc(t)];
            test(i) = rel_residual(A_Gamma, xL, Gamma, norm_A_Gamma);
            [cost, test, stats] = truncate_history(cost, test, stats, i);
            return
        end

        if i > 1
            rel_change = abs(cost(i) - cost(i-1)) / cost(i);
            if rel_change < opts.reltol
                % cost(i) >= opts.tol here, otherwise we would have returned above.
                disp('No more progress in relative change, but not converged. Aborting!')
                stats.conv = false;
                stats.time = [stats.time stats.time(end)+toc(t)];
                test(i) = rel_residual(A_Gamma, xL, Gamma, norm_A_Gamma);
                [cost, test, stats] = truncate_history(cost, test, stats, i);
                return
            end
        end

        % Quantities needed by the CG direction of the next iteration.
        ip_xi_xi_old = ip_xi_xi;
        xi_trans = TTeMPS_tangent_orth( xL, xR, xi );

        % Time is taken before the test-set evaluation and restarted after it,
        % so test-error bookkeeping does not count towards stats.time.
        stats.time = [stats.time stats.time(end)+toc(t)];
        test(i) = rel_residual(A_Gamma, xL, Gamma, norm_A_Gamma);
        t = tic;
    end

    [cost, test, stats] = truncate_history(cost, test, stats, opts.maxiter);
end

function res = rel_residual(A, X, idx, norm_A)
% Relative residual ||A - X(idx)|| / norm_A, computed through func.
    res = sqrt(2*func(A, X, idx)) / norm_A;
end

function [cost, test, stats] = truncate_history(cost, test, stats, n_iter)
% Keep the first n_iter entries of the per-iteration histories and drop the
% leading 0 placeholder of stats.time.
    cost = cost(1:n_iter,1);
    test = test(1:n_iter,1);
    stats.gradnorm = stats.gradnorm(1:n_iter,1);
    stats.time = stats.time(2:end);
end
0137 
0138 
function res = func(A_Omega, X, Omega)
% FUNC  Least-squares cost on the sampling set:
%   res = 0.5 * ||A_Omega - X(Omega)||^2
    residual = A_Omega - X(Omega);
    res = 0.5 * norm(residual)^2;
end
0142 
function res = euclidgrad(A_Omega, X, Omega)
% EUCLIDGRAD  Euclidean gradient of FUNC with respect to the sampled
%   entries: res = X(Omega) - A_Omega
    sampled = X(Omega);
    res = sampled - A_Omega;
end

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005