
checkdiff

PURPOSE

Checks the consistency of the cost function and directional derivatives.

SYNOPSIS

function checkdiff(problem, x, d, force_gradient)

DESCRIPTION

 Checks the consistency of the cost function and directional derivatives.

 function checkdiff(problem)
 function checkdiff(problem, x)
 function checkdiff(problem, x, d)

 checkdiff performs a numerical test to check that the directional
 derivatives defined in the problem structure agree up to first order with
 the cost function at some point x, along some direction d. The test is
 based on a truncated Taylor series (see online Manopt documentation).
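
 Concretely (a sketch of the test, paraphrasing the online documentation):
 writing Exp_x for the exponential map (or a retraction when exp is
 unavailable), the truncated Taylor expansion

     f(Exp_x(h*d)) = f(x) + h*Df(x)[d] + O(h^2)

 implies that the error |f(Exp_x(h*d)) - f(x) - h*Df(x)[d]| should decrease
 like h^2, that is, with slope 2 on a log-log plot against h.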

 Both x and d are optional and will be sampled at random if omitted.

 See also: checkgradient checkhessian
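
 As an illustration, a minimal usage sketch follows. The manifold, cost
 and data are hypothetical (not part of this file); spherefactory is
 Manopt's sphere factory, and supplying the Euclidean gradient egrad
 lets Manopt infer directional derivatives:

     % Hypothetical example: check a quadratic cost on the unit sphere.
     n = 5;
     A = randn(n); A = (A + A')/2;   % random symmetric matrix
     problem.M = spherefactory(n);   % unit sphere in R^n
     problem.cost  = @(x) x'*A*x;
     problem.egrad = @(x) 2*A*x;     % Euclidean gradient of the cost
     checkdiff(problem);             % x and d are sampled at random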

SOURCE CODE

function checkdiff(problem, x, d, force_gradient)
% Checks the consistency of the cost function and directional derivatives.
%
% function checkdiff(problem)
% function checkdiff(problem, x)
% function checkdiff(problem, x, d)
%
% checkdiff performs a numerical test to check that the directional
% derivatives defined in the problem structure agree up to first order with
% the cost function at some point x, along some direction d. The test is
% based on a truncated Taylor series (see online Manopt documentation).
%
% Both x and d are optional and will be sampled at random if omitted.
%
% See also: checkgradient checkhessian

% If force_gradient = true (hidden parameter), then the function will call
% getGradient and infer the directional derivative, rather than call
% getDirectionalDerivative directly. This is used by checkgradient.

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%
%   April 3, 2015 (NB):
%       Works with the new StoreDB class system.
%
%   March 26, 2017 (JB):
%       Detects if the approximated linear model is exact
%       and provides the user with the corresponding feedback.
%
%   Aug. 2, 2018 (NB):
%       Using storedb.remove() to avoid unnecessary cache build-up.
%
%   Sep. 6, 2018 (NB):
%       Now checks whether M.exp() is available; uses retraction otherwise.

    if ~exist('force_gradient', 'var')
        force_gradient = false;
    end

    % Verify that the problem description is sufficient.
    if ~canGetCost(problem)
        error('It seems no cost was provided.');
    end
    if ~force_gradient && ~canGetDirectionalDerivative(problem)
        error('It seems no directional derivatives were provided.');
    end
    if force_gradient && ~canGetGradient(problem)
        % Would normally issue a warning, but this function should only be
        % called with force_gradient on by checkgradient, which will
        % already have issued a warning.
    end

    x_isprovided = exist('x', 'var') && ~isempty(x);
    d_isprovided = exist('d', 'var') && ~isempty(d);

    if ~x_isprovided && d_isprovided
        error('If d is provided, x must be too, since d is tangent at x.');
    end

    % If x and / or d are not specified, pick them at random.
    if ~x_isprovided
        x = problem.M.rand();
    end
    if ~d_isprovided
        d = problem.M.randvec(x);
    end

    % Compute the value f0 at x and the directional derivative at x along d.
    storedb = StoreDB();
    xkey = storedb.getNewKey();
    f0 = getCost(problem, x, storedb, xkey);

    if ~force_gradient
        df0 = getDirectionalDerivative(problem, x, d, storedb, xkey);
    else
        grad = getGradient(problem, x, storedb, xkey);
        df0 = problem.M.inner(x, grad, d);
    end
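
    % Note: the "else" branch above uses the defining property of the
    % Riemannian gradient, Df(x)[d] = <grad f(x), d>_x, which is what
    % M.inner(x, grad, d) computes.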

    % Pick a stepping function: exponential or retraction?
    if isfield(problem.M, 'exp')
        stepper = problem.M.exp;
    else
        stepper = problem.M.retr;
        % No need to issue a warning: to check the gradient, any retraction
        % (which is first-order by definition) is appropriate.
    end

    % Compute the value of f at points on the geodesic (or approximation
    % of it) originating from x, along direction d, for stepsizes in a
    % large range given by h.
    h = logspace(-8, 0, 51);
    value = zeros(size(h));
    for k = 1 : length(h)
        y = stepper(x, d, h(k));
        ykey = storedb.getNewKey();
        value(k) = getCost(problem, y, storedb, ykey);
        storedb.remove(ykey); % no need to keep it in memory
    end

    % Compute the linear approximation of the cost function using f0 and
    % df0 at the same points.
    model = polyval([df0 f0], h);
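    % (polyval with coefficients [df0 f0] evaluates the first-order model
    % f0 + df0*h at every stepsize h.)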

    % Compute the approximation error
    err = abs(model - value);

    % And plot it.
    loglog(h, err);
    title(sprintf(['Directional derivative check.\nThe slope of the '...
                   'continuous line should match that of the dashed\n'...
                   '(reference) line over at least a few orders of '...
                   'magnitude for h.']));
    xlabel('h');
    ylabel('Approximation error');

    line('xdata', [1e-8 1e0], 'ydata', [1e-8 1e8], ...
         'color', 'k', 'LineStyle', '--', ...
         'YLimInclude', 'off', 'XLimInclude', 'off');
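    % (This dashed reference line spans 8 decades in h and 16 decades in
    % error, so it has slope 2 on the log-log plot: the target slope.)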

    if ~all(err < 1e-12)
        % In a numerically reasonable neighborhood, the error should
        % decrease as the square of the stepsize, i.e., in loglog scale,
        % the error should have a slope of 2.
        isModelExact = false;
        window_len = 10;
        [range, poly] = identify_linear_piece(log10(h), log10(err), window_len);
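        % (identify_linear_piece, a Manopt tool, slides a window of
        % window_len samples over the data and returns the index range
        % where log10(err) vs. log10(h) is most nearly linear, together
        % with the corresponding line fit poly.)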
    else
        % The 1st order model is exact: all errors are (numerically) zero.
        % Fit line from all points, use log scale only in h.
        isModelExact = true;
        range = 1:numel(h);
        poly = polyfit(log10(h), err, 1);
        % Set mean error in log scale for plot.
        poly(end) = log10(poly(end));
        % Change title to something more descriptive for this special case.
        title(sprintf(...
              ['Directional derivative check.\n'...
               'It seems the linear model is exact:\n'...
               'Model error is numerically zero for all h.']));
    end
    hold all;
    loglog(h(range), 10.^polyval(poly, log10(h(range))), 'LineWidth', 3);
    hold off;

    if ~isModelExact
        fprintf('The slope should be 2. It appears to be: %g.\n', poly(1));
        fprintf(['If it is far from 2, then directional derivatives ' ...
                 'might be erroneous.\n']);
    else
        fprintf(['The linear model appears to be exact ' ...
                 '(within numerical precision),\n'...
                 'hence the slope computation is irrelevant.\n']);
    end

end
