function [eta, Heta, hesscalls, stop_str, stats] = arc_conjugate_gradient(problem, x, grad, gradnorm, sigma, options, storedb, key)
% Subproblem solver for ARC based on a nonlinear conjugate gradient method.
%
% [eta, Heta, hesscalls, stop_str, stats] =
%     arc_conjugate_gradient(problem, x, grad, gradnorm, sigma, options, storedb, key)
%
% This routine approximately solves the following problem:
%
%   min_{eta in T_x M}  m(eta),  where
%
%       m(eta) = <eta, g> + .5 <eta, H[eta]> + (sigma/3) ||eta||^3,
%
% where eta is a tangent vector at x on the manifold given by problem.M,
% g = grad is a tangent vector at x, H[eta] is the result of applying the
% Hessian of the problem at x along eta, and the inner product and norm
% are those from the Riemannian structure on the tangent space T_x M.
%
% The solve is approximate in the sense that the returned eta only ought
% to satisfy the following conditions:
%
%   ||gradient of m at eta|| <= theta*||eta||^2   and   m(eta) <= m(0),
%
% where theta is specified in options.theta (see below for the default
% value). Since the gradient of the model at 0 is g, if g is zero, then
% eta = 0 is returned. This is the only scenario where eta = 0 is
% returned. Numerical errors can perturb the described expected behavior.
%
% Inputs:
%     problem: Manopt optimization problem structure
%     x: point on the manifold problem.M
%     grad: gradient of the cost function of the problem at x
%     gradnorm: norm of the gradient, often available to the caller
%     sigma: cubic regularization parameter (positive scalar)
%     options: structure containing options for the subproblem solver
%     storedb, key: caching data for problem at x
%
% Options specific to this subproblem solver:
%   theta (0.25)
%     Stopping criterion parameter for the subproblem solver: the gradient
%     of the model at the returned step should have norm no more than theta
%     times the squared norm of the step.
%   maxinner (the manifold's dimension)
%     Maximum number of iterations of the conjugate gradient algorithm
%     (capped internally at 3 times the manifold's dimension).
%   beta_type ('P-R')
%     The update rule for calculating beta:
%     'F-R' for Fletcher-Reeves, 'P-R' for Polak-Ribiere, and 'H-S' for
%     Hestenes-Stiefel.
%   subproblemstop ('sqrule')
%     Which termination rule to use: 'sqrule', 'grule', 'srule' or
%     'ssrule'. See the termination check in the code below for the exact
%     inequality each rule implements.
%
% Outputs:
%     eta: approximate solution to the cubic regularized subproblem at x
%     Heta: Hess f(x)[eta] -- this is necessary in the outer loop, and it
%           is often naturally available to the subproblem solver at the
%           end of execution, so that it may be cheaper to return it here.
%     hesscalls: number of Hessian calls during execution
%     stop_str: string describing why the subsolver stopped
%     stats: a structure with some statistics about the inner work:
%            we record the model cost value and model gradient norm at
%            each inner iteration.

% This file is part of Manopt: www.manopt.org.
% Original authors: May 2, 2019,
%    Bryan Zhu, Nicolas Boumal.
% Contributors:
% Change log:
%
%   Aug. 19, 2019 (NB):
%       Option maxiter_cg renamed to maxinner to match trustregions.

% TODO: Support preconditioning through getPrecon().
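
% Example (illustrative sketch): this function is normally not called
% directly; it is selected as the subsolver of Manopt's arc solver. The
% option name subproblemsolver below follows arc's interface; check
% arc's documentation if in doubt. Dominant eigenvector of a symmetric
% matrix A:
%
%     n = 100;
%     A = randn(n); A = .5*(A + A');
%     problem.M = spherefactory(n);
%     problem.cost  = @(x) -x'*(A*x);
%     problem.egrad = @(x) -2*A*x;
%     problem.ehess = @(x, u) -2*A*u;
%     opts.subproblemsolver = @arc_conjugate_gradient;
%     opts.theta = 0.25;             % subsolver stopping parameter (default)
%     x = arc(problem, [], opts);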

% Some shortcuts
M = problem.M;
n = M.dim();
inner = @(u, v) M.inner(x, u, v);
rnorm = @(u) M.norm(x, u);
tangent = @(u) problem.M.tangent(x, u);
Hess = @(u) getHessian(problem, x, u, storedb, key);

% Counter for Hessian calls issued
hesscalls = 0;

% If the gradient has norm zero, return a zero step
if gradnorm == 0
    eta = M.zerovec(x);
    Heta = eta;
    stop_str = 'Cost gradient is zero';
    stats = struct('gradnorms', 0, 'func_values', 0);
    return;
end

% Set local defaults here
localdefaults.theta = 0.25;
localdefaults.maxinner = n;
localdefaults.beta_type = 'P-R';
localdefaults.subproblemstop = 'sqrule';

% Merge local defaults with user options, if any
if ~exist('options', 'var') || isempty(options)
    options = struct();
end
options = mergeOptions(localdefaults, options);

% Calculate the Cauchy point as our initial step: the minimizer of the
% model along the steepest descent direction -grad (see the derivation
% note at the end of this file).
hess_grad = Hess(grad);
hesscalls = hesscalls + 1;
temp = inner(grad, hess_grad) / (2 * sigma * gradnorm * gradnorm);
R_c = -temp + sqrt(temp * temp + gradnorm / sigma);
eta = M.lincomb(x, -R_c / gradnorm, grad);
Heta = M.lincomb(x, -R_c / gradnorm, hess_grad);

% Initialize variables needed for calculation of conjugate direction
prev_grad = M.lincomb(x, -1, grad);
prev_conj = prev_grad;
Hp_conj = M.lincomb(x, -1, hess_grad);

% Main conjugate gradients iteration
maxiter = min(options.maxinner, 3*n);
gradnorms = zeros(maxiter, 1);
func_values = zeros(maxiter, 1);
gradnorm_reached = false;
j = 1;
while j < maxiter

    % Calculate the negative of the gradient of the model at eta
    eta_norm = rnorm(eta);
    new_grad = M.lincomb(x, 1, Heta, 1, grad);
    new_grad = M.lincomb(x, -1, new_grad, -sigma * eta_norm, eta);
    new_grad = tangent(new_grad);

    % Compute some statistics
    gradnorms(j) = rnorm(new_grad);
    func_values(j) = inner(grad, eta) + 0.5 * inner(eta, Heta) + (sigma/3) * eta_norm^3;

    if options.verbosity >= 4
        fprintf('\nModel grad norm: %.16e, Iterate norm: %.16e', gradnorms(j), eta_norm);
    end

    % Check termination condition
    % TODO -- factor this out, as it is the same for all subsolvers.
    % TODO -- I haven't found a scenario where sqrule doesn't win.
    % TODO -- 1e-4 might be too small (g, s, ss seem equivalent).
    switch lower(options.subproblemstop)
        case 'sqrule'
            stop = (gradnorms(j) <= options.theta * eta_norm^2);
        case 'grule'
            stop = (gradnorms(j) <= min(1e-4, sqrt(gradnorms(1)))*gradnorms(1));
        case 'srule'
            stop = (gradnorms(j) <= min(1e-4, eta_norm)*gradnorms(1));
        case 'ssrule'
            stop = (gradnorms(j) <= min(1e-4, eta_norm/max(1, sigma))*gradnorms(1));
        otherwise
            error(['Unknown value for options.subproblemstop.\n' ...
                   'Possible values: ''sqrule'', ''grule'', ' ...
                   '''srule'', ''ssrule''.\n']); % ...
                 % 'Current value: ', options.subproblemstop, '\n']);
    end
    if stop
        stop_str = 'Model grad norm condition satisfied';
        gradnorm_reached = true;
        break;
    end
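
    % Note on sign conventions (explanatory sketch): new_grad holds the
    % *negative* model gradient, and prev_grad was initialized to -grad,
    % so the classical nonlinear CG formulas apply unchanged below. With
    % delta = new_grad - prev_grad, the rules compute
    %     F-R:  beta = ||new_grad||^2 / ||prev_grad||^2
    %     P-R:  beta = max(0, <new_grad, delta> / ||prev_grad||^2)
    %     H-S:  beta = max(0, -<new_grad, delta> / <prev_conj, delta>)
    % The max(0, .) clamp acts as an automatic restart: beta = 0 reduces
    % the next direction to steepest descent, a standard safeguard for
    % the Polak-Ribiere and Hestenes-Stiefel rules.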

    % Calculate the conjugate direction using the selected beta rule
    delta = M.lincomb(x, 1, new_grad, -1, prev_grad);
    switch upper(options.beta_type)
        case 'F-R'
            beta = inner(new_grad, new_grad) / inner(prev_grad, prev_grad);
        case 'P-R'
            beta = max(0, inner(new_grad, delta) / inner(prev_grad, prev_grad));
        case 'H-S'
            beta = max(0, -inner(new_grad, delta) / inner(prev_conj, delta));
        otherwise
            error('Unknown options.beta_type. Should be F-R, P-R, or H-S.');
    end
    new_conj = M.lincomb(x, 1, new_grad, beta, prev_conj);
    Hn_grad = Hess(new_grad);
    hesscalls = hesscalls + 1;
    Hn_conj = M.lincomb(x, 1, Hn_grad, beta, Hp_conj);

    % Find the optimal step in the conjugate direction
    alpha = solve_along_line(M, x, eta, new_conj, grad, Hn_conj, sigma);
    if alpha == 0
        stop_str = 'Unable to make further progress in search direction';
        gradnorm_reached = true;
        break;
    end
    eta = M.lincomb(x, 1, eta, alpha, new_conj);
    Heta = M.lincomb(x, 1, Heta, alpha, Hn_conj);
    prev_grad = new_grad;
    prev_conj = new_conj;
    Hp_conj = Hn_conj;
    j = j + 1;
end

% Check why we stopped iterating
if ~gradnorm_reached
    stop_str = sprintf(['Reached max number of conjugate gradient iterations ' ...
                        '(options.maxinner = %d)'], options.maxinner);
    j = j - 1;
end

% Return the point we ended on
eta = tangent(eta);
stats = struct('gradnorms', gradnorms(1:j), 'func_values', func_values(1:j));
if options.verbosity >= 4
    fprintf('\n');
end

end
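
% Derivation note (sketch supporting the Cauchy point computation above).
% Restricting the model to the steepest descent ray eta = -r*grad/gradnorm
% with r >= 0 gives the scalar function
%
%     phi(r) = -r*gradnorm + (r^2/2)*<grad, H[grad]>/gradnorm^2 + (sigma/3)*r^3.
%
% Setting phi'(r) = 0 yields the quadratic
%
%     sigma*r^2 + (<grad, H[grad]>/gradnorm^2)*r - gradnorm = 0,
%
% whose unique nonnegative root is R_c = -temp + sqrt(temp^2 + gradnorm/sigma)
% with temp = <grad, H[grad]>/(2*sigma*gradnorm^2), matching the expressions
% for temp and R_c in the code. A quick numerical sanity check with
% hypothetical scalar values:
%
%     sigma = 2; gnorm = 3; gHg = 1.5;           % gHg stands in for <grad, H[grad]>
%     temp = gHg / (2*sigma*gnorm^2);
%     R_c = -temp + sqrt(temp^2 + gnorm/sigma);
%     sigma*R_c^2 + (gHg/gnorm^2)*R_c - gnorm    % should vanish up to round-off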