
egradcompute

PURPOSE

Computes the Euclidean gradient of the cost function at x via automatic differentiation (AD).

SYNOPSIS

function egrad = egradcompute(problem, x, complexflag)

DESCRIPTION

 Computes the Euclidean gradient of the cost function at x via automatic
 differentiation (AD).

 function egrad = egradcompute(problem, x)
 function egrad = egradcompute(problem, x, complexflag)

 Returns the Euclidean gradient of the cost function described by
 problem.autogradfunc at the point x.

 Note: the problem structure must contain the field autogradfunc.
 problem.autogradfunc should be either an AcceleratedFunction or a handle
 to a function that calls dlgradient. x is a point on the target manifold.
 complexflag is a boolean which should be true when the problem involves
 complex numbers and the installed Matlab version is R2021a or earlier.
 It defaults to false.

 See also: manoptAD autograd mat2dl mat2dl_complex dl2mat dl2mat_complex
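
The call described above can be sketched as follows. This is a minimal illustration, not the way Manopt itself builds autogradfunc (see autograd for that). It assumes Manopt is on the path and the Deep Learning Toolbox is installed. The helper costAndGrad, the matrix A and the choice of cost f(x) = x'*A*x are hypothetical and chosen only for the example. Since egradcompute reads only the field autogradfunc, no other field of the problem structure is set here.

```matlab
% Sketch: requires Manopt (egradcompute, mat2dl, dl2mat) on the path and
% the Deep Learning Toolbox (dlarray, dlfeval, dlgradient).
n = 5;
A = randn(n);
A = (A + A')/2;                 % symmetric matrix (illustrative data)

% A function handle whose target calls dlgradient. It returns two
% outputs, the cost and the gradient, as egradcompute expects.
problem.autogradfunc = @(x) costAndGrad(x, A);

x = randn(n, 1);
x = x / norm(x);                % a point on the unit sphere

egrad = egradcompute(problem, x);

% For f(x) = x'*A*x with symmetric A, the Euclidean gradient is 2*A*x.
fprintf('Deviation from 2*A*x: %g\n', norm(egrad - 2*A*x));

function [y, g] = costAndGrad(x, A)
    % Hypothetical helper: evaluates the cost and its gradient.
    % It must be evaluated through dlfeval so that dlgradient can
    % trace the operations that produced y.
    y = sum(x .* (A*x), 'all');
    g = dlgradient(y, x);
end
```

When the problem involves complex numbers and the Matlab version is R2021a or earlier, the same call would pass true as the third argument.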

CROSS-REFERENCE INFORMATION

This function calls: mat2dl, mat2dl_complex, dl2mat, dl2mat_complex
This function is called by:

SOURCE CODE

function egrad = egradcompute(problem, x, complexflag)
% Computes the Euclidean gradient of the cost function at x via AD.
%
% function egrad = egradcompute(problem, x)
% function egrad = egradcompute(problem, x, complexflag)
%
% Returns the Euclidean gradient of the cost function described by
% problem.autogradfunc at the point x.
%
% Note: the problem structure must contain the field autogradfunc.
% problem.autogradfunc should be either an AcceleratedFunction or a handle
% to a function that calls dlgradient. x is a point on the target manifold.
% complexflag is a boolean which should be true when the problem involves
% complex numbers and the installed Matlab version is R2021a or earlier.
% It defaults to false.
%
% See also: manoptAD autograd mat2dl mat2dl_complex dl2mat dl2mat_complex

% This file is part of Manopt: www.manopt.org.
% Original author: Xiaowen Jiang, Aug. 31, 2021.
% Contributors: Nicolas Boumal
% Change log:

% To do: Add AD to fixedTTrankfactory, fixedranktensorembeddedfactory
% and the product manifold which contains fixedrankembeddedfactory
% or anchoredrotationsfactory

    % check availability
    assert(isfield(problem, 'autogradfunc'), ['the problem structure must' ...
        ' contain the field autogradfunc, see autograd.']);
    if ~exist('complexflag', 'var')
        complexflag = false;
    end
    % convert x into dlarrays to prepare for AD
    if complexflag
        dlx = mat2dl_complex(x);
    else
        dlx = mat2dl(x);
    end

    % In Matlab R2021b Prerelease, an AcceleratedFunction only accepts
    % inputs with a fixed data structure. If the representation of a point
    % on the manifold changes while an algorithm runs, the
    % AcceleratedFunction fails. In particular, it is sensitive to the
    % order in which the fields of a structure were defined. If points are
    % represented as structures and the retr and rand functions of the
    % manifold factory define the fields in different orders, an error
    % occurs. In that case, the old cache is cleared so that the new input
    % is accepted.
    if isa(problem.autogradfunc, 'deep.AcceleratedFunction')
        try
            % compute egrad according to autogradfunc
            [~, egrad] = dlfeval(problem.autogradfunc, dlx);
        catch
            % clear the old cache, then evaluate again
            clearCache(problem.autogradfunc);
            [~, egrad] = dlfeval(problem.autogradfunc, dlx);
            warning('manopt:AD:cachedlaccelerte', ...
            ['The representation of points on the manifold is inconsistent.\n'...
            'AcceleratedFunction has to clear its old cache to accept the new '...
            'representation of the input.\nPlease check the consistency when '...
            'writing the manifold factory.\n'...
            'To disable this warning: warning(''off'', ''manopt:AD:cachedlaccelerte'')']);
        end
    else
        [~, egrad] = dlfeval(problem.autogradfunc, dlx);
    end

    % convert egrad back to numeric arrays
    if complexflag
        egrad = dl2mat_complex(egrad);
    else
        egrad = dl2mat(egrad);
    end

end
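
The comment in the source about field order can be illustrated with plain structures. The sketch below uses only base Matlab. The field names U and S are hypothetical, and sorting the fields with orderfields is merely one possible way to obtain a consistent layout when writing a manifold factory; it is not something egradcompute does.

```matlab
% Two representations of the same point whose fields were created in a
% different order (field names U and S are illustrative).
p1 = struct('U', eye(2), 'S', 3);
p2 = struct('S', 3, 'U', eye(2));

% isequal matches struct fields by name, so the contents agree ...
sameContent = isequal(p1, p2);
% ... but the order in which the fields are stored differs.
sameOrder = isequal(fieldnames(p1), fieldnames(p2));

% orderfields sorts the fields by name, giving both the same layout.
q1 = orderfields(p1);
q2 = orderfields(p2);
sameOrderAfter = isequal(fieldnames(q1), fieldnames(q2));
```

Defining the fields in the same order in every function of a factory that returns a point avoids the cache reset and the associated warning.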

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005