using_counters

PURPOSE

Manopt example on how to use counters during optimization.

SYNOPSIS

function using_counters()

DESCRIPTION

 Manopt example on how to use counters during optimization. Typical uses,
 as demonstrated here, include counting calls to cost, gradient and
 Hessian functions. The example also demonstrates how to record total time
 spent in cost/grad/hess calls iteration by iteration.

 See also: statscounters incrementcounter statsfunhelper
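To make the counter semantics concrete without depending on Manopt, here is a minimal plain MATLAB/Octave sketch of the bookkeeping described above. It assumes that a counter is a named numeric field that is created on first use and then incremented, either by 1 or by an arbitrary double such as elapsed seconds. The struct `counters` and the `events` list are hypothetical; this is not Manopt's `incrementcounter`, which operates on the store structure passed to the cost, gradient and Hessian functions.

```matlab
% Hypothetical sketch: mimics the counter behavior described for Manopt's
% incrementcounter (create on first use, add an increment stored as a double).
% It does not call Manopt and does not involve any store structure.
counters = struct();

% Each row is one hypothetical event: {counter name, increment}.
events = {'costcalls',    1;
          'Aproducts',    1;
          'functiontime', 0.25;
          'gradcalls',    1;
          'functiontime', 0.5;
          'costcalls',    1};

for k = 1:size(events, 1)
    name   = events{k, 1};
    amount = events{k, 2};
    if ~isfield(counters, name)
        counters.(name) = 0;   % counters need not be declared in advance
    end
    % All counters are doubles, so non-integer quantities (time) work too.
    counters.(name) = counters.(name) + amount;
end

disp(counters);
```

Using dynamic field names keeps the sketch close to the example, where any valid structure field name can serve as a counter name.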

SOURCE CODE

function using_counters()
% Manopt example on how to use counters during optimization. Typical uses,
% as demonstrated here, include counting calls to cost, gradient and
% Hessian functions. The example also demonstrates how to record total time
% spent in cost/grad/hess calls iteration by iteration.
%
% See also: statscounters incrementcounter statsfunhelper

% This file is part of Manopt: www.manopt.org.
% Original author: Nicolas Boumal, July 27, 2018.
% Contributors:
% Change log:

    % Fix random seed in MATLAB (this particular syntax fails in Octave).
    if exist('OCTAVE_VERSION', 'builtin') == 0
        rng(0);
    end

    % Set up an optimization problem to illustrate the use of counters.
    n = 1000;
    A = randn(n);
    A = .5*(A+A');

    manifold = spherefactory(n);
    problem.M = manifold;


    % Define the problem cost function and its gradient.
    %
    % Since the most expensive operation in computing the cost and the
    % gradient at x is the product A*x, and since this operation is the
    % same for both the cost and the gradient, we use the caching
    % functionality of Manopt for this product. This function ensures the
    % product A*x is available in the store structure. Remember that a
    % store structure is associated with a particular point x: if cost and
    % egrad are called on the same point x, they will see the same store.
    function store = prepare(x, store)
        if ~isfield(store, 'Ax')
            store.Ax = A*x;
            % Increment a counter for the number of matrix-vector products
            % involving A. The names of the counters (here, Aproducts) are
            % for us to choose: they only need to be valid structure field
            % names. They need not have been defined in advance.
            store = incrementcounter(store, 'Aproducts');
        end
    end
    %
    problem.cost = @cost;
    function [f, store] = cost(x, store)
        t = tic();
        store = prepare(x, store);
        f = -.5*(x'*store.Ax);
        % Increment a counter for the number of calls to the cost function.
        store = incrementcounter(store, 'costcalls');
        % We also increment a counter with the amount of time spent in this
        % function (all counters are stored as doubles; here we exploit
        % this to track a non-integer quantity).
        store = incrementcounter(store, 'functiontime', toc(t));
    end
    %
    problem.egrad = @egrad;
    function [g, store] = egrad(x, store)
        t = tic();
        store = prepare(x, store);
        g = -store.Ax;
        % Count the number of calls to the gradient function.
        store = incrementcounter(store, 'gradcalls');
        % We also add the time spent in this call to the same counter used
        % by the cost function.
        store = incrementcounter(store, 'functiontime', toc(t));
    end
    %
    problem.ehess = @ehess;
    function [h, store] = ehess(x, xdot, store) %#ok<INUSL>
        t = tic();
        h = -A*xdot;
        % Count the number of calls to the Hessian operator and also count
        % the matrix-vector product with A.
        store = incrementcounter(store, 'hesscalls');
        store = incrementcounter(store, 'Aproducts');
        % The time spent in this call goes to the same counter as the time
        % spent in the cost and gradient.
        store = incrementcounter(store, 'functiontime', toc(t));
    end


    % Set up a callback to log statistics. We combine statscounters and
    % statsfunhelper to indicate which counters we want the optimization
    % algorithm to log. Here, stats is a structure where each field is a
    % function handle corresponding to one of the counters. Before passing
    % stats to statsfunhelper, we could add more fields to stats to log
    % other things as well.
    stats = statscounters({'costcalls', 'gradcalls', 'hesscalls', ...
                           'Aproducts', 'functiontime'});
    options.statsfun = statsfunhelper(stats);

    % As an example, we set up a stopping criterion based on the number of
    % matrix-vector products. A short version:
    % options.stopfun = @(problem, x, info, last) info(last).Aproducts > 250;
    % A longer version that also returns a reason string:
    options.stopfun = @stopfun;
    function [stop, reason] = stopfun(problem, x, info, last) %#ok<INUSL>
        reason = 'Exceeded Aproducts budget.';
        stop = (info(last).Aproducts > 250);   % true if budget exceeded
        % Here, info(last) contains the stats of the latest iteration.
        % That includes all registered counters.
    end

    % Solve with different solvers to compare.
    options.tolgradnorm = 1e-9;
    [x, xcost, infortr] = trustregions(problem, [], options); %#ok<ASGLU>
    [x, xcost, inforcg] = conjugategradient(problem, [], options); %#ok<ASGLU>
    [x, xcost, infobfg] = rlbfgs(problem, [], options); %#ok<ASGLU>


    % Display some statistics. The logged data is available in the info
    % struct-arrays. Notice how the counters are available by their
    % corresponding field names.
    figure(1);
    subplot(3, 3, 1);
    semilogy([infortr.iter], [infortr.gradnorm], '.-', ...
             [inforcg.iter], [inforcg.gradnorm], '.-', ...
             [infobfg.iter], [infobfg.gradnorm], '.-');
    legend('RTR', 'RCG', 'RLBFGS');
    xlabel('Iteration #');
    ylabel('Gradient norm');
    ylim([1e-12, 1e2]); set(gca, 'YTick', [1e-12, 1e-6, 1e0]);
    subplot(3, 3, 2);
    semilogy([infortr.costcalls], [infortr.gradnorm], '.-', ...
             [inforcg.costcalls], [inforcg.gradnorm], '.-', ...
             [infobfg.costcalls], [infobfg.gradnorm], '.-');
    xlabel('# cost calls');
    ylabel('Gradient norm');
    ylim([1e-12, 1e2]); set(gca, 'YTick', [1e-12, 1e-6, 1e0]);
    subplot(3, 3, 3);
    semilogy([infortr.gradcalls], [infortr.gradnorm], '.-', ...
             [inforcg.gradcalls], [inforcg.gradnorm], '.-', ...
             [infobfg.gradcalls], [infobfg.gradnorm], '.-');
    xlabel('# gradient calls');
    ylabel('Gradient norm');
    ylim([1e-12, 1e2]); set(gca, 'YTick', [1e-12, 1e-6, 1e0]);
    subplot(3, 3, 4);
    semilogy([infortr.hesscalls], [infortr.gradnorm], '.-', ...
             [inforcg.hesscalls], [inforcg.gradnorm], '.-', ...
             [infobfg.hesscalls], [infobfg.gradnorm], '.-');
    xlabel('# Hessian calls');
    ylabel('Gradient norm');
    ylim([1e-12, 1e2]); set(gca, 'YTick', [1e-12, 1e-6, 1e0]);
    subplot(3, 3, 5);
    semilogy([infortr.Aproducts], [infortr.gradnorm], '.-', ...
             [inforcg.Aproducts], [inforcg.gradnorm], '.-', ...
             [infobfg.Aproducts], [infobfg.gradnorm], '.-');
    xlabel('# matrix-vector products');
    ylabel('Gradient norm');
    ylim([1e-12, 1e2]); set(gca, 'YTick', [1e-12, 1e-6, 1e0]);
    subplot(3, 3, 6);
    semilogy([infortr.time], [infortr.gradnorm], '.-', ...
             [inforcg.time], [inforcg.gradnorm], '.-', ...
             [infobfg.time], [infobfg.gradnorm], '.-');
    xlabel('Computation time [s]');
    ylabel('Gradient norm');
    ylim([1e-12, 1e2]); set(gca, 'YTick', [1e-12, 1e-6, 1e0]);
    subplot(3, 3, 7);
    semilogy([infortr.functiontime], [infortr.gradnorm], '.-', ...
             [inforcg.functiontime], [inforcg.gradnorm], '.-', ...
             [infobfg.functiontime], [infobfg.gradnorm], '.-');
    xlabel('Time spent in cost/grad/hess [s]');
    ylabel('Gradient norm');
    ylim([1e-12, 1e2]); set(gca, 'YTick', [1e-12, 1e-6, 1e0]);
    % The following plot shows what fraction of the time is spent inside
    % the user-supplied functions (cost/grad/hess) versus the total time
    % spent by the solver. This gives a sense of the relative importance of
    % cost-function-related computations versus the solver's inner
    % workings, retractions, and other solver-specific operations.
    subplot(3, 3, 8);
    maxtime = max([[infortr.time], [inforcg.time], [infobfg.time]]);
    plot([infortr.time], [infortr.functiontime], '.-', ...
         [inforcg.time], [inforcg.functiontime], '.-', ...
         [infobfg.time], [infobfg.functiontime], '.-', ...
         [0, maxtime], [0, maxtime], 'k--');
    axis tight;
    xlabel('Total computation time [s]');
    ylabel(sprintf('Time spent in\ncost/grad/hess [s]'));

end
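After a solver returns, the logged counters appear as fields of the info struct-array and hold cumulative values, which is why the example can plot them on the horizontal axis and compare them against a budget in stopfun. The sketch below, in plain MATLAB/Octave, shows a few ways to post-process such fields: the fraction of time spent in user functions, per-iteration increments, and the same budget test as stopfun. The `info` array is a hand-made toy stand-in with invented values, not output from the solvers above; only its field names mirror the example.

```matlab
% Toy stand-in for a Manopt info struct-array. Values are invented for
% illustration only; the field names mirror those logged in the example.
info = struct('iter',         {0, 1, 2, 3}, ...
              'time',         {0.0, 0.5, 1.0, 2.0}, ...
              'functiontime', {0.0, 0.25, 0.5, 1.5}, ...
              'Aproducts',    {1, 3, 6, 10});

% Fraction of total solver time spent in cost/grad/hess at each iteration.
% The first entry is skipped because its total time is zero.
t    = [info.time];
ft   = [info.functiontime];
frac = ft(2:end) ./ t(2:end);

% Counters are cumulative, so diff gives the matrix-vector products
% performed during each iteration.
perIter = diff([info.Aproducts]);

% Same budget test as stopfun in the example, applied to the last entry.
last = numel(info);
overBudget = info(last).Aproducts > 250;

fprintf('Time fraction in user functions: %s\n', mat2str(frac));
fprintf('A-products per iteration: %s\n', mat2str(perIter));
fprintf('Over budget: %d\n', overBudget);
```

With real solver output, the same lines apply unchanged to infortr, inforcg or infobfg.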

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005