
stochasticgradient

PURPOSE

Stochastic gradient (SG) minimization algorithm for Manopt.

SYNOPSIS

function [x, info, options] = stochasticgradient(problem, x, options)

DESCRIPTION

 Stochastic gradient (SG) minimization algorithm for Manopt.

 function [x, info, options] = stochasticgradient(problem)
 function [x, info, options] = stochasticgradient(problem, x0)
 function [x, info, options] = stochasticgradient(problem, x0, options)
 function [x, info, options] = stochasticgradient(problem, [], options)

 Apply the Riemannian stochastic gradient algorithm to the problem defined
 in the problem structure, starting at x0 if it is provided (otherwise, at
 a random point on the manifold). To specify options whilst not specifying
 an initial guess, give x0 as [] (the empty matrix).

 The problem structure must contain the following fields:

  problem.M:
       Defines the manifold to optimize over, given by a factory.

  problem.partialgrad or problem.partialegrad (or equivalent)
       Describes the partial gradients of the cost function. If the cost
       function is of the form f(x) = sum_{k=1}^N f_k(x),
        then partialgrad(x, K) = sum_{k \in K} grad f_k(x), and likewise
        partialegrad(x, K) = sum_{k \in K} egrad f_k(x) with Euclidean
        gradients. As usual, partialgrad must define the Riemannian
        gradient, whereas partialegrad defines a Euclidean (classical)
        gradient which is converted automatically to a Riemannian
        gradient. Use the tool checkgradient(problem) to check it. K is
        a /row/ vector of indices, which makes it natural to loop over
        it as: for k = K, ..., end.

  problem.ncostterms
       An integer specifying how many terms are in the cost function (in
        the example above, that would be N).

 Importantly, the cost function itself need not be specified.
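
 For illustration, here is a minimal, self-contained sketch (not part of
 this file) of such a problem: a least-squares cost f(x) = sum_k
 (a_k'*x - b_k)^2 restricted to the unit sphere. The data A, b, the
 dimensions and the manifold choice are placeholders:

   N = 1000; d = 50;
   A = randn(N, d);
   b = randn(N, 1);
   problem.M = spherefactory(d);    % unit sphere in R^d
   problem.ncostterms = N;
   % Euclidean partial gradient over a batch K, converted automatically:
   problem.partialegrad = @(x, K) 2 * A(K, :)' * (A(K, :) * x - b(K));
   [x, info] = stochasticgradient(problem);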

 Some of the options of the solver are specific to this file. Please have
 a look inside the code.
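
 For instance, the following options, whose defaults appear near the top
 of the source below, can be overridden when calling the solver (the
 values shown are merely illustrative):

   options.maxiter = 5000;      % total number of SG steps
   options.batchsize = 10;      % cost terms sampled per step
   options.checkperiod = 200;   % save stats / check stopping every 200 steps
   options.verbosity = 1;       % output verbosity: 0, 1 or 2
   x = stochasticgradient(problem, [], options);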

 To record, for example, the value of the cost function or the norm of
 the gradient (statistics the algorithm does not require and hence does
 not compute by default), one can set the following options:

   metrics.cost = @(problem, x) getCost(problem, x);
   metrics.gradnorm = @(problem, x) problem.M.norm(x, getGradient(problem, x));
   options.statsfun = statsfunhelper(metrics);

 Important caveat: stochastic algorithms usually return an average of the
 last few iterates. Computing averages on manifolds can be expensive.
 Currently, this solver does not compute averages and simply returns the
 last iterate. Using options.statsfun, it is possible for the user to
 compute averages manually. If you have ideas on how to do this
 generically, we welcome feedback. In particular, approximate means could
 be computed with M.pairmean, which is available in many geometries.
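
 For example, the following sketch tracks an exponentially weighted
 running mean of the iterates through options.statsfun. It assumes the
 chosen manifold provides M.pairmean; the function name avgstatsfun and
 the recorded field stats.xbar are illustrative:

   function stats = avgstatsfun(problem, x, stats)
       persistent xbar
       if isempty(xbar) || stats.iter == 0
           xbar = x;
       else
           % Midpoint of the previous mean and the current iterate:
           % this yields an exponentially weighted (not uniform) average.
           xbar = problem.M.pairmean(xbar, x);
       end
       stats.xbar = xbar;
   end

 Then set options.statsfun = @avgstatsfun before calling the solver.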

 See also: steepestdescent

CROSS-REFERENCE INFORMATION

This function calls:
This function is called by:

SUBFUNCTIONS

SOURCE CODE

function [x, info, options] = stochasticgradient(problem, x, options)
% Stochastic gradient (SG) minimization algorithm for Manopt.
%
% function [x, info, options] = stochasticgradient(problem)
% function [x, info, options] = stochasticgradient(problem, x0)
% function [x, info, options] = stochasticgradient(problem, x0, options)
% function [x, info, options] = stochasticgradient(problem, [], options)
%
% Apply the Riemannian stochastic gradient algorithm to the problem defined
% in the problem structure, starting at x0 if it is provided (otherwise, at
% a random point on the manifold). To specify options whilst not specifying
% an initial guess, give x0 as [] (the empty matrix).
%
% The problem structure must contain the following fields:
%
%  problem.M:
%       Defines the manifold to optimize over, given by a factory.
%
%  problem.partialgrad or problem.partialegrad (or equivalent)
%       Describes the partial gradients of the cost function. If the cost
%       function is of the form f(x) = sum_{k=1}^N f_k(x),
%       then partialgrad(x, K) = sum_{k \in K} grad f_k(x), and likewise
%       partialegrad(x, K) = sum_{k \in K} egrad f_k(x) with Euclidean
%       gradients. As usual, partialgrad must define the Riemannian
%       gradient, whereas partialegrad defines a Euclidean (classical)
%       gradient which is converted automatically to a Riemannian
%       gradient. Use the tool checkgradient(problem) to check it. K is
%       a /row/ vector of indices, which makes it natural to loop over
%       it as: for k = K, ..., end.
%
%  problem.ncostterms
%       An integer specifying how many terms are in the cost function (in
%       the example above, that would be N).
%
% Importantly, the cost function itself need not be specified.
%
% Some of the options of the solver are specific to this file. Please have
% a look inside the code.
%
% To record, for example, the value of the cost function or the norm of
% the gradient (statistics the algorithm does not require and hence does
% not compute by default), one can set the following options:
%
%   metrics.cost = @(problem, x) getCost(problem, x);
%   metrics.gradnorm = @(problem, x) problem.M.norm(x, getGradient(problem, x));
%   options.statsfun = statsfunhelper(metrics);
%
% Important caveat: stochastic algorithms usually return an average of the
% last few iterates. Computing averages on manifolds can be expensive.
% Currently, this solver does not compute averages and simply returns the
% last iterate. Using options.statsfun, it is possible for the user to
% compute averages manually. If you have ideas on how to do this
% generically, we welcome feedback. In particular, approximate means could
% be computed with M.pairmean, which is available in many geometries.
%
% See also: steepestdescent

% This file is part of Manopt: www.manopt.org.
% Original authors: Bamdev Mishra <bamdevm@gmail.com>,
%                   Hiroyuki Kasai <kasai@is.uec.ac.jp>, and
%                   Hiroyuki Sato <hsato@ms.kagu.tus.ac.jp>, 22 April 2016.
% Contributors: Nicolas Boumal
% Change log:
%
%   06 July 2019 (BM):
%      Added preconditioner support. This allows the use of adaptive
%      algorithms.

    % Verify that the problem description is sufficient for the solver.
    if ~canGetPartialGradient(problem)
        warning('manopt:getPartialGradient', ...
         'No partial gradient provided. The algorithm will likely abort.');
    end
    
    % Set local defaults.
    localdefaults.maxiter = 1000;       % Maximum number of iterations
    localdefaults.batchsize = 1;        % Batchsize (# cost terms per iter)
    localdefaults.verbosity = 2;        % Output verbosity (0, 1 or 2)
    localdefaults.storedepth = 20;      % Limit amount of caching
    
    % Check stopping criteria and save stats every checkperiod iterations.
    localdefaults.checkperiod = 100;
    
    % stepsizefun is a function implementing a step size selection
    % algorithm. See that function for help with options, which can be
    % specified in the options structure passed to the solver directly.
    localdefaults.stepsizefun = @stepsize_sg;
    
    % Merge global and local defaults, then merge w/ user options, if any.
    localdefaults = mergeOptions(getGlobalDefaults(), localdefaults);
    if ~exist('options', 'var') || isempty(options)
        options = struct();
    end
    options = mergeOptions(localdefaults, options);
    
    assert(options.checkperiod >= 1, ...
                 'options.checkperiod must be a positive integer (>= 1).');
    
    % If no initial point x is given by the user, generate one at random.
    if ~exist('x', 'var') || isempty(x)
        x = problem.M.rand();
    end
    
    % Create a store database and get a key for the current x.
    storedb = StoreDB(options.storedepth);
    key = storedb.getNewKey();
    
    % Elapsed time for the current set of iterations, where a set of
    % iterations comprises options.checkperiod iterations. We do not
    % count time spent on such things as logging statistics, as these are
    % not relevant to the actual optimization process.
    elapsed_time = 0;
    
    % Total number of completed steps.
    iter = 0;
    
    % Total number of saved stats at this point.
    savedstats = 0;
    
    % Collect and save stats in a struct array info, and preallocate.
    stats = savestats();
    info(1) = stats;
    savedstats = savedstats + 1;
    if isinf(options.maxiter)
        % We trust that if the user set maxiter = inf, then they defined
        % another stopping criterion.
        preallocate = 1e5;
    else
        preallocate = ceil(options.maxiter / options.checkperiod) + 1;
    end
    info(preallocate).iter = [];
    
    % Display information header for the user.
    if options.verbosity >= 2
        fprintf('    iter       time [s]       step size\n');
    end
    
    % Main loop.
    stop = false;
    while iter < options.maxiter
        
        % Record start time.
        start_time = tic();
        
        % Draw the samples with replacement.
        idx_batch = randi(problem.ncostterms, options.batchsize, 1);
        
        % Compute partial gradient on this batch.
        pgrad = getPartialGradient(problem, x, idx_batch, storedb, key);

        % Apply preconditioner to the partial gradient.
        Ppgrad = getPrecon(problem, x, pgrad, storedb, key);
        
        % Compute a step size and the corresponding new point x.
        [stepsize, newx, newkey, ssstats] = ...
                           options.stepsizefun(problem, x, Ppgrad, iter, ...
                                               options, storedb, key);
        
        % Make the step: transfer iterate, remove cache from previous x.
        storedb.removefirstifdifferent(key, newkey);
        x = newx;
        key = newkey;
        
        % Make sure we do not use too much memory for the store database.
        storedb.purge();
        
        % Total number of completed steps.
        iter = iter + 1;
        
        % Elapsed time doing actual optimization work so far in this
        % set of options.checkperiod iterations.
        elapsed_time = elapsed_time + toc(start_time);
        
        % Check stopping criteria and save stats every checkperiod iters.
        if mod(iter, options.checkperiod) == 0
            
            % Log statistics for the freshly executed iteration.
            stats = savestats();
            info(savedstats+1) = stats;
            savedstats = savedstats + 1;
            
            % Reset timer.
            elapsed_time = 0;
            
            % Print output.
            if options.verbosity >= 2
                fprintf('%8d     %10.2f       %.3e\n', ...
                                               iter, stats.time, stepsize);
            end
            
            % Run standard stopping criterion checks.
            [stop, reason] = stoppingcriterion(problem, x, ...
                                               options, info, savedstats);
            if stop
                if options.verbosity >= 1
                    fprintf([reason '\n']);
                end
                break;
            end
        
        end

    end
    
    % Keep only the relevant portion of the info struct-array.
    info = info(1:savedstats);
    
    % Display a final information message.
    if options.verbosity >= 1
        if ~stop
            % We stopped not because of stoppingcriterion but because the
            % loop came to an end, which means maxiter triggered.
            msg = 'Max iteration count reached; options.maxiter = %g.\n';
            fprintf(msg, options.maxiter);
        end
        fprintf('Total time is %f [s] (excludes statsfun)\n', ...
                info(end).time + elapsed_time);
    end
    
    % Helper function to collect statistics to be saved at
    % index savedstats+1 in info.
    function stats = savestats()
        stats.iter = iter;
        if savedstats == 0
            stats.time = 0;
            stats.stepsize = NaN;
            stats.stepsize_stats = [];
        else
            stats.time = info(savedstats).time + elapsed_time;
            stats.stepsize = stepsize;
            stats.stepsize_stats = ssstats;
        end
        stats = applyStatsfun(problem, x, storedb, key, options, stats);
    end
    
end
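
 A closing note on options.stepsizefun: it defaults to stepsize_sg, but any
 function matching the calling convention visible in the main loop above can
 be substituted. Here is a minimal sketch with a 1/(1+iter) decay (the name
 stepsize_decay and the schedule are illustrative, not part of Manopt):

   function [stepsize, newx, newkey, ssstats] = stepsize_decay( ...
                              problem, x, grad, iter, options, storedb, key)
       stepsize = 1 / (1 + iter);
       % Move against the (preconditioned) partial gradient, then retract.
       newx = problem.M.retr(x, problem.M.lincomb(x, -stepsize, grad));
       newkey = storedb.getNewKey();
       ssstats = [];   % no step size statistics recorded in this sketch
   end

 Enable it with options.stepsizefun = @stepsize_decay.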
