Home > manopt > manifolds > multinomial > multinomialdoublystochasticgeneralfactory.m

multinomialdoublystochasticgeneralfactory

PURPOSE ^

Manifold of n-by-m positive matrices such that row sum is p and column sum is q.

SYNOPSIS ^

function M = multinomialdoublystochasticgeneralfactory(n, m, p, q)

DESCRIPTION ^

 Manifold of n-by-m positive matrices such that row sum is p and column sum is q.

 function M = multinomialdoublystochasticgeneralfactory(n, m, p, q)

  X > 0.
  X1 = p, p is a column positive vector of size n.
  X'1 = q, q is a column positive vector of size m.
 
 Ensure that p > 0 and q > 0. Also, ensure that sum(p) == sum(q).


 Please cite the Manopt paper as well as the research papers:


 @Techreport{mishra21a,
   Title   = {Manifold optimization for optimal transport},
   Author  = {Mishra, B. and Satya Dev, N. T. V. and Kasai, H. and Jawanpuria, P.},
   Journal = {Arxiv preprint arXiv:2103.00902},
   Year    = {2021}
 }

 @article{douik2019manifold,
 title={Manifold optimization over the set of doubly stochastic matrices: A second-order geometry},
  author={Douik, A. and Hassibi, B.},
  journal={IEEE Transactions on Signal Processing},
  volume={67},
  number={22},
  pages={5761--5774},
  year={2019}
}


 @article{shi21a,
 title={Coupling matrix manifolds assisted optimization for optimal transport problems},
  author={Shi, D. and Gao, J. and Hong, X. and Choy, ST. B. and Wang, Z.},
  journal={Machine Learning},
  pages={1--26},
  year={2021}
}


 The factory file extends the factory file
 multinomialdoublystochasticfactory 
 to handle general scaling of rows and columns.


 See also multinomialdoublystochasticfactory multinomialsymmetricfactory multinomialfactory

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

