function [x, info, options] = stochasticgradient(problem, x, options)
% Stochastic gradient (SG) minimization algorithm for Manopt.
%
% function [x, info, options] = stochasticgradient(problem)
% function [x, info, options] = stochasticgradient(problem, x0)
% function [x, info, options] = stochasticgradient(problem, x0, options)
% function [x, info, options] = stochasticgradient(problem, [], options)
%
% Apply the Riemannian stochastic gradient algorithm to the problem defined
% in the problem structure, starting at x0 if it is provided (otherwise, at
% a random point on the manifold). To specify options whilst not specifying
% an initial guess, give x0 as [] (the empty matrix).
%
% The problem structure must contain the following fields:
%
%  problem.M:
%       Defines the manifold to optimize over, given by a factory.
%
%  problem.partialgrad or problem.partialegrad (or equivalent)
%       Describes the partial gradients of the cost function. If the cost
%       function is of the form f(x) = sum_{k=1}^N f_k(x),
%       then partialegrad(x, K) = sum_{k \in K} grad f_k(x).
%       As usual, partialgrad must define the Riemannian gradient, whereas
%       partialegrad defines a Euclidean (classical) gradient which is
%       converted automatically to a Riemannian gradient. Use the tool
%       checkgradient(problem) to check it. K is a /row/ vector, which
%       makes it natural to write loops of the form: for k = K, ..., end.
%
%  problem.ncostterms
%       An integer specifying how many terms are in the cost function (in
%       the example above, that would be N).
%
% Importantly, the cost function itself need not be specified.
%
% Some of the options of the solver are specific to this file. Please have
% a look inside the code.
%
% To record, for example, the value of the cost function or the norm of the
% gradient (statistics the algorithm does not require and hence does not
% compute by default), one can set the following options:
%
%     metrics.cost = @(problem, x) getCost(problem, x);
%     metrics.gradnorm = @(problem, x) problem.M.norm(x, getGradient(problem, x));
%     options.statsfun = statsfunhelper(metrics);
%
% Important caveat: stochastic algorithms usually return an average of the
% last few iterates. Computing averages on manifolds can be expensive.
% Currently, this solver does not compute averages and simply returns the
% last iterate. Using options.statsfun, it is possible for the user to
% compute averages manually. If you have ideas on how to do this
% generically, we welcome feedback. In particular, approximate means could
% be computed with M.pairmean, which is available in many geometries.
%
% See also: steepestdescent

% This file is part of Manopt: www.manopt.org.
% Original authors: Bamdev Mishra <bamdevm@gmail.com>,
%                   Hiroyuki Kasai <kasai@is.uec.ac.jp>, and
%                   Hiroyuki Sato <hsato@ms.kagu.tus.ac.jp>, 22 April 2016.
% Contributors: Nicolas Boumal
% Change log:
%
%   06 July 2019 (BM):
%       Added preconditioner support. This makes it possible to use
%       adaptive algorithms.


    % Verify that the problem description is sufficient for the solver.
    if ~canGetPartialGradient(problem)
        warning('manopt:getPartialGradient', ...
                'No partial gradient provided. The algorithm will likely abort.');
    end


    % Set local defaults.
    localdefaults.maxiter = 1000;       % Maximum number of iterations
    localdefaults.batchsize = 1;        % Batchsize (# cost terms per iter)
    localdefaults.verbosity = 2;        % Output verbosity (0, 1 or 2)
    localdefaults.storedepth = 20;      % Limit amount of caching

    % Check stopping criteria and save stats every checkperiod iterations.
    localdefaults.checkperiod = 100;

    % stepsizefun is a function implementing a step size selection
    % algorithm. See that function for help with options, which can be
    % specified in the options structure passed to the solver directly.
    localdefaults.stepsizefun = @stepsize_sg;

    % Merge global and local defaults, then merge with user options, if any.
    localdefaults = mergeOptions(getGlobalDefaults(), localdefaults);
    if ~exist('options', 'var') || isempty(options)
        options = struct();
    end
    options = mergeOptions(localdefaults, options);

    assert(options.checkperiod >= 1, ...
           'options.checkperiod must be a positive integer (>= 1).');

    % If no initial point x is given by the user, generate one at random.
    if ~exist('x', 'var') || isempty(x)
        x = problem.M.rand();
    end

    % Create a store database and get a key for the current x.
    storedb = StoreDB(options.storedepth);
    key = storedb.getNewKey();

    % Elapsed time for the current set of iterations, where a set of
    % iterations comprises options.checkperiod iterations. We do not
    % count time spent on such things as logging statistics, as these are
    % not relevant to the actual optimization process.
    elapsed_time = 0;

    % Total number of completed steps.
    iter = 0;

    % Total number of saved stats at this point.
    savedstats = 0;

    % Collect and save stats in a struct array info, and preallocate.
    stats = savestats();
    info(1) = stats;
    savedstats = savedstats + 1;
    if isinf(options.maxiter)
        % We trust that if the user set maxiter = inf, then they defined
        % another stopping criterion.
        preallocate = 1e5;
    else
        preallocate = ceil(options.maxiter / options.checkperiod) + 1;
    end
    info(preallocate).iter = [];


    % Display information header for the user.
    if options.verbosity >= 2
        fprintf('    iter       time [s]       step size\n');
    end


    % Main loop.
    stop = false;
    while iter < options.maxiter

        % Record start time.
        start_time = tic();

        % Draw the samples with replacement.
        idx_batch = randi(problem.ncostterms, options.batchsize, 1);

        % Compute partial gradient on this batch.
        pgrad = getPartialGradient(problem, x, idx_batch, storedb, key);

        % Apply preconditioner to the partial gradient.
        Ppgrad = getPrecon(problem, x, pgrad, storedb, key);

        % Compute a step size and the corresponding new point x.
        [stepsize, newx, newkey, ssstats] = ...
                           options.stepsizefun(problem, x, Ppgrad, iter, ...
                                               options, storedb, key);

        % Make the step: transfer iterate, remove cache from previous x.
        storedb.removefirstifdifferent(key, newkey);
        x = newx;
        key = newkey;

        % Make sure we do not use too much memory for the store database.
        storedb.purge();

        % Total number of completed steps.
        iter = iter + 1;

        % Elapsed time doing actual optimization work so far in this
        % set of options.checkperiod iterations.
        elapsed_time = elapsed_time + toc(start_time);


        % Check stopping criteria and save stats every checkperiod iters.
        if mod(iter, options.checkperiod) == 0

            % Log statistics for the freshly executed iteration.
            stats = savestats();
            info(savedstats+1) = stats;
            savedstats = savedstats + 1;

            % Reset timer.
            elapsed_time = 0;

            % Print output.
            if options.verbosity >= 2
                fprintf('%8d     %10.2f       %.3e\n', ...
                        iter, stats.time, stepsize);
            end

            % Run standard stopping criterion checks.
            [stop, reason] = stoppingcriterion(problem, x, ...
                                               options, info, savedstats);
            if stop
                if options.verbosity >= 1
                    fprintf([reason '\n']);
                end
                break;
            end

        end

    end


    % Keep only the relevant portion of the info struct-array.
    info = info(1:savedstats);


    % Display a final information message.
    if options.verbosity >= 1
        if ~stop
            % We stopped not because of stoppingcriterion but because the
            % loop came to an end, which means maxiter triggered.
            msg = 'Max iteration count reached; options.maxiter = %g.\n';
            fprintf(msg, options.maxiter);
        end
        fprintf('Total time is %f [s] (excludes statsfun)\n', ...
                info(end).time + elapsed_time);
    end


    % Helper function to collect statistics to be saved at
    % index savedstats+1 in the info struct-array.
    function stats = savestats()
        stats.iter = iter;
        if savedstats == 0
            stats.time = 0;
            stats.stepsize = NaN;
            stats.stepsize_stats = [];
        else
            stats.time = info(savedstats).time + elapsed_time;
            stats.stepsize = stepsize;
            stats.stepsize_stats = ssstats;
        end
        stats = applyStatsfun(problem, x, storedb, key, options, stats);
    end

end
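
% Below is a minimal usage sketch (not part of the original file). It sets
% up the problem f(x) = (1/2) * sum_{k=1}^N ||x - a_k||^2 over the unit
% sphere in R^d, where the columns a_k of A are data vectors. Each term f_k
% has Euclidean gradient x - a_k, so partialegrad sums these over the batch
% of sampled indices K. spherefactory, batchsize and maxiter are standard
% Manopt names; the problem itself, the data A and the parameter values are
% made up for illustration.
%
%     d = 10;
%     N = 1000;
%     A = randn(d, N);                % columns of A are the data points a_k
%
%     problem.M = spherefactory(d);
%     problem.ncostterms = N;
%     % Euclidean partial gradient, converted to a Riemannian gradient
%     % automatically by the solver; K may be a row or column index vector.
%     problem.partialegrad = @(x, K) numel(K)*x - sum(A(:, K), 2);
%
%     options.batchsize = 10;         % number of cost terms sampled per step
%     options.maxiter = 1000;
%     x = stochasticgradient(problem, [], options);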