
arc_conjugate_gradient

PURPOSE

Subproblem solver for ARC based on a nonlinear conjugate gradient method.

SYNOPSIS

function [eta, Heta, hesscalls, stop_str, stats] = arc_conjugate_gradient(problem, x, grad, gradnorm, sigma, options, storedb, key)

DESCRIPTION

 Subproblem solver for ARC based on a nonlinear conjugate gradient method.

 [eta, Heta, hesscalls, stop_str, stats] = 
     arc_conjugate_gradient(problem, x, grad, gradnorm, sigma, options, storedb, key)

 This routine approximately solves the following problem:

   min_{eta in T_x M}  m(eta),  where

       m(eta) = <eta, g> + .5 <eta, H[eta]> + (sigma/3) ||eta||^3

 where eta is a tangent vector at x on the manifold given by problem.M,
 g = grad is a tangent vector at x, H[eta] is the result of applying the
 Hessian of the problem at x along eta, and the inner product and norm
 are those from the Riemannian structure on the tangent space T_x M.

 The solve is approximate in the sense that the returned eta only ought
 to satisfy the following conditions:

   ||gradient of m at eta|| <= theta*||eta||^2   and   m(eta) <= m(0),

 where theta is specified in options.theta (see below for default value).
 Since the gradient of the model at 0 is g, if it is zero, then eta = 0
 is returned. This is the only scenario where eta = 0 is returned.

 Numerical errors can perturb the described expected behavior.
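
 As a minimal illustration (not part of this file), the model value and its
 gradient can be evaluated with standard Manopt tools; the names problem, x,
 eta, sigma and theta below are assumed to already exist in the workspace:

     M = problem.M;
     g = getGradient(problem, x);          % g = grad f(x)
     Heta = getHessian(problem, x, eta);   % H[eta]
     eta_norm = M.norm(x, eta);

     % Model value m(eta); note that m(0) = 0.
     m_eta = M.inner(x, eta, g) + 0.5*M.inner(x, eta, Heta) ...
             + (sigma/3)*eta_norm^3;

     % Gradient of the model at eta: g + H[eta] + sigma*||eta||*eta.
     m_grad = M.lincomb(x, 1, g, 1, Heta);
     m_grad = M.lincomb(x, 1, m_grad, sigma*eta_norm, eta);

     % An acceptable eta satisfies both conditions, up to numerical error.
     ok = (M.norm(x, m_grad) <= theta*eta_norm^2) && (m_eta <= 0);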

 Inputs:
     problem: Manopt optimization problem structure
     x: point on the manifold problem.M
     grad: gradient of the cost function of the problem at x
     gradnorm: norm of the gradient, often available to the caller
     sigma: cubic regularization parameter (positive scalar)
     options: structure containing options for the subproblem solver
     storedb, key: caching data for problem at x

 Options specific to this subproblem solver:
   theta (0.25)
     Stopping criterion parameter for subproblem solver: the gradient of
     the model at the returned step should have norm no more than theta
     times the squared norm of the step.
   maxinner (the manifold's dimension)
     Maximum number of iterations of the conjugate gradient algorithm.
   beta_type ('P-R')
     The update rule for calculating beta:
     'F-R' for Fletcher-Reeves, 'P-R' for Polak-Ribiere, and 'H-S' for
     Hestenes-Stiefel.
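
 For instance, these options are typically set on the options structure passed
 to the outer arc solver, which forwards that structure to the subproblem
 solver (a sketch; the field values below are only examples):

     options = struct();
     options.subproblemsolver = @arc_conjugate_gradient;
     options.theta = 0.25;       % model gradient tolerance (default)
     options.maxinner = 100;     % cap on inner CG iterations
     options.beta_type = 'H-S';  % 'F-R', 'P-R' (default) or 'H-S'
     [xopt, xcost, info] = arc(problem, [], options);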

 Outputs:
     eta: approximate solution to the cubic regularized subproblem at x
     Heta: Hess f(x)[eta] -- this is necessary in the outer loop, and it
           is often naturally available to the subproblem solver at the
           end of execution, so that it may be cheaper to return it here.
     hesscalls: number of Hessian calls during execution
     stop_str: string describing why the subsolver stopped
     stats: a structure specifying some statistics about inner work - 
            we record the model cost value and model gradient norm at each
            inner iteration.
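
 The subsolver can also be called directly (normally, the outer arc solver
 does this). The sketch below assumes problem, x and sigma already exist and
 relies on the standard Manopt helpers StoreDB and getGradient:

     storedb = StoreDB();
     key = storedb.getNewKey();
     grad = getGradient(problem, x, storedb, key);
     gradnorm = problem.M.norm(x, grad);
     opts = struct('theta', 0.25, 'verbosity', 0);
     [eta, Heta, hesscalls, stop_str, stats] = ...
         arc_conjugate_gradient(problem, x, grad, gradnorm, sigma, opts, storedb, key);
     fprintf('%s after %d Hessian calls.\n', stop_str, hesscalls);
     semilogy(stats.gradnorms);  % model gradient norm at each inner iteration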

CROSS-REFERENCE INFORMATION

This function calls:
This function is called by:

SOURCE CODE

0001 function [eta, Heta, hesscalls, stop_str, stats] = arc_conjugate_gradient(problem, x, grad, gradnorm, sigma, options, storedb, key)
0002 % Subproblem solver for ARC based on a nonlinear conjugate gradient method.
0003 %
0004 % [eta, Heta, hesscalls, stop_str, stats] =
0005 %     arc_conjugate_gradient(problem, x, grad, gradnorm, sigma, options, storedb, key)
0006 %
0007 % This routine approximately solves the following problem:
0008 %
0009 %   min_{eta in T_x M}  m(eta),  where
0010 %
0011 %       m(eta) = <eta, g> + .5 <eta, H[eta]> + (sigma/3) ||eta||^3
0012 %
0013 % where eta is a tangent vector at x on the manifold given by problem.M,
0014 % g = grad is a tangent vector at x, H[eta] is the result of applying the
0015 % Hessian of the problem at x along eta, and the inner product and norm
0016 % are those from the Riemannian structure on the tangent space T_x M.
0017 %
0018 % The solve is approximate in the sense that the returned eta only ought
0019 % to satisfy the following conditions:
0020 %
0021 %   ||gradient of m at eta|| <= theta*||eta||^2   and   m(eta) <= m(0),
0022 %
0023 % where theta is specified in options.theta (see below for default value).
0024 % Since the gradient of the model at 0 is g, if it is zero, then eta = 0
0025 % is returned. This is the only scenario where eta = 0 is returned.
0026 %
0027 % Numerical errors can perturb the described expected behavior.
0028 %
0029 % Inputs:
0030 %     problem: Manopt optimization problem structure
0031 %     x: point on the manifold problem.M
0032 %     grad: gradient of the cost function of the problem at x
0033 %     gradnorm: norm of the gradient, often available to the caller
0034 %     sigma: cubic regularization parameter (positive scalar)
0035 %     options: structure containing options for the subproblem solver
0036 %     storedb, key: caching data for problem at x
0037 %
0038 % Options specific to this subproblem solver:
0039 %   theta (0.25)
0040 %     Stopping criterion parameter for subproblem solver: the gradient of
0041 %     the model at the returned step should have norm no more than theta
0042 %     times the squared norm of the step.
0043 %   maxinner (the manifold's dimension)
0044 %     Maximum number of iterations of the conjugate gradient algorithm.
0045 %   beta_type ('P-R')
0046 %     The update rule for calculating beta:
0047 %     'F-R' for Fletcher-Reeves, 'P-R' for Polak-Ribiere, and 'H-S' for
0048 %     Hestenes-Stiefel.
0049 %
0050 % Outputs:
0051 %     eta: approximate solution to the cubic regularized subproblem at x
0052 %     Heta: Hess f(x)[eta] -- this is necessary in the outer loop, and it
0053 %           is often naturally available to the subproblem solver at the
0054 %           end of execution, so that it may be cheaper to return it here.
0055 %     hesscalls: number of Hessian calls during execution
0056 %     stop_str: string describing why the subsolver stopped
0057 %     stats: a structure specifying some statistics about inner work -
0058 %            we record the model cost value and model gradient norm at each
0059 %            inner iteration.
0060 
0061 % This file is part of Manopt: www.manopt.org.
0062 % Original authors: May 2, 2019,
0063 %    Bryan Zhu, Nicolas Boumal.
0064 % Contributors:
0065 % Change log:
0066 %
0067 %   Aug. 19, 2019 (NB):
0068 %       Option maxiter_cg renamed to maxinner to match trustregions.
0069 
0070 % TODO: Support preconditioning through getPrecon().
0071 
0072     % Some shortcuts
0073     M = problem.M;
0074     n = M.dim();
0075     inner = @(u, v) M.inner(x, u, v);
0076     rnorm = @(u) M.norm(x, u);
0077     tangent = @(u) problem.M.tangent(x, u);
0078     Hess = @(u) getHessian(problem, x, u, storedb, key);
0079     
0080     % Counter for Hessian calls issued
0081     hesscalls = 0;
0082     
0083     % If the gradient has norm zero, return a zero step
0084     if gradnorm == 0
0085         eta = M.zerovec(x);
0086         Heta = eta;
0087         stop_str = 'Cost gradient is zero';
0088         stats = struct('gradnorms', 0, 'func_values', 0);
0089         return;
0090     end
0091     
0092     % Set local defaults here
0093     localdefaults.theta = 0.25;
0094     localdefaults.maxinner = n;
0095     localdefaults.beta_type = 'P-R';
0096     localdefaults.subproblemstop = 'sqrule';
0097     
0098     % Merge local defaults with user options, if any
0099     if ~exist('options', 'var') || isempty(options)
0100         options = struct();
0101     end
0102     options = mergeOptions(localdefaults, options);
0103     
0104     % Calculate the Cauchy point as our initial step
0105     hess_grad = Hess(grad);
0106     hesscalls = hesscalls + 1;
0107     temp = inner(grad, hess_grad) / (2 * sigma * gradnorm * gradnorm);
0108     R_c = -temp + sqrt(temp * temp + gradnorm / sigma);
0109     eta = M.lincomb(x, -R_c / gradnorm, grad);
0110     Heta = M.lincomb(x, -R_c / gradnorm, hess_grad);
0111     
0112     % Initialize variables needed for calculation of conjugate direction
0113     prev_grad = M.lincomb(x, -1, grad);
0114     prev_conj = prev_grad;
0115     Hp_conj = M.lincomb(x, -1, hess_grad);
0116     
0117     % Main conjugate gradients iteration
0118     maxiter = min(options.maxinner, 3*n);
0119     gradnorms = zeros(maxiter, 1);
0120     func_values = zeros(maxiter, 1);
0121     gradnorm_reached = false;
0122     j = 1;
0123     while j < maxiter
0124         % Calculate the negative gradient of the model (a descent direction)
0125         eta_norm = rnorm(eta);
0126         new_grad = M.lincomb(x, 1, Heta, 1, grad);
0127         new_grad = M.lincomb(x, -1, new_grad, -sigma * eta_norm, eta);
0128         new_grad = tangent(new_grad);
0129         
0130         % Compute some statistics
0131         gradnorms(j) = rnorm(new_grad);
0132         func_values(j) = inner(grad, eta) + 0.5 * inner(eta, Heta) + (sigma/3) * eta_norm^3;
0133         
0134         if options.verbosity >= 4
0135             fprintf('\nModel grad norm: %.16e, Iterate norm: %.16e', gradnorms(j), eta_norm);
0136         end
0137 
0138         % Check termination condition
0139         % TODO -- factor this out, as it is the same for all subsolvers.
0140         % TODO -- I haven't found a scenario where sqrule doesn't win.
0141         % TODO -- 1e-4 might be too small (g, s, ss seem equivalent).
0142         switch lower(options.subproblemstop)
0143             case 'sqrule'
0144                 stop = (gradnorms(j) <= options.theta * eta_norm^2);
0145             case 'grule'
0146                 stop = (gradnorms(j) <= min(1e-4, sqrt(gradnorms(1)))*gradnorms(1));
0147             case 'srule'
0148                 stop = (gradnorms(j) <= min(1e-4, eta_norm)*gradnorms(1));
0149             case 'ssrule'
0150                 stop = (gradnorms(j) <= min(1e-4, eta_norm/max(1, sigma))*gradnorms(1));
0151             otherwise
0152                 error(['Unknown value for options.subproblemstop.\n' ...
0153                        'Possible values: ''sqrule'', ''grule'', ' ...
0154                        '''srule'', ''ssrule''.\n']); % ...
0155                        % 'Current value: ', options.subproblemstop, '\n']);
0156         end
0157         if stop
0158             stop_str = 'Model grad norm condition satisfied';
0159             gradnorm_reached = true;
0160             break;
0161         end
0162         
0163         % Calculate the conjugate direction using the selected beta rule
0164         delta = M.lincomb(x, 1, new_grad, -1, prev_grad);
0165         switch upper(options.beta_type)
0166             case 'F-R'
0167                 beta = inner(new_grad, new_grad) / inner(prev_grad, prev_grad);
0168             case 'P-R'
0169                 beta = max(0, inner(new_grad, delta) / inner(prev_grad, prev_grad));
0170             case 'H-S'
0171                 beta = max(0, -inner(new_grad, delta) / inner(prev_conj, delta));
0172             otherwise
0173                 error('Unknown options.beta_type. Should be F-R, P-R, or H-S.');
0174         end
0175         new_conj = M.lincomb(x, 1, new_grad, beta, prev_conj);
0176         Hn_grad = Hess(new_grad);
0177         hesscalls = hesscalls + 1;
0178         Hn_conj = M.lincomb(x, 1, Hn_grad, beta, Hp_conj);
0179         
0180         % Find the optimal step in the conjugate direction
0181         alpha = solve_along_line(M, x, eta, new_conj, grad, Hn_conj, sigma);
0182         if alpha == 0
0183             stop_str = 'Unable to make further progress in search direction';
0184             gradnorm_reached = true;
0185             break;
0186         end
0187         eta = M.lincomb(x, 1, eta, alpha, new_conj);
0188         Heta = M.lincomb(x, 1, Heta, alpha, Hn_conj);
0189         prev_grad = new_grad;
0190         prev_conj = new_conj;
0191         Hp_conj = Hn_conj;
0192         j = j + 1;
0193     end
0194     
0195     % Check why we stopped iterating
0196     if ~gradnorm_reached
0197         stop_str = sprintf(['Reached max number of conjugate gradient iterations ' ...
0198                '(options.maxinner = %d)'], options.maxinner);
0199         j = j - 1;
0200     end
0201     
0202     % Return the point we ended on
0203     eta = tangent(eta);
0204     stats = struct('gradnorms', gradnorms(1:j), 'func_values', func_values(1:j));
0205     if options.verbosity >= 4
0206         fprintf('\n');
0207     end
0208         
0209 end
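
A note on the initialization (source lines 0104-0110): the starting iterate is
the Cauchy point, that is, the minimizer of the model along the steepest
descent direction -grad/||grad||. Restricting the model to that ray gives
phi(R) = -R*||grad|| + .5*R^2*<grad, H[grad]>/||grad||^2 + (sigma/3)*R^3, and
setting phi'(R) = 0 yields sigma*R^2 + (<grad, H[grad]>/||grad||^2)*R - ||grad|| = 0,
whose unique positive root is the radius R_c computed in the code.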
