function autogradfunc = autograd(problem, fixedrankflag)
% Apply automatic differentiation to computing the Euclidean gradient
%
% function autogradfunc = autograd(problem)
% function autogradfunc = autograd(problem, fixedrankflag)
%
% Returns an AcceleratedFunction or a function handle which can be used to
% compute Euclidean gradients. See
% https://ch.mathworks.com/help/deeplearning/ref/deep.acceleratedfunction.html
% for more about AcceleratedFunction.
%
% Note: to evaluate the Euclidean gradient at a point x (x should be of
% type dlarray), call dfeval(autogradfunc, x) instead of autogradfunc(x).
%
% See also: manoptAD, egradcompute, costgradcompute

% This file is part of Manopt: www.manopt.org.
% Original author: Xiaowen Jiang, Aug. 31, 2021.
% Contributors: Nicolas Boumal
% Change log:
%
% To do: Add AD to fixedTTrankfactory, fixedranktensorembeddedfactory
% and the product manifold which contains fixedrankembeddedfactory
% or anchoredrotationsfactory.

    % Check availability.
    assert(isfield(problem, 'M') && isfield(problem, 'cost'), ...
           'The problem structure must contain the fields M and cost.');
    assert(exist('dlarray', 'file') == 2, ...
           ['The Deep Learning Toolbox is needed for automatic ' ...
            'differentiation.']);

    % Set fixedrankflag to false if the manifold struct is not
    % fixed(multilinear)-rank matrices or tensors with an embedded geometry
    % or tensors of fixed Tensor Train (TT) rank.
    if ~exist('fixedrankflag', 'var') || isempty(fixedrankflag)
        fixedrankflag = false;
    end

    % Obtain the Euclidean gradient function via AD.
    costfunction = problem.cost;

    % fixedrankflag is true if the manifold is fixed-rank matrices with
    % an embedded geometry. The other two cases are not implemented yet.
    if fixedrankflag
        % AcceleratedFunction can lead to a slowdown in this case.
        autogradfunc = @(x, A, B) autogradfuncinternelfixedrankembedded(x, A, B);
    else
        func = @autogradfuncinternal;
        % Accelerate if possible.
        try
            autogradfunc = dlaccelerate(func); % Introduced in Matlab R2021a
            clearCache(autogradfunc);
        catch
            warning('manopt:dlaccelerate', ...
                    ['The function dlaccelerate is not available.\n' ...
                     'Please upgrade to Matlab R2021a or later and to the ' ...
                     'latest Deep Learning Toolbox version if possible.\n' ...
                     'Meanwhile, automatic differentiation may be somewhat ' ...
                     'slower, and the Hessian is not available either.\n' ...
                     'To disable this warning: ' ...
                     'warning(''off'', ''manopt:dlaccelerate'')']);
            autogradfunc = func;
        end
    end

    % Define the Euclidean gradient function.
    function [y, egrad] = autogradfuncinternal(x)

        y = costfunction(x);
        % In case the user forgot to take the real part of the cost
        % when dealing with complex problems with Matlab R2021a or
        % earlier, take the real part for AD.
        if iscstruct(y)
            y = creal(y);
        end

        % Call dlgradient to compute the Euclidean gradient. By default,
        % 'RetainData' and 'EnableHigherDerivatives' are set to false.
        egrad = dlgradient(y, x);

        % In case the user is optimizing over anchoredrotationsfactory,
        % the egrad of the anchors with indices in A should be zero.
        problem_name = problem.M.name();
        if (contains(problem_name, 'Product rotations manifold') && ...
            contains(problem_name, 'anchors') && ...
            ~startsWith(problem_name, 'Product manifold'))
            A = findA_anchors(problem);
            egrad(:, :, A) = 0;
        end
    end

    % fixedrankembeddedfactory part:
    % obtain the product of egrad with V and the product of egrad
    % transposed with U by differentiating g1 and g2 w.r.t. A and B.
    function [g1, egrad] = autogradfuncinternelfixedrankembedded(x, A, B)
        X1.U = A;   X1.S = eye(size(x.S, 1)); X1.V = x.V;
        X2.U = x.U; X2.S = eye(size(x.S, 1)); X2.V = B;
        g1 = costfunction(X1); g2 = costfunction(X2);
        egrad.A = dlgradient(g1, A); egrad.B = dlgradient(g2, B);
    end

end
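
% Usage sketch (illustrative only, not part of this file): for a manifold
% whose points are plain arrays, such as the sphere, one could obtain and
% evaluate the Euclidean gradient function along these lines. The specific
% manifold and cost below are made-up examples; dfeval is the evaluator
% mentioned in the help text above.
%
%   problem.M = spherefactory(3);          % example manifold
%   problem.cost = @(x) sum(x.^4);         % example smooth cost
%   autogradfunc = autograd(problem);
%   x = dlarray(problem.M.rand());         % point must be a dlarray
%   [cost, egrad] = dfeval(autogradfunc, x);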