Home > manopt > manifolds > ttfixedrank > TTeMPS_1.1 > algorithms > completion > completion_orth_lambda.m

completion_orth_lambda

PURPOSE ^

RTTC: Riemannian Tensor Train Completion

SYNOPSIS ^

function [xL,cost,test,stats] = completion_orth_lambda( A_Omega, Omega, A_Gamma, Gamma, X, opts, lambda )

DESCRIPTION ^

 RTTC: Riemannian Tensor Train Completion
 as described in 
   
   Michael Steinlechner, Riemannian optimization for high-dimensional tensor completion,
   Technical report, March 2015, revised December 2015. 
   To appear in SIAM J. Sci. Comput.

CROSS-REFERENCE INFORMATION ^

This function calls: (no functions listed). This function is called by: (no functions listed).

SUBFUNCTIONS ^

SOURCE CODE ^

0001 % RTTC: Riemannian Tensor Train Completion
0002 % as described in
0003 %
0004 %   Michael Steinlechner, Riemannian optimization for high-dimensional tensor completion,
0005 %   Technical report, March 2015, revised December 2015.
0006 %   To appear in SIAM J. Sci. Comput.
0007 %
0008 
0009 %   TTeMPS Toolbox.
0010 %   Michael Steinlechner, 2013-2016
0011 %   Questions and contact: michael.steinlechner@epfl.ch
0012 %   BSD 2-clause license, see LICENSE.txt
function [xL,cost,test,stats] = completion_orth_lambda( A_Omega, Omega, A_Gamma, Gamma, X, opts, lambda )
% RTTC (Riemannian Tensor Train Completion) with an additional
% lambda-weighted term added to the last TT core of the gradient and of the
% line-search formula.
%
%   [xL,cost,test,stats] = completion_orth_lambda( A_Omega, Omega, A_Gamma, Gamma, X, opts, lambda )
%
%   Inputs:
%     A_Omega : sampled tensor values on the training set Omega.
%     Omega   : training sampling set (format expected by TTeMPS indexing).
%     A_Gamma : sampled tensor values on the test set Gamma (monitoring only).
%     Gamma   : test sampling set.
%     X       : initial guess; a TTeMPS tensor (has .size, .rank, .order, .U).
%     opts    : options struct; any missing field gets the default set below.
%     lambda  : weight of the extra term applied to the last core.
%
%   Outputs:
%     xL    : final iterate, left-orthogonalized.
%     cost  : per-iteration relative training residual (truncated on exit).
%     test  : per-iteration relative test residual (truncated on exit).
%     stats : .time = cumulative elapsed time after each iteration,
%             .conv = true if a convergence criterion was met.
    
    % Fill in defaults for any option the caller did not provide.
    if ~isfield( opts, 'maxiter');  opts.maxiter = 100;     end
    if ~isfield( opts, 'cg');       opts.cg = true;         end
    if ~isfield( opts, 'tol');      opts.tol = 1e-6;        end
    if ~isfield( opts, 'reltol');   opts.reltol = 1e-6;     end
    if ~isfield( opts, 'gradtol');  opts.gradtol = 10*eps;  end
    if ~isfield( opts, 'verbose');  opts.verbose = false;   end
    

    n = X.size;
    nn = prod(n);                  % total number of tensor entries
    sizeOmega = numel(A_Omega);    % number of training samples
    r = X.rank;
    
    % Keep both a left- and a right-orthogonalized copy of the iterate;
    % both are passed to TTeMPS_tangent_orth below.
    xL = X;
    xR = orthogonalize(X, 1);

    % Norms of the sampled data, used to report *relative* residuals.
    norm_A_Omega = norm( A_Omega );
    norm_A_Gamma = norm( A_Gamma );
    
    cost = zeros(opts.maxiter,1);
    test = zeros(opts.maxiter,1);

    t = tic;
    stats.time = [];
    stats.conv = false;

    for i = 1:opts.maxiter
        % Euclidean gradient on the sampled entries, then (presumably) its
        % projection onto the tangent space at the current iterate -- TODO
        % confirm against TTeMPS_tangent_orth.
        grad = euclidgrad(A_Omega, xL, Omega);
        xi = TTeMPS_tangent_orth(xL, xR, grad, Omega);
        
        % This is where we add normalization
        % Add lambda * |Omega|/nn times the last core of xL to the last
        % component of the tangent vector.
        d = xL.order;
        xi.dU{d} = xi.dU{d} + (lambda * sizeOmega / nn) * xL.U{d};
        ip_xi_xi = innerprod(xi, xi);

        % Stop if the (regularized) gradient norm is essentially zero.
        if sqrt( abs(ip_xi_xi) ) < opts.gradtol 
            % NOTE(review): cost(i) has not been assigned yet in this
            % iteration (it is set further below), so here it is still the
            % preallocated 0 and this branch always reports convergence --
            % verify whether cost(max(i-1,1)) was intended.
            if cost(i) < opts.tol
                disp(sprintf('CONVERGED AFTER %i STEPS. Gradient is smaller than %0.3g', ...
                      i, opts.gradtol))
                stats.conv = true;
            else
                disp('No more progress in gradient change, but not converged. Aborting!')
                stats.conv = false;
            end
            cost = cost(1:i,1);
            test = test(1:i,1);
            stats.time = [stats.time toc(t)];
            return
        end

        % Search direction: steepest descent on the first iteration (or when
        % CG is disabled), otherwise a CG direction with restart safeguard.
        if (i == 1) || (~opts.cg) 
            eta = -xi;
        else
            % theta measures how far from orthogonal the transported previous
            % gradient is to the current one; a large value triggers a restart.
            ip_xitrans_xi = innerprod( xi_trans, xi );
            theta = ip_xitrans_xi / ip_xi_xi;
            if theta >= 0.1
                if opts.verbose, disp('steepest descent step'), end
                eta = -xi;
            else
                if opts.verbose, disp('CG step'), end
                % Fletcher-Reeves-type coefficient; previous direction eta is
                % transported to the current tangent space before being reused.
                beta = ip_xi_xi/ip_xi_xi_old;
                eta = -xi + beta*TTeMPS_tangent_orth( xL, xR, eta );
            end
        end
        
        %line search
        % Note we have to include normalization for this explicit line search
        % Exact step size for the quadratic model, with the lambda term
        % included in both the numerator and the denominator.
        eta_Omega = at_Omega( eta, Omega );
        alpha = -(eta_Omega'*grad + (lambda * sizeOmega / nn) * xL.U{d}(:)' * eta.dU{d}(:)) / (norm(eta_Omega)^2 + (lambda * sizeOmega / nn) * innerprod(eta, eta));
        
        % Retract alpha*eta back onto the fixed-rank manifold, then refresh
        % the left- and right-orthogonalized copies.
        X = tangentAdd( eta, alpha, true );
        xL = orthogonalize( X, X.order );
        xR = orthogonalize( X, 1 );
        cost(i) = sqrt(2*func(A_Omega, xL, Omega )) / norm_A_Omega;
        test(i) = sqrt(2*func(A_Gamma, xL, Gamma )) / norm_A_Gamma;

        % Absolute tolerance on the relative training residual.
        if cost(i) < opts.tol
            disp(sprintf('CONVERGED AFTER %i STEPS. Rel. residual smaller than %0.3g', ...
                          i, opts.tol))
            stats.conv = true;
            cost = cost(1:i,1);
            test = test(1:i,1);
            stats.time = [stats.time toc(t)];
            return
        end

        % Stagnation check: relative change of the residual between iterations.
        if i > 1
            reltol = abs(cost(i) - cost(i-1)) / cost(i);
            if reltol < opts.reltol
                if cost(i) < opts.tol
                    disp(sprintf('CONVERGED AFTER %i STEPS. Relative change is smaller than %0.3g', ...
                              i, opts.reltol))
                    stats.conv = true;
                else
                    disp('No more progress in relative change, but not converged. Aborting!')
                    stats.conv = false;
                end

                cost = cost(1:i,1);
                test = test(1:i,1);
                stats.time = [stats.time toc(t)];
                return
            end
        end

        % Remember quantities needed by the CG update of the next iteration:
        % the gradient norm and the gradient transported to the new iterate.
        ip_xi_xi_old = ip_xi_xi;
        xi_trans = TTeMPS_tangent_orth( xL, xR, xi );

        stats.time = [stats.time toc(t)];
    end

    

end
0129 
0130 
function res = func(A_Omega, X, Omega)
% Objective value: half the squared 2-norm of the fit error on the
% sampling set, i.e. 0.5 * || A_Omega - X(Omega) ||^2.
    residual = X(Omega) - A_Omega;
    res = 0.5 * norm( residual )^2;
end
0134 
function res = euclidgrad(A_Omega, X, Omega)
% Euclidean gradient of func with respect to the sampled entries:
% the pointwise fit error X(Omega) - A_Omega.
    sampled = X(Omega);
    res = sampled - A_Omega;
end

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005