Home > manopt > tools > checkdiff.m

checkdiff

PURPOSE ^

Checks the consistency of the cost function and directional derivatives.

SYNOPSIS ^

function checkdiff(problem, x, d, force_gradient)

DESCRIPTION ^

 Checks the consistency of the cost function and directional derivatives.

 function checkdiff(problem)
 function checkdiff(problem, x)
 function checkdiff(problem, x, d)

 checkdiff performs a numerical test to check that the directional
 derivatives defined in the problem structure agree up to first order with
 the cost function at some point x, along some direction d. The test is
 based on a truncated Taylor series (see online Manopt documentation).

 Both x and d are optional and will be sampled at random if omitted.

 See also: checkgradient checkhessian

CROSS-REFERENCE INFORMATION ^

This function calls: (none listed). This function is called by: (none listed).

SOURCE CODE ^

function checkdiff(problem, x, d, force_gradient)
% Checks the consistency of the cost function and directional derivatives.
%
% function checkdiff(problem)
% function checkdiff(problem, x)
% function checkdiff(problem, x, d)
%
% checkdiff runs a numerical comparison between the cost function and the
% directional derivatives declared in the problem structure, at a point x
% and along a tangent direction d: up to first order, the two must agree.
% The test rests on a truncated Taylor series (see the online Manopt
% documentation for details).
%
% Both x and d are optional and will be sampled at random if omitted.
%
% See also: checkgradient checkhessian

% If force_gradient = true (hidden parameter), the directional derivative
% is inferred from a call to getGradient rather than obtained through
% getDirectionalDerivative directly. checkgradient relies on this mode.

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, Dec. 30, 2012.
% Contributors:
% Change log:
%
%   March 26, 2017 (JB):
%       Detects if the approximated linear model is exact
%       and provides the user with the corresponding feedback.
%
%   April 3, 2015 (NB):
%       Works with the new StoreDB class system.
%
%   Aug. 2, 2018 (NB):
%       Using storedb.remove() to avoid unnecessary cache build-up.
%
%   Sep. 6, 2018 (NB):
%       Now checks whether M.exp() is available; uses retraction otherwise.
%
%   June 18, 2019 (NB):
%       Now issues a warning if the cost function returns complex values.

    % The hidden fourth argument defaults to false.
    if ~exist('force_gradient', 'var')
        force_gradient = false;
    end

    % Make sure the problem description provides what this check needs.
    if ~canGetCost(problem)
        error('It seems no cost was provided.');
    end
    if ~force_gradient && ~canGetDirectionalDerivative(problem)
        error('It seems no directional derivatives were provided.');
    end
    if force_gradient && ~canGetGradient(problem)
        % No warning here: force_gradient is only set when we are called
        % by checkgradient, which has already warned the user about this.
    end

    have_x = exist('x', 'var') && ~isempty(x);
    have_d = exist('d', 'var') && ~isempty(d);

    if have_d && ~have_x
        error('If d is provided, x must be too, since d is tangent at x.');
    end

    % Sample a random point and / or direction when not supplied.
    if ~have_x
        x = problem.M.rand();
    end
    if ~have_d
        d = problem.M.randvec(x);
    end

    % Evaluate the cost fx at x and the derivative dfx at x along d.
    db = StoreDB();
    key = db.getNewKey();
    fx = getCost(problem, x, db, key);

    if force_gradient
        grad = getGradient(problem, x, db, key);
        dfx = problem.M.inner(x, grad, d);
    else
        dfx = getDirectionalDerivative(problem, x, d, db, key);
    end

    % Step away from x using the exponential map when the manifold offers
    % one; otherwise, fall back to the retraction. No warning is needed:
    % any retraction is first-order by definition, which suffices here.
    if isfield(problem.M, 'exp')
        step = problem.M.exp;
    else
        step = problem.M.retr;
    end

    % Evaluate the cost along the geodesic (or its approximation) leaving
    % x in the direction d, for stepsizes h spanning many orders of
    % magnitude.
    h = logspace(-8, 0, 51);
    value = zeros(size(h));
    for k = 1 : numel(h)
        y = step(x, d, h(k));
        ykey = db.getNewKey();
        value(k) = getCost(problem, y, db, ykey);
        db.remove(ykey); % discard this point's cache entry right away
    end

    % First-order Taylor model of the cost at the same stepsizes.
    model = fx + dfx*h;

    % Gap between the linear model and the true cost.
    err = abs(model - value);

    % Display that gap on a log-log plot.
    loglog(h, err);
    title(sprintf(['Directional derivative check.\nThe slope of the '...
                   'continuous line should match that of the dashed\n'...
                   '(reference) line over at least a few orders of '...
                   'magnitude for h.']));
    xlabel('h');
    ylabel('Approximation error');

    % Dashed reference line of slope 2 for visual comparison.
    line('xdata', [1e-8 1e0], 'ydata', [1e-8 1e8], ...
         'color', 'k', 'LineStyle', '--', ...
         'YLimInclude', 'off', 'XLimInclude', 'off');


    if all( err < 1e-12 )
        % The first-order model is exact: every error is numerically zero.
        % Fit a line through all the points, with a log scale in h only.
        exact = true;
        range = 1:numel(h);
        poly = polyfit(log10(h), err, 1);
        % Express the mean error in log scale for the plot below.
        poly(end) = log10(poly(end));
        % Use a title tailored to this special situation.
        title(sprintf(...
              ['Directional derivative check.\n'...
               'It seems the linear model is exact:\n'...
               'Model error is numerically zero for all h.']));
    else
        % In a numerically reasonable regime, the error must shrink like
        % the square of the stepsize, that is, with slope 2 on the
        % log-log plot.
        exact = false;
        window_len = 10;
        [range, poly] = identify_linear_piece(log10(h), log10(err), window_len);
    end
    hold all;
    loglog(h(range), 10.^polyval(poly, log10(h(range))), 'LineWidth', 3);
    hold off;

    if exact
        fprintf(['The linear model appears to be exact ' ...
                 '(within numerical precision),\n'...
                 'hence the slope computation is irrelevant.\n']);
    else
        fprintf('The slope should be 2. It appears to be: %g.\n', poly(1));
        fprintf(['If it is far from 2, then directional derivatives ' ...
                 'might be erroneous.\n']);
    end

    % A complex-valued cost almost certainly indicates a coding mistake.
    if ~(isreal(value) && isreal(fx))
        fprintf(['# The cost function appears to return complex values' ...
              '.\n# Please ensure real outputs.\n']);
    end

end

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005