Home > manopt > manifolds > ttfixedrank > TTeMPS_1.1 > algorithms > completion > completion_orth.m

# completion_orth

## PURPOSE

RTTC: Riemannian Tensor Train Completion

## SYNOPSIS

function [xL,cost,test,stats] = completion_orth( A_Omega, Omega, A_Gamma, Gamma, X, opts )

## DESCRIPTION

```
RTTC: Riemannian Tensor Train Completion
as described in

Michael Steinlechner, Riemannian optimization for high-dimensional tensor completion,
Technical report, March 2015, revised December 2015.
To appear in SIAM J. Sci. Comput.
```

## CROSS-REFERENCE INFORMATION

This function calls:
• disp DISP Display TT/MPS tensor.
• innerprod INNERPROD Inner product between two TT/MPS tensors.
• norm NORM Norm of a TT/MPS tensor.
• orthogonalize ORTHOGONALIZE Orthogonalize tensor.
• disp DISP Display TT/MPS block-mu tensor.
• innerprod INNERPROD Inner product between two TT/MPS tensors.
• norm NORM Norm of a TT/MPS block-mu tensor.
• orthogonalize ORTHOGONALIZE Orthogonalize TT/MPS Block-mu tensor.
• disp DISP Display TT/MPS operator.
• disp DISP Display TT/MPS operator.
• TTeMPS_tangent_orth
• orthogonalize Orthonormalizes a basis of tangent vectors in the Manopt framework.
This function is called by: (no callers listed)

## SOURCE CODE

