Computes the cost and the Riemannian gradient at x via AD in one call

function [cost, grad] = costgradcompute(problem, x, complexflag)

Returns the cost and the Riemannian gradient of the cost function described in the problem structure at the point x.

Note: the problem structure must contain the field autogradfunc, which should be either an AcceleratedFunction or a function handle that calls dlgradient. x is a point on the target manifold. complexflag is a logical flag that is true if the problem described in autogradfunc involves complex numbers and the Matlab version is R2021a or earlier.

See also: manoptAD autograd mat2dl mat2dl_complex dl2mat dl2mat_complex
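The following is a minimal sketch of the same pipeline for a real-valued problem: evaluate a handle that calls dlgradient under dlfeval, convert the results back to numeric arrays, then map the Euclidean gradient to the Riemannian one. It calls the Deep Learning Toolbox functions dlarray, dlfeval, dlgradient and extractdata directly instead of Manopt's mat2dl and dl2mat helpers. The cost f(x) = x'*A*x, the matrix A and the choice of the unit sphere are assumptions made for illustration. For the sphere, egrad2rgrad is the orthogonal projection egrad - (x'*egrad)*x.

```matlab
% Minimal sketch, assuming the Deep Learning Toolbox is installed.
% Hypothetical cost on the unit sphere in R^2: f(x) = x' * A * x.
A = [2 1; 1 3];
x = [1; 0];                          % a point on the unit sphere

% A handle that calls dlgradient, playing the role of problem.autogradfunc.
% deal returns both the cost and its Euclidean gradient from one handle.
costfun = @(X) sum(X .* (A * X));
autogradfunc = @(X) deal(costfun(X), dlgradient(costfun(X), X));

% Wrap x as a dlarray and evaluate under dlfeval so dlgradient can trace it
dlx = dlarray(x);
[dlcost, dlegrad] = dlfeval(autogradfunc, dlx);

% Convert the results back to plain numeric arrays
cost = extractdata(dlcost);
egrad = extractdata(dlegrad);

% Sphere-specific egrad2rgrad (assumption): project egrad onto the
% tangent space at x
grad = egrad - (x' * egrad) * x;
```

With these values, cost is 2 and egrad is 2*A*x = [4; 2]. Projecting onto the tangent space at x = [1; 0] removes the component along x, which leaves grad = [0; 2].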
function [cost, grad] = costgradcompute(problem, x, complexflag)
% Computes the cost and the Riemannian gradient at x via AD in one call
%
% function [cost, grad] = costgradcompute(problem, x, complexflag)
%
% Returns the cost and the Riemannian gradient of the cost function
% described in the problem structure at the point x.
%
% Note: the problem structure must contain the field autogradfunc, which
% should be either an AcceleratedFunction or a function handle that
% calls dlgradient. x is a point on the target manifold. complexflag is
% a logical flag that is true if the problem described in autogradfunc
% involves complex numbers and the Matlab version is R2021a or earlier.
%
% See also: manoptAD autograd mat2dl mat2dl_complex dl2mat dl2mat_complex

% This file is part of Manopt: www.manopt.org.
% Original author: Xiaowen Jiang, Aug. 31, 2021.
% Contributors: Nicolas Boumal
% Change log:
%
% To do: Add AD to fixedTTrankfactory, fixedranktensorembeddedfactory
%        and the product manifold which contains fixedrankembeddedfactory
%        or anchoredrotationsfactory

    assert(isfield(problem, 'autogradfunc'), ['the problem structure must' ...
           ' contain the field autogradfunc, see autograd.'])

    % Convert x into dlarrays to prepare for AD
    if complexflag
        dlx = mat2dl_complex(x);
    else
        dlx = mat2dl(x);
    end

    % In Matlab R2021b Prerelease, AcceleratedFunction only accepts inputs
    % with a fixed data structure. If the representation of a point on the
    % manifold changes while an algorithm runs, the AcceleratedFunction
    % fails. In particular, AcceleratedFunction is sensitive to the order
    % in which the fields of a structure were defined. Suppose a point on
    % a manifold is represented as a structure, and the retr and rand
    % functions of the manifold factory define its fields in different
    % orders. Then an error occurs. In that case, the old cache is cleared
    % so that the new input can be accepted.
    if isa(problem.autogradfunc, 'deep.AcceleratedFunction')
        try
            % Compute egrad according to autogradfunc
            [cost, egrad] = dlfeval(problem.autogradfunc, dlx);
        catch
            % Clear the old cache, then retry with the new input layout
            clearCache(problem.autogradfunc);
            [cost, egrad] = dlfeval(problem.autogradfunc, dlx);
            warning('manopt:AD:cachedlaccelerate', ...
                ['The representation of points on the manifold is inconsistent.\n' ...
                 'AcceleratedFunction has to clear its old cache to accept the new ' ...
                 'representation of the input.\nPlease check the consistency when ' ...
                 'writing the manifold factory.\n' ...
                 'To disable this warning: warning(''off'', ''manopt:AD:cachedlaccelerate'')']);
        end
    else
        [cost, egrad] = dlfeval(problem.autogradfunc, dlx);
    end

    % Convert egrad back to numerical arrays
    if complexflag
        egrad = dl2mat_complex(egrad);
    else
        egrad = dl2mat(egrad);
    end

    % Convert the Euclidean gradient egrad to the Riemannian gradient
    grad = problem.M.egrad2rgrad(x, egrad);
    % Convert cost back to a numeric format
    cost = dl2mat(cost);

end
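The field-order sensitivity described in the comments of the listing can be reproduced with plain structures, without any AD involved. The sketch below builds two representations of the same point in different field orders. The field names U, S and V are hypothetical and chosen only for illustration. isequal treats the two structures as equal, but fieldnames shows that their layouts differ. orderfields is one possible way to normalize the layout. This is an assumption for illustration, not what costgradcompute does: costgradcompute clears the AcceleratedFunction cache and retries instead. The snippet shows only the layout difference, not the caching behavior itself.

```matlab
% Two hypothetical representations of the same point, built in different
% field orders (as could happen if retr and rand define fields differently)
p1 = struct();
p1.U = eye(2); p1.S = diag([3 1]); p1.V = eye(2);
p2 = struct();
p2.S = diag([3 1]); p2.U = eye(2); p2.V = eye(2);

% isequal ignores field order, so the values compare equal
sameValues = isequal(p1, p2);

% fieldnames returns {'U';'S';'V'} versus {'S';'U';'V'}: layouts differ
sameLayout = isequal(fieldnames(p1), fieldnames(p2));

% Reorder p2's fields to match p1, giving both points one consistent layout
p2fixed = orderfields(p2, p1);
sameLayoutFixed = isequal(fieldnames(p1), fieldnames(p2fixed));
```

Here sameValues and sameLayoutFixed are true, while sameLayout is false.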