
# ehesscompute

## PURPOSE

Computes the Euclidean Hessian of the cost function at x along xdot via AD.

## SYNOPSIS

function [ehess,store] = ehesscompute(problem, x, xdot, store, complexflag)

## DESCRIPTION

```
Computes the Euclidean Hessian of the cost function at x along xdot via AD.

function [ehess, store] = ehesscompute(problem, x, xdot)
function [ehess, store] = ehesscompute(problem, x, xdot, store)
function [ehess, store] = ehesscompute(problem, x, xdot, store, complexflag)

This file requires Matlab R2021a or later.

Returns the Euclidean Hessian of the cost function described in the
problem structure at the point x along xdot. Also returns a store
structure that caches the Euclidean gradient and the AD trace, so that
repeated Hessian-vector products at the same point x do not recompute
them.

complexflag is a boolean indicating whether the cost function and the
manifold described in the problem structure involve complex numbers
while the Matlab version is R2021a or earlier.

Note: the Euclidean Hessian-vector product is computed by
differentiating the inner product between egrad and xdot, so the
result is valid only when second-order partial derivatives commute.
When the egrad function has already been specified by the user, the
Euclidean gradient is computed from egrad; otherwise it is computed
from the cost function.
```

## CROSS-REFERENCE INFORMATION

This function calls:
• dl2mat Convert the data type of x from dlarray into double
• dl2mat_complex Convert dlx which stores complex numbers in a structure into double
• findA_anchors Find the indices of the anchors for the anchoredrotationsfactory
• cinnerprodgeneral Computes the Euclidean inner product between x and y in the complex case
• creal Extracts the real part of x
• iscstruct
• innerprodgeneral Compute the Euclidean inner product between x and y
• mat2dl Convert the data type of x from numeric into dlarray
• mat2dl_complex Convert x into a particular data structure to store complex numbers
This function is called by:
• manoptAD Preprocess automatic differentiation for a manopt problem structure

## SOURCE CODE

```
function [ehess,store] = ehesscompute(problem, x, xdot, store, complexflag)
% Computes the Euclidean Hessian of the cost function at x along xdot via AD.
%
% function [ehess, store] = ehesscompute(problem, x, xdot)
% function [ehess, store] = ehesscompute(problem, x, xdot, store)
% function [ehess, store] = ehesscompute(problem, x, xdot, store, complexflag)
%
% This file requires Matlab R2021a or later.
%
% Returns the Euclidean Hessian of the cost function described in the
% problem structure at the point x along xdot. Also returns a store
% structure that caches the Euclidean gradient and the AD trace, so that
% repeated Hessian-vector products at the same point x do not recompute
% them.
%
% complexflag is a boolean indicating whether the cost function and the
% manifold described in the problem structure involve complex numbers
% while the Matlab version is R2021a or earlier.
%
% Note: the Euclidean Hessian-vector product is computed by
% differentiating the inner product between egrad and xdot, so the
% result is valid only when second-order partial derivatives commute.
% When the egrad function has already been specified by the user, the
% Euclidean gradient is computed from egrad; otherwise it is computed
% from the cost function.
%
% ...

% This file is part of Manopt: www.manopt.org.
% Original author: Xiaowen Jiang, Aug. 31, 2021.
% Contributors: Nicolas Boumal
% Change log:

% ...

    % check availability
    assert(isfield(problem,'M') && isfield(problem,'cost'),...
    'problem structure must contain fields M and cost.');
    assert(exist('dlarray', 'file') == 2, ['Deep Learning Toolbox is '...
    'needed for automatic differentiation']);
    assert(exist('dlaccelerate', 'file') == 2, ['AD failed when computing'...
    % ...

    % ...
    end

    % check the Matlab version and the complex number
    if ~exist('complexflag','var')
        complexflag = false;
    end
    % obtain the cost function from problem
    costfunction = problem.cost;

    % prepare the Euclidean gradient if not done yet
    % ...

        % create a tape and start recording the trace of the computation
        % of the Euclidean gradient. Destroying the record object cleans
        % up the tape; this happens when the store is renewed after each
        % iteration.
        tm = deep.internal.recording.TapeManager();
        record = deep.internal.startTracingAndSetupCleanup(tm);

        % compute the Euclidean gradient of the cost function at x
        % ...

        % store the trace, the Euclidean gradient and the point dlx
        % ...
        store.dlx = dlx;
        store.tm = tm;
        store.record = record;

    end

    % ...

        % convert x into dlarrays to prepare for AD
        if complexflag == true
            dlx = mat2dl_complex(x);
        else
            dlx = mat2dl(x);
        end

        % convert dlx into recording arrays
        dlx = deep.internal.recording.recordContainer(dlx);

        % if the user has defined egrad, compute the Euclidean gradient
        % from it and keep the trace
        % ...
            try
                % ...
            catch
                % ...
            end
        end

        % otherwise, compute the Euclidean gradient from the cost function
        % ...
            y = costfunction(dlx);
            % in case the user forgot to take the real part of the cost
            % for a complex problem under Matlab R2021a or earlier, take
            % the real part here for AD
            if iscstruct(y)
                y = creal(y);
            end
            % ...
            % trace the backward pass in order to compute higher-order
            % derivatives in the later steps
            % ...
        end
    end

    %% compute the Euclidean Hessian of the cost function at x along xdot

    % prepare ingredients
    tm = store.tm; %#ok<NASGU>
    record = store.record; %#ok<NASGU>
    % ...
    dlx = store.dlx;

    % To compute the Euclidean Hessian-vector product, the rotations,
    % unitary and essential manifolds require first converting the
    % representation of the tangent vector into the ambient space.
    % If the problem is a product manifold, the components of xdot
    % that belong to the other manifolds remain unchanged.
    if contains(problem.M.name(),'Rotations manifold SO','IgnoreCase',true)...
            ||  contains(problem.M.name(),'Unitary manifold','IgnoreCase',true)...
            || (contains(problem.M.name(),'Product rotations manifold','IgnoreCase',true) &&...
            contains(problem.M.name(),'anchors'))...
            || contains(problem.M.name(),'essential','IgnoreCase',true)
        xdot = problem.M.tangent2ambient(x, xdot);
    end

    % compute the inner product between the Euclidean gradient and xdot
    if complexflag == true
        % ...
    else
        % ...
    end

    % compute derivatives of the inner product w.r.t. dlx
    % ...

    % obtain the numerical representation
    if complexflag == true
        ehess = dl2mat_complex(ehess);
    else
        ehess = dl2mat(ehess);
    end


    % if the user is optimizing over anchoredrotationsfactory, the ehess
    % of the anchors (indices in A) should be zero
    if (contains(problem.M.name(),'Product rotations manifold') &&...
            contains(problem.M.name(),'anchors'))
        A = findA_anchors(problem);
        ehess(:, :, A) = 0;
    end

end
```

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005