## PURPOSE

Stochastic gradient (SG) minimization algorithm for Manopt.

## SYNOPSIS

function [x, info, options] = stochasticgradient(problem, x, options)

## DESCRIPTION

 Stochastic gradient (SG) minimization algorithm for Manopt.

function [x, info, options] = stochasticgradient(problem)
function [x, info, options] = stochasticgradient(problem, x0)
function [x, info, options] = stochasticgradient(problem, x0, options)
function [x, info, options] = stochasticgradient(problem, [], options)

Apply the Riemannian stochastic gradient algorithm to the problem defined
in the problem structure, starting at x0 if it is provided (otherwise, at
a random point on the manifold). To specify options whilst not specifying
an initial guess, give x0 as [] (the empty matrix).

The problem structure must contain the following fields:

problem.M:
Defines the manifold to optimize over, given by a factory.

problem.partialgrad or problem.partialegrad:
Describes the partial gradients of the cost function. If the cost
function is of the form f(x) = sum_{k=1}^N f_k(x), then
partialgrad(x, K) returns the gradient of sum_{k in K} f_k(x), for any
subset K of 1:N. A Euclidean partial gradient (partialegrad) is
converted automatically to a Riemannian gradient. Use the tool
checkgradient(problem) to check it. K is a /row/ vector, which
makes it natural to write for k = K, ..., end.

problem.ncostterms:
An integer specifying how many terms are in the cost function (in
the example above, that would be N).

Importantly, the cost function itself need not be specified.
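For instance, a problem for this solver might be set up as follows. This is a minimal sketch with made-up data (the matrix A and the dimensions are purely illustrative), assuming Manopt is on the MATLAB path:

```matlab
% Minimize f(x) = (1/2) * sum_{k=1}^N (a_k' * x)^2 over the unit sphere,
% where a_k' is the k-th row of a (made-up) data matrix A.
N = 1000;
A = randn(N, 5);
problem.M = spherefactory(5);
problem.ncostterms = N;
% Euclidean partial gradient of sum_{k in K} f_k(x); Manopt converts it
% to a Riemannian partial gradient automatically. K is a row vector.
problem.partialegrad = @(x, K) A(K, :)' * (A(K, :) * x);
x = stochasticgradient(problem);
```

Note that problem.cost is never defined: the solver only queries partial gradients.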

Some of the options of the solver are specific to this file. Please have
a look inside the code.
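As an illustration, a few of the solver-specific options (whose default values appear in the source code below) could be overridden like this, assuming a problem structure as described above:

```matlab
% Override some solver-specific defaults; values here are illustrative.
options.maxiter = 5000;      % total number of stochastic gradient steps
options.batchsize = 10;      % number of cost terms sampled per iteration
options.checkperiod = 200;   % save stats / check stopping every 200 iters
options.verbosity = 2;       % output verbosity (0, 1 or 2)
x = stochasticgradient(problem, [], options);
```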

To record, for example, the value of the cost function or the norm of the
gradient (statistics the algorithm does not require and hence does not
compute by default), one can set the following options:

metrics.cost = @(problem, x) getCost(problem, x);
metrics.gradnorm = @(problem, x) problem.M.norm(x, getGradient(problem, x));
options.statsfun = statsfunhelper(metrics);

Important caveat: stochastic algorithms usually return an average of the
last few iterates. Computing averages on manifolds can be expensive.
Currently, this solver does not compute averages and simply returns the
last iterate. Using options.statsfun, it is possible for the user to
compute averages manually. If you have ideas on how to do this
generically, we welcome feedback. In particular, approximate means could
be computed with M.pairmean which is available in many geometries.
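As a rough sketch of that idea, a user-supplied statsfun could maintain an approximate running mean of the iterates with M.pairmean. The function name average_statsfun and the field stats.xbar below are illustrative, not part of Manopt:

```matlab
function stats = average_statsfun(problem, x, stats, store) %#ok<INUSD>
    % Approximate running average of iterates: repeatedly replace the
    % current average with the pairwise mean of that average and the
    % new iterate. This is only a crude proxy for a true mean.
    persistent xbar;
    if isempty(xbar) || stats.iter == 0
        xbar = x;
    else
        xbar = problem.M.pairmean(xbar, x);
    end
    stats.xbar = xbar;
end
% Then pass it to the solver:
%   options.statsfun = @average_statsfun;
```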

See also: steepestdescent

## CROSS-REFERENCE INFORMATION

This function calls:
• StoreDB
• applyStatsfun Apply the statsfun function to a stats structure (for solvers).
• canGetPartialGradient Checks whether the partial gradient can be computed for a given problem.
• getGlobalDefaults Returns a structure with default option values for Manopt.
• getPartialGradient Computes the gradient of a subset of terms in the cost function at x.
• getPrecon Applies the preconditioner for the Hessian of the cost at x along d.
• mergeOptions Merges two options structures with one having precedence over the other.
• stoppingcriterion Checks for standard stopping criteria, as a helper to solvers.
• stepsize_sg Standard step size selection algorithm for the stochastic gradient method.
This function is called by:
• PCA_stochastic Example of stochastic gradient algorithm in Manopt on a PCA problem.

## SOURCE CODE

0001 function [x, info, options] = stochasticgradient(problem, x, options)
0002 % Stochastic gradient (SG) minimization algorithm for Manopt.
0003 %
0004 % function [x, info, options] = stochasticgradient(problem)
0005 % function [x, info, options] = stochasticgradient(problem, x0)
0006 % function [x, info, options] = stochasticgradient(problem, x0, options)
0007 % function [x, info, options] = stochasticgradient(problem, [], options)
0008 %
0009 % Apply the Riemannian stochastic gradient algorithm to the problem defined
0010 % in the problem structure, starting at x0 if it is provided (otherwise, at
0011 % a random point on the manifold). To specify options whilst not specifying
0012 % an initial guess, give x0 as [] (the empty matrix).
0013 %
0014 % The problem structure must contain the following fields:
0015 %
0016 %  problem.M:
0017 %       Defines the manifold to optimize over, given by a factory.
0018 %
0019 %  problem.partialgrad or problem.partialegrad
0020 %       Describes the partial gradients of the cost function. If the cost
0021 %       function is of the form f(x) = sum_{k=1}^N f_k(x), then
0022 %       partialgrad(x, K) returns the gradient of sum_{k in K} f_k(x),
0023 %       for any subset K of 1:N. A Euclidean partial gradient
0024 %       (partialegrad) is
0025 %       converted automatically to a Riemannian gradient. Use the tool
0026 %       checkgradient(problem) to check it. K is a /row/ vector, which
0027 %       makes it natural to write for k = K, ..., end.
0028 %
0029 %  problem.ncostterms
0030 %       An integer specifying how many terms are in the cost function (in
0031 %       the example above, that would be N).
0032 %
0033 % Importantly, the cost function itself need not be specified.
0034 %
0035 % Some of the options of the solver are specific to this file. Please have
0036 % a look inside the code.
0037 %
0038 % To record, for example, the value of the cost function or the norm of
0039 % the gradient (statistics the algorithm does not require and hence
0040 % does not compute by default), one can set the following options:
0041 %
0042 %   metrics.cost = @(problem, x) getCost(problem, x);
0043 %   metrics.gradnorm = @(problem, x) problem.M.norm(x, getGradient(problem, x));
0044 %   options.statsfun = statsfunhelper(metrics);
0045 %
0046 % Important caveat: stochastic algorithms usually return an average of the
0047 % last few iterates. Computing averages on manifolds can be expensive.
0048 % Currently, this solver does not compute averages and simply returns the
0049 % last iterate. Using options.statsfun, it is possible for the user to
0050 % compute averages manually. If you have ideas on how to do this
0051 % generically, we welcome feedback. In particular, approximate means could
0052 % be computed with M.pairmean which is available in many geometries.
0053 %
0054 % See also: steepestdescent
0055
0056 % This file is part of Manopt: www.manopt.org.
0057 % Original authors: Bamdev Mishra <bamdevm@gmail.com>,
0058 %                   Hiroyuki Kasai <kasai@is.uec.ac.jp>, and
0059 %                   Hiroyuki Sato <hsato@ms.kagu.tus.ac.jp>, 22 April 2016.
0060 % Contributors: Nicolas Boumal
0061 % Change log:
0062 %
0063 %   06 July 2019 (BM):
0065
0066
0067     % Verify that the problem description is sufficient for the solver.
0068     if ~canGetPartialGradient(problem)
0069         warning('manopt:getPartialGradient', ...
0070          'No partial gradient provided. The algorithm will likely abort.');
0071     end
0072
0073
0074     % Set local default
0075     localdefaults.maxiter = 1000;       % Maximum number of iterations
0076     localdefaults.batchsize = 1;        % Batchsize (# cost terms per iter)
0077     localdefaults.verbosity = 2;        % Output verbosity (0, 1 or 2)
0078     localdefaults.storedepth = 20;      % Limit amount of caching
0079
0080     % Check stopping criteria and save stats every checkperiod iterations.
0081     localdefaults.checkperiod = 100;
0082
0083     % stepsizefun is a function implementing a step size selection
0084     % algorithm. See that function for help with options, which can be
0085     % specified in the options structure passed to the solver directly.
0086     localdefaults.stepsizefun = @stepsize_sg;
0087
0088     % Merge global and local defaults, then merge w/ user options, if any.
0089     localdefaults = mergeOptions(getGlobalDefaults(), localdefaults);
0090     if ~exist('options', 'var') || isempty(options)
0091         options = struct();
0092     end
0093     options = mergeOptions(localdefaults, options);
0094
0095
0096     assert(options.checkperiod >= 1, ...
0097                  'options.checkperiod must be a positive integer (>= 1).');
0098
0099
0100     % If no initial point x is given by the user, generate one at random.
0101     if ~exist('x', 'var') || isempty(x)
0102         x = problem.M.rand();
0103     end
0104
0105     % Create a store database and get a key for the current x
0106     storedb = StoreDB(options.storedepth);
0107     key = storedb.getNewKey();
0108
0109
0110     % Elapsed time for the current set of iterations, where a set of
0111     % iterations comprises options.checkperiod iterations. We do not
0112     % count time spent for such things as logging statistics, as these are
0113     % not relevant to the actual optimization process.
0114     elapsed_time = 0;
0115
0116     % Total number of completed steps
0117     iter = 0;
0118
0119
0120     % Total number of saved stats at this point.
0121     savedstats = 0;
0122
0123     % Collect and save stats in a struct array info, and preallocate.
0124     stats = savestats();
0125     info(1) = stats;
0126     savedstats = savedstats + 1;
0127     if isinf(options.maxiter)
0128         % We trust that if the user set maxiter = inf, then they defined
0129         % another stopping criterion.
0130         preallocate = 1e5;
0131     else
0132         preallocate = ceil(options.maxiter / options.checkperiod) + 1;
0133     end
0134     info(preallocate).iter = [];
0135
0136
0137     % Display information header for the user.
0138     if options.verbosity >= 2
0139         fprintf('    iter       time [s]       step size\n');
0140     end
0141
0142
0143     % Main loop.
0144     stop = false;
0145     while iter < options.maxiter
0146
0147         % Record start time.
0148         start_time = tic();
0149
0150         % Draw the samples with replacement.
0151         idx_batch = randi(problem.ncostterms, options.batchsize, 1);
0152
0153         % Compute partial gradient on this batch.
0154         pgrad = getPartialGradient(problem, x, idx_batch, storedb, key);
0155
0156         % Apply preconditioner to the partial gradient.
0157         Ppgrad = getPrecon(problem, x, pgrad, storedb, key);
0158
0159         % Compute a step size and the corresponding new point x.
0160         [stepsize, newx, newkey, ssstats] = ...
0161                            options.stepsizefun(problem, x, Ppgrad, iter, ...
0162                                                options, storedb, key);
0163
0164         % Make the step: transfer iterate, remove cache from previous x.
0165         storedb.removefirstifdifferent(key, newkey);
0166         x = newx;
0167         key = newkey;
0168
0169         % Make sure we do not use too much memory for the store database.
0170         storedb.purge();
0171
0172         % Total number of completed steps.
0173         iter = iter + 1;
0174
0175         % Elapsed time doing actual optimization work so far in this
0176         % set of options.checkperiod iterations.
0177         elapsed_time = elapsed_time + toc(start_time);
0178
0179
0180         % Check stopping criteria and save stats every checkperiod iters.
0181         if mod(iter, options.checkperiod) == 0
0182
0183             % Log statistics for freshly executed iteration.
0184             stats = savestats();
0185             info(savedstats+1) = stats;
0186             savedstats = savedstats + 1;
0187
0188             % Reset timer.
0189             elapsed_time = 0;
0190
0191             % Print output.
0192             if options.verbosity >= 2
0193                 fprintf('%8d     %10.2f       %.3e\n', ...
0194                                                iter, stats.time, stepsize);
0195             end
0196
0197             % Run standard stopping criterion checks.
0198             [stop, reason] = stoppingcriterion(problem, x, ...
0199                                                options, info, savedstats);
0200             if stop
0201                 if options.verbosity >= 1
0202                     fprintf([reason '\n']);
0203                 end
0204                 break;
0205             end
0206
0207         end
0208
0209     end
0210
0211
0212     % Keep only the relevant portion of the info struct-array.
0213     info = info(1:savedstats);
0214
0215
0216     % Display a final information message.
0217     if options.verbosity >= 1
0218         if ~stop
0219             % We stopped not because of stoppingcriterion but because the
0220             % loop came to an end, which means maxiter triggered.
0221             msg = 'Max iteration count reached; options.maxiter = %g.\n';
0222             fprintf(msg, options.maxiter);
0223         end
0224         fprintf('Total time is %f [s] (excludes statsfun)\n', ...
0225                 info(end).time + elapsed_time);
0226     end
0227
0228
0229     % Helper function to collect statistics to be saved at
0230     % index checkperiodcount+1 in info.
0231     function stats = savestats()
0232         stats.iter = iter;
0233         if savedstats == 0
0234             stats.time = 0;
0235             stats.stepsize = NaN;
0236             stats.stepsize_stats = [];
0237         else
0238             stats.time = info(savedstats).time + elapsed_time;
0239             stats.stepsize = stepsize;
0240             stats.stepsize_stats = ssstats;
0241         end
0242         stats = applyStatsfun(problem, x, storedb, key, options, stats);
0243     end
0244
0245 end

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005