pso

PURPOSE

Particle swarm optimization (PSO) for derivative-free minimization.

SYNOPSIS

function [xbest, fbest, info, options] = pso(problem, x, options)

DESCRIPTION

 Particle swarm optimization (PSO) for derivative-free minimization.

 function [x, cost, info, options] = pso(problem)
 function [x, cost, info, options] = pso(problem, x0)
 function [x, cost, info, options] = pso(problem, x0, options)
 function [x, cost, info, options] = pso(problem, [], options)

 Apply the Particle Swarm Optimization minimization algorithm to
 the problem defined in the problem structure, starting with the
 population x0 if it is provided (otherwise, a random population on the
 manifold is generated). A population is a cell containing points on the
 manifold. If the number of elements in the cell differs from
 options.populationsize, that option is overwritten with the size of x0
 and a warning is issued.

 To specify options whilst not specifying an initial guess, give x0 as []
 (the empty matrix).
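A minimal MATLAB sketch of the population handling described above, runnable without Manopt. The unit sphere in R^3 stands in for problem.M and `randsphere` is a hypothetical replacement for problem.M.rand(); the two cases mimic a call with x0 = [] and a call with a 3-point population when 8 particles were requested.

```matlab
% Sketch: initial population handling for PSO, without Manopt.
% Assumption: the unit sphere in R^3 plays the role of problem.M, and
% randsphere is a hypothetical stand-in for problem.M.rand().
unitvec = @(z) z / norm(z);
randsphere = @() unitvec(randn(3, 1));

% Case 1: x0 = [] (random population). Case 2: a user-given 3-point cell.
cases = {[], {unitvec([1; 0; 0]), unitvec([0; 1; 0]), unitvec([0; 0; 1])}};
pops = cell(size(cases));
sizes = zeros(size(cases));
for c = 1 : numel(cases)
    x0 = cases{c};
    populationsize = 8;                 % requested options.populationsize
    if isempty(x0)
        % No initial guess: draw populationsize random points.
        x = cell(populationsize, 1);
        for i = 1 : populationsize
            x{i} = randsphere();
        end
    else
        if ~iscell(x0)
            error('The initial guess x0 must be a cell (a population).');
        end
        x = x0;
        if numel(x) ~= populationsize
            % As in pso.m, the option yields to the given population.
            populationsize = numel(x);
            warning('sketch:pso:size', ...
                    'populationsize was forced to the size of x0.');
        end
    end
    pops{c} = x;
    sizes(c) = populationsize;
end
disp(sizes)   % displays 8 3
```

The option, not the population, is what gets adjusted, so a user-supplied swarm is never truncated or padded.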

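 None of the options are mandatory. The solver-specific options and their
 defaults, with dim the dimension of problem.M, are maxiter = max(500, 4*dim),
 populationsize = min(40, 10*dim), nostalgia = 1.4 (weight of the pull
 toward each particle's own best position), social = 1.4 (weight of the
 pull toward the swarm's best position) and storedepth = 0. See the code
 for details.

 The manifold problem.M must provide a logarithmic map M.log(x, y); an
 approximate logarithm will do too. The solver also uses M.dim, M.rand
 (when no x0 is given), M.randvec, M.lincomb, M.transp and M.retr.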
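The dimension-dependent defaults can be reproduced in a few lines. This sketch assumes nothing from Manopt: the loop that copies user fields is a hypothetical stand-in for Manopt's mergeOptions, which in the solver also folds in the global defaults first.

```matlab
% Sketch: pso.m's local defaults as functions of the manifold dimension,
% followed by a user override. The field-by-field loop is a hypothetical
% stand-in for Manopt's mergeOptions (user values replace defaults).
dims = [2, 200];
maxiters = zeros(size(dims));
popsizes = zeros(size(dims));
for k = 1 : numel(dims)
    dim = dims(k);
    defaults = struct('storedepth', 0, ...
                      'maxiter', max(500, 4*dim), ...
                      'populationsize', min(40, 10*dim), ...
                      'nostalgia', 1.4, ...
                      'social', 1.4);
    maxiters(k) = defaults.maxiter;
    popsizes(k) = defaults.populationsize;
end
disp([maxiters; popsizes])   % rows: [500 800] and [20 40]

% Override a subset of options for the last dimension (dim = 200).
user = struct('populationsize', 25, 'social', 2.0);
options = defaults;
names = fieldnames(user);
for j = 1 : numel(names)
    options.(names{j}) = user.(names{j});
end
```

The population size saturates at 40 while maxiter grows with the dimension once 4*dim exceeds 500.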

 Based on the original PSO description in
   http://particleswarm.info/nn951942.ps.
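For reference, one iteration of the source code below can be written as follows. The notation is ours, not Manopt's: k is the iteration counter iter, K is options.maxiter, \tilde{x}_i is the position of particle i at the previous iteration, y_i its personal best, x_best the best position found by the swarm, c_n = options.nostalgia, c_s = options.social, and r_1, r_2 are fresh uniform draws on [0, 1] (one scalar per particle and per term, as `rand(1)` in the code). T, log and R stand for problem.M.transp, problem.M.log and problem.M.retr.

```latex
\begin{aligned}
w_k &= 0.4 + 0.5\left(1 - \frac{k}{K}\right), \\
v_i &\leftarrow w_k\,\mathcal{T}_{\tilde{x}_i \to x_i}(v_i)
      + r_1 c_{\mathrm{n}} \log_{x_i}(y_i)
      + r_2 c_{\mathrm{s}} \log_{x_i}(x_{\mathrm{best}}), \\
x_i &\leftarrow R_{x_i}(v_i).
\end{aligned}
```

All velocities are computed before any position moves; y_i and x_best are then replaced whenever the new cost f(x_i) improves on them. On R^n, where log_x(y) = y - x, T is the identity and R_x(v) = x + v, this reduces to the usual inertia-weight form of PSO.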

 See also: manopt/solvers/neldermead/neldermead
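pso.m stops with an error unless problem.M provides M.log, and the error message adds that an approximate logarithm will do. The self-contained sketch below, on the unit sphere in R^3, compares the exact logarithm with one possible approximation (the tangent projection of y - x, chosen here for illustration; it is not necessarily what any Manopt factory implements). Both point in the same direction, but the approximation has length sin(theta) instead of theta, so it would shorten the nostalgia and social pulls of distant particles.

```matlab
% Sketch: exact vs. approximate logarithm on the unit sphere in R^3.
% Assumption: the approximation is the tangent projection of y - x, one
% possible choice; Manopt's own sphere factory may implement log differently.
x = [1; 0; 0];
theta = pi / 3;                          % geodesic distance from x to y
y = [cos(theta); sin(theta); 0];

proj = @(p, u) u - p * (p' * u);         % projection onto tangent space at p
approxlog = proj(x, y - x);              % equals y - cos(theta) * x here
dist = acos(max(-1, min(1, x' * y)));    % clamp guards against rounding
exactlog = dist * approxlog / norm(approxlog);

% Same direction, different lengths: sin(theta) vs. theta.
fprintf('approx %.4f, exact %.4f\n', norm(approxlog), norm(exactlog));
% prints: approx 0.8660, exact 1.0472

% A retraction maps a tangent vector back onto the sphere.
retr = @(p, v) (p + v) / norm(p + v);
z = retr(x, exactlog);
```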

SUBFUNCTIONS

function stats = savestats()
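savestats is a nested function of pso: besides the swarm-level fields it records, it calls the user's statsfun once per particle and turns each field the user adds into a cell with one entry per particle. A sketch of that bookkeeping without Manopt follows; `statsfun` is a hypothetical user function that records the norm of each point and is called directly, instead of through applyStatsfun.

```matlab
% Sketch of savestats' per-particle bookkeeping (no Manopt).
% Assumption: statsfun is a hypothetical user statistic; pso.m instead
% calls applyStatsfun(problem, x{ii}, storedb, xkey{ii}, options, stats).
statsfun = @(p, s) setfield(s, 'xnorm', norm(p));

x = {[3; 4], [1; 0], [0; 2]};            % a population of 3 points
populationsize = numel(x);
stats = struct('iter', 0, 'cost', NaN);  % fields the solver fills itself

% Detect user-added fields with the first particle: they come last.
num_old = numel(fieldnames(stats));
all_fields = fieldnames(statsfun(x{1}, stats));
added = all_fields(num_old+1 : end);

% Turn each added field into a cell with one entry per particle.
for jj = 1 : numel(added)
    stats.(added{jj}) = cell(populationsize, 1);
end
for ii = 1 : populationsize
    s = statsfun(x{ii}, stats);
    for jj = 1 : numel(added)
        stats.(added{jj}){ii} = s.(added{jj});
    end
end
disp(cell2mat(stats.xnorm).')   % displays 5 1 2
```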

SOURCE CODE

function [xbest, fbest, info, options] = pso(problem, x, options)
% Particle swarm optimization (PSO) for derivative-free minimization.
%
% function [x, cost, info, options] = pso(problem)
% function [x, cost, info, options] = pso(problem, x0)
% function [x, cost, info, options] = pso(problem, x0, options)
% function [x, cost, info, options] = pso(problem, [], options)
%
% Apply the Particle Swarm Optimization minimization algorithm to
% the problem defined in the problem structure, starting with the
% population x0 if it is provided (otherwise, a random population on the
% manifold is generated). A population is a cell containing points on the
% manifold. If the number of elements in the cell differs from
% options.populationsize, that option is overwritten with the size of x0
% and a warning is issued.
%
% To specify options whilst not specifying an initial guess, give x0 as []
% (the empty matrix).
%
% None of the options are mandatory. See the code for details.
%
% Based on the original PSO description in
%   http://particleswarm.info/nn951942.ps.
%
% See also: manopt/solvers/neldermead/neldermead

% This file is part of Manopt: www.manopt.org.
% Original author: Pierre Borckmans, Dec. 30, 2012.
% Contributors: Bamdev Mishra, June 18, 2014.
% Change log:
%
%   June 18, 2014 (BM) :
%       Modified for handling product manifolds. Still need overall cleanup
%       to avoid potential issues, in particular with respect to logarithms.
%
%   June 23, 2014 (NB) :
%       Added some logic for handling of the populationsize option.
%
%   April 5, 2015 (NB):
%       Working with the new StoreDB class system. The code keeps track of
%       storedb keys for all points, even though it is not strictly
%       necessary. This extra bookkeeping should help maintain the code.


    % Verify that the problem description is sufficient for the solver.
    if ~canGetCost(problem)
        warning('manopt:getCost', ...
            'No cost provided. The algorithm will likely abort.');
    end

    % Dimension of the manifold
    dim = problem.M.dim();

    % Set local defaults here
    localdefaults.storedepth = 0;                   % no need for caching
    localdefaults.maxiter = max(500, 4*dim);

    localdefaults.populationsize = min(40, 10*dim);
    localdefaults.nostalgia = 1.4;
    localdefaults.social = 1.4;

    % Merge global and local defaults, then merge w/ user options, if any.
    localdefaults = mergeOptions(getGlobalDefaults(), localdefaults);
    if ~exist('options', 'var') || isempty(options)
        options = struct();
    end
    options = mergeOptions(localdefaults, options);


    if ~isfield(problem.M, 'log') % BM
        error(['The manifold problem.M must provide a logarithmic map, ' ...
               'M.log(x, y). An approximate logarithm will do too.']);
    end

    % Start timing for initialization
    timetic = tic();

    % If no initial population x is given by the user,
    % generate one at random.
    if ~exist('x', 'var') || isempty(x)
        x = cell(options.populationsize, 1);
        for i = 1 : options.populationsize
            x{i} = problem.M.rand();
        end
    else
        if ~iscell(x)
            error('The initial guess x0 must be a cell (a population).');
        end
        if length(x) ~= options.populationsize
            options.populationsize = length(x);
            warning('manopt:pso:size', ...
                    ['The option populationsize was forced to the size' ...
                     ' of the given initial population x0.']);
        end
    end


    % Create a store database and a key for each point x{i}
    storedb = StoreDB(options.storedepth);
    xkey = cell(size(x));
    for i = 1 : numel(x)
        xkey{i} = storedb.getNewKey();
    end

    % Initialize personal best positions to the initial population
    y = x;
    ykey = xkey;

    % Save a copy of the swarm at the previous iteration
    xprev = x;
    xprevkey = xkey; %#ok<NASGU>

    % Initialize velocities for each particle
    v = cell(size(x));
    for i = 1 : numel(x)
        % random velocity to improve initial exploration
        v{i} = problem.M.randvec(x{i});
        % or null velocity
        % v{i} = problem.M.zerovec(x{i});
    end

    % Compute cost for each particle xi,
    % initialize personal best costs,
    % and set up a counter of cost evaluations.
    costs = zeros(size(x));
    for i = 1 : numel(x)
        costs(i) = getCost(problem, x{i}, storedb, xkey{i});
    end
    fy = costs;
    costevals = options.populationsize;

    % Identify the best particle and store its cost/position
    [fbest, imin] = min(costs);
    xbest = x{imin};
    xbestkey = xkey{imin}; %#ok<NASGU>

    % Iteration counter (at any point, iter is the number of fully executed
    % iterations so far)
    iter = 0;

    % Save stats in a struct array info, and preallocate.
    % savestats will be called twice for the initial iterate (number 0),
    % which is unfortunate, but not problematic.
    stats = savestats();
    info(1) = stats;
    info(min(10000, options.maxiter+1)).iter = [];

    % Start iterating until stopping criterion triggers
    while true

        stats = savestats();
        info(iter+1) = stats; %#ok<AGROW>
        iter = iter + 1;

        % Make sure we don't use too much memory for the store database
        storedb.purge();

        % Log / display iteration information here.
        if options.verbosity >= 2
            fprintf('Cost evals: %7d\tBest cost: %+.8e\n', costevals, fbest);
        end

        % Start timing this iteration
        timetic = tic();

        % BM: Run standard stopping criterion checks.
        % BM: Stop if any particle triggers a stopping criterion.
        % (Loop over all particles, not only the last one.)
        for i = 1 : numel(x)
            [stop, reason] = stoppingcriterion(problem, x{i}, options, info, iter);
            if stop
                break;
            end
        end

        if stop
            if options.verbosity >= 1
                fprintf('%s\n', reason);
            end
            break;
        end


        % Compute the inertia factor
        % (linearly decreasing from .9 to .4, from iter=0 to maxiter)
        w = 0.4 + 0.5*(1-iter/options.maxiter);

        % Compute velocities
        for i = 1 : numel(x)

            % Get the position and past best position of particle i
            xi = x{i};
            yi = y{i};

            % Get the previous position and velocity of particle i
            xiprev = xprev{i};
            vi = v{i};

            % Compute new velocity of particle i,
            % composed of 3 contributions
            inertia = problem.M.lincomb(xi, w, problem.M.transp(xiprev, xi, vi));
            nostalgia = problem.M.lincomb(xi, rand(1)*options.nostalgia, problem.M.log(xi, yi));
            social = problem.M.lincomb(xi, rand(1)*options.social, problem.M.log(xi, xbest));

            v{i} = problem.M.lincomb(xi, 1, inertia, 1, problem.M.lincomb(xi, 1, nostalgia, 1, social));

        end

        % Back up the current swarm positions
        xprev = x;
        xprevkey = xkey; %#ok<NASGU>

        % Update positions, personal bests and global best
        for i = 1 : numel(x)
            % compute new position of particle i
            x{i} = problem.M.retr(x{i}, v{i});
            xkey{i} = storedb.getNewKey();
            % compute new cost of particle i
            fxi = getCost(problem, x{i}, storedb, xkey{i});
            costevals = costevals + 1;

            % update costs of the swarm
            costs(i) = fxi;
            % update self-best if necessary
            if fxi < fy(i)
                % update self-best cost and position
                fy(i) = fxi;
                y{i} = x{i};
                ykey{i} = xkey{i};
                % update global-best if necessary
                if fy(i) < fbest
                    fbest = fy(i);
                    xbest = y{i};
                    xbestkey = ykey{i}; %#ok<NASGU>
                end
            end
        end
    end


    info = info(1:iter);

    % Routine in charge of collecting the current iteration stats
    function stats = savestats()
        stats.iter = iter;
        stats.cost = fbest;
        stats.costevals = costevals;
        stats.x = x;
        stats.v = v;
        stats.xbest = xbest;
        if iter == 0
            stats.time = toc(timetic);
        else
            stats.time = info(iter).time + toc(timetic);
        end

        % BM: Begin storing user defined stats for the entire population
        num_old_fields = size(fieldnames(stats), 1);
        trialstats = applyStatsfun(problem, x{1}, storedb, xkey{1}, options, stats); % BM
        new_fields = fieldnames(trialstats);
        num_new_fields = size(fieldnames(trialstats), 1);
        num_additional_fields = num_new_fields - num_old_fields; % User has defined new fields
        for jj = 1 : num_additional_fields % New fields added
            tempfield = new_fields(num_old_fields + jj);
            stats.(char(tempfield)) = cell(options.populationsize, 1);
        end
        for ii = 1 : options.populationsize % Adding information for each element of the population
            tempstats = applyStatsfun(problem, x{ii}, storedb, xkey{ii}, options, stats);
            for jj = 1 : num_additional_fields
                tempfield = new_fields(num_old_fields + jj);
                tempfield_value = tempstats.(char(tempfield));
                stats.(char(tempfield)){ii} = tempfield_value;
            end
        end
        % BM: End storing

    end


end
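To see the structure of the loop without any manifold machinery, here is a sketch of the same iteration specialized to R^2, where problem.M.log(x, y) becomes y - x, problem.M.transp is the identity and problem.M.retr(x, v) is x + v. The quadratic cost, population size and iteration count are illustrative assumptions, and the stopping test is reduced to a fixed iteration budget.

```matlab
% Sketch: the pso.m iteration specialized to Euclidean space R^n, where
% log_x(y) = y - x, vector transport is the identity and retr(x, v) = x + v.
% Assumptions: cost, n, N and maxiter are illustrative, not Manopt defaults.
n = 2;
f = @(z) sum((z - [1; -2]).^2);        % minimizer at [1; -2]
N = 10; maxiter = 100; nostalgia = 1.4; social = 1.4;

x = cell(N, 1); v = cell(N, 1); fy = zeros(N, 1);
for i = 1 : N
    x{i} = 4 * randn(n, 1);
    v{i} = randn(n, 1);                % random initial velocity, as in pso.m
    fy(i) = f(x{i});
end
y = x;                                 % personal best positions
[fbest, imin] = min(fy);
xbest = x{imin};
f0 = fbest;
history = zeros(maxiter, 1);

for iter = 1 : maxiter
    w = 0.4 + 0.5 * (1 - iter / maxiter);   % inertia schedule of pso.m
    % All velocities first, using the global best of the previous iteration.
    for i = 1 : N
        v{i} = w * v{i} + rand() * nostalgia * (y{i} - x{i}) ...
                        + rand() * social * (xbest - x{i});
    end
    % Then move every particle and update personal and global bests.
    for i = 1 : N
        x{i} = x{i} + v{i};            % Euclidean retraction
        fxi = f(x{i});
        if fxi < fy(i)
            fy(i) = fxi;
            y{i} = x{i};
            if fxi < fbest
                fbest = fxi;
                xbest = x{i};
            end
        end
    end
    history(iter) = fbest;
end
fprintf('best cost %.3e at (%.3f, %.3f)\n', fbest, xbest);
```

Because personal and global bests are only replaced by strictly better points, the recorded best cost never increases, whatever the random draws; the sketch makes no claim about how fast it converges.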

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005