```
0001 % RTTC: Riemannian Tensor Train Completion
0002 % as described in
0003 %
0004 %   Michael Steinlechner, Riemannian optimization for high-dimensional tensor completion,
0005 %   Technical report, March 2015, revised December 2015.
0006 %   To appear in SIAM J. Sci. Comput.
0007 %
0008
0009 %   TTeMPS Toolbox.
0010 %   Michael Steinlechner, 2013-2016
0011 %   Questions and contact: michael.steinlechner@epfl.ch
0013 function [xL,cost,test,stats] = completion_orth( A_Omega, Omega, A_Gamma, Gamma, X, opts )
0014     % Outputs: xL = final iterate; cost(i)/test(i) = norm(A - xL) restricted to Omega/Gamma, relative to norm(A_Omega)/norm(A_Gamma), at iteration i; stats.time = cumulative elapsed time per iteration; stats.conv = convergence flag.
0015     if ~isfield( opts, 'maxiter');  opts.maxiter = 100;     end
0016     if ~isfield( opts, 'cg');       opts.cg = true;         end
0017     if ~isfield( opts, 'tol');      opts.tol = 1e-6;        end
0018     if ~isfield( opts, 'reltol');   opts.reltol = 1e-6;     end
0019     if ~isfield( opts, 'gradtol');  opts.gradtol = 10*eps;  end   % read by the gradient-norm stopping test; previously had no default
0020     if ~isfield( opts, 'verbose');  opts.verbose = false;   end
0021     % opts: maxiter, cg (nonlinear CG if true, steepest descent otherwise), tol (stop when cost < tol), reltol (stop on small relative change of cost), gradtol (stop on small norm of xi), verbose.
0022
0023     n = X.size;
0024     r = X.rank;
0025
0026     xL = orthogonalize(X, X.order);   % same orthogonalization as applied to every later iterate
0027     xR = orthogonalize(X, 1);
0028
0029     norm_A_Omega = norm( A_Omega );
0030     norm_A_Gamma = norm( A_Gamma );
0031
0032     cost = zeros(opts.maxiter,1);
0033     test = zeros(opts.maxiter,1);
0035
0036     t = tic;
0037     stats.time = [0];
0038     stats.conv = false;
0039
0040     for i = 1:opts.maxiter
0041         grad = euclidgrad(A_Omega, xL, Omega);   % residual xL(Omega) - A_Omega
0042         xi = TTeMPS_tangent_orth(xL, xR, grad, Omega);
0043         ip_xi_xi = innerprod(xi, xi);
0045
0046         if sqrt( abs(ip_xi_xi) ) < opts.gradtol
0047             cost(i) = sqrt(2*func(A_Omega, xL, Omega )) / norm_A_Omega;   % cost(i) is otherwise only set after the step, so evaluate the current iterate here
0048             if cost(i) < opts.tol
0049                 disp(sprintf('CONVERGED AFTER %i STEPS. Gradient is smaller than %0.3g', i, opts.gradtol))
0050                 stats.conv = true;
0051             else
0052                 disp('No more progress in gradient change, but not converged. Aborting!')
0053                 stats.conv = false;
0054             end
0055             cost = cost(1:i,1);
0056             test(i) = sqrt(2*func(A_Gamma, xL, Gamma )) / norm_A_Gamma;
0057             test = test(1:i,1);
0058
0059             stats.time = [stats.time stats.time(end) + toc(t)];
0060             stats.time = stats.time(2:end);
0061             return
0062         end
0063
0064         if (i == 1) || (~opts.cg)
0065             eta = -xi;
0066         else
0067             ip_xitrans_xi = innerprod( xi_trans, xi );
0068             theta = ip_xitrans_xi / ip_xi_xi;
0069             if theta >= 0.1   % restart: previous (transported) and current gradients are far from orthogonal
0070                 if opts.verbose, disp('steepest descent step'), end
0071                 eta = -xi;
0072             else
0073                 if opts.verbose, disp('CG step'), end
0074                 beta = ip_xi_xi/ip_xi_xi_old;   % Fletcher-Reeves
0075                 eta = -xi + beta*TTeMPS_tangent_orth( xL, xR, eta );
0076             end
0077         end
0078
0079         %line search
0080         eta_Omega = at_Omega( eta, Omega );
0081         alpha = -(eta_Omega'*grad) / norm(eta_Omega)^2;   % minimizes 0.5*norm(grad + alpha*eta_Omega)^2
0082
0083         X = tangentAdd( eta, alpha, true );
0084         xL = orthogonalize( X, X.order );
0085         xR = orthogonalize( X, 1 );
0086         cost(i) = sqrt(2*func(A_Omega, xL, Omega )) / norm_A_Omega;
0087
0088
0089         if cost(i) < opts.tol
0090             disp(sprintf('CONVERGED AFTER %i STEPS. Rel. residual smaller than %0.3g', ...
0091                           i, opts.tol))
0092             stats.conv = true;
0093             cost = cost(1:i,1);
0095             stats.time = [stats.time stats.time(end)+toc(t)];
0096             test(i) = sqrt(2*func(A_Gamma, xL, Gamma )) / norm_A_Gamma;
0097             test = test(1:i,1);
0098             stats.time = stats.time(2:end);
0099             return
0100         end
0101
0102         if i > 1
0103             reltol = abs(cost(i) - cost(i-1)) / cost(i);
0104             if reltol < opts.reltol
0105                 if cost(i) < opts.tol
0106                     disp(sprintf('CONVERGED AFTER %i STEPS. Relative change is smaller than %0.3g', ...
0107                               i, opts.reltol))
0108                     stats.conv = true;
0109                 else
0110                     disp('No more progress in relative change, but not converged. Aborting!')
0111                     stats.conv = false;
0112                 end
0113
0114                 cost = cost(1:i,1);
0116                 stats.time = [stats.time stats.time(end)+toc(t)];
0117                 test(i) = sqrt(2*func(A_Gamma, xL, Gamma )) / norm_A_Gamma;
0118                 test = test(1:i,1);
0119                 stats.time = stats.time(2:end);
0120                 return
0121             end
0122         end
0123
0124         ip_xi_xi_old = ip_xi_xi;
0125         xi_trans = TTeMPS_tangent_orth( xL, xR, xi );
0126
0127         stats.time = [stats.time stats.time(end)+toc(t)];
0128         test(i) = sqrt(2*func(A_Gamma, xL, Gamma )) / norm_A_Gamma;
0129         t = tic;
0130     end
0131
0132     stats.time = stats.time(2:end);
0133
0134
0135
0136 end
0137
0138
0139 function f = func(A_Omega, X, Omega)
0140     f = norm(A_Omega - X(Omega))^2 / 2;   % half the squared misfit between A_Omega and X sampled at Omega
0141 end
0142
0143 function g = euclidgrad(A_Omega, X, Omega)
0144     g = X(Omega) - A_Omega;   % gradient of func with respect to the sampled entries X(Omega)
0145 end
```

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005