Manopt example on how to use counters during optimization. Typical uses, as demonstrated here, include counting calls to cost, gradient and Hessian functions.


function using_counters()


function using_counters()
% Manopt example on how to use counters during optimization. Typical uses,
% as demonstrated here, include counting calls to cost, gradient and
% Hessian functions. The example also demonstrates how to record total time
% spent in cost/grad/hess calls iteration by iteration.
%
% What this example does, step by step:
%   1. Builds a random symmetric matrix A of size n = 1000 and sets up the
%      cost f(x) = -.5*x'*A*x on the manifold returned by spherefactory(n).
%   2. Instruments the cost, the Euclidean gradient and the Euclidean
%      Hessian with the counters 'costcalls', 'gradcalls', 'hesscalls',
%      'Aproducts' (products with A) and 'functiontime' (seconds spent in
%      the three user-supplied functions).
%   3. Registers these counters so that solvers log them in their info
%      struct-arrays, and installs a stopping criterion that triggers once
%      more than 250 products with A have been logged.
%   4. Runs trustregions, conjugategradient and rlbfgs on the same problem
%      with the same options, then plots the gradient norm against each
%      logged quantity in figure 1 to compare the solvers.
%
% The function takes no inputs and returns no outputs.
%
% See also: statscounters incrementcounter statsfunhelper

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, July 27, 2018.
% Contributors:
% Change log:

    % Fix random seed in Matlab (this particular syntax fails in Octave).
    if exist('OCTAVE_VERSION', 'builtin') == 0
        rng(0);
    end

    % Setup an optimization problem to illustrate the use of counters
    n = 1000;
    A = randn(n);
    A = .5*(A+A');   % symmetrize A

    manifold = spherefactory(n);
    problem.M = manifold;


    % Define the problem cost function and its gradient.
    %
    % Since the most expensive operation in computing the cost and the
    % gradient at x is the product A*x, and since this operation is the
    % same for both the cost and the gradient, we use the caching
    % functionalities of manopt for this product. This function ensures the
    % product A*x is available in the store structure. Remember that a
    % store structure is associated to a particular point x: if cost and
    % egrad are called on the same point x, they will see the same store.
    %
    % prepare: x is the current point, store is the cache attached to x.
    % Returns store with the field Ax set to A*x. The product is computed
    % (and counted in 'Aproducts') only if store.Ax is not already there.
    function store = prepare(x, store)
        if ~isfield(store, 'Ax')
            store.Ax = A*x;
            % Increment a counter for the number of matrix-vector products
            % involving A. The names of the counters (here, Aproducts) are
            % for us to choose: they only need to be valid structure field
            % names. They need not have been defined in advance.
            store = incrementcounter(store, 'Aproducts');
        end
    end
    %
    % cost: returns f = -.5*x'*A*x, reusing the cached product store.Ax.
    % Each call adds one to 'costcalls' and adds its own elapsed time to
    % 'functiontime'.
    problem.cost = @cost;
    function [f, store] = cost(x, store)
        t = tic();
        store = prepare(x, store);
        f = -.5*(x'*store.Ax);
        % Increment a counter for the number of calls to the cost function.
        store = incrementcounter(store, 'costcalls');
        % We also increment a counter with the amount of time spent in this
        % function (all counters are stored as doubles; here we exploit
        % this to track a non-integer quantity.)
        store = incrementcounter(store, 'functiontime', toc(t));
    end
    %
    % egrad: returns the Euclidean gradient g = -A*x, reusing the cached
    % product store.Ax. Each call adds one to 'gradcalls' and adds its own
    % elapsed time to 'functiontime'.
    problem.egrad = @egrad;
    function [g, store] = egrad(x, store)
        t = tic();
        store = prepare(x, store);
        g = -store.Ax;
        % Count the number of calls to the gradient function.
        store = incrementcounter(store, 'gradcalls');
        % We also record time spent in this call, atop the same counter as
        % for the cost function.
        store = incrementcounter(store, 'functiontime', toc(t));
    end
    %
    % ehess: returns the Euclidean Hessian applied to xdot, h = -A*xdot.
    % The product here is with xdot rather than x, so the cached store.Ax
    % cannot be reused: every call performs (and counts) one product with
    % A. The input x is unused, hence the %#ok pragma.
    problem.ehess = @ehess;
    function [h, store] = ehess(x, xdot, store) %#ok<INUSL>
        t = tic();
        h = -A*xdot;
        % Count the number of calls to the Hessian operator and also count
        % the matrix-vector product with A.
        store = incrementcounter(store, 'hesscalls');
        store = incrementcounter(store, 'Aproducts');
        % We also record time spent in this call atop the cost and gradient.
        store = incrementcounter(store, 'functiontime', toc(t));
    end


    % Setup a callback to log statistics. We use a combination of
    % statscounters and of statsfunhelper to indicate which counters we
    % want the optimization algorithm to log. Here, stats is a structure
    % where each field is a function handle corresponding to one of the
    % counters. Before passing stats to statsfunhelper, we could decide to
    % add more fields to stats to log other things as well.
    stats = statscounters({'costcalls', 'gradcalls', 'hesscalls', ...
                           'Aproducts', 'functiontime'});
    options.statsfun = statsfunhelper(stats);

    % As an example: we could set up a stopping criterion based on the
    % number of matrix-vector products. A short version:
    % options.stopfun = @(problem, x, info, last) info(last).Aproducts > 250;
    % A longer version that also returns a reason string:
    options.stopfun = @stopfun;
    % stopfun: stop is true once the latest logged value of 'Aproducts'
    % exceeds 250. The reason string is assigned on every call, whether or
    % not stop is true. The inputs problem and x are unused.
    function [stop, reason] = stopfun(problem, x, info, last) %#ok<INUSL>
        reason = 'Exceeded Aproducts budget.';
        stop = (info(last).Aproducts > 250);   % true if budget exceeded
        % Here, info(last) contains the stats of the latest iteration.
        % That includes all registered counters.
    end

    % Solve with different solvers to compare. Only the info struct-arrays
    % are used afterwards; x and xcost are overwritten by each solver.
    options.tolgradnorm = 1e-9;
    [x, xcost, infortr] = trustregions(problem, [], options); %#ok<ASGLU>
    [x, xcost, inforcg] = conjugategradient(problem, [], options); %#ok<ASGLU>
    [x, xcost, infobfg] = rlbfgs(problem, [], options); %#ok<ASGLU>


    % Display some statistics. The logged data is available in the info
    % struct-arrays. Notice how the counters are available by their
    % corresponding field name.
    %
    % Panels 1 to 7 of the 3-by-3 grid all show the gradient norm on a
    % logarithmic axis with the same y-limits and ticks, so that only the
    % x-axis quantity differs from one panel to the next.
    figure(1);

    % Panel 1: gradient norm vs iteration number.
    subplot(3, 3, 1);
    semilogy([infortr.iter], [infortr.gradnorm], '.-', ...
             [inforcg.iter], [inforcg.gradnorm], '.-', ...
             [infobfg.iter], [infobfg.gradnorm], '.-');
    legend('RTR', 'RCG', 'RLBFGS');
    xlabel('Iteration #');
    ylabel('Gradient norm');
    ylim([1e-12, 1e2]); set(gca, 'YTick', [1e-12, 1e-6, 1e0]);

    % Panel 2: gradient norm vs number of cost calls.
    subplot(3, 3, 2);
    semilogy([infortr.costcalls], [infortr.gradnorm], '.-', ...
             [inforcg.costcalls], [inforcg.gradnorm], '.-', ...
             [infobfg.costcalls], [infobfg.gradnorm], '.-');
    xlabel('# cost calls');
    ylabel('Gradient norm');
    ylim([1e-12, 1e2]); set(gca, 'YTick', [1e-12, 1e-6, 1e0]);

    % Panel 3: gradient norm vs number of gradient calls.
    subplot(3, 3, 3);
    semilogy([infortr.gradcalls], [infortr.gradnorm], '.-', ...
             [inforcg.gradcalls], [inforcg.gradnorm], '.-', ...
             [infobfg.gradcalls], [infobfg.gradnorm], '.-');
    xlabel('# gradient calls');
    ylabel('Gradient norm');
    ylim([1e-12, 1e2]); set(gca, 'YTick', [1e-12, 1e-6, 1e0]);

    % Panel 4: gradient norm vs number of Hessian calls.
    subplot(3, 3, 4);
    semilogy([infortr.hesscalls], [infortr.gradnorm], '.-', ...
             [inforcg.hesscalls], [inforcg.gradnorm], '.-', ...
             [infobfg.hesscalls], [infobfg.gradnorm], '.-');
    xlabel('# Hessian calls');
    ylabel('Gradient norm');
    ylim([1e-12, 1e2]); set(gca, 'YTick', [1e-12, 1e-6, 1e0]);

    % Panel 5: gradient norm vs number of products with A.
    subplot(3, 3, 5);
    semilogy([infortr.Aproducts], [infortr.gradnorm], '.-', ...
             [inforcg.Aproducts], [inforcg.gradnorm], '.-', ...
             [infobfg.Aproducts], [infobfg.gradnorm], '.-');
    xlabel('# matrix-vector products');
    ylabel('Gradient norm');
    ylim([1e-12, 1e2]); set(gca, 'YTick', [1e-12, 1e-6, 1e0]);

    % Panel 6: gradient norm vs the time field logged by the solvers.
    subplot(3, 3, 6);
    semilogy([infortr.time], [infortr.gradnorm], '.-', ...
             [inforcg.time], [inforcg.gradnorm], '.-', ...
             [infobfg.time], [infobfg.gradnorm], '.-');
    xlabel('Computation time [s]');
    ylabel('Gradient norm');
    ylim([1e-12, 1e2]); set(gca, 'YTick', [1e-12, 1e-6, 1e0]);

    % Panel 7: gradient norm vs time spent inside cost/egrad/ehess only.
    subplot(3, 3, 7);
    semilogy([infortr.functiontime], [infortr.gradnorm], '.-', ...
             [inforcg.functiontime], [inforcg.gradnorm], '.-', ...
             [infobfg.functiontime], [infobfg.gradnorm], '.-');
    xlabel('Time spent in cost/grad/hess [s]');
    ylabel('Gradient norm');
    ylim([1e-12, 1e2]); set(gca, 'YTick', [1e-12, 1e-6, 1e0]);

    % The following plot allows to investigate what fraction of the time is
    % spent inside user-supplied function (cost/grad/hess) versus the total
    % time spent by the solver. This gives a sense of the relative
    % importance of cost function-related computational costs vs a solver's
    % inner workings, retractions, and other solver-specific operations.
    % The dashed black diagonal is the line y = x, that is, where a curve
    % would lie if all of the logged time were spent in cost/grad/hess.
    subplot(3, 3, 8);
    maxtime = max([[infortr.time], [inforcg.time], [infobfg.time]]);
    plot([infortr.time], [infortr.functiontime], '.-', ...
         [inforcg.time], [inforcg.functiontime], '.-', ...
         [infobfg.time], [infobfg.functiontime], '.-', ...
         [0, maxtime], [0, maxtime], 'k--');
    axis tight;
    xlabel('Total computation time [s]');
    ylabel(sprintf('Time spent in\ncost/grad/hess [s]'));

end

