
manoptAD

PURPOSE

Preprocess automatic differentiation for a manopt problem structure

SYNOPSIS

function problem = manoptAD(problem, flag)

DESCRIPTION

 Preprocess automatic differentiation for a manopt problem structure

 function problem = manoptAD(problem)
 function problem = manoptAD(problem, 'nohess')
 function problem = manoptAD(problem, 'hess')

 Given a manopt problem structure with problem.cost and problem.M defined,
 this tool adds the following fields to the problem structure:
   problem.egrad
   problem.costgrad
   problem.ehess

 A field problem.autogradfunc is also created for internal use.

 The fields egrad and ehess correspond to the Euclidean gradient and Hessian.
 They are obtained through automatic differentiation of the cost function.
 Manopt converts them into Riemannian objects in the usual way via the
 manifold's M.egrad2rgrad and M.ehess2rhess functions, automatically.
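
 For instance, a minimal sketch of the intended workflow is shown below. The
 sphere manifold and the quadratic cost are only illustrative choices; manoptAD
 itself only requires problem.M and problem.cost to be defined.

   n = 1000;
   A = randn(n);  A = (A + A')/2;     % symmetric data matrix (illustrative)
   problem.M = spherefactory(n);      % unit sphere in R^n
   problem.cost = @(x) -x'*(A*x);     % Rayleigh quotient, to be minimized
   problem = manoptAD(problem);       % adds egrad, costgrad and ehess via AD
   x = trustregions(problem);         % then call any Manopt solver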

 As an optional second input, the user may specify the flag string to be:
   'nohess' -- in which case problem.ehess is not created.
   'hess'   -- which corresponds to the default behavior.
 If problem.egrad is already provided and the Hessian is requested, the
 tool builds problem.ehess based on problem.egrad rather than the cost.
 
 This function requires the following:
   Matlab version R2021a or later.
   Deep Learning Toolbox version 14.2 or later.
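
 A quick way to check these prerequisites from the Matlab prompt is sketched
 below (standard Matlab calls; the checks are only indicative):

   version('-release')        % should return 2021a or later
   exist('dlarray', 'file')   % nonzero if the Deep Learning Toolbox is on the path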

 Support for complex variables in automatic differentiation is added in
   Matlab version R2021b or later.
 There is also better support for Hessian computations in that version.
 Otherwise, see manoptADhelp and complex_example_AD for a workaround, or
 set the 'nohess' flag to tell Manopt not to compute Hessians with AD.
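
 For example, on Matlab R2021b or later, a cost involving complex variables can
 typically be differentiated as-is. Below is a hedged sketch; the Hermitian
 matrix A and the complex sphere are illustrative choices only.

   n = 500;
   A = randn(n) + 1i*randn(n);  A = (A + A')/2;   % Hermitian matrix (illustrative)
   problem.M = spherecomplexfactory(n);
   problem.cost = @(x) -real(x'*(A*x));
   problem = manoptAD(problem);   % on earlier Matlab versions, see complex_example_AD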

 If AD fails for some reason, the original problem structure
 is returned with a warning that tries to hint at what the issue may be.
 Mostly, issues arise because manoptAD relies on the Deep Learning
 Toolbox, which itself relies on the dlarray data type, and only a subset
 of Matlab functions support dlarrays:
 
   See manoptADhelp for more about limitations and workarounds.
   See
   https://ch.mathworks.com/help/deeplearning/ug/list-of-functions-with-dlarray-support.html
   for an official list of functions that support dlarray.

 In particular, sparse matrices are not supported, nor are certain
 standard functions, including trace(), which can be replaced by ctrace().
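
 For example, a cost written with trace() can often be made AD-compatible by
 substituting ctrace(), as in the sketch below (the matrices A and X are
 illustrative; only the trace/ctrace swap matters):

   % problem.cost = @(X) trace(A*X);    % trace does not support dlarray
   problem.cost   = @(X) ctrace(A*X);   % ctrace is the AD-friendly replacement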

 There are a few limitations pertaining to specific manifolds.
 For example:
   fixedrankembeddedfactory: AD creates grad, not egrad; and no Hessian.
   fixedranktensorembeddedfactory: no AD support.
   fixedTTrankfactory: no AD support.
   euclideansparsefactory: no AD support.

 Importantly, while AD is convenient and efficient in terms of human time,
 it is not efficient in terms of CPU time: it is expected that AD slows
 down gradient computations by a factor of about 5. Moreover, while AD can
 most often compute Hessians as well, it is often more efficient to
 compute Hessians with finite differences (which is the default in Manopt
 when the Hessian is not provided by the user).
 Thus: it is often the case that
   problem = manoptAD(problem, 'nohess');
 leads to better overall runtime than
   problem = manoptAD(problem);
 when calling trustregions(problem).
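
 One simple way to compare both options on a given problem is sketched below
 (timings are problem dependent; this snippet is only illustrative):

   p1 = manoptAD(problem, 'nohess');   % AD gradient, finite-difference Hessian
   p2 = manoptAD(problem);             % AD gradient and AD Hessian
   tic; trustregions(p1); t1 = toc;
   tic; trustregions(p2); t2 = toc;
   fprintf('nohess: %.2fs, hess: %.2fs\n', t1, t2);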

 Some manifold factories in Manopt support GPUs: automatic differentiation
 should work with them too, as usual. See using_gpu_AD for more details.
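
 For instance, assuming the chosen factory exposes an optional gpuflag input
 (as spherefactory does in recent Manopt versions) and a supported GPU is
 available, a hedged sketch looks as follows:

   n = 3000;
   A = gpuArray(randn(n));  A = (A + A')/2;   % data stored on the GPU
   problem.M = spherefactory(n, 1, true);     % last argument requests GPU arrays
   problem.cost = @(x) -x'*(A*x);
   problem = manoptAD(problem, 'nohess');
   x = trustregions(problem);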


 See also: manoptADhelp autograd egradcompute ehesscompute complex_example_AD using_gpu_AD

CROSS-REFERENCE INFORMATION

This function calls:
This function is called by:

SOURCE CODE

0001 function problem = manoptAD(problem, flag) 
0002 % Preprocess automatic differentiation for a manopt problem structure
0003 %
0004 % function problem = manoptAD(problem)
0005 % function problem = manoptAD(problem, 'nohess')
0006 % function problem = manoptAD(problem, 'hess')
0007 %
0008 % Given a manopt problem structure with problem.cost and problem.M defined,
0009 % this tool adds the following fields to the problem structure:
0010 %   problem.egrad
0011 %   problem.costgrad
0012 %   problem.ehess
0013 %
0014 % A field problem.autogradfunc is also created for internal use.
0015 %
0016 % The fields egrad and ehess correspond to the Euclidean gradient and Hessian.
0017 % They are obtained through automatic differentiation of the cost function.
0018 % Manopt converts them into Riemannian objects in the usual way via the
0019 % manifold's M.egrad2rgrad and M.ehess2rhess functions, automatically.
0020 %
0021 % As an optional second input, the user may specify the flag string to be:
0022 %   'nohess' -- in which case problem.ehess is not created.
0023 %   'hess'   -- which corresponds to the default behavior.
0024 % If problem.egrad is already provided and the Hessian is requested, the
0025 % tool builds problem.ehess based on problem.egrad rather than the cost.
0026 %
0027 % This function requires the following:
0028 %   Matlab version R2021a or later.
0029 %   Deep Learning Toolbox version 14.2 or later.
0030 %
0031 % Support for complex variables in automatic differentiation is added in
0032 %   Matlab version R2021b or later.
0033 % There is also better support for Hessian computations in that version.
0034 % Otherwise, see manoptADhelp and complex_example_AD for a workaround, or
0035 % set the 'nohess' flag to tell Manopt not to compute Hessians with AD.
0036 %
0037 % If AD fails for some reason, the original problem structure
0038 % is returned with a warning that tries to hint at what the issue may be.
0039 % Mostly, issues arise because manoptAD relies on the Deep Learning
0040 % Toolbox, which itself relies on the dlarray data type, and only a subset
0041 % of Matlab functions support dlarrays:
0042 %
0043 %   See manoptADhelp for more about limitations and workarounds.
0044 %   See
0045 %   https://ch.mathworks.com/help/deeplearning/ug/list-of-functions-with-dlarray-support.html
0046 %   for an official list of functions that support dlarray.
0047 %
0048 % In particular, sparse matrices are not supported, nor are certain
0049 % standard functions, including trace(), which can be replaced by ctrace().
0050 %
0051 % There are a few limitations pertaining to specific manifolds.
0052 % For example:
0053 %   fixedrankembeddedfactory: AD creates grad, not egrad; and no Hessian.
0054 %   fixedranktensorembeddedfactory: no AD support.
0055 %   fixedTTrankfactory: no AD support.
0056 %   euclideansparsefactory: no AD support.
0057 %
0058 % Importantly, while AD is convenient and efficient in terms of human time,
0059 % it is not efficient in terms of CPU time: it is expected that AD slows
0060 % down gradient computations by a factor of about 5. Moreover, while AD can
0061 % most often compute Hessians as well, it is often more efficient to
0062 % compute Hessians with finite differences (which is the default in Manopt
0063 % when the Hessian is not provided by the user).
0064 % Thus: it is often the case that
0065 %   problem = manoptAD(problem, 'nohess');
0066 % leads to better overall runtime than
0067 %   problem = manoptAD(problem);
0068 % when calling trustregions(problem).
0069 %
0070 % Some manifold factories in Manopt support GPUs: automatic differentiation
0071 % should work with them too, as usual. See using_gpu_AD for more details.
0072 %
0073 %
0074 % See also: manoptADhelp autograd egradcompute ehesscompute complex_example_AD using_gpu_AD
0075 
0076 % This file is part of Manopt: www.manopt.org.
0077 % Original author: Xiaowen Jiang, Aug. 31, 2021.
0078 % Contributors: Nicolas Boumal
0079 % Change log:
0080 
0081 % To do: Add AD to fixedTTrankfactory, fixedranktensorembeddedfactory,
0082 % and to product manifolds that contain fixedrankembeddedfactory
0083 % or anchoredrotationsfactory.
0084 
0085 %% Check if AD can be applied to the manifold and the cost function
0086     
0087     % Check availability of the Deep Learning Toolbox.
0088     if ~(exist('dlarray', 'file') == 2)
0089         error('manopt:AD:dl', ...
0090         ['It seems the Deep Learning Toolbox is not installed.\n' ...
0091          'It is needed for automatic differentiation in Manopt.\n' ...
0092          'If possible, install the latest version of that toolbox and ' ...
0093          'ideally also Matlab R2021b or later.']);
0094     end
0095     
0096     % Check for a feature of recent versions of the Deep Learning Toolbox.
0097     if ~(exist('dlaccelerate', 'file') == 2)
0098         warning('manopt:AD:dlaccelerate', ...
0099            ['Function dlaccelerate not available:\n If possible, ' ...
0100             'upgrade to Matlab R2021a or later and use the latest ' ...
0101             'version of the Deep Learning Toolbox.\n' ...
0102             'Automatic differentiation may still work but be a lot ' ...
0103             'slower.\nMoreover, the Hessian is not available in AD.\n' ...
0104             'Setting flag to ''nohess''. ' ...
0105             'To disable this warning: ' ...
0106             'warning(''off'', ''manopt:AD:dlaccelerate'');']);
0107         flag = 'nohess';
0108     end
0109 
0110     % The problem structure must provide a manifold and a cost function.
0111     assert(isfield(problem, 'M') && isfield(problem, 'cost'), ... 
0112               'The problem structure must contain the fields M and cost.');
0113     
0114     % Check the flag value if provided, or set its default value.
0115     if exist('flag', 'var')
0116         assert(strcmp(flag, 'nohess') || strcmp(flag, 'hess'), ...
0117            'The second argument should be either ''nohess'' or ''hess''.');
0118     else
0119         flag = 'hess'; % default behavior
0120     end
0121     
0122     % If the gradient and Hessian information is already provided, return.
0123     if canGetGradient(problem) && canGetHessian(problem)
0124         warning('manopt:AD:alreadydefined', ...
0125           ['Gradient and Hessian already defined, skipping AD.\n' ...
0126            'To disable this warning: ' ...
0127            'warning(''off'', ''manopt:AD:alreadydefined'');']);
0128         return;
0129     end
0130     
0131     % Below, it is convenient for several purposes to have a point on the
0132     % manifold. This makes it possible to investigate its representation.
0133     x = problem.M.rand();
0134     
0135     % AD does not support certain manifolds.
0136     manifold_name = problem.M.name();
0137     if contains(manifold_name, 'sparsity')
0138          error('manopt:AD:sparse', ...
0139               ['Automatic differentiation currently does not support ' ...
0140                'sparse matrices, e.g., euclideansparsefactory.']);
0141     end
0142     if ( startsWith(manifold_name, 'Product manifold') && ...
0143         ((sum(isfield(x, {'U', 'S', 'V'})) == 3) && ...
0144         (contains(manifold_name(), 'rank', 'IgnoreCase', true))) ...
0145        ) || ( ...
0146         exist('tenrand', 'file') == 2 && isfield(x, 'X') && ...
0147         isa(x.X, 'ttensor') ...
0148        ) || ...
0149        isa(x, 'TTeMPS')
0150         error('manopt:AD:fixedrankembedded', ...
0151              ['Automatic differentiation ' ...
0152               'does not support fixedranktensorembeddedfactory,\n'...
0153               'fixedTTrankfactory, and product manifolds containing '...
0154               'fixedrankembeddedfactory.']);
0155     end
0156     
0157     % complexflag is used to detect if both of the following are true:
0158     %   A) the problem variables contain complex numbers, and
0159     %   B) the Matlab version is R2021a or earlier.
0160     % If so, we attempt a workaround.
0161     % If Matlab is R2021b or later, then it is not an issue to have
0162     % complex numbers in the variables.
0163     complexflag = false;
0164     % Check if AD can be applied to the cost function by passing the point
0165     % x we created earlier to problem.cost.
0166     try
0167         dlx = mat2dl(x);
0168         costtestdlx = problem.cost(dlx); %#ok<NASGU>
0169     catch ME
0170         % Detect complex number by looking in error message.
0171         % Note: the error deep:dlarray:ComplexNotSupported is removed
0172         % in Matlab R2021b or later
0173         if (strcmp(ME.identifier, 'deep:dlarray:ComplexNotSupported'))
0174             try
0175                 % Let's try to run AD with 'complex' workaround.
0176                 dlx = mat2dl_complex(x);
0177                 costtestx = problem.cost(x); %#ok<NASGU>
0178                 costtestdlx = problem.cost(dlx); %#ok<NASGU>
0179             catch
0180                 error('manopt:AD:complex', ...
0181                      ['Automatic differentiation failed. ' ...
0182                       'Problem defining the cost function.\n' ...
0183                       'Variables contain complex numbers. ' ...
0184                       'Check your Matlab version and see\n' ...
0185                       'complex_example_AD.m and manoptADhelp.m for ' ...
0186                       'help about how to deal with complex variables.']);
0187             end
0188             % If no error appears, set complexflag to true.
0189             complexflag = true;
0190         else
0191             % If the error is not related to complex numbers, then the
0192             % issue is likely with the cost function definition.
0193             warning('manopt:AD:cost', ...
0194                ['Automatic differentiation failed. '...
0195                 'Problem defining the cost function.\n'...
0196                 '<a href = "https://www.mathworks.ch/help/deeplearning'...
0197                 '/ug/list-of-functions-with-dlarray-support.html">'...
0198                 'Check the list of functions with AD support.</a>'...
0199                 ' and see manoptADhelp for more information.']);
0200             return;
0201         end
0202     end
0203     
0204 %% Keep track of what we create with AD
0205     ADded_gradient = false;
0206     ADded_hessian  = false;
0207     
0208 %% Handle special case of fixedrankembeddedfactory first
0209 
0210     % Check if the manifold struct is fixed-rank matrices
0211     % with an embedded geometry. For fixedrankembeddedfactory,
0212     % only the Riemannian gradient can be computed via AD so far.
0213     fixedrankflag = false;
0214     if (sum(isfield(x, {'U', 'S', 'V'})) == 3) && ...
0215         (contains(manifold_name, 'rank', 'IgnoreCase', true)) && ...
0216         (~startsWith(manifold_name, 'Product manifold'))
0217     
0218         if ~strcmp(flag, 'nohess')
0219             warning('manopt:AD:fixedrank', ...
0220               ['Computing the exact Hessian via AD is not supported ' ...
0221                'for fixedrankembeddedfactory.\n' ...
0222                'Setting flag to ''nohess''.\nTo disable this warning: ' ...
0223                'warning(''off'', ''manopt:AD:fixedrank'');']);
0224             flag = 'nohess';
0225         end
0226         
0227         % Set the fixedrankflag to true to prepare for autograd.
0228         fixedrankflag = true;
0229         % If no gradient information is provided, compute grad using AD.
0230         % Note that here we define the Riemannian gradient.
0231         if ~canGetGradient(problem)
0232             problem.autogradfunc = autograd(problem, fixedrankflag);
0233             problem.grad = @(x) gradcomputefixedrankembedded(problem, x);
0234             problem.costgrad = @(x) costgradcomputefixedrankembedded(problem, x);
0235             ADded_gradient = true;
0236         end
0237         
0238     end
0239     
0240 %% Compute the Euclidean gradient and the Euclidean Hessian via AD
0241     
0242     % Provide egrad and (if requested) ehess via AD.
0243     % Manopt converts to Riemannian derivatives via egrad2rgrad and
0244     % ehess2rhess as usual: no need to worry about this here.
0245     if ~fixedrankflag
0246         
0247         if ~canGetGradient(problem)
0248             problem.autogradfunc = autograd(problem);
0249             problem.egrad = @(x) egradcompute(problem, x, complexflag);
0250             problem.costgrad = @(x) costgradcompute(problem, x, complexflag);
0251             ADded_gradient = true;
0252         end
0253         
0254         if ~canGetHessian(problem) && strcmp(flag, 'hess')
0255             problem.ehess = @(x, xdot, store) ...
0256                                      ehesscompute(problem, x, xdot, ...
0257                                                   store, complexflag);
0258             ADded_hessian = true;
0259         end
0260         
0261     end
0262             
0263     
0264 %% Check whether the gradient / Hessian we AD'ded actually work.
0265 
0266     % Some functions cannot be differentiated with AD in the
0267     % Deep Learning Toolbox, e.g., cat(3, A, B).
0268     % In this clean-up phase, we check if things actually work, and we
0269     % remove functions if they do not, with a warning.
0270     
0271     if ADded_gradient && ~fixedrankflag
0272         
0273         try 
0274             egrad = problem.egrad(x);
0275         catch
0276             warning('manopt:AD:failgrad', ...
0277                ['Automatic differentiation for gradient failed. '...
0278                 'Problem defining the cost function.\n'...
0279                 '<a href = "https://www.mathworks.ch/help/deeplearning'...
0280                 '/ug/list-of-functions-with-dlarray-support.html">'...
0281                 'Check the list of functions with AD support.</a>'...
0282                 ' and see manoptADhelp for more information.']);
0283             problem = rmfield(problem, 'autogradfunc');
0284             problem = rmfield(problem, 'egrad');
0285             problem = rmfield(problem, 'costgrad');
0286             if ADded_hessian
0287                 problem = rmfield(problem, 'ehess');
0288             end
0289             return;
0290         end
0291         
0292         if isNaNgeneral(egrad)
0293             warning('manopt:AD:NaN', ...
0294                    ['Automatic differentiation for gradient failed. '...
0295                     'Problem defining the cost function.\n'...
0296                     'NaN comes up in the computation of egrad via AD.\n'...
0297                     'Check the example thomson_problem.m for help.']);
0298             problem = rmfield(problem, 'autogradfunc');
0299             problem = rmfield(problem, 'egrad');
0300             problem = rmfield(problem, 'costgrad');
0301             if ADded_hessian
0302                problem = rmfield(problem, 'ehess');
0303             end
0304             return;
0305         end
0306         
0307     end
0308         
0309     
0310     if ADded_hessian
0311         
0312         % Randomly generate a vector in the tangent space at x.
0313         xdot = problem.M.randvec(x);
0314         store = struct();
0315         try 
0316             ehess = problem.ehess(x, xdot, store);
0317         catch
0318             warning('manopt:AD:failhess', ...
0319                    ['Automatic differentiation for Hessian failed. ' ...
0320                     'Problem defining the cost function.\n' ...
0321                     '<a href = "https://www.mathworks.ch/help/deeplearning' ...
0322                     '/ug/list-of-functions-with-dlarray-support.html">' ...
0323                     'Check the list of functions with AD support.</a>' ...
0324                     ' and see manoptADhelp for more information.']);
0325             problem = rmfield(problem, 'ehess');
0326             return;
0327         end
0328         
0329         if isNaNgeneral(ehess)
0330             warning('manopt:AD:NaN', ...
0331                    ['Automatic differentiation for Hessian failed. ' ...
0332                     'Problem defining the cost function.\n' ...
0333                     'NaN comes up in the computation of ehess via AD.\n' ...
0334                     'Check the example thomson_problem.m for help.']);
0335             problem = rmfield(problem, 'ehess');
0336             return;
0337         end
0338         
0339     end
0340         
0341     % Check the case of fixed-rank matrices as embedded submanifold.
0342     if ADded_gradient && fixedrankflag
0343         try 
0344             grad = problem.grad(x);
0345         catch
0346             warning('manopt:AD:costfixedrank', ...
0347                    ['Automatic differentiation for gradient failed. ' ...
0348                     'Problem defining the cost function.\n' ...
0349                     '<a href = "https://www.mathworks.ch/help/deeplearning' ...
0350                     '/ug/list-of-functions-with-dlarray-support.html">' ...
0351                     'Check the list of functions with AD support.</a>' ...
0352                     ' and see manoptADhelp for more information.']);
0353             problem = rmfield(problem, 'autogradfunc');                
0354             problem = rmfield(problem, 'grad');
0355             problem = rmfield(problem, 'costgrad');
0356             return;
0357         end
0358         
0359         if isNaNgeneral(grad)
0360             warning('manopt:AD:NaN', ...
0361                    ['Automatic differentiation for gradient failed. ' ...
0362                     'Problem defining the cost function.\n' ...
0363                     'NaN comes up in the computation of grad via AD.\n' ...
0364                     'Check the example thomson_problem.m for help.']);
0365             problem = rmfield(problem, 'autogradfunc');
0366             problem = rmfield(problem, 'grad');
0367             problem = rmfield(problem, 'costgrad');
0368             return;
0369         end
0370         
0371     end
0372     
0373     
0374 end
