
costgradcompute

PURPOSE

Computes the cost and the gradient at x via AD in one call

SYNOPSIS

function [cost, grad] = costgradcompute(problem, x, complexflag)

DESCRIPTION

 Computes the cost and the gradient at x via AD in one call 

 function [cost, grad] = costgradcompute(problem, x, complexflag)

 Returns the cost and the Riemannian gradient of the cost function
 described in the problem structure at the point x. The Euclidean
 gradient is obtained by AD and converted with problem.M.egrad2rgrad.
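The one-call pattern (a single traced evaluation that returns both the cost and the Euclidean gradient, followed by a conversion to the Riemannian gradient) can be illustrated without Manopt. The following is a minimal sketch, assuming the Deep Learning Toolbox (dlarray, dlfeval, dlgradient, extractdata) is installed. The cost f(x) = x'*A*x, the matrix A, the point x and the sphere projection used in place of problem.M.egrad2rgrad are hypothetical choices for illustration; dlarray and extractdata stand in for mat2dl and dl2mat on a plain vector.

```matlab
% Minimal sketch, assuming the Deep Learning Toolbox is installed.
% Hypothetical cost on the unit sphere in R^2: f(x) = x'*A*x.
A = [2 1; 1 3];
x = [1; 0];                          % a point on the unit sphere

% Cost as a function of a dlarray X (elementwise product, matrix product, sum).
f = @(X) sum(X .* (A*X));

% One traced call returns both the cost and the Euclidean gradient,
% in the role played by problem.autogradfunc.
costandegrad = @(X) deal(f(X), dlgradient(f(X), X));

dlx = dlarray(x);                    % plays the role of mat2dl(x)
[dlcost, dlegrad] = dlfeval(costandegrad, dlx);

cost  = extractdata(dlcost);         % plays the role of dl2mat(cost)
egrad = extractdata(dlegrad);        % plays the role of dl2mat(egrad)

% Hypothetical stand-in for problem.M.egrad2rgrad on the sphere:
% project egrad onto the tangent space at x.
rgrad = egrad - x*(x'*egrad);
```

For this A and x, the cost is 2, the Euclidean gradient is (A + A')*x = [4; 2], and its projection onto the tangent space at x is [0; 2].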

 Note: the problem structure must contain the field autogradfunc (see
 autograd). autogradfunc should be either an AcceleratedFunction or a
 function handle that calls dlgradient and returns both the cost and the
 Euclidean gradient. x is a point on the target manifold. complexflag is
 a logical flag that should be true only when the cost function in
 autogradfunc involves complex numbers and the MATLAB version is R2021a
 or earlier; in that case x is converted with mat2dl_complex and the
 gradient with dl2mat_complex.
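To illustrate why complex data needs separate handling when only real dlarrays are available, here is a sketch of one way to differentiate a real-valued cost of a complex vector: store the real and imaginary parts as two real columns, differentiate, then recombine the partial derivatives as d/dRe + 1i*d/dIm. The result is the Euclidean gradient for the real inner product real(u'*v) that Manopt uses on complex spaces. This is not necessarily how mat2dl_complex and dl2mat_complex lay out the data; the point z and the cost are hypothetical, and the Deep Learning Toolbox is assumed.

```matlab
% Minimal sketch, assuming the Deep Learning Toolbox is installed.
% Hypothetical complex point and real-valued cost f(z) = sum(abs(z).^2).
z = [1+2i; -3+0.5i];

% Store real and imaginary parts as two real columns of one dlarray.
% (Illustrative layout only; not necessarily what mat2dl_complex uses.)
dlz = dlarray([real(z), imag(z)]);

% The same cost written in terms of the real parameters.
f = @(X) sum(sum(X.^2));
costandgrad = @(X) deal(f(X), dlgradient(f(X), X));

[dlcost, dlg] = dlfeval(costandgrad, dlz);
cost = extractdata(dlcost);
g = extractdata(dlg);

% Recombine d/dRe + 1i*d/dIm into a complex Euclidean gradient.
egrad = g(:, 1) + 1i*g(:, 2);
```

Here the cost is 14.25 and the recombined gradient equals 2*z, as expected for f(z) = sum(abs(z).^2).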

 See also: manoptAD autograd mat2dl mat2dl_complex dl2mat dl2mat_complex

SOURCE CODE

function [cost, grad] = costgradcompute(problem, x, complexflag)
% Computes the cost and the gradient at x via AD in one call
%
% function [cost, grad] = costgradcompute(problem, x, complexflag)
%
% Returns the cost and the Riemannian gradient of the cost function
% described in the problem structure at the point x. The Euclidean
% gradient is obtained by AD and converted with problem.M.egrad2rgrad.
%
% Note: the problem structure must contain the field autogradfunc (see
% autograd). autogradfunc should be either an AcceleratedFunction or a
% function handle that calls dlgradient and returns both the cost and the
% Euclidean gradient. x is a point on the target manifold. complexflag is
% a logical flag that should be true only when the cost function in
% autogradfunc involves complex numbers and the MATLAB version is R2021a
% or earlier; in that case x is converted with mat2dl_complex and the
% gradient with dl2mat_complex.
%
% See also: manoptAD autograd mat2dl mat2dl_complex dl2mat dl2mat_complex

% This file is part of Manopt: www.manopt.org.
% Original author: Xiaowen Jiang, Aug. 31, 2021.
% Contributors: Nicolas Boumal
% Change log:
%
% To do: Add AD to fixedTTrankfactory, fixedranktensorembeddedfactory
% and the product manifold which contains fixedrankembeddedfactory
% or anchoredrotationsfactory

    assert(isfield(problem, 'autogradfunc'), ['The problem structure must' ...
        ' contain the field autogradfunc, see autograd.']);

    % Convert x into dlarrays to prepare for AD.
    if complexflag
        dlx = mat2dl_complex(x);
    else
        dlx = mat2dl(x);
    end

    % In MATLAB R2021b Prerelease, an AcceleratedFunction only accepts
    % inputs with a fixed data structure. If the representation of a
    % point on the manifold changes while an algorithm runs, the
    % AcceleratedFunction then fails to work properly. In particular,
    % AcceleratedFunction is sensitive to the order in which the fields of
    % a structure were defined: if a manifold factory represents points as
    % structures and its retr and rand functions define the fields in
    % different orders, an error occurs. In that case, the old cache is
    % cleared so that the new input can be accepted.
    if isa(problem.autogradfunc, 'deep.AcceleratedFunction')
        try
            % Compute the cost and egrad with autogradfunc.
            [cost, egrad] = dlfeval(problem.autogradfunc, dlx);
        catch
            % Clear the old cache and try again.
            clearCache(problem.autogradfunc);
            [cost, egrad] = dlfeval(problem.autogradfunc, dlx);
            warning('manopt:AD:cachedlaccelerate', ...
                ['The representation of points on the manifold is inconsistent.\n' ...
                 'AcceleratedFunction has to clear its old cache to accept the new ' ...
                 'representation of the input.\nPlease check the consistency when ' ...
                 'writing the manifold factory.\n' ...
                 'To disable this warning: warning(''off'', ''manopt:AD:cachedlaccelerate'')']);
        end
    else
        [cost, egrad] = dlfeval(problem.autogradfunc, dlx);
    end

    % Convert egrad back to numerical arrays.
    if complexflag
        egrad = dl2mat_complex(egrad);
    else
        egrad = dl2mat(egrad);
    end

    % Convert egrad to the Riemannian gradient.
    grad = problem.M.egrad2rgrad(x, egrad);
    % Convert cost back to a numeric format.
    cost = dl2mat(cost);

end
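The field-order sensitivity described in the source comment can be seen with plain structures: two structures with the same fields and values compare equal with isequal, yet fieldnames returns their fields in different orders. The sketch below is illustrative only. The U/S/V fields mimic a hypothetical factory whose rand and retr functions build the fields in different orders, and applying orderfields (or simply defining the fields in the same order everywhere) is one possible way to keep the representation consistent; it is not something costgradcompute does itself.

```matlab
% Two hypothetical representations of the same point, e.g. one built by a
% rand function and one by a retraction, with fields defined in different orders.
xrand = struct('U', eye(2), 'S', diag([3 1]), 'V', eye(2));
xretr = struct('S', diag([3 1]), 'U', eye(2), 'V', eye(2));

samevalues = isequal(xrand, xretr);                          % field order ignored
sameorder  = isequal(fieldnames(xrand), fieldnames(xretr)); % field order compared

% orderfields sorts the fields alphabetically, giving a consistent layout.
xrand2 = orderfields(xrand);
xretr2 = orderfields(xretr);
sameorder2 = isequal(fieldnames(xrand2), fieldnames(xretr2));
```

Here samevalues is true and sameorder is false, while sameorder2 is true once both structures have been passed through orderfields.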

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005