
getCostGrad

PURPOSE

Computes the cost function and the gradient at x in one call if possible.

SYNOPSIS

function [cost, grad] = getCostGrad(problem, x, storedb, key)

DESCRIPTION

 Computes the cost function and the gradient at x in one call if possible.

 function [cost, grad] = getCostGrad(problem, x)
 function [cost, grad] = getCostGrad(problem, x, storedb)
 function [cost, grad] = getCostGrad(problem, x, storedb, key)

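 Returns the value at x of the cost function described in the problem
 structure, as well as the gradient at x. If problem.costgrad is defined,
 both come from a single call to it; otherwise getCost and getGradient are
 called separately. The cost and gradient are always cached in the store
 associated with x, so a later call with the same storedb and key reuses
 them instead of recomputing.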
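
 A minimal usage sketch, assuming Manopt is on the MATLAB path. The sphere
 manifold (spherefactory), the quadratic cost x'*A*x and the matrix A are
 illustrative choices, not part of getCostGrad. Since problem.costgrad has a
 single input here, getCostGrad calls it directly; deal lets the anonymous
 function return both the cost and the Riemannian gradient.

```matlab
% Sketch: assumes Manopt is on the path (e.g., after running importmanopt).
n = 4;
A = diag(1:n);                 % illustrative symmetric matrix
M = spherefactory(n);          % unit sphere in R^n
problem.M = M;

% One-input costgrad: cost x'*A*x and its Riemannian gradient, obtained by
% projecting the Euclidean gradient 2*A*x onto the tangent space at x.
problem.costgrad = @(x) deal(x'*A*x, M.proj(x, 2*A*x));

x = M.rand();                       % random unit-norm point
[f, g] = getCostGrad(problem, x);   % storedb and key omitted
```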

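 storedb is a StoreDB object and key is the StoreDB key for the point x.
 If key is omitted, a new key is generated; if storedb is omitted as well,
 a fresh StoreDB is created. In either case, the values cached by this call
 are not reused by later calls.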
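
 The sketch below, again assuming Manopt is on the path and using the same
 illustrative sphere problem, passes an explicit storedb and key. The first
 call computes and caches the cost and gradient; the second call with the
 same key reads them back from the store, where they sit under the fields
 cost__ and grad__ used in the source code.

```matlab
% Sketch: assumes Manopt is on the path. Illustrative problem: x'*A*x on
% the unit sphere; only the storedb/key handling matters here.
n = 4;
A = diag(1:n);
M = spherefactory(n);
problem.M = M;
problem.costgrad = @(x) deal(x'*A*x, M.proj(x, 2*A*x));
x = M.rand();

storedb = StoreDB();           % database shared across calls
key = storedb.getNewKey();     % key associated with the point x

[f1, g1] = getCostGrad(problem, x, storedb, key);  % computes, then caches
[f2, g2] = getCostGrad(problem, x, storedb, key);  % read back from the cache

% The cached values live in the store under the fields cost__ and grad__.
store = storedb.getWithShared(key);
```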

 See also: canGetCost canGetGradient getCost getGradient
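
 When the problem structure has no costgrad field, getCostGrad falls back to
 getCost and getGradient. The sketch below assumes Manopt is on the path;
 problem.cost and problem.grad (the Riemannian gradient) are illustrative
 one-input functions for the same quadratic cost on the unit sphere.

```matlab
% Sketch: assumes Manopt is on the path. No costgrad field is defined, so
% getCostGrad calls getCost and getGradient separately.
n = 4;
A = diag(1:n);                          % illustrative symmetric matrix
M = spherefactory(n);
problem.M = M;
problem.cost = @(x) x'*A*x;
problem.grad = @(x) M.proj(x, 2*A*x);   % Riemannian gradient
x = M.rand();

[f, g] = getCostGrad(problem, x);
```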

SOURCE CODE

function [cost, grad] = getCostGrad(problem, x, storedb, key)
% Computes the cost function and the gradient at x in one call if possible.
%
% function [cost, grad] = getCostGrad(problem, x)
% function [cost, grad] = getCostGrad(problem, x, storedb)
% function [cost, grad] = getCostGrad(problem, x, storedb, key)
%
% Returns the value at x of the cost function described in the problem
% structure, as well as the gradient at x.
%
% storedb is a StoreDB object, key is the StoreDB key for point x.
%
% See also: canGetCost canGetGradient getCost getGradient

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%
%   April 3, 2015 (NB):
%       Works with the new StoreDB class system.
%
%   Aug. 2, 2018 (NB):
%       The value of the cost function is now always cached.
%
%   Sep. 6, 2018 (NB):
%       The gradient is now also cached.

    % Allow omission of the key, and even of storedb.
    if ~exist('key', 'var')
        if ~exist('storedb', 'var')
            storedb = StoreDB();
        end
        key = storedb.getNewKey();
    end

    % Contrary to most similar functions, here we get the store by
    % default. This is for the caching functionality described below.
    store = storedb.getWithShared(key);
    store_is_stale = false;

    % Check whether the cost or the gradient is readily available.
    force_grad_caching = true;
    if isfield(store, 'cost__')
        cost = store.cost__;
        if force_grad_caching && isfield(store, 'grad__')
            grad = store.grad__;
            return;
        else
            grad = getGradient(problem, x, storedb, key); % caches grad
            return;
        end
    end
    % If we get here, the cost was not previously cached, but maybe the
    % gradient was?
    if force_grad_caching && isfield(store, 'grad__')
        grad = store.grad__;
        cost = getCost(problem, x, storedb, key); % this call caches cost
        return;
    end

    % Neither the cost nor the gradient was available: compute both.

    if isfield(problem, 'costgrad')
    %% Compute the cost/grad pair using costgrad.

        % Check whether this function wants to deal with storedb or not.
        switch nargin(problem.costgrad)
            case 1
                [cost, grad] = problem.costgrad(x);
            case 2
                [cost, grad, store] = problem.costgrad(x, store);
            case 3
                % Pass along the whole storedb (by reference), with key.
                [cost, grad] = problem.costgrad(x, storedb, key);
                store_is_stale = true;
            otherwise
                up = MException('manopt:getCostGrad:badcostgrad', ...
                    'costgrad should accept 1, 2 or 3 inputs.');
                throw(up);
        end

    else
    %% Revert to calling getCost and getGradient separately.

        % The two calls below already cache cost and grad; those caches
        % are then overwritten at the end of this function with the same
        % values, which is harmless.
        cost = getCost(problem, x, storedb, key);
        grad = getGradient(problem, x, storedb, key);
        store_is_stale = true;

    end

    if store_is_stale
        store = storedb.getWithShared(key);
    end

    % Cache here.
    store.cost__ = cost;
    if force_grad_caching
        store.grad__ = grad;
    end

    storedb.setWithShared(store, key);

end
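
The switch on nargin(problem.costgrad) above accepts exactly three forms for costgrad: (x), (x, store) returning the updated store as a third output, and (x, storedb, key). The sketch below, assuming Manopt is on the MATLAB path and using an illustrative quadratic cost on the unit sphere, writes the same cost/gradient pair in each form, checks that getCostGrad returns the same values, and confirms that a four-input costgrad raises the manopt:getCostGrad:badcostgrad error.

```matlab
% Sketch: assumes Manopt is on the MATLAB path. The cost y'*A*y on the
% unit sphere and the matrix A are illustrative choices.
n = 4;
A = diag(1:n);
M = spherefactory(n);
x = M.rand();
f_ref = x'*A*x;                 % expected cost
g_ref = M.proj(x, 2*A*x);       % expected Riemannian gradient

% The same cost/gradient pair written in the three accepted forms.
forms = { ...
    @(y) deal(y'*A*y, M.proj(y, 2*A*y)), ...               % 1 input
    @(y, store) deal(y'*A*y, M.proj(y, 2*A*y), store), ... % 2 inputs, returns store
    @(y, storedb, key) deal(y'*A*y, M.proj(y, 2*A*y))};    % 3 inputs
cost_err = zeros(1, numel(forms));
grad_err = zeros(1, numel(forms));
for k = 1:numel(forms)
    problem = struct('M', M, 'costgrad', forms{k});
    [f, g] = getCostGrad(problem, x);
    cost_err(k) = abs(f - f_ref);
    grad_err(k) = norm(g - g_ref);
end

% Any other number of inputs is rejected with a dedicated error identifier.
problem = struct('M', M, 'costgrad', @(y, a, b, c) deal(0, 0));
try
    getCostGrad(problem, x);
    err_id = '';
catch err
    err_id = err.identifier;
end
```

The two-input form lets costgrad keep intermediate quantities in the store for later calls at the same point, while the three-input form receives the whole database by reference, as noted in the source comments.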

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005