Home > manopt > autodiff > autograd.m

autograd

PURPOSE ^

Apply automatic differentiation to compute the Euclidean gradient

SYNOPSIS ^

function autogradfunc = autograd(problem, fixedrankflag)

DESCRIPTION ^

 Apply automatic differentiation to compute the Euclidean gradient

 function autogradfunc = autograd(problem)
 function autogradfunc = autograd(problem, fixedrankflag)

 Returns an AcceleratedFunction or a function handle which can be used to
 compute Euclidean gradients. See
 https://ch.mathworks.com/help/deeplearning/ref/deep.acceleratedfunction.html
 for more details about AcceleratedFunction.

 Note: to evaluate the Euclidean gradient at a point x (x should be of
 type dlarray), call dlfeval(autogradfunc, x) instead of autogradfunc(x).

 See also: manoptAD, egradcompute, costgradcompute

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

0001 function autogradfunc = autograd(problem, fixedrankflag)
0002 % Apply automatic differentiation to compute the Euclidean gradient
0003 %
0004 % function autogradfunc = autograd(problem)
0005 % function autogradfunc = autograd(problem, fixedrankflag)
0006 %
0007 % Returns an AcceleratedFunction or a function handle which can be used to
0008 % compute Euclidean gradients. See
0009 % https://ch.mathworks.com/help/deeplearning/ref/deep.acceleratedfunction.html
0010 % for more details about AcceleratedFunction.
0011 %
0012 % Inputs:
0013 %   problem       - Manopt problem structure; must contain the fields M
0014 %                   and cost.
0015 %   fixedrankflag - (optional, default false) set to true only if
0016 %                   problem.M is the manifold of fixed-rank matrices
0017 %                   with an embedded geometry.
0018 %
0019 % Output:
0020 %   autogradfunc  - If fixedrankflag is false: a function of x returning
0021 %                   [cost, egrad] (when iscstruct(cost) holds, the real
0022 %                   part creal(cost) is used). It is an
0023 %                   AcceleratedFunction when dlaccelerate is available
0024 %                   (Matlab R2021a or later), and a plain function handle
0025 %                   otherwise.
0026 %                   If fixedrankflag is true: a function handle of
0027 %                   (x, A, B) returning [g1, egrad], where egrad.A and
0028 %                   egrad.B are gradients of the cost with respect to
0029 %                   A and B (see the nested function below).
0030 %
0031 % Note: to evaluate the Euclidean gradient at a point x (x should be of
0032 % type dlarray), call dlfeval(autogradfunc, x) instead of autogradfunc(x).
0033 %
0034 % See also: manoptAD, egradcompute, costgradcompute
0035
0036 % This file is part of Manopt: www.manopt.org.
0037 % Original author: Xiaowen Jiang, Aug. 31, 2021.
0038 % Contributors: Nicolas Boumal
0039 % Change log:
0040 %
0041 % To do: Add AD to fixedTTrankfactory, fixedranktensorembeddedfactory
0042 % and the product manifold which contains fixedrankembeddedfactory
0043 % or anchoredrotationsfactory
0044
0045     % Check availability
0046     assert(isfield(problem,'M') && isfield(problem,'cost'),...
0047     'problem structure must contain the fields M and cost.');
0048     assert(exist('dlarray', 'file') == 2, ['Deep learning tool box is '...
0049     'needed for automatic differentiation'])
0050
0051     % fixedrankflag defaults to false. It should be true only if the
0052     % manifold is fixed-rank matrices with an embedded geometry; the
0053     % fixed multilinear-rank and fixed TT-rank tensor cases are not
0054     % implemented yet.
0055     if ~exist('fixedrankflag', 'var') || isempty(fixedrankflag)
0056         fixedrankflag = false;
0057     end
0058
0059     costfunction = problem.cost;
0060
0061     % anchoredrotationsfactory: the Euclidean gradient with respect to the
0062     % anchored rotations (indices returned by findA_anchors) must be zero.
0063     % The manifold name is inspected once here rather than on every
0064     % gradient evaluation. A manifold structure without a name field is
0065     % treated as not anchored rather than making every gradient call fail.
0066     anchorindices = [];
0067     if isfield(problem.M, 'name')
0068         manifoldname = problem.M.name();
0069         if contains(manifoldname, 'Product rotations manifold') && ...
0070            contains(manifoldname, 'anchors') && ...
0071            ~startsWith(manifoldname, 'Product manifold')
0072             anchorindices = findA_anchors(problem);
0073         end
0074     end
0075
0076     if fixedrankflag
0077         % AcceleratedFunction can lead to a slow down in this case
0078         autogradfunc = @autogradfuncinternalfixedrankembedded;
0079     else
0080         func = @autogradfuncinternal;
0081         % Accelerate (dlaccelerate was introduced in Matlab R2021a);
0082         % fall back to the plain function handle if it is unavailable.
0083         try
0084             autogradfunc = dlaccelerate(func);
0085             clearCache(autogradfunc);
0086         catch
0087             warning('manopt:dlaccelerate', ...
0088                     ['Function dlaccelerate is not available:\nPlease ' ...
0089                      'upgrade to Matlab 2021a or later and the latest deep\nlearning ' ...
0090                      'toolbox version if possible.\nMeanwhile, auto-diff ' ...
0091                      'may be somewhat slower.\n The hessian is not available as well.\n' ...
0092                      'To disable this warning: warning(''off'', ''manopt:dlaccelerate'')']);
0093             autogradfunc = func;
0094         end
0095     end
0096
0097     % Cost and Euclidean gradient at x (expected to be a dlarray; this
0098     % function is meant to be called through dlfeval so dlgradient works).
0099     function [y, egrad] = autogradfuncinternal(x)
0100         y = costfunction(x);
0101         % In case that the user forgot to take the real part of the cost
0102         % when dealing with complex problems with Matlab R2021a or earlier,
0103         % take the real part for AD
0104         if iscstruct(y)
0105             y = creal(y);
0106         end
0107
0108         % Call dlgradient to compute the Euclidean gradient. By default,
0109         % 'RetainData' and 'EnableHigherDerivatives' are set to false
0110         egrad = dlgradient(y, x);
0111
0112         % Zero out the gradient of the anchored rotations, if any
0113         % (anchorindices is computed once in the parent function).
0114         if ~isempty(anchorindices)
0115             egrad(:, :, anchorindices) = 0;
0116         end
0117     end
0118
0119     % fixedrankembeddedfactory part:
0120     % obtain the product of egrad and V and the product of egrad
0121     % transpose and U by differentiating g1 and g2 w.r.t A and B, where
0122     % g1 is the cost at the point with factors (U, S, V) = (A, I, x.V)
0123     % and g2 is the cost at the point with factors (x.U, I, B).
0124     function [g1, egrad] = autogradfuncinternalfixedrankembedded(x, A, B)
0125         r = size(x.S, 1);
0126         X1.U = A;   X1.S = eye(r); X1.V = x.V;
0127         X2.U = x.U; X2.S = eye(r); X2.V = B;
0128         g1 = costfunction(X1);
0129         g2 = costfunction(X2);
0130         egrad.A = dlgradient(g1, A);
0131         egrad.B = dlgradient(g2, B);
0132     end
0133
0134 end

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005