getPartialGradient

PURPOSE

Computes the gradient of a subset of terms in the cost function at x.

SYNOPSIS

function grad = getPartialGradient(problem, x, I, storedb, key)

DESCRIPTION

 Computes the gradient of a subset of terms in the cost function at x.

 function grad = getPartialGradient(problem, x, I)
 function grad = getPartialGradient(problem, x, I, storedb)
 function grad = getPartialGradient(problem, x, I, storedb, key)

 Assume the cost function described in the problem structure is a sum of
 many terms, as

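    f(x) = sum_i f_i(x) for i = 1:d,

 where d is specified as d = problem.ncostterms.

 For a subset I of 1:d, getPartialGradient obtains the Riemannian gradient
 of the partial cost function

    f_I(x) = sum_i f_i(x) for i = I.

 storedb is a StoreDB object and key is the StoreDB key of point x. Both
 are optional: if key is omitted, a new key is generated, and if storedb is
 also omitted, a fresh StoreDB is created.

 See also: getGradient canGetPartialGradient getPartialEuclideanGradient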
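The sum structure f(x) = sum_i f_i(x) can be illustrated without Manopt on a toy least-squares cost in R^n, where the Riemannian gradient coincides with the Euclidean one. This is a sketch under stated assumptions: the data A and b, the point x and the split of 1:d into two index sets are made up for illustration; only the problem fields ncostterms and partialgrad (in its 2-input form partialgrad(x, I)) come from this page. It checks that partial gradients over a partition of 1:d add up to the full gradient.

```matlab
% Toy cost (illustrative assumption): f(x) = sum_i 0.5*(A(i,:)*x - b(i))^2,
% with one term f_i per row of A, so d = size(A, 1) = 5 and x is in R^3.
A = reshape(1:15, 5, 3) + eye(5, 3);
b = (1:5).';
x = [1; -1; 2];

problem.ncostterms = size(A, 1);

% 2-input form: gradient of f_I(x) = sum over i in I of f_i(x),
% computed from the rows of A and entries of b indexed by I only.
problem.partialgrad = @(x, I) A(I, :).' * (A(I, :) * x - b(I));

% Two disjoint index sets whose union is 1:d.
I1 = [1, 3, 5];
I2 = [2, 4];

g1 = problem.partialgrad(x, I1);
g2 = problem.partialgrad(x, I2);
gfull = problem.partialgrad(x, 1:problem.ncostterms);

% The gradient is linear in the cost, so the partial gradients over a
% partition of 1:d sum to the gradient of the full cost.
err = norm(g1 + g2 - gfull);
fprintf('partition error: %g\n', err);
```

With Manopt on the path, the same structure could be passed as getPartialGradient(problem, x, I1), which would take its 2-input branch and call problem.partialgrad(x, I1) after reshaping I1 into a row vector.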

SOURCE CODE

function grad = getPartialGradient(problem, x, I, storedb, key)
% Computes the gradient of a subset of terms in the cost function at x.
%
% function grad = getPartialGradient(problem, x, I)
% function grad = getPartialGradient(problem, x, I, storedb)
% function grad = getPartialGradient(problem, x, I, storedb, key)
%
% Assume the cost function described in the problem structure is a sum of
% many terms, as
%
%    f(x) = sum_i f_i(x) for i = 1:d,
%
% where d is specified as d = problem.ncostterms.
%
% For a subset I of 1:d, getPartialGradient obtains the Riemannian gradient
% of the partial cost function
%
%    f_I(x) = sum_i f_i(x) for i = I.
%
% storedb is a StoreDB object and key is the StoreDB key of point x. Both
% are optional.
%
% See also: getGradient canGetPartialGradient getPartialEuclideanGradient

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, June 28, 2016
% Contributors:
% Change log:
%
%   Feb. 10, 2020 (NB):
%       Allowing M.egrad2rgrad to take (storedb, key) as extra inputs.


    % Allow omission of the key, and even of storedb.
    if ~exist('key', 'var')
        if ~exist('storedb', 'var')
            storedb = StoreDB();
        end
        key = storedb.getNewKey();
    end


    % Make sure I is a row vector, so that it is natural to loop over it
    % with " for i = I ".
    I = (I(:)).';


    if isfield(problem, 'partialgrad')
    %% Compute the partial gradient using partialgrad.

        % Dispatch on the number of inputs partialgrad accepts:
        % (x, I), (x, I, store) or (x, I, storedb, key).
        switch nargin(problem.partialgrad)
            case 2
                grad = problem.partialgrad(x, I);
            case 3
                % Obtain, pass along, and save the store for x.
                store = storedb.getWithShared(key);
                [grad, store] = problem.partialgrad(x, I, store);
                storedb.setWithShared(store, key);
            case 4
                % Pass along the whole storedb (by reference), with key.
                grad = problem.partialgrad(x, I, storedb, key);
            otherwise
                up = MException('manopt:getPartialGradient:badpartialgrad', ...
                    'partialgrad should accept 2, 3 or 4 inputs.');
                throw(up);
        end

    elseif canGetPartialEuclideanGradient(problem)
    %% Compute the partial gradient using the Euclidean partial gradient.

        egrad = getPartialEuclideanGradient(problem, x, I, storedb, key);
        % Convert to the Riemannian gradient.
        switch nargin(problem.M.egrad2rgrad)
            case 2
                grad = problem.M.egrad2rgrad(x, egrad);
            case 4
                grad = problem.M.egrad2rgrad(x, egrad, storedb, key);
            otherwise
                up = MException('manopt:getPartialGradient:egrad2rgrad', ...
                    'egrad2rgrad should accept 2 or 4 inputs.');
                throw(up);
        end

    else
    %% Abandon computing the partial gradient.

        up = MException('manopt:getPartialGradient:fail', ...
            ['The problem description is not explicit enough to ' ...
             'compute the partial gradient of the cost.']);
        throw(up);

    end

end
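When problem.partialgrad is absent, getPartialGradient falls back to a partial Euclidean gradient and converts it with problem.M.egrad2rgrad, choosing the call form from nargin. The following is a minimal self-contained sketch of that second branch, without Manopt, assuming the unit sphere in R^3 with the Euclidean metric, where the Riemannian gradient is the orthogonal projection g - (x'*g)*x of the Euclidean gradient onto the tangent space at x. The field name partialegrad, the data A and b, and the hand-written projection standing in for M.egrad2rgrad are assumptions for illustration; in Manopt the manifold would normally come from a factory such as spherefactory.

```matlab
% Hand-written stand-in for M.egrad2rgrad on the unit sphere (assumption):
% project the Euclidean gradient onto the tangent space at x.
M.egrad2rgrad = @(x, egrad) egrad - (x.' * egrad) * x;

% Toy cost (illustrative assumption): f(x) = sum_i 0.5*(A(i,:)*x - b(i))^2
% restricted to the sphere, one term per row of A.
A = reshape(1:15, 5, 3) + eye(5, 3);
b = (1:5).';

problem.M = M;
problem.ncostterms = size(A, 1);
% Assumed field: partial Euclidean gradient in the 2-input form (x, I).
problem.partialegrad = @(x, I) A(I, :).' * (A(I, :) * x - b(I));

x = [1; 2; 2] / 3;   % a unit-norm point on the sphere
I = [4; 2];          % column input ...
I = (I(:)).';        % ... reshaped to a row, as getPartialGradient does

egrad = problem.partialegrad(x, I);

% Mirror the nargin-based dispatch applied to egrad2rgrad.
switch nargin(problem.M.egrad2rgrad)
    case 2
        rgrad = problem.M.egrad2rgrad(x, egrad);
    otherwise
        error('This sketch only handles the 2-input egrad2rgrad form.');
end

% The Riemannian partial gradient is tangent to the sphere at x,
% so its inner product with x is zero up to rounding.
fprintf('x''*rgrad = %g\n', x.' * rgrad);
```

Because the projection is linear in egrad, converting each partial Euclidean gradient separately and summing gives the same result as converting their sum, which is what makes the per-subset conversion consistent with the full gradient.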

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005