0001 function M = multinomialdoublystochasticgeneralfactory(n, m, p, q)
0002 % Manifold of n-by-m positive matrices such that row sum is p and column sum is q.
0003 %
0004 % function M = multinomialdoublystochasticgeneralfactory(n, m, p, q)
0005 %
0006 %  X > 0.
0007 %  X1 = p, p is a column positive vector of size n.
0008 %  X'1 = q, q is a column positive vector of size m.
0009 %
0010 % Ensure that p > 0 and q > 0. Also, ensure that sum(p) == sum(q).
0011 %
0012 %
0013 % Please cite the Manopt paper as well as the research papers:
0014 %
0015 %
0016 % @Techreport{mishra21a,
0017 %   Title   = {Manifold optimization for optimal transport},
0018 %   Author  = {Mishra, B. and Satya Dev, N. T. V. and Kasai, H. and Jawanpuria, P.},
0019 %   Journal = {Arxiv preprint arXiv:2103.00902},
0020 %   Year    = {2021}
0021 % }
0022 %
0023 % @article{douik2019manifold,
0024 % title={Manifold optimization over the set of doubly stochastic matrices: A second-order geometry},
0025 %  author={Douik, A. and Hassibi, B.},
0026 %  journal={IEEE Transactions on Signal Processing},
0027 %  volume={67},
0028 %  number={22},
0029 %  pages={5761--5774},
0030 %  year={2019}
0031 %}
0032 %
0033 %
0034 % @article{shi21a,
0035 % title={Coupling matrix manifolds assisted optimization for optimal transport problems},
0036 %  author={Shi, D. and Gao, J. and Hong, X. and Choy, ST. B. and Wang, Z.},
0037 %  journal={Machine Learning},
0038 %  pages={1--26},
0039 %  year={2021}
0040 %}
0041 %
0042 %
0043 % The factory file extends the factory file
0044 % multinomialdoublystochasticfactory
0045 % to handle general scaling of rows and columns.
0046 %
0047 %
0048 % See also multinomialdoublystochasticfactory multinomialsymmetricfactory multinomialfactory
0049 
0050 % This file is part of Manopt: www.manopt.org.
0051 % Original author: Bamdev Mishra, Oct 30, 2020.
0052 % Contributors:
0053 % Change log:
0054 
0055     e1 = ones(n, 1); % all-ones column of length n, used as e1*beta' to repeat beta' over the n rows
0056     e2 = ones(m, 1); % all-ones column of length m, used as alpha*e2' to repeat alpha over the m columns
0057 
0058     maxDSiters = min(1000, n*m); % Passed to doubly_stochastic_general; ideally it should be supplied by the user.
0059 
0060     if ~isequal(size(p), [n, 1]) % checks the full shape: row vectors and n-by-k matrices are rejected too
0061         error('p should be a column vector of size n.');
0062     end
0063 
0064     if ~isequal(size(q), [m, 1]) % same full-shape check for the column-sum vector
0065         error('q should be a column vector of size m.');
0066     end
0067 
0068     function [alpha, beta] = mylinearsolve(X, b)
0069         % Solve A*[alpha; beta] = b, A = [diag(p) X ; X' diag(q)], for the
0070         % n-vector alpha and the m-vector beta (zeta = A\b would solve the
0071         % same system, but A is deliberately never formed).
0072         %
0073         % pcg only receives the handle mycompute, which evaluates A*x
0074         % from X, p and q. Tolerance is 1e-6 with at most 100 iterations.
0075         % The flag output is requested, then discarded, to keep pcg quiet.
0076         [zeta, ~] = pcg(@mycompute, b, 1e-6, 100);
0077         function Ax = mycompute(x) % matrix-free product A*x
0078             xtop = x(1:n,1); % vector of size n akin to alpha
0079             xbottom = x(n+1:end,1); % vector of size m akin to beta
0080             Axtop = xtop.*p + X*xbottom; % first block row: diag(p)*xtop + X*xbottom
0081             Axbottom = X'*xtop + xbottom.*q; % second block row: X'*xtop + diag(q)*xbottom
0082             Ax = [Axtop; Axbottom];
0083         end
0084         alpha = zeta(1:n, 1); % leading n entries of the solution
0085         beta = zeta(n+1:end, 1); % remaining m entries of the solution
0086     end
0087 
0088     M.name = @() sprintf('%dx%d matrices with positive entries such that row sum is p and column sum is q', n, m); % reports n rows and m columns
0089     % Dimension: n*m entries minus n+m-1 independent sum constraints (sum(p) == sum(q) makes one redundant).
0090     M.dim = @() (n-1)*(m-1); % BM okay
0091 
0092     % Fisher metric: entrywise products of the two vectors, weighted by 1./X, summed.
0093     M.inner = @iproduct;
0094     function val = iproduct(X, eta, zeta)
0095         val = sum((eta(:).*zeta(:))./X(:));
0096     end
0097 
0098     M.norm = @(X, eta) sqrt(iproduct(X, eta, eta)); % norm induced by the metric above
0099 
0100     M.dist = @(X, Y) error('multinomialdoublystochasticgeneralfactory.dist not implemented yet.');
0101 
0102     % With this metric the manifold is not compact, so there is no
0103     % canonical typical distance and the value below is an arbitrary
0104     % choice. It is notably used to pick default values of the initial and
0105     % maximal trust-region radius in the trustregions solver.
0106     M.typicaldist = @() n+m;
0107 
0108     % Random point on the manifold.
0109     M.rand = @random;
0110     function X = random()
0111         % abs(randn) gives a nonnegative seed in the ambient space; doubly_stochastic_general maps it onto the manifold.
0112         X = doubly_stochastic_general(abs(randn(n, m)), p, q, maxDSiters);
0113     end
0114 
0115     % Random tangent vector at X, scaled to unit norm under M.norm.
0116     M.randvec = @randomvec;
0117     function eta = randomvec(X)
0118         % Start from a Gaussian matrix in the ambient space ...
0119         ambient = randn(n, m);
0120         % ... remove its component outside the tangent space at X ...
0121         sums = [sum(ambient, 2) ; sum(ambient, 1)'];
0122         [rowmult, colmult] = mylinearsolve(X, sums);
0123         eta = ambient - (rowmult*e2' + e1*colmult').*X;
0124         % ... and rescale the result to unit length.
0125         len = M.norm(X, eta);
0126         eta = eta / len;
0127     end
0128 
0129     % Map a matrix eta from the ambient space to the tangent space at X.
0130     M.proj = @projection;
0131     function etaproj = projection(X, eta)
0132         % The multipliers are solved from the row sums and column sums of eta.
0133         [rowmult, colmult] = mylinearsolve(X, [sum(eta, 2) ; sum(eta, 1)']);
0134         etaproj = eta - (rowmult*e2' + e1*colmult').*X;
0135     end
0136 
0137     M.tangent = @projection; % same function handle as M.proj
0138     M.tangent2ambient = @(X, eta) eta; % identity: the vector is returned unchanged
0139 
0140     % Conversion of Euclidean to Riemannian gradient
0141     M.egrad2rgrad = @egrad2rgrad;
0142     function rgrad = egrad2rgrad(X, egrad)
0143         % The Euclidean gradient is first scaled entrywise by X, then sent
0144         % through projection, which subtracts (alpha*e2' + e1*beta').*X with
0145         % alpha, beta solved from the row and column sums of X.*egrad.
0146         rgrad = projection(X, X.*egrad);
0147     end
0148 
0149     % First-order retraction: entrywise exponential step, then rescaling.
0150     M.retr = @retraction;
0151     function Y = retraction(X, eta, t)
0152         % t is an optional multiplier of eta; it defaults to 1 when omitted.
0153         if nargin <= 2, t = 1.0; end
0154         moved = X.*exp(t*(eta./X));
0155         % Clamp every entry into [1e-50, 1e50] for numerical stability, so
0156         % that overflow or underflow in exp cannot pass Inf or 0 entries on.
0157         clamped = max(min(moved, 1e50), 1e-50);
0158         % Hand the clamped matrix to doubly_stochastic_general together
0159         % with p, q and maxDSiters to obtain the new point.
0160         Y = doubly_stochastic_general(clamped, p, q, maxDSiters);
0161     end
0162 
0163     % Conversion of Euclidean to Riemannian Hessian
0164     M.ehess2rhess = @ehess2rhess; % BM okay
0165     function rhess = ehess2rhess(X, egrad, ehess, eta)
0166         % NOTE(review): ehess is presumably the Euclidean Hessian applied to eta (Manopt convention) - confirm.
0167         % Computing the directional derivative of the Riemannian
0168         % gradient
0169         gamma = egrad.*X; % same scaled gradient that egrad2rgrad projects
0170         gammadot = ehess.*X + egrad.*eta; % product rule applied to egrad.*X, with eta as the variation of X
0171         
0172         b = [sum(gamma, 2) ; sum(gamma, 1)']; % row sums stacked over column sums of gamma
0173         bdot = [sum(gammadot, 2) ; sum(gammadot, 1)']; % same stacking for gammadot
0174         [alpha, beta] = mylinearsolve(X, b); % multipliers of the Riemannian gradient
0175         [alphadot, betadot] = mylinearsolve(X, bdot- [eta*beta; eta'*alpha]); % differentiated system: the X blocks of A vary by eta, diag(p) and diag(q) are constant
0176         
0177         S = (alpha*e2' + e1*beta'); % n-by-m matrix with entries alpha(i) + beta(j)
0178         deltadot = gammadot - (alphadot*e2' + e1*betadot').*X- S.*eta; % rgraddot
0179 
0180         % Computing Riemannian gradient
0181         delta = gamma - S.*X; % rgrad
0182 
0183         % Riemannian Hessian in the ambient space
0184         nabla = deltadot - 0.5*(delta.*eta)./X; % NOTE(review): second term looks like the correction for the X-dependent metric - confirm against Douik & Hassibi (2019)
0185 
0186         % Riemannian Hessian on the tangent space
0187         rhess = projection(X, nabla);
0188     end
0189 
0190 
0191     % Remaining fields of the manifold structure.
0192     M.hash = @(X) ['z', hashmd5(X(:))]; % 'z' followed by the hashmd5 string of the vectorized point
0193     M.lincomb = @matrixlincomb;
0194     M.zerovec = @(X) zeros([n, m]);
0195     M.transp = @(Xfrom, Xto, xi) projection(Xto, xi); % transport = projection onto the tangent space at the second point
0196     M.vec = @(X, U) reshape(U, [], 1);
0197     M.mat = @(X, u) reshape(u, [n, m]);
0198     M.vecmatareisometries = @() false;
0199     
0200 end

Generated on Fri 30-Sep-2022 13:18:25 by m2html © 2005