ehesscompute

PURPOSE

Computes the Euclidean Hessian of the cost function at x along xdot via AD.

SYNOPSIS

function [ehess,store] = ehesscompute(problem, x, xdot, store, complexflag)

DESCRIPTION

 Computes the Euclidean Hessian of the cost function at x along xdot via AD.

 function [ehess, store] = ehesscompute(problem, x, xdot)
 function [ehess, store] = ehesscompute(problem, x, xdot, store)
 function [ehess, store] = ehesscompute(problem, x, xdot, store, complexflag)

 This file requires Matlab R2021a or later.

 Returns the Euclidean Hessian of the cost function described in the
 problem structure at the point x, applied to the direction xdot (a
 Hessian-vector product). Also returns the store structure, which caches
 the Euclidean gradient and its AD trace so that they are not recomputed
 for further Hessian-vector products at the same point x.
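A minimal usage sketch, assuming Manopt is on the path and the Deep Learning Toolbox is installed (Matlab R2021a or later). The quadratic cost and the matrix A are illustrative choices, not taken from Manopt. For f(x) = x'*A*x with symmetric A, the Euclidean Hessian applied to v is 2*A*v, which the last line checks. The second call passes the returned store back in, which is the reuse described above.

```matlab
% Sketch: Hessian-vector products of f(x) = x'*A*x on R^3 via ehesscompute.
% Assumes Manopt (manopt/autodiff) and the Deep Learning Toolbox are available.
n = 3;
A = randn(n); A = (A + A')/2;           % illustrative symmetric matrix: Hessian is 2*A
problem.M = euclideanfactory(n, 1);
problem.cost = @(x) x'*(A*x);

x = problem.M.rand();
xdot1 = problem.M.randvec(x);
xdot2 = problem.M.randvec(x);

% First call: records the gradient computation and caches it in store.
[h1, store] = ehesscompute(problem, x, xdot1);
% Second call at the same x: reuses store.dlegrad instead of recomputing it.
[h2, store] = ehesscompute(problem, x, xdot2, store);

% Compare with the closed-form Hessian-vector products.
fprintf('errors: %g, %g\n', norm(h1 - 2*A*xdot1), norm(h2 - 2*A*xdot2));
```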

 complexflag is a logical flag indicating that the cost function and the
 manifold described in the problem structure involve complex numbers and
 that the Matlab version is R2021a or earlier (given the requirement
 above, in practice R2021a). It defaults to false.
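The flag combines two conditions the caller must know. A minimal sketch, assuming the caller already knows whether its problem is complex (the variable involvescomplex is hypothetical); manoptAD may decide this differently. MATLAB 9.11 is R2021b, so verLessThan('matlab', '9.11') holds on R2021b's predecessors.

```matlab
% Hypothetical: the caller knows its cost and manifold use complex numbers.
involvescomplex = true;
% True on releases before R2021b (MATLAB 9.11); given ehesscompute's
% R2021a minimum, this means R2021a.
complexflag = involvescomplex && verLessThan('matlab', '9.11');
fprintf('complexflag = %d\n', complexflag);
% Passed as the fifth argument; an empty struct lets ehesscompute fill the store:
% [ehess, store] = ehesscompute(problem, x, xdot, struct(), complexflag);
```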

 Note: the Euclidean Hessian-vector product is computed by differentiating
 the inner product between egrad and xdot with respect to x, so the result
 is valid only when the second-order partial derivatives commute (that is,
 when the Hessian is symmetric). If the user has specified problem.egrad
 (and problem.autogradfunc is absent), the Euclidean gradient is obtained
 by evaluating egrad on dlarray inputs; if that evaluation fails, or no
 egrad is given, it is computed from the cost function by AD.
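The symmetry condition can be made explicit in the real case. Write g(x) for the Euclidean gradient of f, J(x) for the Jacobian of g, and use the inner product a'*b. Differentiating the inner product then yields J(x)'*xdot, whereas the Hessian-vector product is J(x)*xdot:

```latex
\frac{\partial}{\partial x_i} \langle g(x), \dot x \rangle
  = \sum_j \frac{\partial g_j}{\partial x_i}(x)\, \dot x_j
  = \big( J(x)^{\top} \dot x \big)_i,
\qquad
J(x)_{ji} = \frac{\partial g_j}{\partial x_i}(x)
          = \frac{\partial^2 f}{\partial x_i \, \partial x_j}(x),

\nabla_x \langle g(x), \dot x \rangle = J(x)^{\top} \dot x
  = J(x)\, \dot x = \nabla^2 f(x)\, \dot x
  \quad \text{for all } \dot x \text{ if and only if } J(x) = J(x)^{\top}.
```

For a twice continuously differentiable cost, J(x) is symmetric (Schwarz's theorem). A user-supplied egrad that is not the exact gradient of the cost can have a non-symmetric Jacobian, in which case the returned vector is J(x)'*xdot rather than J(x)*xdot.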

 See also: manoptAD mat2dl dl2mat dl2mat_complex mat2dl_complex innerprodgeneral cinnerprodgeneral

SOURCE CODE

function [ehess,store] = ehesscompute(problem, x, xdot, store, complexflag)
% Computes the Euclidean Hessian of the cost function at x along xdot via AD.
%
% function [ehess, store] = ehesscompute(problem, x, xdot)
% function [ehess, store] = ehesscompute(problem, x, xdot, store)
% function [ehess, store] = ehesscompute(problem, x, xdot, store, complexflag)
%
% This file requires Matlab R2021a or later.
%
% Returns the Euclidean Hessian of the cost function described in the
% problem structure at the point x, applied to xdot. Also returns the
% store structure, which caches the Euclidean gradient and the AD trace
% to avoid redundant computation of Hessian-vector products at the same x.
%
% complexflag is a logical flag indicating that the cost function and the
% manifold described in the problem structure involve complex numbers and
% that the Matlab version is R2021a or earlier. It defaults to false.
%
% Note: the Euclidean Hessian-vector product is computed by differentiating
% the inner product between egrad and xdot, so the result is valid only
% when second-order partial derivatives commute. If the user has specified
% egrad, the Euclidean gradient is computed from egrad; otherwise it is
% computed from the cost function.
%
% See also: manoptAD mat2dl dl2mat dl2mat_complex mat2dl_complex innerprodgeneral cinnerprodgeneral

% This file is part of Manopt: www.manopt.org.
% Original author: Xiaowen Jiang, Aug. 31, 2021.
% Contributors: Nicolas Boumal
% Change log:

    %% Prepare Euclidean gradient

    % Check availability of the required fields and toolboxes.
    assert(isfield(problem,'M') && isfield(problem,'cost'), ...
        'problem structure must contain fields M and cost.');
    assert(exist('dlarray', 'file') == 2, ['Deep Learning Toolbox is ' ...
        'needed for automatic differentiation.']);
    assert(exist('dlaccelerate', 'file') == 2, ['AD failed when computing ' ...
        'the Hessian. Please upgrade to Matlab R2021a or later.']);

    % Check whether the user has already specified egrad.
    egradflag = isfield(problem,'egrad') && ~isfield(problem,'autogradfunc');

    % complexflag defaults to false.
    if ~exist('complexflag','var')
        complexflag = false;
    end
    % Obtain the cost function from the problem structure.
    costfunction = problem.cost;

    % Prepare the Euclidean gradient if it is not stored yet.
    if ~exist('store','var') || ~isfield(store,'dlegrad')

        % Create a tape and start recording the trace of the computation
        % of the Euclidean gradient. Destroying the record object cleans
        % up the tape; this happens when the store is renewed after each
        % iteration.
        tm = deep.internal.recording.TapeManager();
        record = deep.internal.startTracingAndSetupCleanup(tm);

        % Compute the Euclidean gradient of the cost function at x.
        [dlx,dlegrad] = subautograd(costfunction,complexflag,x);

        % Store the trace, the Euclidean gradient and the point dlx.
        store.dlegrad = dlegrad;
        store.dlx = dlx;
        store.tm = tm;
        store.record = record;

    end

    % Gradient computation, similar to autograd.
    function [dlx,dlegrad] = subautograd(costfunction,complexflag,x)

        % Convert x into dlarrays to prepare for AD.
        if complexflag
            dlx = mat2dl_complex(x);
        else
            dlx = mat2dl(x);
        end

        % Convert dlx into recording arrays.
        dlx = deep.internal.recording.recordContainer(dlx);

        % If the user has defined egrad, compute the Euclidean gradient
        % from it and keep the trace; fall back to the cost on failure.
        if egradflag
            try
                dlegrad = problem.egrad(dlx);
            catch
                egradflag = false;
            end
        end

        % Otherwise, compute the Euclidean gradient from the cost function.
        if ~egradflag
            y = costfunction(dlx);
            % In case the user forgot to take the real part of the cost
            % for a complex problem on Matlab R2021a or earlier, take the
            % real part for AD.
            if iscstruct(y)
                y = creal(y);
            end
            % Call dlgradient to compute the Euclidean gradient, tracing
            % the backward pass so that higher-order derivatives can be
            % computed in later steps.
            dlegrad = dlgradient(y,dlx,'RetainData',true,'EnableHigherDerivatives',true);
        end
    end

    %% Compute the Euclidean Hessian of the cost function at x along xdot

    % Prepare ingredients.
    tm = store.tm; %#ok<NASGU>
    record = store.record; %#ok<NASGU>
    dlegrad = store.dlegrad;
    dlx = store.dlx;

    % To compute the Euclidean Hessian-vector product, the rotations,
    % unitary and essential manifolds require first converting the
    % representation of the tangent vector into the ambient space. For
    % product manifolds involving them, the xdot components of the other
    % manifolds remain unchanged.
    if contains(problem.M.name(),'Rotations manifold SO','IgnoreCase',true) ...
            || contains(problem.M.name(),'Unitary manifold','IgnoreCase',true) ...
            || (contains(problem.M.name(),'Product rotations manifold','IgnoreCase',true) && ...
                contains(problem.M.name(),'anchors')) ...
            || contains(problem.M.name(),'essential','IgnoreCase',true)
        xdot = problem.M.tangent2ambient(x, xdot);
    end

    % Compute the inner product between the Euclidean gradient and xdot.
    if complexflag
        z = cinnerprodgeneral(dlegrad, xdot);
    else
        z = innerprodgeneral(dlegrad, xdot);
    end

    % Differentiate the inner product with respect to dlx.
    ehess = dlgradient(z,dlx,'RetainData',false,'EnableHigherDerivatives',false);

    % Obtain the numerical representation.
    if complexflag
        ehess = dl2mat_complex(ehess);
    else
        ehess = dl2mat(ehess);
    end

    % When optimizing over anchoredrotationsfactory, the ehess components
    % of the anchors (indices in A) must be zero.
    if contains(problem.M.name(),'Product rotations manifold') && ...
            contains(problem.M.name(),'anchors')
        A = findA_anchors(problem);
        ehess(:, :, A) = 0;
    end

end
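ehesscompute records the gradient on an internal tape (the deep.internal namespace). The same technique, differentiating the inner product of the Euclidean gradient with xdot with respect to x, can be sketched with the documented dlfeval/dlgradient interface. The cost f(x) = sum(x.^4) and the file name hvp_demo.m are illustrative assumptions; since the Hessian of this cost is diag(12*x.^2), the result can be checked by hand. Requires the Deep Learning Toolbox.

```matlab
function hv = hvp_demo()
% Hessian-vector product by differentiating <egrad, xdot>, using dlfeval.
% Illustrative cost: f(x) = sum(x.^4), whose Hessian is diag(12*x.^2).
    x = dlarray([1; 2]);
    xdot = [1; 1];
    hv = extractdata(dlfeval(@hessvec, x, xdot));
    % Analytically, Hess(x)*xdot = 12*x.^2.*xdot = [12; 48].
end

function hv = hessvec(dlx, xdot)
    y = sum(dlx.^4);                                          % cost f(x)
    % Euclidean gradient 4*x.^3, with its backward pass traced.
    g = dlgradient(y, dlx, 'EnableHigherDerivatives', true);
    z = sum(g .* xdot);                                       % <egrad(x), xdot>
    hv = dlgradient(z, dlx);                                  % d/dx <egrad(x), xdot>
end
```

Unlike ehesscompute, this sketch recomputes the gradient on every call; caching it across calls is what the store and the internal tape provide.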

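The tangent2ambient branch in the listing exists because some factories store tangent vectors in a compact form. A minimal sketch, assuming Manopt is on the path: for rotationsfactory, a tangent vector at R is stored as a skew-symmetric matrix Omega, and its ambient representation R*Omega is what the inner product with the Euclidean gradient requires.

```matlab
% Sketch: compact vs ambient tangent vectors on SO(3) (requires Manopt).
M = rotationsfactory(3);
R = M.rand();                       % a random rotation matrix
Omega = M.randvec(R);               % tangent vector, stored as skew-symmetric
U = M.tangent2ambient(R, Omega);    % ambient representation of the same vector
disp(M.name());                     % name matched by ehesscompute's contains(...) test
fprintf('||Omega + Omega''|| = %g, ||U - R*Omega|| = %g\n', ...
        norm(Omega + Omega', 'fro'), norm(U - R*Omega, 'fro'));
```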
Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005