Computes the Euclidean gradient of the cost function at x via AD.

function egrad = egradcompute(problem, x)
function egrad = egradcompute(problem, x, complexflag)

Returns the Euclidean gradient, at the point x, of the cost function described in problem.autogradfunc.

Note: the problem structure must contain the field autogradfunc. autogradfunc should be either an AcceleratedFunction or a function handle whose function calls dlgradient. x is a point on the target manifold. complexflag is a boolean which is true when the problem involves complex numbers and the installed Matlab version is R2021a or earlier; it defaults to false when omitted.

See also: manoptAD autograd mat2dl mat2dl_complex dl2mat dl2mat_complex
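The following is a minimal usage sketch, not taken from the Manopt sources. It assumes that Manopt and the Deep Learning Toolbox are on the path, and that autograd(problem) returns an object suitable for the autogradfunc field (this is an assumption about how manoptAD prepares the problem). The matrix A, the cost function and the choice of spherefactory are illustrative only.

```matlab
% Illustrative setup (hypothetical data): a Rayleigh quotient on the sphere.
n = 5;
A = randn(n);
A = (A + A')/2;                      % symmetrize A

problem.M = spherefactory(n);        % unit sphere in R^n
problem.cost = @(x) -x'*(A*x);       % cost to differentiate via AD

% Assumption: autograd(problem) builds the function handle or
% AcceleratedFunction that egradcompute expects in problem.autogradfunc.
problem.autogradfunc = autograd(problem);

x = problem.M.rand();                % a point on the manifold
egrad = egradcompute(problem, x);    % complexflag omitted: defaults to false

% Compare with the analytic Euclidean gradient, -2*A*x.
err = norm(egrad - (-2*A*x));
fprintf('Difference with analytic gradient: %g\n', err);
```

In typical use one would call manoptAD(problem) rather than egradcompute directly; calling it by hand as above is mostly useful for checking the AD gradient against a known formula.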
function egrad = egradcompute(problem, x, complexflag)
% Computes the Euclidean gradient of the cost function at x via AD.
%
% function egrad = egradcompute(problem, x)
% function egrad = egradcompute(problem, x, complexflag)
%
% Returns the Euclidean gradient, at the point x, of the cost function
% described in problem.autogradfunc.
%
% Note: the problem structure must contain the field autogradfunc.
% autogradfunc should be either an AcceleratedFunction or a function handle
% whose function calls dlgradient. x is a point on the target manifold.
% complexflag is a boolean which is true when the problem involves complex
% numbers and the installed Matlab version is R2021a or earlier; it
% defaults to false when omitted.
%
% See also: manoptAD autograd mat2dl mat2dl_complex dl2mat dl2mat_complex

% This file is part of Manopt: www.manopt.org.
% Original author: Xiaowen Jiang, Aug. 31, 2021.
% Contributors: Nicolas Boumal
% Change log:

% To do: Add AD to fixedTTrankfactory, fixedranktensorembeddedfactory
% and the product manifold which contains fixedrankembeddedfactory
% or anchoredrotationsfactory

    % check availability
    assert(isfield(problem, 'autogradfunc'), ['the problem structure must' ...
        ' contain the field autogradfunc, see autograd.']);
    if ~exist('complexflag', 'var')
        complexflag = false;
    end

    % convert x into dlarrays to prepare for AD
    if complexflag == true
        dlx = mat2dl_complex(x);
    else
        dlx = mat2dl(x);
    end

    % In Matlab R2021b Prerelease, an AcceleratedFunction only accepts
    % inputs with a fixed data structure. If the representation of a point
    % on the manifold changes while an algorithm runs, the
    % AcceleratedFunction fails to work properly. In particular,
    % AcceleratedFunction is sensitive to the order in which the fields of
    % a structure were defined. If a point on a manifold is represented as
    % a structure and the retr and rand functions of the manifold factory
    % define the fields in different orders, an error occurs. In that case,
    % the old cache must be cleared so that the new input is accepted.
    if isa(problem.autogradfunc, 'deep.AcceleratedFunction')
        try
            % compute egrad according to autogradfunc
            [~, egrad] = dlfeval(problem.autogradfunc, dlx);
        catch
            % clear the old cache, then try again
            clearCache(problem.autogradfunc);
            [~, egrad] = dlfeval(problem.autogradfunc, dlx);
            warning('manopt:AD:cachedlaccelerte', ...
                ['The representation of points on the manifold is inconsistent.\n' ...
                'AcceleratedFunction has to clear its old cache to accept the new ' ...
                'representation of the input.\nPlease check the consistency when ' ...
                'writing the manifold factory.\n' ...
                'To disable this warning: warning(''off'', ''manopt:AD:cachedlaccelerte'')']);
        end
    else
        [~, egrad] = dlfeval(problem.autogradfunc, dlx);
    end

    % convert egrad back to numeric arrays
    if complexflag == true
        egrad = dl2mat_complex(egrad);
    else
        egrad = dl2mat(egrad);
    end

end