Standard step size selection algorithm for the stochastic gradient method

Given a problem structure, a point x on the manifold problem.M and a tangent vector d at x, produces a stepsize (a positive real number) and a new point newx obtained by retracting -stepsize*d at x. Additional inputs include iter (the iteration number of x, where 0 marks the initial guess), an options structure, a storedb database and the key of point x in that database. Additional outputs include the key of newx in the database, newkey, as well as a structure ssstats collecting statistics about the work done during the call to this function.

See the code for the role of the available options:
  options.stepsize_type
  options.stepsize_init
  options.stepsize_lambda
  options.stepsize_decaysteps

This function may create and maintain a structure called sssgmem inside storedb.internal. This gives the function the opportunity to remember what happened in previous calls.

See also: stochasticgradient
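For reference, the step size computed by the function's code at iteration $k$ (the input iter) can be written in closed form. The symbols are notation introduced here only: $\alpha_0$ stands for options.stepsize_init, $\lambda$ for options.stepsize_lambda, $K$ for options.stepsize_decaysteps, and $R$ for the retraction problem.M.retr.

$$
\alpha_k =
\begin{cases}
\dfrac{\alpha_0}{1 + \alpha_0 \lambda k} & \text{type 'decay'},\\[6pt]
\alpha_0 & \text{type 'fix' or 'fixed'},\\[6pt]
\dfrac{\alpha_0}{1 + \alpha_0 \lambda \min(k, K)} & \text{type 'hybrid'},
\end{cases}
\qquad
x_{k+1} = R_{x_k}\!\left(-\alpha_k d_k\right).
$$

With the default values $\alpha_0 = 0.1$, $\lambda = 0.1$ and $K = 100$, the 'decay' schedule gives $\alpha_{100} = 0.1 / (1 + 0.1 \cdot 0.1 \cdot 100) = 0.05$, and the 'hybrid' schedule stays at $0.05$ for every $k \geq 100$.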
```matlab
function [stepsize, newx, newkey, ssstats] = ...
                       stepsize_sg(problem, x, d, iter, options, storedb, key) %#ok<INUSD>
% Standard step size selection algorithm for the stochastic gradient method
%
% Given a problem structure, a point x on the manifold problem.M and a
% tangent vector d at x, produces a stepsize (a positive real number) and a
% new point newx obtained by retracting -stepsize*d at x. Additional inputs
% include iter (the iteration number of x, where 0 marks the initial
% guess), an options structure, a storedb database and the key of point x
% in that database. Additional outputs include the key of newx in the
% database, newkey, as well as a structure ssstats collecting statistics
% about the work done during the call to this function.
%
% See the code for the role of the available options:
%   options.stepsize_type
%   options.stepsize_init
%   options.stepsize_lambda
%   options.stepsize_decaysteps
%
% This function may create and maintain a structure called sssgmem inside
% storedb.internal. This gives the function the opportunity to remember
% what happened in previous calls.
%
% See also: stochasticgradient

% This file is part of Manopt: www.manopt.org.
% Original authors: Bamdev Mishra and Nicolas Boumal, March 30, 2017.
% Contributors: Hiroyuki Kasai and Hiroyuki Sato.
% Change log:


    % Allow omission of the key, and even of storedb.
    if ~exist('key', 'var')
        if ~exist('storedb', 'var')
            storedb = StoreDB();
        end
        key = storedb.getNewKey(); %#ok<NASGU>
    end


    % Initial stepsize guess.
    default_options.stepsize_init = 0.1;
    % Stepsize evolution type. Options are 'decay', 'fix' and 'hybrid'.
    default_options.stepsize_type = 'decay';
    % If stepsize_type = 'decay' or 'hybrid', lambda is a weighting factor.
    default_options.stepsize_lambda = 0.1;
    % If stepsize_type = 'hybrid', decaysteps specifies for how many
    % iterations the step size decays before becoming constant.
    default_options.stepsize_decaysteps = 100;

    if ~exist('options', 'var') || isempty(options)
        options = struct();
    end
    options = mergeOptions(default_options, options);


    type = options.stepsize_type;
    init = options.stepsize_init;
    lambda = options.stepsize_lambda;
    decaysteps = options.stepsize_decaysteps;


    switch lower(type)

        % Step size decays as O(1/iter).
        case 'decay'
            stepsize = init / (1 + init*lambda*iter);

        % Step size is fixed.
        case {'fix', 'fixed'}
            stepsize = init;

        % Step size decays only for the first decaysteps iterations,
        % then stays constant.
        case 'hybrid'
            if iter < decaysteps
                stepsize = init / (1 + init*lambda*iter);
            else
                stepsize = init / (1 + init*lambda*decaysteps);
            end

        otherwise
            error(['Unknown options.stepsize_type. ' ...
                   'Should be ''fix'', ''decay'' or ''hybrid''.']);

    end

    % Store some information.
    ssstats = struct();
    ssstats.stepsize = stepsize;

    % Compute the new point and give it a key.
    newx = problem.M.retr(x, d, -stepsize);
    newkey = storedb.getNewKey();

end
```
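To make the control flow of the switch statement easy to experiment with outside MATLAB, here is a minimal Python sketch of the same rule. It is a hypothetical port written for illustration, not part of Manopt: the name `stepsize_sg_sketch` is invented, points are plain lists of floats, the manifold is assumed to be the unit sphere (retraction by normalizing `x - stepsize*d`), and the StoreDB key bookkeeping (`storedb`, `key`, `newkey`) is omitted.

```python
import math


def stepsize_sg_sketch(x, d, it, options=None):
    """Hypothetical Python port of the stepsize_sg rule (illustration only).

    Assumes x is a unit vector and d a tangent vector at x on the unit
    sphere; the retraction used is normalization of x - stepsize * d.
    Returns (stepsize, newx, ssstats); no storedb keys are handled.
    """
    # Same defaults as in the MATLAB code; user options override them.
    defaults = {
        "stepsize_init": 0.1,
        "stepsize_type": "decay",
        "stepsize_lambda": 0.1,
        "stepsize_decaysteps": 100,
    }
    opts = {**defaults, **(options or {})}
    init = opts["stepsize_init"]
    lam = opts["stepsize_lambda"]
    decaysteps = opts["stepsize_decaysteps"]
    kind = opts["stepsize_type"].lower()

    if kind == "decay":
        # O(1/iter) decay.
        stepsize = init / (1 + init * lam * it)
    elif kind in ("fix", "fixed"):
        # Constant step size.
        stepsize = init
    elif kind == "hybrid":
        # Decay until decaysteps, then freeze at the value reached there.
        stepsize = init / (1 + init * lam * min(it, decaysteps))
    else:
        raise ValueError(
            "Unknown stepsize_type. Should be 'fix', 'decay' or 'hybrid'.")

    # Sphere retraction (assumption): step along -stepsize*d, then normalize.
    y = [xi - stepsize * di for xi, di in zip(x, d)]
    norm = math.sqrt(sum(v * v for v in y))
    newx = [v / norm for v in y]

    ssstats = {"stepsize": stepsize}
    return stepsize, newx, ssstats


if __name__ == "__main__":
    x0, d0 = [1.0, 0.0], [0.0, 1.0]
    for k in (0, 50, 100, 200):
        s, _, _ = stepsize_sg_sketch(x0, d0, k, {"stepsize_type": "hybrid"})
        print(k, s)
```

Running the script prints the 'hybrid' schedule: the step size shrinks until iteration 100 and is constant afterwards. Taking `min(it, decaysteps)` reproduces the MATLAB `if iter < decaysteps ... else ...` branch in one expression, since both branches coincide at `iter == decaysteps`.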