
getGradient

PURPOSE

Computes the gradient of the cost function at x.

SYNOPSIS

function grad = getGradient(problem, x, storedb, key)

DESCRIPTION

 Computes the gradient of the cost function at x.

 function grad = getGradient(problem, x)
 function grad = getGradient(problem, x, storedb)
 function grad = getGradient(problem, x, storedb, key)

 Returns the gradient at x of the cost function described in the problem
 structure.

 storedb is a StoreDB object, and key is the StoreDB key corresponding to
 the point x.

 See also: getDirectionalDerivative, canGetGradient
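
 For illustration, a minimal sketch of a call follows. The sphere manifold
 and the quadratic cost are assumptions made for the example, not
 requirements of getGradient; any Manopt problem structure works the same.

  problem.M = spherefactory(5);
  A = randn(5); A = (A + A')/2;          % symmetric matrix for the cost
  problem.cost = @(x) x'*A*x;
  problem.grad = @(x) problem.M.egrad2rgrad(x, 2*A*x);
  x = problem.M.rand();
  grad = getGradient(problem, x);        % Riemannian gradient at x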

CROSS-REFERENCE INFORMATION

This function calls: canGetEuclideanGradient, getEuclideanGradient, canGetPartialGradient, getPartialGradient, canGetDirectionalDerivative, getDirectionalDerivative, getApproxGradient, tangentorthobasis, lincomb
This function is called by:

SOURCE CODE

function grad = getGradient(problem, x, storedb, key)
% Computes the gradient of the cost function at x.
%
% function grad = getGradient(problem, x)
% function grad = getGradient(problem, x, storedb)
% function grad = getGradient(problem, x, storedb, key)
%
% Returns the gradient at x of the cost function described in the problem
% structure.
%
% storedb is a StoreDB object, and key is the StoreDB key corresponding to
% the point x.
%
% See also: getDirectionalDerivative, canGetGradient

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%
%   April 3, 2015 (NB):
%       Works with the new StoreDB class system.
%
%   June 28, 2016 (NB):
%       Works with getPartialGradient.
%
%   Nov. 1, 2016 (NB):
%       Added support for gradient from directional derivatives.
%       Last resort is a call to getApproxGradient instead of an exception.
%
%   Sep. 6, 2018 (NB):
%       The gradient is now cached by default. This is made practical by
%       the new storedb 'remove' functionalities that keep the number of
%       cached points down to a minimum. If the gradient is obtained via
%       costgrad, the cost is also cached.
%
%   Feb. 10, 2020 (NB):
%       Allowing M.egrad2rgrad to take (storedb, key) as extra inputs.
    % Allow omission of the key, and even of storedb.
    if ~exist('key', 'var')
        if ~exist('storedb', 'var')
            storedb = StoreDB();
        end
        key = storedb.getNewKey();
    end

    % Contrary to most similar functions, here, we get the store by
    % default. This is for the caching functionality described below.
    store = storedb.getWithShared(key);
    store_is_stale = false;

    % If the gradient has been computed before at this point (and its
    % memory is still in storedb), then we just look up the value.
    force_grad_caching = true;
    if force_grad_caching && isfield(store, 'grad__')
        grad = store.grad__;
        return;
    end

    % We don't normally compute the cost value, but if we get it as a side
    % result, then we may as well take note of it for caching.
    cost_computed = false;


    if isfield(problem, 'grad')
    %% Compute the gradient using grad.

        % Check whether this function wants to deal with storedb or not.
        switch nargin(problem.grad)
            case 1
                grad = problem.grad(x);
            case 2
                [grad, store] = problem.grad(x, store);
            case 3
                % Pass along the whole storedb (by reference), with key.
                grad = problem.grad(x, storedb, key);
                % The store structure in storedb might have been modified
                % (since it is passed by reference), so before caching
                % we'll have to update (see below).
                store_is_stale = true;
            otherwise
                up = MException('manopt:getGradient:badgrad', ...
                    'grad should accept 1, 2 or 3 inputs.');
                throw(up);
        end
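        % For illustration only (hypothetical handles; mygrad is a
        % user-defined function returning [grad, store]), the three
        % accepted forms are:
        %   problem.grad = @(x) 2*x;                      % nargin == 1
        %   problem.grad = @(x, store) mygrad(x, store);  % nargin == 2
        %   problem.grad = @(x, storedb, key) ...         % nargin == 3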

    elseif isfield(problem, 'costgrad')
    %% Compute the gradient using costgrad.

        % Check whether this function wants to deal with storedb or not.
        switch nargin(problem.costgrad)
            case 1
                [cost, grad] = problem.costgrad(x);
            case 2
                [cost, grad, store] = problem.costgrad(x, store);
            case 3
                % Pass along the whole storedb (by reference), with key.
                [cost, grad] = problem.costgrad(x, storedb, key);
                store_is_stale = true;
            otherwise
                up = MException('manopt:getGradient:badcostgrad', ...
                    'costgrad should accept 1, 2 or 3 inputs.');
                throw(up);
        end
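        % For illustration only (mycostgrad is a hypothetical user
        % function with signature [f, g] = mycostgrad(x)): supplying
        %   problem.costgrad = @(x) mycostgrad(x);
        % lets the cost and the gradient be obtained in one evaluation.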

        cost_computed = true;

    elseif canGetEuclideanGradient(problem)
    %% Compute the Riemannian gradient using the Euclidean gradient.

        egrad = getEuclideanGradient(problem, x, storedb, key);
        % Convert to the Riemannian gradient.
        switch nargin(problem.M.egrad2rgrad)
            case 2
                grad = problem.M.egrad2rgrad(x, egrad);
            case 4
                grad = problem.M.egrad2rgrad(x, egrad, storedb, key);
            otherwise
                up = MException('manopt:getGradient:egrad2rgrad', ...
                    'egrad2rgrad should accept 2 or 4 inputs.');
                throw(up);
        end
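        % Note: egrad2rgrad maps the gradient of the cost seen as a
        % function on the ambient (Euclidean) space to the Riemannian
        % gradient on M. On the unit sphere, for instance, this is the
        % tangent-space projection: grad = egrad - (x'*egrad)*x.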
        store_is_stale = true;

    elseif canGetPartialGradient(problem)
    %% Compute the gradient using a full partial gradient.

        d = problem.ncostterms;
        grad = getPartialGradient(problem, x, 1:d, storedb, key);
        store_is_stale = true;

    elseif canGetDirectionalDerivative(problem)
    %% Compute gradient based on directional derivatives; expensive!

        B = tangentorthobasis(problem.M, x);
        df = zeros(size(B));
        for k = 1 : numel(B)
            df(k) = getDirectionalDerivative(problem, x, B{k}, storedb, key);
        end
        grad = lincomb(problem.M, x, B, df);
        store_is_stale = true;
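        % Rationale: the gradient is characterized by <grad, v> = Df(x)[v]
        % for all tangent vectors v. With an orthonormal tangent basis B,
        % its coefficients are <grad, B{k}> = Df(x)[B{k}] = df(k), so the
        % linear combination above recovers grad. This takes dim(M) calls
        % to getDirectionalDerivative, hence "expensive".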

    else
    %% Attempt the computation of an approximation of the gradient.

        grad = getApproxGradient(problem, x, storedb, key);
        store_is_stale = true;

    end

    % If we are not sure that the store structure is up to date, update.
    if store_is_stale
        store = storedb.getWithShared(key);
    end

    % Cache here.
    if force_grad_caching
        store.grad__ = grad;
    end
    % If we got the gradient via costgrad, then the cost has also been
    % computed and we can cache it.
    if cost_computed
        store.cost__ = cost;
    end

    storedb.setWithShared(store, key);

end
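
As a complement to the listing, a minimal sketch of the Euclidean-gradient route and of the caching behaviour follows; the sphere manifold, the cost and its Euclidean gradient are assumptions made for the example.

 problem.M = spherefactory(5);
 A = randn(5); A = (A + A')/2;
 problem.cost = @(x) x'*A*x;
 problem.egrad = @(x) 2*A*x;    % Euclidean gradient only: getGradient
                                % converts it via problem.M.egrad2rgrad.
 x = problem.M.rand();
 storedb = StoreDB();
 key = storedb.getNewKey();
 g1 = getGradient(problem, x, storedb, key);  % computed, then cached as grad__
 g2 = getGradient(problem, x, storedb, key);  % served from the cache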

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005