Computes the Euclidean Hessian of the cost function at x along xdot via AD.

function [ehess, store] = ehesscompute(problem, x, xdot)
function [ehess, store] = ehesscompute(problem, x, xdot, store)
function [ehess, store] = ehesscompute(problem, x, xdot, store, complexflag)

This file requires Matlab R2021a or later.

Returns the Euclidean Hessian of the cost function described in the problem structure at the point x along xdot. Also returns the store structure, which holds the Euclidean gradient and the AD trace so that further Hessian-vector products at the same point x do not recompute them.

complexflag is a boolean which indicates whether the cost function and the manifold described in the problem structure involve complex numbers while the Matlab version is R2021a or earlier.

Note: the Euclidean Hessian-vector product is computed by differentiating the inner product between egrad and xdot, so the result is valid only when second-order partial derivatives commute. If the user has already specified the egrad function, the Euclidean gradient is computed from egrad; otherwise it is computed from the cost function.

See also: manoptAD mat2dl dl2mat dl2mat_complex mat2dl_complex innerprodgeneral cinnerprodgeneral

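The note on commuting second-order partial derivatives can be made concrete with a short derivation. The sketch below assumes the simplest setting, a twice-differentiable real cost f on R^n with the standard inner product; the symbols f and z are labels chosen here (z matches the variable name used in the source listing), and xdot is treated as a constant while differentiating.

```latex
\begin{align*}
z(x) &= \langle \nabla f(x), \dot{x} \rangle
      = \sum_i \frac{\partial f}{\partial x_i}(x)\,\dot{x}_i, \\
\frac{\partial z}{\partial x_j}(x)
     &= \sum_i \frac{\partial^2 f}{\partial x_j\,\partial x_i}(x)\,\dot{x}_i, \\
\big(\nabla^2 f(x)[\dot{x}]\big)_j
     &= \sum_i \frac{\partial^2 f}{\partial x_i\,\partial x_j}(x)\,\dot{x}_i .
\end{align*}
```

The second line is what differentiating the inner product produces, and the third is the directional derivative of the gradient along xdot. The two differ only in the order of the partial derivatives, so they agree whenever the mixed partials commute.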
function [ehess,store] = ehesscompute(problem, x, xdot, store, complexflag)
% Computes the Euclidean Hessian of the cost function at x along xdot via AD.
%
% function [ehess, store] = ehesscompute(problem, x, xdot)
% function [ehess, store] = ehesscompute(problem, x, xdot, store)
% function [ehess, store] = ehesscompute(problem, x, xdot, store, complexflag)
%
% This file requires Matlab R2021a or later.
%
% Returns the Euclidean Hessian of the cost function described in the
% problem structure at the point x along xdot. Also returns the store
% structure, which holds the Euclidean gradient and the AD trace so that
% further Hessian-vector products at the same point x do not recompute them.
%
% complexflag is a boolean which indicates whether the cost function and
% the manifold described in the problem structure involve complex numbers
% while the Matlab version is R2021a or earlier.
%
% Note: the Euclidean Hessian-vector product is computed by
% differentiating the inner product between egrad and xdot, so the
% result is valid only when second-order partial derivatives commute.
% If the user has already specified the egrad function, the Euclidean
% gradient is computed from egrad; otherwise it is computed from the
% cost function.
%
% See also: manoptAD mat2dl dl2mat dl2mat_complex mat2dl_complex innerprodgeneral cinnerprodgeneral

% This file is part of Manopt: www.manopt.org.
% Original author: Xiaowen Jiang, Aug. 31, 2021.
% Contributors: Nicolas Boumal
% Change log:

%% Prepare Euclidean gradient

    % check availability
    assert(isfield(problem,'M') && isfield(problem,'cost'), ...
        'problem structure must contain fields M and cost.');
    assert(exist('dlarray', 'file') == 2, ['Deep Learning Toolbox is ' ...
        'needed for automatic differentiation']);
    assert(exist('dlaccelerate', 'file') == 2, ['AD failed when computing ' ...
        'the Hessian. Please upgrade to Matlab R2021a or later.'])

    % check whether the user has specified the egrad already
    egradflag = false;
    if isfield(problem,'egrad') && ~isfield(problem,'autogradfunc')
        egradflag = true;
    end

    % complexflag defaults to false when the caller does not provide it
    if ~exist('complexflag','var')
        complexflag = false;
    end
    % obtain the cost function from the problem structure
    costfunction = problem.cost;

    % prepare the Euclidean gradient if it is not stored yet
    if (~exist('store','var') || ~isfield(store,'dlegrad'))

        % create a tape and start recording the trace of the computation
        % of the Euclidean gradient. Destroying the record object cleans
        % up the tape; this happens when the store is renewed after each
        % iteration.
        tm = deep.internal.recording.TapeManager();
        record = deep.internal.startTracingAndSetupCleanup(tm);

        % compute the Euclidean gradient of the cost function at x
        [dlx,dlegrad] = subautograd(costfunction,complexflag,x);

        % store the trace, the Euclidean gradient and the point dlx
        store.dlegrad = dlegrad;
        store.dlx = dlx;
        store.tm = tm;
        store.record = record;

    end

    % nested gradient computation function, similar to autograd;
    % it shares egradflag and problem with the enclosing function
    function [dlx,dlegrad] = subautograd(costfunction,complexflag,x)

        % convert x into dlarrays to prepare for AD
        if complexflag == true
            dlx = mat2dl_complex(x);
        else
            dlx = mat2dl(x);
        end

        % convert dlx into recording arrays
        dlx = deep.internal.recording.recordContainer(dlx);

        % if the user has defined the egrad, compute the Euclidean gradient
        % and keep the trace; fall back to the cost function on failure
        if egradflag
            try
                dlegrad = problem.egrad(dlx);
            catch
                egradflag = false;
            end
        end

        % otherwise, compute the Euclidean gradient from the cost function
        if ~egradflag
            y = costfunction(dlx);
            % in case the user forgot to take the real part of the cost
            % when dealing with complex problems while the Matlab version
            % is R2021a or earlier, take the real part for AD
            if iscstruct(y)
                y = creal(y);
            end
            % call dlgradient to compute the Euclidean gradient and
            % trace the backward pass so that higher-order derivatives
            % can be computed in the following steps
            dlegrad = dlgradient(y,dlx,'RetainData',true,'EnableHigherDerivatives',true);
        end
    end

%% compute the Euclidean Hessian of the cost function at x along xdot

    % prepare ingredients
    tm = store.tm; %#ok<NASGU>
    record = store.record; %#ok<NASGU>
    dlegrad = store.dlegrad;
    dlx = store.dlx;

    % To compute the Euclidean Hessian-vector product, the rotations,
    % unitary and essential manifolds first require converting the
    % representation of the tangent vector to the ambient space.
    % If the problem is a product manifold, the xdot components of
    % manifolds other than those listed above remain unchanged.
    if contains(problem.M.name(),'Rotations manifold SO','IgnoreCase',true) ...
            || contains(problem.M.name(),'Unitary manifold','IgnoreCase',true) ...
            || (contains(problem.M.name(),'Product rotations manifold','IgnoreCase',true) && ...
            contains(problem.M.name(),'anchors')) ...
            || contains(problem.M.name(),'essential','IgnoreCase',true)
        xdot = problem.M.tangent2ambient(x, xdot);
    end

    % compute the inner product between the Euclidean gradient and xdot
    if complexflag == true
        z = cinnerprodgeneral(dlegrad, xdot);
    else
        z = innerprodgeneral(dlegrad, xdot);
    end

    % compute derivatives of the inner product w.r.t. dlx
    ehess = dlgradient(z,dlx,'RetainData',false,'EnableHigherDerivatives',false);

    % obtain the numerical representation
    if complexflag == true
        ehess = dl2mat_complex(ehess);
    else
        ehess = dl2mat(ehess);
    end

    % in case the user is optimizing over anchoredrotationsfactory,
    % the ehess of anchors with indices in A should be zero
    if (contains(problem.M.name(),'Product rotations manifold') && ...
            contains(problem.M.name(),'anchors'))
        A = findA_anchors(problem);
        ehess(:, :, A) = 0;
    end

end
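The last stage of the listing, differentiating the inner product z between the Euclidean gradient and xdot with respect to x, can be checked without the Deep Learning Toolbox. The following is a minimal sketch, not part of Manopt: it swaps automatic differentiation for central finite differences and uses a hypothetical quadratic cost f(x) = 0.5*x'*A*x, whose gradient A*x and Hessian-vector product A*xdot are known in closed form. The names A, egrad, z, h, ehess_fd and ehess_exact are illustrative only.

```matlab
% Hypothetical test problem: f(x) = 0.5*x'*A*x with symmetric A,
% so the Euclidean gradient is A*x and the Hessian along xdot is A*xdot.
A = [2 1 0; 1 3 1; 0 1 4];
egrad = @(x) A*x;

x    = [1; -1; 2];      % point at which the Hessian is evaluated
xdot = [0.5; 0; -1];    % direction, held constant while differentiating

% Scalar function whose gradient is the Hessian-vector product:
% z(y) = <egrad(y), xdot>
z = @(y) egrad(y)' * xdot;

% Differentiate z with central finite differences (stand-in for dlgradient)
h = 1e-4;
n = numel(x);
ehess_fd = zeros(n, 1);
for j = 1:n
    e = zeros(n, 1);
    e(j) = 1;
    ehess_fd(j) = (z(x + h*e) - z(x - h*e)) / (2*h);
end

% Compare with the closed-form Hessian-vector product
ehess_exact = A * xdot;
err = norm(ehess_fd - ehess_exact);
fprintf('error between finite differences and A*xdot: %g\n', err);
```

Because z is linear in y for a quadratic cost, the central difference is exact up to rounding, which makes this a convenient sanity check for the identity that ehesscompute relies on. For a non-symmetric A the gradient would be 0.5*(A + A')*x, so the test matrix is kept symmetric on purpose.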