
neldermead

PURPOSE

Nelder-Mead optimization algorithm for derivative-free minimization.

SYNOPSIS

function [x, cost, info, options] = neldermead(problem, x, options)

DESCRIPTION

 Nelder-Mead optimization algorithm for derivative-free minimization.

 function [x, cost, info, options] = neldermead(problem)
 function [x, cost, info, options] = neldermead(problem, x0)
 function [x, cost, info, options] = neldermead(problem, x0, options)
 function [x, cost, info, options] = neldermead(problem, [], options)

 Apply a Nelder-Mead minimization algorithm to the problem defined in
 the problem structure, starting with the population x0 if it is provided
 (otherwise, a random population on the manifold is generated). A
 population is a cell containing points on the manifold. The number of
 elements in the cell must be dim+1, where dim is the dimension of the
 manifold: problem.M.dim().
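
 For illustration, a minimal call could look like this (the sphere manifold
 and the quadratic cost are placeholders; any manifold and cost would do):

     % Placeholder problem: minimize x'*A*x over the unit sphere in R^n.
     n = 5;
     A = randn(n); A = .5*(A + A.');
     problem.M = spherefactory(n);      % manifold structure from Manopt
     problem.cost = @(x) x.'*A*x;       % only a cost handle: no derivatives
     [x, xcost, info] = neldermead(problem);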

 To specify options whilst not specifying an initial guess, give x0 as []
 (the empty matrix).
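
 For example, with a problem structure already set up as above (the option
 values are arbitrary, for illustration only):

     options.maxiter = 500;
     options.verbosity = 1;
     [x, xcost, info] = neldermead(problem, [], options);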

 This algorithm is a plain adaptation of the Euclidean Nelder-Mead method
 to the Riemannian setting. It comes with no convergence guarantees and
 there is room for improvement. In particular, we compute centroids as
 Karcher means, which seems overly expensive: cheaper forms of
 average-like quantities might work better.
 This solver is useful nonetheless for problems for which no derivatives
 are available, and it may constitute a starting point for the development
 of other Riemannian derivative-free methods.
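
 Concretely, in the source below the centroid of the dim best simplex
 points is computed by the centroid helper:

     xbar = centroid(problem.M, x(1:end-1));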

 None of the options are mandatory. See the code below for details.
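
 For instance, the solver-specific defaults set in the code below are:

     options.storedepth  = 0;                 % no caching needed
     options.maxiter     = max(2000, 4*dim);  % where dim = problem.M.dim()
     options.reflection  = 1;
     options.expansion   = 2;
     options.contraction = .5;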

 Requires problem.M.pairmean(x, y) to be defined (computes the average
 between two points, x and y).
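
 For intuition: in a Euclidean space, pairmean(x, y) is simply (x+y)/2; on
 a general manifold it is typically the midpoint of a geodesic from x to y.
 A sketch of such a definition, assuming the manifold provides exp and log
 (this is not the code of any particular factory):

     M.pairmean = @(x, y) M.exp(x, M.log(x, y), .5);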

 If options.statsfun is defined, it will receive a cell of points x (the
 current simplex being considered at that iteration), and, if required,
 one store structure corresponding to the best point, x{1}. The points are
 ordered by increasing cost: f(x{1}) <= f(x{2}) <= ... <= f(x{dim+1}),
 where dim = problem.M.dim().
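
 For example, the following statsfun records the distance between the best
 and worst simplex points at each iteration (the field name simplexspread
 is arbitrary, and this assumes the manifold provides a dist function):

     options.statsfun = @(problem, x, stats) setfield(stats, ...
         'simplexspread', problem.M.dist(x{1}, x{end}));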

 Based on http://www.optimization-online.org/DB_FILE/2007/08/1742.pdf.

 See also: manopt/solvers/pso/pso

CROSS-REFERENCE INFORMATION ^

This function calls:
This function is called by:

SUBFUNCTIONS

function stats = savestats()

SOURCE CODE

0001 function [x, cost, info, options] = neldermead(problem, x, options)
0002 % Nelder-Mead optimization algorithm for derivative-free minimization.
0003 %
0004 % function [x, cost, info, options] = neldermead(problem)
0005 % function [x, cost, info, options] = neldermead(problem, x0)
0006 % function [x, cost, info, options] = neldermead(problem, x0, options)
0007 % function [x, cost, info, options] = neldermead(problem, [], options)
0008 %
0009 % Apply a Nelder-Mead minimization algorithm to the problem defined in
0010 % the problem structure, starting with the population x0 if it is provided
0011 % (otherwise, a random population on the manifold is generated). A
0012 % population is a cell containing points on the manifold. The number of
0013 % elements in the cell must be dim+1, where dim is the dimension of the
0014 % manifold: problem.M.dim().
0015 %
0016 % To specify options whilst not specifying an initial guess, give x0 as []
0017 % (the empty matrix).
0018 %
0019 % This algorithm is a plain adaptation of the Euclidean Nelder-Mead method
0020 % to the Riemannian setting. It comes with no convergence guarantees and
0021 % there is room for improvement. In particular, we compute centroids as
0022 % Karcher means, which seems overly expensive: cheaper forms of
0023 % average-like quantities might work better.
0024 % This solver is useful nonetheless for problems for which no derivatives
0025 % are available, and it may constitute a starting point for the development
0026 % of other Riemannian derivative-free methods.
0027 %
0028 % None of the options are mandatory. See in code for details.
0029 %
0030 % Requires problem.M.pairmean(x, y) to be defined (computes the average
0031 % between two points, x and y).
0032 %
0033 % If options.statsfun is defined, it will receive a cell of points x (the
0034 % current simplex being considered at that iteration), and, if required,
0035 % one store structure corresponding to the best point, x{1}. The points are
0036 % ordered by increasing cost: f(x{1}) <= f(x{2}) <= ... <= f(x{dim+1}),
0037 % where dim = problem.M.dim().
0038 %
0039 % Based on http://www.optimization-online.org/DB_FILE/2007/08/1742.pdf.
0040 %
0041 % See also: manopt/solvers/pso/pso
0042 
0043 % This file is part of Manopt: www.manopt.org.
0044 % Original author: Nicolas Boumal, Dec. 30, 2012.
0045 % Contributors:
0046 % Change log:
0047 %
0048 %   Apr.  4, 2015 (NB):
0049 %       Working with the new StoreDB class system.
0050 %       Clarified interactions with statsfun and store.
0051 %
0052 %   Nov. 11, 2016 (NB):
0053 %       If options.verbosity is < 2, prints minimal output.
0054 %
0055 %   Sep.  6, 2018 (NB):
0056 %       Using retraction instead of exponential.
0057 
0058     
0059     % Verify that the problem description is sufficient for the solver.
0060     if ~canGetCost(problem)
0061         warning('manopt:getCost', ...
0062                 'No cost provided. The algorithm will likely abort.');  
0063     end
0064     
0065     % Dimension of the manifold
0066     dim = problem.M.dim();
0067 
0068     % Set local defaults here
0069     localdefaults.storedepth = 0;                     % no need for caching
0070     localdefaults.maxiter = max(2000, 4*dim);
0071     
0072     localdefaults.reflection = 1;
0073     localdefaults.expansion = 2;
0074     localdefaults.contraction = .5;
0076     % Shrinkage coefficient forced to .5 so the manifold's pairmean can be used.
0076     % localdefaults.shrinkage = .5;
0077     
0078     % Merge global and local defaults, then merge w/ user options, if any.
0079     localdefaults = mergeOptions(getGlobalDefaults(), localdefaults);
0080     if ~exist('options', 'var') || isempty(options)
0081         options = struct();
0082     end
0083     options = mergeOptions(localdefaults, options);
0084     
0085     % Start timing for initialization.
0086     timetic = tic();
0087     
0088     % If no initial simplex x is given by the user, generate one at random.
0089     if ~exist('x', 'var') || isempty(x)
0090         x = cell(dim+1, 1);
0091         for i = 1 : dim+1
0092             x{i} = problem.M.rand();
0093         end
0094     end
0095     
0096     % Create a store database and a key for each point.
0097     storedb = StoreDB(options.storedepth);
0098     key = cell(size(x));
0099     for i = 1 : dim+1
0100         key{i} = storedb.getNewKey();
0101     end
0102     
0103     % Compute objective-related quantities for x, and setup a
0104     % function evaluations counter.
0105     costs = zeros(dim+1, 1);
0106     for i = 1 : dim+1
0107         costs(i) = getCost(problem, x{i}, storedb, key{i});
0108     end
0109     costevals = dim+1;
0110     
0111     % Sort simplex points by cost.
0112     [costs, order] = sort(costs);
0113     x = x(order);
0114     key = key(order);
0115     
0116     % Iteration counter.
0117     % At any point, iter is the number of fully executed iterations so far.
0118     iter = 0;
0119     
0120     % Save stats in a struct array info, and preallocate.
0121     % savestats will be called twice for the initial iterate (number 0),
0122     % which is unfortunate, but not problematic.
0123     stats = savestats();
0124     info(1) = stats;
0125     info(min(10000, options.maxiter+1)).iter = [];
0126     
0127     % Start iterating until stopping criterion triggers.
0128     while true
0129         
0130         % Make sure we don't use too much memory for the store database.
0131         storedb.purge();
0132         
0133         stats = savestats();
0134         info(iter+1) = stats; %#ok<AGROW>
0135         iter = iter + 1;
0136         
0137         % Start timing this iteration.
0138         timetic = tic();
0139         
0140         % Sort simplex points by cost.
0141         [costs, order] = sort(costs);
0142         x = x(order);
0143         key = key(order);
0144 
0145         % Log / display iteration information here.
0146         if options.verbosity >= 2
0147             fprintf('Cost evals: %7d\tBest cost: %+.4e\t', ...
0148                     costevals, costs(1));
0149         end
0150         
0151         % Run standard stopping criterion checks.
0152         [stop, reason] = stoppingcriterion(problem, x, options, info, iter);
0153     
0154         if stop
0155             if options.verbosity >= 1
0156                 fprintf([reason '\n']);
0157             end
0158             break;
0159         end
0160         
0161         % Compute a centroid for the dim best points.
0162         xbar = centroid(problem.M, x(1:end-1));
0163         
0164         % Compute the direction for moving along the axis xbar - worst x.
0165         vec = problem.M.log(xbar, x{end});
0166         
0167         % Reflection step
0168         xr = problem.M.retr(xbar, vec, -options.reflection);
0169         keyr = storedb.getNewKey();
0170         costr = getCost(problem, xr, storedb, keyr);
0171         costevals = costevals + 1;
0172         
0173         % If the reflected point is honorable, drop the worst point,
0174         % replace it with the reflected point, and start a new iteration.
0175         if costr >= costs(1) && costr < costs(end-1)
0176             if options.verbosity >= 2
0177                 fprintf('Reflection\n');
0178             end
0179             costs(end) = costr;
0180             x{end} = xr;
0181             key{end} = keyr;
0182             continue;
0183         end
0184         
0185         % If the reflected point is better than the best point, expand.
0186         if costr < costs(1)
0187             xe = problem.M.retr(xbar, vec, -options.expansion);
0188             keye = storedb.getNewKey();
0189             coste = getCost(problem, xe, storedb, keye);
0190             costevals = costevals + 1;
0191             if coste < costr
0192                 if options.verbosity >= 2
0193                     fprintf('Expansion\n');
0194                 end
0195                 costs(end) = coste;
0196                 x{end} = xe;
0197                 key{end} = keye;
0198                 continue;
0199             else
0200                 if options.verbosity >= 2
0201                     fprintf('Reflection (failed expansion)\n');
0202                 end
0203                 costs(end) = costr;
0204                 x{end} = xr;
0205                 key{end} = keyr;
0206                 continue;
0207             end
0208         end
0209         
0210         % If the reflected point is worse than the second to worst point,
0211         % contract.
0212         if costr >= costs(end-1)
0213             if costr < costs(end)
0214                 % do an outside contraction
0215                 xoc = problem.M.retr(xbar, vec, -options.contraction);
0216                 keyoc = storedb.getNewKey();
0217                 costoc = getCost(problem, xoc, storedb, keyoc);
0218                 costevals = costevals + 1;
0219                 if costoc <= costr
0220                     if options.verbosity >= 2
0221                         fprintf('Outside contraction\n');
0222                     end
0223                     costs(end) = costoc;
0224                     x{end} = xoc;
0225                     key{end} = keyoc;
0226                     continue;
0227                 end
0228             else
0229                 % do an inside contraction
0230                 xic = problem.M.retr(xbar, vec, options.contraction);
0231                 keyic = storedb.getNewKey();
0232                 costic = getCost(problem, xic, storedb, keyic);
0233                 costevals = costevals + 1;
0234                 if costic <= costs(end)
0235                     if options.verbosity >= 2
0236                         fprintf('Inside contraction\n');
0237                     end
0238                     costs(end) = costic;
0239                     x{end} = xic;
0240                     key{end} = keyic;
0241                     continue;
0242                 end
0243             end
0244         end
0245         
0246         % If we get here, shrink the simplex around x{1}.
0247         if options.verbosity >= 2
0248             fprintf('Shrinkage\n');
0249         end
0250         for i = 2 : dim+1
0251             x{i} = problem.M.pairmean(x{1}, x{i});
0252             key{i} = storedb.getNewKey();
0253             costs(i) = getCost(problem, x{i}, storedb, key{i});
0254         end
0255         costevals = costevals + dim;
0256         
0257     end
0258     
0259     
0260     info = info(1:iter);
0261     
0262     % Iteration done: return only the best point found.
0263     cost = costs(1);
0264     x = x{1};
0265     key = key{1};
0266     
0267     
0268     
0269     % Routine in charge of collecting the current iteration stats.
0270     function stats = savestats()
0271         stats.iter = iter;
0272         stats.cost = costs(1);
0273         stats.costevals = costevals;
0274         if iter == 0
0275             stats.time = toc(timetic);
0276         else
0277             stats.time = info(iter).time + toc(timetic);
0278         end
0279         % The statsfun can only possibly receive one store structure. We
0280         % pass the key to the best point, so that the best point's store
0281         % will be passed. But the whole cell x of points is passed through.
0282         stats = applyStatsfun(problem, x, storedb, key{1}, options, stats);
0283     end
0284     
0285 end

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005