function checkdiff(problem, x, d, force_gradient)
% Numerically checks the directional derivative of the cost of a problem.
%
% function checkdiff(problem)
% function checkdiff(problem, x)
% function checkdiff(problem, x, d)
% function checkdiff(problem, x, d, force_gradient)
%
% The cost is evaluated at points obtained by moving away from x along d
% with step sizes h ranging from 1e-8 to 1 (51 log-spaced values). These
% values are compared against the first-order model f0 + h*df0, where f0 is
% the cost at x and df0 the directional derivative at x along d. The
% absolute error is plotted against h on a log-log scale, next to a dashed
% reference line of slope 2, and a fitted slope is printed.
%
% Inputs:
%   problem        - problem structure; problem.M must offer rand, randvec,
%                    inner and either exp or retr.
%   x              - (optional) point at which to run the check. If omitted
%                    or empty, problem.M.rand() is used.
%   d              - (optional) direction at x. If omitted or empty,
%                    problem.M.randvec(x) is used. Requires x to be given.
%   force_gradient - (optional, default false) if true, df0 is computed as
%                    the inner product of the gradient with d instead of
%                    calling getDirectionalDerivative.
%
% No outputs: the function draws in the current axes and prints to the
% command window.

    % Optional flag defaults to false when the caller does not pass it.
    if nargin < 4
        force_gradient = false;
    end

    % The check is meaningless without a cost ...
    if ~canGetCost(problem)
        error('It seems no cost was provided.');
    end
    % ... and, unless the gradient route is forced, without directional
    % derivatives.
    if ~(force_gradient || canGetDirectionalDerivative(problem))
        error('It seems no directional derivatives were provided.');
    end
    if force_gradient && ~canGetGradient(problem)
        % NOTE(review): intentionally empty in the original code. Presumably
        % the caller is responsible for warning about a missing gradient;
        % confirm before adding any action here.
    end

    % An argument counts as provided only if it was passed and is non-empty.
    x_isprovided = (nargin >= 2) && ~isempty(x);
    d_isprovided = (nargin >= 3) && ~isempty(d);

    if d_isprovided && ~x_isprovided
        error('If d is provided, x must be too, since d is tangent at x.');
    end

    % Fill in whatever the caller did not supply.
    if ~x_isprovided
        x = problem.M.rand();
    end
    if ~d_isprovided
        d = problem.M.randvec(x);
    end

    % Cost at the base point, evaluated through a fresh store database.
    storedb = StoreDB();
    xkey = storedb.getNewKey();
    f0 = getCost(problem, x, storedb, xkey);

    % Directional derivative at x along d, either obtained directly or as
    % the inner product between the gradient and d.
    if force_gradient
        grad = getGradient(problem, x, storedb, xkey);
        df0 = problem.M.inner(x, grad, d);
    else
        df0 = getDirectionalDerivative(problem, x, d, storedb, xkey);
    end

    % Pick the function used to move away from x: exp if the manifold
    % structure has that field, retr otherwise.
    if isfield(problem.M, 'exp')
        stepper = problem.M.exp;
    else
        stepper = problem.M.retr;
    end

    % Evaluate the cost at each step size; the temporary key of each trial
    % point is removed from the store database right after use.
    h = logspace(-8, 0, 51);
    value = zeros(size(h));
    for k = 1 : numel(h)
        y = stepper(x, d, h(k));
        ykey = storedb.getNewKey();
        value(k) = getCost(problem, y, storedb, ykey);
        storedb.remove(ykey);
    end

    % First-order model f0 + h*df0 and its absolute deviation from the
    % computed cost values.
    model = polyval([df0 f0], h);
    err = abs(model - value);

    % Error curve.
    loglog(h, err);
    title(sprintf(['Directional derivative check.\nThe slope of the '...
                   'continuous line should match that of the dashed\n'...
                   '(reference) line over at least a few orders of '...
                   'magnitude for h.']));
    xlabel('h');
    ylabel('Approximation error');

    % Dashed reference line of slope 2 in log-log coordinates, excluded
    % from the automatic axis limits.
    line('xdata', [1e-8 1e0], 'ydata', [1e-8 1e8], ...
         'color', 'k', 'LineStyle', '--', ...
         'YLimInclude', 'off', 'XLimInclude', 'off');

    % The model is deemed exact when the error is below 1e-12 for every h.
    isModelExact = all( err < 1e-12 );
    if isModelExact
        % Fit over the whole range of h; only the constant coefficient is
        % moved to a log10 scale before plotting below.
        idx = 1:numel(h);
        fit_coeffs = polyfit(log10(h), err, 1);
        fit_coeffs(end) = log10(fit_coeffs(end));

        title(sprintf(...
              ['Directional derivative check.\n'...
               'It seems the linear model is exact:\n'...
               'Model error is numerically zero for all h.']));
    else
        % Let the helper select a window of indices (window length 10) and
        % the corresponding degree-one fit in log-log coordinates.
        window_len = 10;
        [idx, fit_coeffs] = identify_linear_piece(log10(h), log10(err), window_len);
    end

    % Overlay the fitted segment on the error curve.
    hold all;
    loglog(h(idx), 10.^polyval(fit_coeffs, log10(h(idx))), 'LineWidth', 3);
    hold off;

    % Textual report.
    if isModelExact
        fprintf(['The linear model appears to be exact ' ...
                 '(within numerical precision),\n'...
                 'hence the slope computation is irrelevant.\n']);
    else
        fprintf('The slope should be 2. It appears to be: %g.\n', fit_coeffs(1));
        fprintf(['If it is far from 2, then directional derivatives ' ...
                 'might be erroneous.\n']);
    end

    % Flag complex cost values, at x or at any of the trial points.
    if ~isreal(value) || ~isreal(f0)
        fprintf(['# The cost function appears to return complex values' ...
                 '.\n# Please ensure real outputs.\n']);
    end

end