function problem = manoptAD(problem, flag)
% Preprocess automatic differentiation for a manopt problem structure
%
% function problem = manoptAD(problem)
% function problem = manoptAD(problem, 'nohess')
% function problem = manoptAD(problem, 'hess')
%
% Given a manopt problem structure with problem.cost and problem.M defined,
% this tool adds the following fields to the problem structure:
%   problem.egrad
%   problem.costgrad
%   problem.ehess
%
% A field problem.autogradfunc is also created for internal use.
%
% The fields egrad and ehess correspond to the Euclidean gradient and
% Hessian. They are obtained through automatic differentiation of the cost
% function. Manopt converts them into Riemannian objects in the usual way
% via the manifold's M.egrad2rgrad and M.ehess2rhess functions,
% automatically.
%
% As an optional second input, the user may specify the flag string to be:
%   'nohess' -- in which case problem.ehess is not created.
%   'hess'   -- which corresponds to the default behavior.
% If problem.egrad is already provided and the Hessian is requested, the
% tool builds problem.ehess based on problem.egrad rather than on the cost.
%
% This function requires the following:
%   Matlab version R2021a or later.
%   Deep Learning Toolbox version 14.2 or later.
%
% Support for complex variables in automatic differentiation was added in
% Matlab R2021b, which also improved support for Hessian computations.
% With earlier versions, see manoptADhelp and complex_example_AD for a
% workaround, or set the 'nohess' flag to tell Manopt not to compute
% Hessians with AD.
%
% If AD fails for some reason, the original problem structure is returned
% with a warning trying to hint at what the issue may be. Mostly, issues
% arise because manoptAD relies on the Deep Learning Toolbox, which itself
% relies on the dlarray data type, and only a subset of Matlab functions
% support dlarrays:
%
%   See manoptADhelp for more about limitations and workarounds.
%   See https://ch.mathworks.com/help/deeplearning/ug/list-of-functions-with-dlarray-support.html
%   for an official list of functions that support dlarray.
%
% In particular, sparse matrices are not supported, nor are certain
% standard functions, including trace(), which can be replaced by ctrace().
%
% There are a few limitations pertaining to specific manifolds.
% For example:
%   fixedrankembeddedfactory: AD creates grad, not egrad; and no Hessian.
%   fixedranktensorembeddedfactory: no AD support.
%   fixedTTrankfactory: no AD support.
%   euclideansparsefactory: no AD support.
%
% Importantly, while AD is convenient and efficient in terms of human time,
% it is not efficient in terms of CPU time: expect AD to slow down gradient
% computations by a factor of about 5. Moreover, while AD can most often
% compute Hessians as well, it is often more efficient to compute Hessians
% with finite differences (which is the default in Manopt when the Hessian
% is not provided by the user). Thus, it is often the case that
%   problem = manoptAD(problem, 'nohess');
% leads to better overall runtime than
%   problem = manoptAD(problem);
% when calling trustregions(problem).
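%
% The following is a minimal usage sketch (the specific problem, a
% dominant eigenvector computation on the sphere with a symmetric data
% matrix A, is only for illustration; its cost uses only
% dlarray-supported operations):
%
%     n = 100;
%     A = randn(n); A = (A + A')/2;
%     problem.M = spherefactory(n);
%     problem.cost = @(x) -x'*(A*x);
%     problem = manoptAD(problem, 'nohess'); % AD supplies egrad, costgrad
%     x = trustregions(problem); % Hessian approximated by finite differences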
%
% Some manifold factories in Manopt support GPUs: automatic differentiation
% should work with them too, as usual. See using_gpu_AD for more details.
%
% See also: manoptADhelp autograd egradcompute ehesscompute complex_example_AD using_gpu_AD

% This file is part of Manopt: www.manopt.org.
% Original author: Xiaowen Jiang, Aug. 31, 2021.
% Contributors: Nicolas Boumal
% Change log:

% To do: Add AD to fixedTTrankfactory, fixedranktensorembeddedfactory
% and the product manifold which contains fixedrankembeddedfactory
% or anchoredrotationsfactory.

%% Check if AD can be applied to the manifold and the cost function

    % Check availability of the Deep Learning Toolbox.
    if ~(exist('dlarray', 'file') == 2)
        error('manopt:AD:dl', ...
              ['It seems the Deep Learning Toolbox is not installed.\n' ...
               'It is needed for automatic differentiation in Manopt.\n' ...
               'If possible, install the latest version of that toolbox ' ...
               'and ideally also Matlab R2021b or later.']);
    end

    % Check for a feature of recent versions of the Deep Learning Toolbox.
    if ~(exist('dlaccelerate', 'file') == 2)
        warning('manopt:AD:dlaccelerate', ...
                ['Function dlaccelerate not available:\n If possible, ' ...
                 'upgrade to Matlab R2021a or later and use the latest ' ...
                 'version of the Deep Learning Toolbox.\n' ...
                 'Automatic differentiation may still work but be a lot ' ...
                 'slower.\nMoreover, the Hessian is not available in AD.\n' ...
                 'Setting flag to ''nohess''. ' ...
                 'To disable this warning: ' ...
                 'warning(''off'', ''manopt:AD:dlaccelerate'');']);
        flag = 'nohess';
    end

    % The problem structure must provide a manifold and a cost function.
    assert(isfield(problem, 'M') && isfield(problem, 'cost'), ...
           'The problem structure must contain the fields M and cost.');

    % Check the flag value if provided, or set its default value.
    if exist('flag', 'var')
        assert(strcmp(flag, 'nohess') || strcmp(flag, 'hess'), ...
               'The second argument should be either ''nohess'' or ''hess''.');
    else
        flag = 'hess'; % default behavior
    end

    % If gradient and Hessian information is already provided, return.
    if canGetGradient(problem) && canGetHessian(problem)
        warning('manopt:AD:alreadydefined', ...
                ['Gradient and Hessian already defined, skipping AD.\n' ...
                 'To disable this warning: ' ...
                 'warning(''off'', ''manopt:AD:alreadydefined'');']);
        return;
    end

    % Below, it is convenient for several purposes to have a point on the
    % manifold. This makes it possible to investigate its representation.
    x = problem.M.rand();

    % AD does not support certain manifolds.
    manifold_name = problem.M.name();
    if contains(manifold_name, 'sparsity')
        error('manopt:AD:sparse', ...
              ['Automatic differentiation currently does not support ' ...
               'sparse matrices, e.g., euclideansparsefactory.']);
    end
    if ( startsWith(manifold_name, 'Product manifold') && ...
         ((sum(isfield(x, {'U', 'S', 'V'})) == 3) && ...
          (contains(manifold_name, 'rank', 'IgnoreCase', true))) ...
       ) || ( ...
         exist('tenrand', 'file') == 2 && isfield(x, 'X') && ...
         isa(x.X, 'ttensor') ...
       ) || ...
       isa(x, 'TTeMPS')
        error('manopt:AD:fixedrankembedded', ...
              ['Automatic differentiation ' ...
               'does not support fixedranktensorembeddedfactory,\n' ...
               'fixedTTrankfactory, and product manifolds containing ' ...
               'fixedrankembeddedfactory.']);
    end

    % complexflag is used to detect if both of the following are true:
    %   A) the problem variables contain complex numbers, and
    %   B) the Matlab version is R2021a or earlier.
    % If so, we attempt a workaround.
    % If Matlab is R2021b or later, then it is not an issue to have
    % complex numbers in the variables.
    complexflag = false;
    % Check if AD can be applied to the cost function by passing the point
    % x we created earlier to problem.cost.
    try
        dlx = mat2dl(x);
        costtestdlx = problem.cost(dlx); %#ok<NASGU>
    catch ME
        % Detect complex numbers by looking at the error message.
        % Note: the error deep:dlarray:ComplexNotSupported is removed
        % in Matlab R2021b and later.
        if (strcmp(ME.identifier, 'deep:dlarray:ComplexNotSupported'))
            try
                % Let's try to run AD with the complex workaround.
                dlx = mat2dl_complex(x);
                costtestx = problem.cost(x); %#ok<NASGU>
                costtestdlx = problem.cost(dlx); %#ok<NASGU>
            catch
                error('manopt:AD:complex', ...
                      ['Automatic differentiation failed. ' ...
                       'Problem defining the cost function.\n' ...
                       'Variables contain complex numbers. ' ...
                       'Check your Matlab version and see\n' ...
                       'complex_example_AD.m and manoptADhelp.m for ' ...
                       'help about how to deal with complex variables.']);
            end
            % If no error appears, set complexflag to true.
            complexflag = true;
        else
            % If the error is not related to complex numbers, then the
            % issue is likely with the cost function definition.
            warning('manopt:AD:cost', ...
                    ['Automatic differentiation failed. ' ...
                     'Problem defining the cost function.\n' ...
                     '<a href = "https://www.mathworks.ch/help/deeplearning' ...
                     '/ug/list-of-functions-with-dlarray-support.html">' ...
                     'Check the list of functions with AD support.</a>' ...
                     ' and see manoptADhelp for more information.']);
            return;
        end
    end

%% Keep track of what we create with AD
    ADded_gradient = false;
    ADded_hessian = false;

%% Handle the special case of fixedrankembeddedfactory first

    % Check if the manifold struct is fixed-rank matrices
    % with an embedded geometry. For fixedrankembeddedfactory,
    % only the Riemannian gradient can be computed via AD so far.
    fixedrankflag = false;
    if (sum(isfield(x, {'U', 'S', 'V'})) == 3) && ...
       (contains(manifold_name, 'rank', 'IgnoreCase', true)) && ...
       (~startsWith(manifold_name, 'Product manifold'))

        if ~strcmp(flag, 'nohess')
            warning('manopt:AD:fixedrank', ...
                    ['Computing the exact Hessian via AD is not supported ' ...
                     'for fixedrankembeddedfactory.\n' ...
                     'Setting flag to ''nohess''.\nTo disable this warning: ' ...
                     'warning(''off'', ''manopt:AD:fixedrank'');']);
            flag = 'nohess';
        end

        % Set the fixedrankflag to true to prepare for autograd.
        fixedrankflag = true;
        % If no gradient information is provided, compute grad using AD.
        % Note that here we define the Riemannian gradient.
        if ~canGetGradient(problem)
            problem.autogradfunc = autograd(problem, fixedrankflag);
            problem.grad = @(x) gradcomputefixedrankembedded(problem, x);
            problem.costgrad = @(x) costgradcomputefixedrankembedded(problem, x);
            ADded_gradient = true;
        end

    end

%% Compute the Euclidean gradient and the Euclidean Hessian via AD

    % Provide egrad and (if requested) ehess via AD.
    % Manopt converts to Riemannian derivatives via egrad2rgrad and
    % ehess2rhess as usual: no need to worry about this here.
    if ~fixedrankflag

        if ~canGetGradient(problem)
            problem.autogradfunc = autograd(problem);
            problem.egrad = @(x) egradcompute(problem, x, complexflag);
            problem.costgrad = @(x) costgradcompute(problem, x, complexflag);
            ADded_gradient = true;
        end

        if ~canGetHessian(problem) && strcmp(flag, 'hess')
            problem.ehess = @(x, xdot, store) ...
                                  ehesscompute(problem, x, xdot, ...
                                               store, complexflag);
            ADded_hessian = true;
        end

    end

%% Check whether the gradient / Hessian we AD'ded actually work

    % Some functions cannot be differentiated with AD in the Deep Learning
    % Toolbox, e.g., cat(3, A, B).
    % In this clean-up phase, we check whether things actually work, and
    % we remove the offending fields if they do not, with a warning.

    if ADded_gradient && ~fixedrankflag

        try
            egrad = problem.egrad(x);
        catch
            warning('manopt:AD:failgrad', ...
                    ['Automatic differentiation for the gradient failed. ' ...
                     'Problem defining the cost function.\n' ...
                     '<a href = "https://www.mathworks.ch/help/deeplearning' ...
                     '/ug/list-of-functions-with-dlarray-support.html">' ...
                     'Check the list of functions with AD support.</a>' ...
                     ' and see manoptADhelp for more information.']);
            problem = rmfield(problem, 'autogradfunc');
            problem = rmfield(problem, 'egrad');
            problem = rmfield(problem, 'costgrad');
            if ADded_hessian
                problem = rmfield(problem, 'ehess');
            end
            return;
        end

        if isNaNgeneral(egrad)
            warning('manopt:AD:NaN', ...
                    ['Automatic differentiation for the gradient failed. ' ...
                     'Problem defining the cost function.\n' ...
                     'NaN comes up in the computation of egrad via AD.\n' ...
                     'Check the example thomson_problem.m for help.']);
            problem = rmfield(problem, 'autogradfunc');
            problem = rmfield(problem, 'egrad');
            problem = rmfield(problem, 'costgrad');
            if ADded_hessian
                problem = rmfield(problem, 'ehess');
            end
            return;
        end

    end

    if ADded_hessian

        % Randomly generate a vector in the tangent space at x.
        xdot = problem.M.randvec(x);
        store = struct();
        try
            ehess = problem.ehess(x, xdot, store);
        catch
            warning('manopt:AD:failhess', ...
                    ['Automatic differentiation for the Hessian failed. ' ...
                     'Problem defining the cost function.\n' ...
                     '<a href = "https://www.mathworks.ch/help/deeplearning' ...
                     '/ug/list-of-functions-with-dlarray-support.html">' ...
                     'Check the list of functions with AD support.</a>' ...
                     ' and see manoptADhelp for more information.']);
            problem = rmfield(problem, 'ehess');
            return;
        end

        if isNaNgeneral(ehess)
            warning('manopt:AD:NaN', ...
                    ['Automatic differentiation for the Hessian failed. ' ...
                     'Problem defining the cost function.\n' ...
                     'NaN comes up in the computation of ehess via AD.\n' ...
                     'Check the example thomson_problem.m for help.']);
            problem = rmfield(problem, 'ehess');
            return;
        end

    end

    % Check the case of fixed-rank matrices as an embedded submanifold.
    if ADded_gradient && fixedrankflag

        try
            grad = problem.grad(x);
        catch
            warning('manopt:AD:costfixedrank', ...
                    ['Automatic differentiation for the gradient failed. ' ...
                     'Problem defining the cost function.\n' ...
                     '<a href = "https://www.mathworks.ch/help/deeplearning' ...
                     '/ug/list-of-functions-with-dlarray-support.html">' ...
                     'Check the list of functions with AD support.</a>' ...
                     ' and see manoptADhelp for more information.']);
            problem = rmfield(problem, 'autogradfunc');
            problem = rmfield(problem, 'grad');
            problem = rmfield(problem, 'costgrad');
            return;
        end

        if isNaNgeneral(grad)
            warning('manopt:AD:NaN', ...
                    ['Automatic differentiation for the gradient failed. ' ...
                     'Problem defining the cost function.\n' ...
                     'NaN comes up in the computation of grad via AD.\n' ...
                     'Check the example thomson_problem.m for help.']);
            problem = rmfield(problem, 'autogradfunc');
            problem = rmfield(problem, 'grad');
            problem = rmfield(problem, 'costgrad');
            return;
        end

    end

end