function using_counters()
% USING_COUNTERS  Demonstrate Manopt counters while comparing three solvers.
%
% Builds the cost f(x) = -0.5*x'*A*x for a random symmetric n-by-n matrix A
% on the manifold returned by spherefactory(n), and instruments the cost,
% gradient and Hessian callbacks with counters:
%   costcalls, gradcalls, hesscalls - number of calls to each callback,
%   Aproducts                       - number of products with A,
%   functiontime                    - time spent inside the callbacks.
% The problem is then solved with trustregions, conjugategradient and
% rlbfgs, each stopped once more than 250 products with A were recorded,
% and the gradient norm is plotted against every counter in figure 1.
%
% Takes no input and returns no output.
%
% See also: statscounters incrementcounter statsfunhelper

    % Seed the generator so that repeated runs are identical. The call is
    % skipped when running under Octave (detected via OCTAVE_VERSION).
    if exist('OCTAVE_VERSION', 'builtin') == 0
        rng(0);
    end

    % Problem data: a random symmetric matrix of size n.
    n = 1000;
    A = randn(n);
    A = .5*(A+A');

    manifold = spherefactory(n);
    problem.M = manifold;

    % The product A*x is needed by both the cost and the gradient at x.
    % prepare() computes it at most once per store structure and caches it
    % in store.Ax; the Aproducts counter is incremented only when the
    % product is actually computed.
    function store = prepare(x, store)
        if ~isfield(store, 'Ax')
            store.Ax = A*x;
            % Counter names are plain strings chosen here; this one tracks
            % matrix-vector products with A.
            store = incrementcounter(store, 'Aproducts');
        end
    end

    % Cost f(x) = -0.5*x'*A*x, using the cached product A*x.
    problem.cost = @cost;
    function [f, store] = cost(x, store)
        t = tic();
        store = prepare(x, store);
        f = -.5*(x'*store.Ax);
        % Count this call to the cost function.
        store = incrementcounter(store, 'costcalls');
        % Add the time elapsed in this call to the functiontime counter.
        store = incrementcounter(store, 'functiontime', toc(t));
    end

    % Euclidean gradient -A*x, using the cached product A*x.
    problem.egrad = @egrad;
    function [g, store] = egrad(x, store)
        t = tic();
        store = prepare(x, store);
        g = -store.Ax;
        % Count this call to the gradient.
        store = incrementcounter(store, 'gradcalls');
        % Add the time elapsed in this call to the functiontime counter.
        store = incrementcounter(store, 'functiontime', toc(t));
    end

    % Euclidean Hessian along xdot: -A*xdot. This product is not cached, so
    % every call counts as one more product with A.
    problem.ehess = @ehess;
    function [h, store] = ehess(x, xdot, store) %#ok<INUSL>
        t = tic();
        h = -A*xdot;
        store = incrementcounter(store, 'hesscalls');
        store = incrementcounter(store, 'Aproducts');
        % Add the time elapsed in this call to the functiontime counter.
        store = incrementcounter(store, 'functiontime', toc(t));
    end

    % Register the counters that should be recorded by the solvers, and
    % hand them over through options.statsfun.
    stats = statscounters({'costcalls', 'gradcalls', 'hesscalls', ...
                           'Aproducts', 'functiontime'});
    options.statsfun = statsfunhelper(stats);

    % Custom stopping criterion: stop as soon as the latest recorded value
    % of the Aproducts counter exceeds the budget of 250 products.
    options.stopfun = @stopfun;
    function [stop, reason] = stopfun(problem, x, info, last) %#ok<INUSL>
        reason = 'Exceeded Aproducts budget.';
        % info(last) is the entry read for the most recent iteration.
        stop = (info(last).Aproducts > 250);
    end

    % Run the three solvers with identical options. Only the info
    % struct-arrays are needed for the plots below.
    options.tolgradnorm = 1e-9;
    [~, ~, infortr] = trustregions(problem, [], options);
    [~, ~, inforcg] = conjugategradient(problem, [], options);
    [~, ~, infobfg] = rlbfgs(problem, [], options);

    % Draw, in cell `position` of a 3-by-3 grid, the gradient norm of the
    % three solvers against the info field named `counter`.
    function plotgradnorm(position, counter, label)
        subplot(3, 3, position);
        semilogy([infortr.(counter)], [infortr.gradnorm], '.-', ...
                 [inforcg.(counter)], [inforcg.gradnorm], '.-', ...
                 [infobfg.(counter)], [infobfg.gradnorm], '.-');
        xlabel(label);
        ylabel('Gradient norm');
        ylim([1e-12, 1e2]);
        set(gca, 'YTick', [1e-12, 1e-6, 1e0]);
    end

    % Gradient norm against each recorded quantity.
    figure(1);
    plotgradnorm(1, 'iter', 'Iteration #');
    legend('RTR', 'RCG', 'RLBFGS');
    plotgradnorm(2, 'costcalls', '# cost calls');
    plotgradnorm(3, 'gradcalls', '# gradient calls');
    plotgradnorm(4, 'hesscalls', '# Hessian calls');
    plotgradnorm(5, 'Aproducts', '# matrix-vector products');
    plotgradnorm(6, 'time', 'Computation time [s]');
    plotgradnorm(7, 'functiontime', 'Time spent in cost/grad/hess [s]');

    % Time spent inside cost/grad/hess against the total time recorded by
    % the solver. The dashed black diagonal marks where both are equal.
    subplot(3, 3, 8);
    maxtime = max([[infortr.time], [inforcg.time], [infobfg.time]]);
    plot([infortr.time], [infortr.functiontime], '.-', ...
         [inforcg.time], [inforcg.functiontime], '.-', ...
         [infobfg.time], [infobfg.functiontime], '.-', ...
         [0, maxtime], [0, maxtime], 'k--');
    axis tight;
    xlabel('Total computation time [s]');
    ylabel(sprintf('Time spent in\ncost/grad/hess [s]'));

end